first commit

This commit is contained in:
Guocheng Qian
2023-08-02 19:51:43 -07:00
parent c2891c38cc
commit 13e18567fa
202 changed files with 43362 additions and 17 deletions

10
shencoder/src/shencoder.h Normal file
View File

@@ -0,0 +1,10 @@
# pragma once
#include <stdint.h>
#include <torch/torch.h>
// inputs: [B, D], float, in [-1, 1]
// outputs: [B, F], float
void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx);
void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs);