From ec8a1eec9bd562d1b16fcc0f91a4ea09426c736f Mon Sep 17 00:00:00 2001
From: Ilia Kulikov
Date: Wed, 8 Jan 2025 16:22:23 -0800
Subject: [PATCH] Training DPO with pre-computed reference scores (#958)

* dpo with offline scores added

* new docstring for reference model

* mypy, lint

---------

Co-authored-by: Ilia Kulikov
---
 src/fairseq2/datasets/preference.py           | 24 ++++++-
 .../recipes/lm/preference_finetune/dpo.py     | 69 +++++++++++++------
 2 files changed, 70 insertions(+), 23 deletions(-)

diff --git a/src/fairseq2/datasets/preference.py b/src/fairseq2/datasets/preference.py
index 6cb1bcdbd..66e8794ba 100644
--- a/src/fairseq2/datasets/preference.py
+++ b/src/fairseq2/datasets/preference.py
@@ -48,6 +48,8 @@ class PreferenceOptimizationBatch:
 
     chosen: SequenceBatch
     rejected: SequenceBatch
+    reference_score_chosen: torch.Tensor | None
+    reference_score_rejected: torch.Tensor | None
 
 
 class PreferenceOptimizationDataset(ABC):
@@ -232,6 +234,10 @@ def cat_source_and_target(example: dict[str, Any]) -> dict[str, Any]:
                 "indices_prompt": source_indices,
                 "indices_chosen": indices_chosen,
                 "indices_rejected": indices_rejected,
+                "reference_score_chosen": example.get("reference_score_chosen", None),
+                "reference_score_rejected": example.get(
+                    "reference_score_rejected", None
+                ),
                 "target_mask_chosen": target_mask_chosen,
                 "target_mask_rejected": target_mask_rejected,
                 "total_tokens": total_tokens,
@@ -324,7 +330,23 @@ def to_batch(example: dict[str, Any]) -> PreferenceOptimizationBatch:
                 example=example,
             )
 
-            return PreferenceOptimizationBatch(batch_chosen, batch_rejected)
+            batch_reference_scores_chosen = None
+            if all(example["reference_score_chosen"]):
+                batch_reference_scores_chosen = torch.Tensor(
+                    example["reference_score_chosen"]
+                ).to(gang.device)
+            batch_reference_scores_rejected = None
+            if all(example["reference_score_rejected"]):
+                batch_reference_scores_rejected = torch.Tensor(
+                    example["reference_score_rejected"]
+                ).to(gang.device)
+
+            return PreferenceOptimizationBatch(
+                batch_chosen,
+                batch_rejected,
+                batch_reference_scores_chosen,
+                batch_reference_scores_rejected,
+            )
 
         pipeline = builder.map(to_batch).and_return()
 
diff --git a/src/fairseq2/recipes/lm/preference_finetune/dpo.py b/src/fairseq2/recipes/lm/preference_finetune/dpo.py
index 60604b76f..ed460d351 100644
--- a/src/fairseq2/recipes/lm/preference_finetune/dpo.py
+++ b/src/fairseq2/recipes/lm/preference_finetune/dpo.py
@@ -44,7 +44,7 @@ class DpoFinetuneUnit(AbstractTrainUnit[PreferenceOptimizationBatch]):
     """Represents the language model DPO-finetuning unit.
     Paper: https://arxiv.org/abs/2305.18290."""
 
-    _reference_model: Module
+    _reference_model: Module | None
     _beta: float
     _nll_scale: float
     _metric_bag: DpoFinetuneMetricBag
@@ -53,7 +53,7 @@ class DpoFinetuneUnit(AbstractTrainUnit[PreferenceOptimizationBatch]):
     def __init__(
         self,
         model: Module,
-        reference_model: Module,
+        reference_model: Module | None,
         gang: Gang,
         beta: float = 0.1,
         nll_scale: float = 1.0,
@@ -76,6 +76,11 @@ def __call__(self, batch: PreferenceOptimizationBatch) -> tuple[Tensor, int]:
         rejected_input_batch, rejected_target_batch = as_auto_regressive_input(
             rejected_batch
         )
+        if (
+            chosen_target_batch.target_mask is None
+            or rejected_target_batch.target_mask is None
+        ):
+            raise RuntimeError("target_mask attributes must exist for DPO loss")
 
         chosen_output = cast(SequenceModelOutput, self._model(chosen_input_batch))
         rejected_output = cast(SequenceModelOutput, self._model(rejected_input_batch))
@@ -87,18 +92,36 @@ def __call__(self, batch: PreferenceOptimizationBatch) -> tuple[Tensor, int]:
             rejected_output, rejected_target_batch
         )
 
-        with torch.no_grad():
-            ref_chosen_output = cast(
-                SequenceModelOutput, self._reference_model(chosen_batch)
+        if self._reference_model is not None:
+            with torch.no_grad():
+                ref_chosen_output = cast(
+                    SequenceModelOutput, self._reference_model(chosen_batch)
+                )
+                ref_rejected_output = cast(
+                    SequenceModelOutput, self._reference_model(rejected_batch)
+                )
+                ref_chosen_logps, ref_average_chosen_logps = _gather_lprobs_avg(
+                    ref_chosen_output, chosen_target_batch
+                )
+                ref_rejected_logps, ref_average_rejected_logps = _gather_lprobs_avg(
+                    ref_rejected_output, rejected_target_batch
+                )
+        elif (
+            batch.reference_score_chosen is not None
+            and batch.reference_score_rejected is not None
+        ):
+            # reference scores must exist in the batch if reference model is None
+            ref_chosen_logps = batch.reference_score_chosen
+            ref_average_chosen_logps = (
+                ref_chosen_logps / chosen_target_batch.target_mask.sum(-1)
             )
-            ref_rejected_output = cast(
-                SequenceModelOutput, self._reference_model(rejected_batch)
+            ref_rejected_logps = batch.reference_score_rejected
+            ref_average_rejected_logps = (
+                ref_rejected_logps / rejected_target_batch.target_mask.sum(-1)
             )
-            ref_chosen_logps, ref_average_chosen_logps = _gather_lprobs_avg(
-                ref_chosen_output, chosen_target_batch
-            )
-            ref_rejected_logps, ref_average_rejected_logps = _gather_lprobs_avg(
-                ref_rejected_output, rejected_target_batch
+        else:
+            raise RuntimeError(
+                "Reference model is not initialized and data batch does not provide reference score, but at least one must exist."
             )
 
         if self._length_normalization:
@@ -206,8 +229,8 @@ class DpoConfig:
     """Holds the DPO configuration of a language model preference-finetuning task."""
 
     # Reference Model
-    reference_model: AssetReference = "llama3_1_8b_instruct"
-    """The name, path, or path to the asset card of the reference model."""
+    reference_model: AssetReference | None = "llama3_1_8b_instruct"
+    """The name, path, or path to the asset card of the reference model. If reference_model is None, recipe expects to get reference log-probabilities for chosen and rejected targets as float values in the data example (fields `reference_score_rejected` and `reference_score_chosen`)."""
 
     reference_dtype: DataType = torch.bfloat16
     """The data type of the reference model."""
@@ -230,14 +253,16 @@ class DpoConfig:
 def create_dpo_unit(
     config: DpoConfig, model: Module, root_gang: Gang, gangs: Mapping[str, Gang]
 ) -> DpoFinetuneUnit:
-    reference_model = _load_reference_model(
-        config.reference_model,
-        config.reference_dtype,
-        root_gang,
-        gangs,
-        config.reference_tensor_parallel_size,
-        log,
-    )
+    reference_model = None
+    if config.reference_model is not None:
+        reference_model = _load_reference_model(
+            config.reference_model,
+            config.reference_dtype,
+            root_gang,
+            gangs,
+            config.reference_tensor_parallel_size,
+            log,
+        )
 
     dp_gang = gangs["dp"]  # data
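Note (not part of the patch): when reference_model is set to None in DpoConfig, every dataset example has to carry reference_score_chosen and reference_score_rejected as plain floats. Because the recipe recovers the per-token average by dividing the stored score by target_mask.sum(-1), each score should be the reference model's log-probabilities summed over the target tokens of the corresponding completion. The sketch below shows one way such scores could be pre-computed offline; the helper name and tensor layout are illustrative assumptions, not fairseq2 API.

# Illustrative sketch (not fairseq2 API): pre-computing reference scores offline.
import torch


@torch.no_grad()
def summed_target_logps(
    logits: torch.Tensor,       # (batch, seq_len, vocab) logits from the reference model
    targets: torch.Tensor,      # (batch, seq_len) target token ids (int64)
    target_mask: torch.Tensor,  # (batch, seq_len) bool, True on completion (target) tokens
) -> torch.Tensor:
    """Sum of reference log-probabilities over the target tokens, one value per sequence."""
    logps = torch.log_softmax(logits.float(), dim=-1)
    # Pick the log-probability of each realized target token.
    per_token = logps.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    # Zero out prompt/padding positions and sum over the sequence dimension.
    return per_token.masked_fill(~target_mask, 0.0).sum(dim=-1)

Each resulting pair of floats is then stored alongside the preference example, for instance as a JSONL line such as {"src": "...", "tgt_chosen": "...", "tgt_rejected": "...", "reference_score_chosen": -41.3, "reference_score_rejected": -57.9}, where the prompt/completion field names and the score values are placeholders. With the scores present in every example and reference_model set to None (null in a YAML recipe config), create_dpo_unit skips loading a reference model and the DPO loss uses the batch-provided scores instead.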