21 lines
526 B
Python
21 lines
526 B
Python
import torch
|
|
from torch.autograd import Function
|
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
|
|
|
class _trunc_exp(Function):
    """Exponential whose gradient is computed from a truncated input.

    The forward pass returns ``exp(x)`` unmodified. The backward pass
    evaluates the derivative at ``min(x, 15)`` instead of at ``x``. Each
    element's gradient is therefore at most ``exp(15)`` times the upstream
    gradient, even when the forward value is huge. For float32 inputs the
    forward value becomes ``inf`` once ``x`` exceeds roughly 88.7; only the
    gradient is truncated.

    Call it through the ``trunc_exp`` alias below.
    """

    # NOTE(review): newer PyTorch releases deprecate torch.cuda.amp.custom_fwd /
    # custom_bwd in favour of torch.amp.custom_fwd(device_type="cuda") etc.;
    # migrate once the minimum supported torch version allows it.
    @staticmethod
    @custom_fwd(cast_inputs=torch.float)  # under CUDA autocast: inputs cast to float32, autocast off inside
    def forward(ctx, x):
        # Keep the input itself, not the output: backward needs x so it can clamp it.
        ctx.save_for_backward(x)
        out = torch.exp(x)
        return out

    @staticmethod
    @custom_bwd  # backward runs with the same autocast state as the forward did
    def backward(ctx, g):
        (saved_x,) = ctx.saved_tensors
        # d/dx exp(x) = exp(x); cap the exponent at 15 so the gradient stays bounded.
        bounded_grad = torch.clamp(saved_x, max=15).exp()
        return g * bounded_grad


# Functional entry point: use as ``trunc_exp(x)`` like any torch op.
trunc_exp = _trunc_exp.apply
|
|
|
|
def biased_softplus(x, bias=0):
    """Return ``softplus(x - bias)``.

    The activation is shifted right by ``bias``. ``torch.nn.functional.softplus``
    is called with its defaults (``beta=1``, ``threshold=20``), so the result is
    ``log(1 + exp(x - bias))``. It becomes the identity once
    ``x - bias > 20``.
    """
    shifted = x - bias
    return torch.nn.functional.softplus(shifted)