diff --git a/src/brevitas/core/stats/stats_wrapper.py b/src/brevitas/core/stats/stats_wrapper.py index be549328e..49bf62a82 100644 --- a/src/brevitas/core/stats/stats_wrapper.py +++ b/src/brevitas/core/stats/stats_wrapper.py @@ -115,9 +115,11 @@ def __init__( @brevitas.jit.script_method def forward(self, x: Optional[torch.Tensor] = None) -> torch.Tensor: - stats_input = self.first_tracked_param(x) if self.extra_tracked_params_list is not None: + stats_input = self.first_tracked_param(None) for extra_tracked_param in self.extra_tracked_params_list: stats_input = extra_tracked_param(stats_input) + else: + stats_input = self.first_tracked_param(x) out = self.stats(stats_input) return out