From 2e8ef812096ccb949ccb900581cd5e33828d47de Mon Sep 17 00:00:00 2001 From: Ravin Kumar <16964978+mr-ravin@users.noreply.github.com> Date: Mon, 15 Jul 2024 21:25:10 +0530 Subject: [PATCH] APTx Function: Default value of gamma should be 0.5 when trainable=False (#222) In the APTx activation function, the default value of gamma should be 0.5 when trainable=False --- neurodiffeq/networks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/neurodiffeq/networks.py b/neurodiffeq/networks.py index e7cf288..4b88b64 100644 --- a/neurodiffeq/networks.py +++ b/neurodiffeq/networks.py @@ -190,13 +190,13 @@ class APTx(nn.Module): :type trainable: bool """ - def __init__(self, alpha=1.0, beta=1.0, gamma=1.0, trainable=False): + def __init__(self, alpha=1.0, beta=1.0, gamma=0.5, trainable=False): super(APTx, self).__init__() alpha = float(alpha) beta = float(beta) gamma = float(gamma) self.trainable = trainable - if trainable: + if self.trainable: self.alpha = nn.Parameter(torch.tensor(alpha)) self.beta = nn.Parameter(torch.tensor(beta)) self.gamma = nn.Parameter(torch.tensor(gamma)) @@ -206,4 +206,4 @@ def __init__(self, alpha=1.0, beta=1.0, gamma=1.0, trainable=False): self.gamma = gamma def forward(self, x): - return (self.alpha + torch.nn.functional.tanh(self.beta*x))*self.gamma*x \ No newline at end of file + return (self.alpha + torch.nn.functional.tanh(self.beta*x))*self.gamma*x