Files
Magic123/activation.py
Guocheng Qian 13e18567fa first commit
2023-08-02 19:51:43 -07:00

21 lines
526 B
Python

import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd
class _trunc_exp(Function):
    """Exponential whose backward pass clamps its input.

    Forward returns ``exp(x)`` exactly, with no truncation. Backward returns
    ``grad * exp(min(x, 15))``, so the exponential factor in the gradient
    never exceeds ``exp(15)``.

    When CUDA autocast is active, ``custom_fwd(cast_inputs=torch.float)``
    casts the input to float32 and runs forward with autocast disabled.
    ``custom_bwd`` runs backward under the same autocast state as forward.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float)
    def forward(ctx, x):
        # Keep the raw input; backward recomputes a clamped exp from it
        # rather than reusing this (unclamped) output.
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        (saved_input,) = ctx.saved_tensors
        # Clamp only here: the gradient factor is capped at exp(15), while
        # the forward value above is left as-is.
        capped_exp = torch.exp(torch.clamp(saved_input, max=15))
        return g * capped_exp


# Functional entry point: call as ``trunc_exp(x)``.
trunc_exp = _trunc_exp.apply
def biased_softplus(x, bias=0):
    """Apply softplus to ``x`` after shifting it down by ``bias``.

    Equivalent to ``torch.nn.functional.softplus(x - bias)`` with PyTorch's
    default ``beta``/``threshold``. With the default ``bias=0`` this is plain
    softplus.

    Args:
        x: Input tensor.
        bias: Amount subtracted from ``x`` before the softplus.

    Returns:
        Tensor of the same shape as ``x - bias``.
    """
    shifted = x - bias
    return torch.nn.functional.softplus(shifted)