first commit
This commit is contained in:
21
activation.py
Normal file
21
activation.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import torch
|
||||
from torch.autograd import Function
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
class _trunc_exp(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float)
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return torch.exp(x)
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, g):
|
||||
x = ctx.saved_tensors[0]
|
||||
return g * torch.exp(x.clamp(max=15))
|
||||
|
||||
# Functional entry point: trunc_exp(x) == exp(x), with the backward pass
# computed from the input clamped at 15 (see _trunc_exp.backward).
trunc_exp = _trunc_exp.apply
|
||||
|
||||
def biased_softplus(x, bias=0):
    """Softplus shifted right by ``bias``: softplus(x - bias)."""
    shifted = x - bias
    return torch.nn.functional.softplus(shifted)
|
||||
Reference in New Issue
Block a user