Commit

mypy, lint
Ilia Kulikov committed Jan 8, 2025
1 parent a19560b commit 994785c
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/fairseq2/recipes/lm/preference_finetune/dpo.py
@@ -7,7 +7,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Any, Mapping, cast, final
+from typing import Mapping, cast, final
 
 import torch
 import torch.distributed
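
A note on the import hunk: once nothing in the module is annotated with `Any`, the name becomes an unused import, which pyflakes-style linters (flake8, ruff) report as F401. A minimal illustration of the rule, not code from dpo.py:

```python
from typing import Any, Mapping  # F401: 'typing.Any' imported but unused


def describe(config: Mapping[str, str]) -> str:
    # Only Mapping is used, so the Any import above is dead weight;
    # deleting it, as this commit does, silences the warning.
    return ", ".join(f"{key}={value}" for key, value in config.items())
```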
@@ -76,6 +76,11 @@ def __call__(self, batch: PreferenceOptimizationBatch) -> tuple[Tensor, int]:
         rejected_input_batch, rejected_target_batch = as_auto_regressive_input(
             rejected_batch
         )
+        if (
+            chosen_target_batch.target_mask is None
+            or rejected_target_batch.target_mask is None
+        ):
+            raise RuntimeError("target_mask attributes must exist for DPO loss")
 
         chosen_output = cast(SequenceModelOutput, self._model(chosen_input_batch))
         rejected_output = cast(SequenceModelOutput, self._model(rejected_input_batch))
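
The added `if` block is the mypy half of the commit: the guard implies `target_mask` is an optional attribute, so mypy flags any use of it that has not first ruled out `None`. Checking both masks up front narrows their types to `Tensor` for the rest of the method. A minimal sketch of the narrowing pattern, assuming an `Optional[Tensor]` field; the names below are illustrative, not the fairseq2 API:

```python
from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor


@dataclass
class FakeTargetBatch:
    seqs: Tensor
    target_mask: Optional[Tensor]  # None means no mask is available


def masked_token_count(batch: FakeTargetBatch) -> int:
    if batch.target_mask is None:
        # Without this guard, mypy reports an error along the lines of:
        #   Item "None" of "Optional[Tensor]" has no attribute "sum"
        raise RuntimeError("target_mask must exist")

    # After the check, mypy has narrowed batch.target_mask to Tensor.
    return int(batch.target_mask.sum().item())


batch = FakeTargetBatch(
    seqs=torch.ones(2, 4),
    target_mask=torch.ones(2, 4, dtype=torch.bool),
)
assert masked_token_count(batch) == 8
```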
@@ -116,7 +121,7 @@ def __call__(self, batch: PreferenceOptimizationBatch) -> tuple[Tensor, int]:
             )
         else:
             raise RuntimeError(
-                f"Reference model is not initialized and data batch does not provide reference score, but at least one must exist."
+                "Reference model is not initialized and data batch does not provide reference score, but at least one must exist."
             )
 
         if self._length_normalization:
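
The last hunk is the other lint fix: the message contains no `{...}` placeholders, so the `f` prefix does nothing, and pyflakes-style linters flag it as F541 ("f-string is missing placeholders"). Dropping the prefix, as the diff does, is the standard fix; a small illustration outside the diff:

```python
# Flagged: the f prefix is inert because nothing is interpolated.
msg = f"at least one must exist"  # F541: f-string is missing placeholders

# Fixed: a plain string literal, as in the hunk above.
msg = "at least one must exist"

# An f prefix is only warranted when the string interpolates a value.
name = "reference model"
msg = f"{name} is not initialized"
```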
