first commit
This commit is contained in:
1
freqencoder/__init__.py
Normal file
1
freqencoder/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .freq import FreqEncoder
|
||||
42
freqencoder/backend.py
Normal file
42
freqencoder/backend.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import os

from torch.utils.cpp_extension import load

# Directory of this package; the CUDA/C++ sources live in its src/ folder.
_src_path = os.path.dirname(os.path.abspath(__file__))

# Flags handed to nvcc when JIT-compiling the extension.
nvcc_flags = [
    '-O3', '-std=c++14',
    # Re-enable the half-precision operators that the torch headers disable by default.
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
    '-use_fast_math'
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        """Return the newest MSVC Hostx64/x64 bin directory found, or None."""
        import glob
        for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
            for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
                paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
                if paths:
                    return paths[0]
        return None  # explicit: no supported MSVC installation located

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path
else:
    # Fail fast with a clear message. Previously c_flags was simply left
    # undefined on other platforms, so the load() call below crashed with a
    # confusing NameError instead.
    raise RuntimeError(f"freqencoder: unsupported platform (os.name={os.name!r})")

# JIT-compile and load the CUDA extension on first import.
_backend = load(name='_freqencoder',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'freqencoder.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
|
||||
77
freqencoder/freq.py
Normal file
77
freqencoder/freq.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
try:
|
||||
import _freqencoder as _backend
|
||||
except ImportError:
|
||||
from .backend import _backend
|
||||
|
||||
|
||||
class _freq_encoder(Function):
    """Autograd Function wrapping the CUDA frequency-encoding kernels."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
    def forward(ctx, inputs, degree, output_dim):
        """Encode `inputs` ([B, input_dim] float) into [B, output_dim].

        The caller (FreqEncoder) computes output_dim as
        input_dim + input_dim * 2 * degree.
        """
        # The kernel requires a contiguous CUDA tensor.
        if not inputs.is_cuda:
            inputs = inputs.cuda()
        inputs = inputs.contiguous()

        batch, coord_dim = inputs.shape

        encoded = torch.empty(batch, output_dim, dtype=inputs.dtype, device=inputs.device)
        _backend.freq_encode_forward(inputs, batch, coord_dim, degree, output_dim, encoded)

        # Backward needs the encoded sin/cos values to form derivatives.
        ctx.save_for_backward(inputs, encoded)
        ctx.dims = [batch, coord_dim, degree, output_dim]

        return encoded

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        """Propagate grad ([B, output_dim]) back to the inputs ([B, input_dim])."""
        inputs, encoded = ctx.saved_tensors
        batch, coord_dim, degree, output_dim = ctx.dims

        grad_inputs = torch.zeros_like(inputs)
        _backend.freq_encode_backward(grad.contiguous(), encoded, batch, coord_dim, degree, output_dim, grad_inputs)

        # No gradients for the integer arguments (degree, output_dim).
        return grad_inputs, None, None
|
||||
|
||||
# Functional entry point: freq_encode(inputs, degree, output_dim).
freq_encode = _freq_encoder.apply
|
||||
|
||||
|
||||
class FreqEncoder(nn.Module):
    """NeRF-style frequency (positional) encoder backed by a CUDA kernel.

    Maps [..., input_dim] coordinates to [..., output_dim] features, where
    output_dim = input_dim + input_dim * 2 * degree (the raw coordinates
    plus sin/cos pairs at `degree` frequency bands).
    """

    def __init__(self, input_dim=3, degree=4):
        super().__init__()

        self.input_dim = input_dim
        self.degree = degree
        # identity part + (sin, cos) per frequency band and coordinate
        self.output_dim = input_dim + input_dim * 2 * degree

    def __repr__(self):
        return (
            f"FreqEncoder: input_dim={self.input_dim} "
            f"degree={self.degree} output_dim={self.output_dim}"
        )

    def forward(self, inputs, **kwargs):
        """Encode `inputs` ([..., input_dim]) to [..., output_dim].

        Leading batch dimensions are flattened for the kernel call and
        restored on the result.
        """
        lead_shape = list(inputs.shape[:-1])
        flat = inputs.reshape(-1, self.input_dim)

        encoded = freq_encode(flat, self.degree, self.output_dim)

        return encoded.reshape(lead_shape + [self.output_dim])
|
||||
52
freqencoder/setup.py
Normal file
52
freqencoder/setup.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import os

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Directory containing this setup.py; the extension sources live in src/.
_src_path = os.path.dirname(os.path.abspath(__file__))

# Flags handed to nvcc when compiling the extension.
nvcc_flags = [
    '-O3', '-std=c++14',
    # Re-enable the half-precision operators that the torch headers disable by default.
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
    '-use_fast_math'
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        """Return the newest MSVC Hostx64/x64 bin directory found, or None."""
        import glob
        for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
            for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
                paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
                if paths:
                    return paths[0]
        return None  # explicit: no supported MSVC installation located

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path
else:
    # Fail fast with a clear message. Previously c_flags was simply left
    # undefined on other platforms, so setup() crashed with a NameError.
    raise RuntimeError(f"freqencoder: unsupported platform (os.name={os.name!r})")

setup(
    name='freqencoder', # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_freqencoder', # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'freqencoder.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
|
||||
8
freqencoder/src/bindings.cpp
Normal file
8
freqencoder/src/bindings.cpp
Normal file
@@ -0,0 +1,8 @@
|
||||
#include <torch/extension.h>

#include "freqencoder.h"

// Python bindings for the CUDA frequency encoder. The module name
// (TORCH_EXTENSION_NAME) is supplied by the build: '_freqencoder' in both
// backend.py (JIT load) and setup.py (CUDAExtension).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
    m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
}
|
||||
129
freqencoder/src/freqencoder.cu
Normal file
129
freqencoder/src/freqencoder.cu
Normal file
@@ -0,0 +1,129 @@
|
||||
#include <stdint.h>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <algorithm>
#include <stdexcept>

#include <cstdio>


// Input-validation helpers used by the host-side launchers below.
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")

// pi as a device-side float constant.
inline constexpr __device__ float PI() { return 3.141592653589793f; }

// Ceiling division: number of size-`divisor` groups needed to cover `val`.
template <typename T>
__host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}
|
||||
|
||||
// inputs: [B, D]
// outputs: [B, C], C = D + D * deg * 2
// Forward encoding kernel, one thread per output element. Row layout: the D
// raw inputs first, then for each frequency band f a block of D sin(2^f * x)
// values followed by a block of D cos values (cos computed as sin(x + pi/2)).
__global__ void kernel_freq(
    const float * __restrict__ inputs,
    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
    float * outputs
) {
    // parallel on per-element
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * C) return;

    // get index
    const uint32_t b = t / C;
    const uint32_t c = t - b * C; // t % C;

    // locate
    inputs += b * D;   // start of this row's inputs
    outputs += t;      // this thread's single output slot

    // write self
    if (c < D) {
        outputs[0] = inputs[c];
    // write freq
    } else {
        const uint32_t col = c / D - 1;               // block index past the identity part
        const uint32_t d = c % D;                     // which input coordinate
        const uint32_t freq = col / 2;                // frequency octave: scale = 2^freq
        const float phase_shift = (col % 2) * (PI() / 2);  // even block -> sin, odd -> cos
        // scalbnf(x, f) = x * 2^f; __sinf is the fast intrinsic sine
        outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
    }
}
|
||||
|
||||
// grad: [B, C], C = D + D * deg * 2
// outputs: [B, C]
// grad_inputs: [B, D]
// Backward kernel, one thread per input element. Accumulates the chain-rule
// sum over every output that depends on input d, using the stored outputs:
//   d/dx sin(2^f x) =  2^f * cos(2^f x)  (the stored cos output)
//   d/dx cos(2^f x) = -2^f * sin(2^f x)  (the stored sin output)
__global__ void kernel_freq_backward(
    const float * __restrict__ grad,
    const float * __restrict__ outputs,
    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
    float * grad_inputs
) {
    // parallel on per-element
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * D) return;

    const uint32_t b = t / D;
    const uint32_t d = t - b * D; // t % D;

    // locate
    grad += b * C;
    outputs += b * C;
    grad_inputs += t;

    // register
    float result = grad[d];  // identity part: the pass-through output's gradient
    // skip past the identity block to the first sin/cos band
    grad += D;
    outputs += D;

    for (uint32_t f = 0; f < deg; f++) {
        // scalbnf(1, f) = 2^f; outputs[d] is the sin block, outputs[D + d] the cos block
        result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
        grad += 2 * D;
        outputs += 2 * D;
    }

    // write
    grad_inputs[0] = result;
}
|
||||
|
||||
|
||||
// Host-side launcher for the forward encoding kernel.
// inputs:  [B, D] contiguous CUDA tensor
// outputs: [B, C] contiguous CUDA tensor, C = D + D * deg * 2
// NOTE(review): CHECK_IS_FLOATING admits Half/Double, but data_ptr<float>()
// requires Float — callers must pass float32 (the Python wrapper casts).
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(outputs);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(outputs);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(outputs);

    static constexpr uint32_t N_THREADS = 128;
    const uint32_t n_blocks = div_round_up(B * C, N_THREADS);  // one thread per output element

    kernel_freq<<<n_blocks, N_THREADS>>>(
        inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
}
|
||||
|
||||
|
||||
// Host-side launcher for the backward kernel.
// grad:        [B, C] contiguous CUDA tensor (upstream gradient)
// outputs:     [B, C] contiguous CUDA tensor (saved forward outputs)
// grad_inputs: [B, D] contiguous CUDA tensor (written)
// NOTE(review): CHECK_IS_FLOATING admits Half/Double, but data_ptr<float>()
// requires Float — callers must pass float32 (the Python wrapper casts).
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
    CHECK_CUDA(grad);
    CHECK_CUDA(outputs);
    CHECK_CUDA(grad_inputs);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(outputs);
    CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(outputs);
    CHECK_IS_FLOATING(grad_inputs);

    static constexpr uint32_t N_THREADS = 128;
    const uint32_t n_blocks = div_round_up(B * D, N_THREADS);  // one thread per input element

    kernel_freq_backward<<<n_blocks, N_THREADS>>>(
        grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
}
|
||||
10
freqencoder/src/freqencoder.h
Normal file
10
freqencoder/src/freqencoder.h
Normal file
@@ -0,0 +1,10 @@
|
||||
# pragma once

#include <stdint.h>
#include <torch/torch.h>

// Host-side launchers implemented in freqencoder.cu and exported to Python
// via bindings.cpp. B = batch size, D = input_dim, deg = degree,
// C = output_dim (= D + D * deg * 2).

// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);

// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
|
||||
Reference in New Issue
Block a user