diff --git a/transformers4rec/torch/ranking_metric.py b/transformers4rec/torch/ranking_metric.py index 5495b98a4..a281dd83b 100644 --- a/transformers4rec/torch/ranking_metric.py +++ b/transformers4rec/torch/ranking_metric.py @@ -131,7 +131,7 @@ def _metric(self, ks: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor) # Compute recalls at K num_relevant = torch.sum(labels, dim=-1) - rel_indices = (num_relevant != 0).nonzero().squeeze() + rel_indices = (num_relevant != 0).nonzero().squeeze(dim=1) rel_count = num_relevant[rel_indices] if rel_indices.shape[0] > 0: