
Commit

fixed APTx activation (#214)
pasq-cat authored May 14, 2024
1 parent a722d79 commit 0ee81fd
Showing 2 changed files with 53 additions and 0 deletions.
34 changes: 34 additions & 0 deletions neurodiffeq/networks.py
@@ -173,3 +173,37 @@ def __init__(self, beta=1.0, trainable=False):

    def forward(self, x):
        return x * torch.sigmoid(self.beta * x)

class APTx(nn.Module):
    r"""The APTx (Alpha Plus Tanh Times) activation function: :math:`\mathrm{APTx}(x) = (\alpha + \tanh{(\beta x)}) \gamma x`.
    It behaves similarly to the Mish activation function but requires fewer mathematical
    operations to compute. The lower computational cost of APTx speeds up model training
    and thus also reduces the hardware requirements of the deep learning model.

    :param alpha: The :math:`\alpha` parameter in the APTx activation.
    :type alpha: float
    :param beta: The :math:`\beta` parameter in the APTx activation.
    :type beta: float
    :param gamma: The :math:`\gamma` parameter in the APTx activation.
    :type gamma: float
    :param trainable: Whether the scalars :math:`\alpha`, :math:`\beta`, and :math:`\gamma` can be trained.
    :type trainable: bool
    """

    def __init__(self, alpha=1.0, beta=1.0, gamma=1.0, trainable=False):
        super(APTx, self).__init__()
        alpha = float(alpha)
        beta = float(beta)
        gamma = float(gamma)
        self.trainable = trainable
        if trainable:
            # register the scalars as learnable parameters
            self.alpha = nn.Parameter(torch.tensor(alpha))
            self.beta = nn.Parameter(torch.tensor(beta))
            self.gamma = nn.Parameter(torch.tensor(gamma))
        else:
            self.alpha = alpha
            self.beta = beta
            self.gamma = gamma

    def forward(self, x):
        # torch.tanh replaces the deprecated torch.nn.functional.tanh
        return (self.alpha + torch.tanh(self.beta * x)) * self.gamma * x
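
For reference, a minimal usage sketch of the new activation (illustrative, not part of this commit; the tensor `x` and the parameter values are arbitrary examples):

import torch
from neurodiffeq.networks import APTx

x = torch.linspace(-3.0, 3.0, 7)

# fixed (non-trainable) scalars: APTx(x) = (alpha + tanh(beta * x)) * gamma * x
f = APTx(alpha=1.0, beta=1.0, gamma=0.5)
y = f(x)

# with the defaults alpha = beta = gamma = 1, the output reduces to (1 + tanh(x)) * x
g = APTx()
assert torch.allclose(g(x), (1 + torch.tanh(x)) * x)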
19 changes: 19 additions & 0 deletions tests/test_networks.py
@@ -7,6 +7,7 @@
from neurodiffeq.networks import MonomialNN
from neurodiffeq.networks import SinActv
from neurodiffeq.networks import Swish
from neurodiffeq.networks import APTx

MAGIC = 42
torch.manual_seed(MAGIC)
@@ -147,3 +148,21 @@ def test_swish():
    assert len(list(f.parameters())) == 1
    assert list(f.parameters())[0].shape == ()
    assert torch.isclose(f(x), x * torch.sigmoid(beta * x)).all()



def test_APTx():
    x = torch.rand(10, 5)

    f = APTx()
    assert len(list(f.parameters())) == 0
    assert torch.isclose(f(x), (1 + torch.tanh(x)) * x).all()

    alpha = 1.0
    beta = 1.0
    gamma = 0.5
    f = APTx(alpha, beta, gamma, trainable=True)
    assert len(list(f.parameters())) == 3
    assert list(f.parameters())[0].shape == ()
    assert torch.isclose(f(x), (alpha + torch.tanh(beta * x)) * gamma * x).all()
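
A natural follow-up test, sketched here but not included in the commit, would confirm that the trainable scalars actually receive gradients (the test name is hypothetical):

def test_APTx_trainable_grads():
    x = torch.rand(10, 5)
    f = APTx(trainable=True)
    f(x).sum().backward()
    # alpha, beta, and gamma should all participate in autograd
    assert all(p.grad is not None for p in f.parameters())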
