Skip to content

Commit

Permalink
Training DPO with pre-computed reference scores (#958)
Browse files Browse the repository at this point in the history
* dpo with offline scores added

* new docstring for reference model

* mypy, lint

---------

Co-authored-by: Ilia Kulikov <[email protected]>
  • Loading branch information
uralik and Ilia Kulikov authored Jan 9, 2025
1 parent 713afda commit ec8a1ee
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 23 deletions.
24 changes: 23 additions & 1 deletion src/fairseq2/datasets/preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class PreferenceOptimizationBatch:

chosen: SequenceBatch
rejected: SequenceBatch
reference_score_chosen: torch.Tensor | None
reference_score_rejected: torch.Tensor | None


class PreferenceOptimizationDataset(ABC):
Expand Down Expand Up @@ -232,6 +234,10 @@ def cat_source_and_target(example: dict[str, Any]) -> dict[str, Any]:
"indices_prompt": source_indices,
"indices_chosen": indices_chosen,
"indices_rejected": indices_rejected,
"reference_score_chosen": example.get("reference_score_chosen", None),
"reference_score_rejected": example.get(
"reference_score_rejected", None
),
"target_mask_chosen": target_mask_chosen,
"target_mask_rejected": target_mask_rejected,
"total_tokens": total_tokens,
Expand Down Expand Up @@ -324,7 +330,23 @@ def to_batch(example: dict[str, Any]) -> PreferenceOptimizationBatch:
example=example,
)

return PreferenceOptimizationBatch(batch_chosen, batch_rejected)
batch_reference_scores_chosen = None
if all(example["reference_score_chosen"]):
batch_reference_scores_chosen = torch.Tensor(
example["reference_score_chosen"]
).to(gang.device)
batch_reference_scores_rejected = None
if all(example["reference_score_rejected"]):
batch_reference_scores_rejected = torch.Tensor(
example["reference_score_rejected"]
).to(gang.device)

return PreferenceOptimizationBatch(
batch_chosen,
batch_rejected,
batch_reference_scores_chosen,
batch_reference_scores_rejected,
)

pipeline = builder.map(to_batch).and_return()

Expand Down
69 changes: 47 additions & 22 deletions src/fairseq2/recipes/lm/preference_finetune/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
class DpoFinetuneUnit(AbstractTrainUnit[PreferenceOptimizationBatch]):
"""Represents the language model DPO-finetuning unit. Paper: https://arxiv.org/abs/2305.18290."""

_reference_model: Module
_reference_model: Module | None
_beta: float
_nll_scale: float
_metric_bag: DpoFinetuneMetricBag
Expand All @@ -53,7 +53,7 @@ class DpoFinetuneUnit(AbstractTrainUnit[PreferenceOptimizationBatch]):
def __init__(
self,
model: Module,
reference_model: Module,
reference_model: Module | None,
gang: Gang,
beta: float = 0.1,
nll_scale: float = 1.0,
Expand All @@ -76,6 +76,11 @@ def __call__(self, batch: PreferenceOptimizationBatch) -> tuple[Tensor, int]:
rejected_input_batch, rejected_target_batch = as_auto_regressive_input(
rejected_batch
)
if (
chosen_target_batch.target_mask is None
or rejected_target_batch.target_mask is None
):
raise RuntimeError("target_mask attributes must exist for DPO loss")

chosen_output = cast(SequenceModelOutput, self._model(chosen_input_batch))
rejected_output = cast(SequenceModelOutput, self._model(rejected_input_batch))
Expand All @@ -87,18 +92,36 @@ def __call__(self, batch: PreferenceOptimizationBatch) -> tuple[Tensor, int]:
rejected_output, rejected_target_batch
)

with torch.no_grad():
ref_chosen_output = cast(
SequenceModelOutput, self._reference_model(chosen_batch)
if self._reference_model is not None:
with torch.no_grad():
ref_chosen_output = cast(
SequenceModelOutput, self._reference_model(chosen_batch)
)
ref_rejected_output = cast(
SequenceModelOutput, self._reference_model(rejected_batch)
)
ref_chosen_logps, ref_average_chosen_logps = _gather_lprobs_avg(
ref_chosen_output, chosen_target_batch
)
ref_rejected_logps, ref_average_rejected_logps = _gather_lprobs_avg(
ref_rejected_output, rejected_target_batch
)
elif (
batch.reference_score_chosen is not None
and batch.reference_score_rejected is not None
):
# reference scores must exist in the batch if reference model is None
ref_chosen_logps = batch.reference_score_chosen
ref_average_chosen_logps = (
ref_chosen_logps / chosen_target_batch.target_mask.sum(-1)
)
ref_rejected_output = cast(
SequenceModelOutput, self._reference_model(rejected_batch)
ref_rejected_logps = batch.reference_score_rejected
ref_average_rejected_logps = (
ref_rejected_logps / rejected_target_batch.target_mask.sum(-1)
)
ref_chosen_logps, ref_average_chosen_logps = _gather_lprobs_avg(
ref_chosen_output, chosen_target_batch
)
ref_rejected_logps, ref_average_rejected_logps = _gather_lprobs_avg(
ref_rejected_output, rejected_target_batch
else:
raise RuntimeError(
"Reference model is not initialized and data batch does not provide reference score, but at least one must exist."
)

if self._length_normalization:
Expand Down Expand Up @@ -206,8 +229,8 @@ class DpoConfig:
"""Holds the DPO configuration of a language model preference-finetuning task."""

# Reference Model
reference_model: AssetReference = "llama3_1_8b_instruct"
"""The name, path, or path to the asset card of the reference model."""
reference_model: AssetReference | None = "llama3_1_8b_instruct"
"""The name, path, or path to the asset card of the reference model. If reference_model is None, recipe expects to get reference log-probabilities for chosen and rejected targets as float values in the data example (fields `reference_score_rejected` and `reference_score_chosen`)."""

reference_dtype: DataType = torch.bfloat16
"""The data type of the reference model."""
Expand All @@ -230,14 +253,16 @@ class DpoConfig:
def create_dpo_unit(
config: DpoConfig, model: Module, root_gang: Gang, gangs: Mapping[str, Gang]
) -> DpoFinetuneUnit:
reference_model = _load_reference_model(
config.reference_model,
config.reference_dtype,
root_gang,
gangs,
config.reference_tensor_parallel_size,
log,
)
reference_model = None
if config.reference_model is not None:
reference_model = _load_reference_model(
config.reference_model,
config.reference_dtype,
root_gang,
gangs,
config.reference_tensor_parallel_size,
log,
)

dp_gang = gangs["dp"] # data

Expand Down

0 comments on commit ec8a1ee

Please sign in to comment.