Skip to content

Commit

Permalink
fix ndcg metrics in multitasks cases (#2720)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2720

The NDCG metric raises errors when it is used with multiple tasks. This PR:
- adds unit tests for the multitask cases
- fixes the NDCG implementation for the multitask cases

Reviewed By: iamzainhuda, venkatrsrinivas

Differential Revision: D69057771

fbshipit-source-id: 7d33432f663d933bab17bce6e660d5350962d953
  • Loading branch information
khanhnp-meta authored and facebook-github-bot committed Feb 4, 2025
1 parent 6e60bbe commit 6e3d296
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 7 deletions.
12 changes: 6 additions & 6 deletions torchrec/metrics/ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,12 @@ def _get_ndcg_states(
)

# Expand these to be [num_task, num_sessions, batch_size] for masking to handle later.
expanded_sorted_labels_by_labels = sorted_labels_by_labels.expand(
(num_tasks, num_sessions, batch_size)
)
expanded_sorted_labels_by_predictions = sorted_labels_by_predictions.expand(
expanded_sorted_labels_by_labels = sorted_labels_by_labels.unsqueeze(1).expand(
(num_tasks, num_sessions, batch_size)
)
expanded_sorted_labels_by_predictions = sorted_labels_by_predictions.unsqueeze(
1
).expand((num_tasks, num_sessions, batch_size))

# Make sure to correspondingly sort session IDs according to how we sorted labels above.
session_ids_by_sorted_labels = torch.gather(
Expand All @@ -188,10 +188,10 @@ def _get_ndcg_states(

# Figure out after sorting which example indices belong to which session.
sorted_session_ids_by_labels_mask = (
task_to_session_to_examples == session_ids_by_sorted_labels
task_to_session_to_examples == session_ids_by_sorted_labels.unsqueeze(1)
).long()
sorted_session_ids_by_predictions_mask = (
task_to_session_to_examples == session_ids_by_sorted_predictions
task_to_session_to_examples == session_ids_by_sorted_predictions.unsqueeze(1)
).long()

# Get the ranks (1, N] for each example in each session for every task.
Expand Down
140 changes: 139 additions & 1 deletion torchrec/metrics/tests/test_ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import Any, Dict, List

import torch
from torchrec.metrics.metrics_config import DefaultTaskInfo
from torchrec.metrics.metrics_config import DefaultTaskInfo, RecComputeMode

from torchrec.metrics.ndcg import NDCGMetric, SESSION_KEY
from torchrec.metrics.test_utils import RecTaskInfo
Expand All @@ -22,6 +22,27 @@
# Forwarded as `world_size` / `batch_size` whenever a test builds a metric.
WORLD_SIZE = 4
BATCH_SIZE = 10

# Three task descriptors for the multitask tests. They differ only in `name`;
# all of them read the same label / prediction / weight keys.
DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2 = (
    RecTaskInfo(
        name=task_name,
        label_name="label",
        prediction_name="prediction",
        weight_name="weight",
    )
    for task_name in ("DefaultTask0", "DefaultTask1", "DefaultTask2")
)


def get_test_case_single_session_within_batch() -> Dict[str, torch.Tensor]:
return {
Expand Down Expand Up @@ -117,6 +138,41 @@ def get_test_case_scale_by_weights_tensor() -> Dict[str, torch.Tensor]:
}


def get_test_case_multitask() -> Dict[str, torch.Tensor]:
    """Build the multitask NDCG fixture: three tasks that all see the same data.

    Every input tensor has shape (3, 8), one row per task, and the rows are
    identical. The eight examples belong to two sessions (ids 1 and 2).

    Returns:
        A dict holding the "predictions", "session_ids", "labels" and
        "weights" inputs plus the per-task expected values under
        "expected_ndcg_exp" and "expected_ndcg_non_exp".
    """
    num_tasks = 3

    # One row per quantity; each row is replicated once per task below.
    prediction_row = [0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3]
    session_row = [1, 1, 1, 1, 1, 2, 2, 2]
    label_row = [0.0, 1.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0]
    weight_row = [1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 3.0]

    def per_task(row):
        # Stack `num_tasks` copies of `row` into a (num_tasks, len(row)) tensor.
        return torch.tensor([list(row) for _ in range(num_tasks)])

    fixture = {}
    fixture["predictions"] = per_task(prediction_row)
    fixture["session_ids"] = per_task(session_row)
    fixture["labels"] = per_task(label_row)
    fixture["weights"] = per_task(weight_row)
    fixture["expected_ndcg_exp"] = torch.tensor([0.6748] * num_tasks)
    fixture["expected_ndcg_non_exp"] = torch.tensor([0.6463] * num_tasks)
    return fixture


class NDCGMetricValueTest(unittest.TestCase):
def generate_metric(
self,
Expand All @@ -130,6 +186,7 @@ def generate_metric(
remove_single_length_sessions: bool = False,
scale_by_weights_tensor: bool = False,
report_ndcg_as_decreasing_curve: bool = True,
compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION,
**kwargs: Dict[str, Any],
) -> NDCGMetric:
return NDCGMetric(
Expand All @@ -149,6 +206,7 @@ def generate_metric(
report_ndcg_as_decreasing_curve=report_ndcg_as_decreasing_curve,
# pyre-ignore[6]
k=k,
compute_mode=compute_mode,
# pyre-ignore[6]
**kwargs,
)
Expand Down Expand Up @@ -565,3 +623,83 @@ def test_case_report_as_increasing_ndcg_and_scale_by_weights_tensor(self) -> Non
equal_nan=True,
msg=f"Actual: {actual_metric}, Expected: {expected_metric}",
)

def test_multitask_non_exp(self) -> None:
    """
    Fused-computation NDCG over three tasks with exponential gain disabled:
    each task's lifetime NDCG must match the fixture's expected value.
    """
    task_infos = (DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2)
    fixture = get_test_case_multitask()

    ndcg_metric = self.generate_metric(
        world_size=WORLD_SIZE,
        my_rank=0,
        batch_size=BATCH_SIZE,
        tasks=list(task_infos),
        exponential_gain=False,
        session_key=SESSION_KEY,
        compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
    )
    ndcg_metric.update(
        predictions=fixture["predictions"],
        labels=fixture["labels"],
        weights=fixture["weights"],
        required_inputs={SESSION_KEY: fixture["session_ids"]},
    )
    computed = ndcg_metric.compute()

    # Collect the lifetime value reported for each task, in task order.
    per_task_values = []
    for task in task_infos:
        per_task_values.append(computed[f"ndcg-{task.name}|lifetime_ndcg"])
    actual_metric = torch.stack(per_task_values)
    expected_metric = fixture["expected_ndcg_non_exp"]

    torch.testing.assert_close(
        actual_metric,
        expected_metric,
        atol=1e-4,
        rtol=1e-4,
        check_dtype=False,
        equal_nan=True,
        msg=f"Actual: {actual_metric}, Expected: {expected_metric}",
    )

def test_multitask_exp(self) -> None:
    """
    Fused-computation NDCG over three tasks with exponential gain enabled:
    each task's lifetime NDCG must match the fixture's expected value.
    """
    task_infos = (DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2)
    fixture = get_test_case_multitask()

    ndcg_metric = self.generate_metric(
        world_size=WORLD_SIZE,
        my_rank=0,
        batch_size=BATCH_SIZE,
        tasks=list(task_infos),
        exponential_gain=True,
        session_key=SESSION_KEY,
        compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
    )
    ndcg_metric.update(
        predictions=fixture["predictions"],
        labels=fixture["labels"],
        weights=fixture["weights"],
        required_inputs={SESSION_KEY: fixture["session_ids"]},
    )
    computed = ndcg_metric.compute()

    # Collect the lifetime value reported for each task, in task order.
    per_task_values = []
    for task in task_infos:
        per_task_values.append(computed[f"ndcg-{task.name}|lifetime_ndcg"])
    actual_metric = torch.stack(per_task_values)
    expected_metric = fixture["expected_ndcg_exp"]

    torch.testing.assert_close(
        actual_metric,
        expected_metric,
        atol=1e-4,
        rtol=1e-4,
        check_dtype=False,
        equal_nan=True,
        msg=f"Actual: {actual_metric}, Expected: {expected_metric}",
    )

0 comments on commit 6e3d296

Please sign in to comment.