Skip to content

Commit

Permalink
Fix LSTM
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 1, 2024
1 parent f44fe35 commit 609a164
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/brevitas/core/stats/stats_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,11 @@ def __init__(

@brevitas.jit.script_method
def forward(self, x: Optional[torch.Tensor] = None) -> torch.Tensor:
stats_input = self.first_tracked_param(x)
if self.extra_tracked_params_list is not None:
stats_input = self.first_tracked_param(None)
for extra_tracked_param in self.extra_tracked_params_list:
stats_input = extra_tracked_param(stats_input)
else:
stats_input = self.first_tracked_param(x)
out = self.stats(stats_input)
return out

0 comments on commit 609a164

Please sign in to comment.