Problem with LLama training with LoRA #567

Open
freQuensy23-coder opened this issue Oct 17, 2023 · 3 comments
Labels: bug (Something isn't working)

@freQuensy23-coder

🐛 Describe the bug

I've tried to train a Llama model using a reward model. I created the following config:

```python
# Imports assumed for this config (trlX 0.7.x and peft)
from peft import LoraConfig, TaskType
from trlx.data.configs import (
    ModelConfig,
    OptimizerConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)
from trlx.models.modeling_ppo import PPOConfig

config = TRLConfig(
        train=TrainConfig(
            seq_length=4096,
            epochs=100,
            total_steps=10000,
            batch_size=4,
            checkpoint_interval=100,
            eval_interval=100,
            pipeline="PromptPipeline",
            trainer="AcceleratePPOTrainer",
        ),
        model=ModelConfig(model_path="meta-llama/Llama-2-7b-hf",
                          num_layers_unfrozen=-1,
                          peft_config=LoraConfig(
                            r=8,
                            task_type=TaskType.CAUSAL_LM,
                            lora_alpha=32,
                            lora_dropout=0.1,
                            )
                          ),
        tokenizer=TokenizerConfig(tokenizer_path="meta-llama/Llama-2-7b-hf", truncation_side="right"),
        optimizer=OptimizerConfig(
            name="adamw", kwargs=dict(lr=3e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
        ),
        scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=3e-5)),
        method=PPOConfig(
            name="PPOConfig",
            num_rollouts=128,
            chunk_size=4,
            ppo_epochs=4,
            init_kl_coef=0.0004,
            target=None,
            horizon=10000,
            gamma=1,
            lam=0.95,
            cliprange=0.2,
            cliprange_value=0.2,
            vf_coef=1,
            scale_reward="ignored",
            ref_mean=None,
            ref_std=None,
            cliprange_reward=10,
            gen_kwargs=dict(
                max_new_tokens=64,
                top_k=0,
                top_p=1.0,
                do_sample=True,
            ),
        ),
    )
```

And the training code:

```python
import trlx

trlx.train(
    reward_fn=reward_function,
    prompts=df_train['query'].tolist(),
    eval_prompts=df_test['query'].tolist(),
    config=config,
)
```
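
The `reward_function`, `df_train`, and `df_test` used above are not shown in the report. For context only, here is a minimal, hypothetical stand-in for the reward function; the `(samples, prompts, outputs)` keyword signature is assumed from trlX's `reward_fn` convention, and the length-based score is purely illustrative:

```python
from typing import List

# Hypothetical stand-in for the unshown `reward_function` (assumption: trlX calls
# reward_fn with `samples`, `prompts`, and `outputs` keyword arguments and expects
# one float per sample). The length-based score below is only a placeholder.
def reward_function(samples: List[str], prompts: List[str], outputs: List[str], **kwargs) -> List[float]:
    return [float(len(output)) for output in outputs]
```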

But it failed with a TypeError about the head_mask argument:

```
[RANK 0] Collecting rollouts
You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 1
----> 1 trlx.train(
      2     reward_fn=reward_function,
      3     prompts=df_train['query'].tolist(),
      4     eval_prompts=df_test['query'].tolist(),
      5     config=config,
      6 )

File ~/.conda/envs/trlx/lib/python3.10/site-packages/trlx/trlx.py:129, in train(model_path, reward_fn, dataset, samples, rewards, prompts, eval_prompts, metric_fn, config, stop_sequences)
    126 if config.train.resume_from_checkpoint and os.path.exists(config.train.resume_from_checkpoint):
    127     trainer.load(config.train.resume_from_checkpoint)
--> 129 trainer.learn()
    130 return trainer

File ~/.conda/envs/trlx/lib/python3.10/site-packages/trlx/trainer/accelerate_base_trainer.py:521, in AccelerateRLTrainer.learn(self)
    516 """
    517 Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader`
    518 """
    519 logger.info("Starting training")
--> 521 self.prepare_learning()
    522 self.iter_count = 0
    523 self.nth_evaluation = 0

File ~/.conda/envs/trlx/lib/python3.10/site-packages/trlx/trainer/accelerate_ppo_trainer.py:234, in AcceleratePPOTrainer.prepare_learning(self)
    231 eval_dataloader = self.eval_pipeline.create_loader(self.config.method.chunk_size)
    232 self.eval_dataloader = self.accelerator.prepare_data_loader(eval_dataloader)
--> 234 self.make_experience(self.config.method.num_rollouts)
    236 self.train_dataloader = self.create_train_dataloader()
    238 self.n_inner_epochs = self.config.method.ppo_epochs

File ~/.conda/envs/trlx/lib/python3.10/site-packages/trlx/trainer/accelerate_ppo_trainer.py:418, in AcceleratePPOTrainer.make_experience(self, num_rollouts, iter_count)
    416 position_ids.masked_fill_(attention_mask == 0, 1)
    417 with torch.no_grad():
--> 418     logits, *_, values = self.model(
    419         all_tokens, attention_mask=attention_mask, position_ids=position_ids
    420     )
    421     # TODO(dahoas): When hydra model works need to also support generation on hydra head
    422     if hasattr(self.model, "frozen_head") or self.model.peft_type:

File ~/.conda/envs/trlx/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/trlx/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.conda/envs/trlx/lib/python3.10/site-packages/trlx/models/modeling_ppo.py:329, in AutoModelForCausalLMWithValueHead.forward(self, input_ids, attention_mask, past_key_values, position_ids, head_mask, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, ignore_peft_adapter)
    327         outputs = self.base_model.base_model(**forward_kwargs)
    328 else:
--> 329     outputs = self.base_model(**forward_kwargs)
    331 # TODO: Apply PEFT to value branch
    332 if self.num_value_layers_unfrozen > 0:

File ~/.conda/envs/trlx/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/trlx/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.conda/envs/trlx/lib/python3.10/site-packages/peft/peft_model.py:918, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, **kwargs)
    907             raise AssertionError("forward in MPTForCausalLM does not support inputs_embeds")
    908         return self.base_model(
    909             input_ids=input_ids,
    910             attention_mask=attention_mask,
   (...)
    915             **kwargs,
    916         )
--> 918     return self.base_model(
    919         input_ids=input_ids,
    920         attention_mask=attention_mask,
    921         inputs_embeds=inputs_embeds,
    922         labels=labels,
    923         output_attentions=output_attentions,
    924         output_hidden_states=output_hidden_states,
    925         return_dict=return_dict,
    926         **kwargs,
    927     )
    929 batch_size = _get_batch_size(input_ids, inputs_embeds)
    930 if attention_mask is not None:
    931     # concat prompt attention mask

File ~/.conda/envs/trlx/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/trlx/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.conda/envs/trlx/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:94, in BaseTuner.forward(self, *args, **kwargs)
     93 def forward(self, *args: Any, **kwargs: Any):
---> 94     return self.model.forward(*args, **kwargs)

TypeError: LlamaForCausalLM.forward() got an unexpected keyword argument 'head_mask'
```
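
The failing call is the one at trlx/models/modeling_ppo.py:329: the value-head wrapper builds forward_kwargs that include head_mask, and the PEFT wrapper passes them through to LlamaForCausalLM.forward, which does not declare that argument (unlike GPT-2). A hedged workaround sketch, not an official trlX fix, is to filter the kwargs against the signature of the unwrapped model before that call:

```python
import inspect

def drop_unsupported_kwargs(forward_fn, kwargs):
    """Keep only keyword arguments that `forward_fn` explicitly declares."""
    params = inspect.signature(forward_fn).parameters
    return {k: v for k, v in kwargs.items() if k in params}

# Hypothetical patch site around the failing line in AutoModelForCausalLMWithValueHead.forward.
# The PEFT adapter is unwrapped first, because PeftModelForCausalLM.forward accepts **kwargs
# and would let `head_mask` through to LlamaForCausalLM anyway:
# inner = self.base_model.get_base_model() if self.peft_type else self.base_model
# outputs = self.base_model(**drop_unsupported_kwargs(inner.forward, forward_kwargs))
```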

### Which trlX version are you using?

0.7.0

### Additional system and package information

Linux
freQuensy23-coder added the bug (Something isn't working) label on Oct 17, 2023

@freQuensy23-coder (Author)

The same code works with a GPT-2 model (its forward method does have a head_mask parameter).
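
For reference, this difference can be checked without loading any weights; a minimal sketch using inspect:

```python
import inspect
from transformers import GPT2LMHeadModel, LlamaForCausalLM

# GPT2LMHeadModel.forward declares `head_mask`, LlamaForCausalLM.forward does not,
# which is why the same trlX code path only breaks for Llama.
print("head_mask" in inspect.signature(GPT2LMHeadModel.forward).parameters)   # True
print("head_mask" in inspect.signature(LlamaForCausalLM.forward).parameters)  # False
```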

@freQuensy23-coder (Author)

After updating to a new branch, another problem appeared, this time with wandb logging:

wandb: WARNING Step only supports monotonically increasing values, use define_metric to set a custom x axis. For details see: https://wandb.me/define-metric
wandb: WARNING (User provided step: 0 is less than current step: 64. Dropping entry: {'time/rollout_generate': 5.391999550163746, 'time/rollout_score': 0.23438506573438644, 'rollout_scores/mean': 0.08249974995851517, 'rollout_scores/std': 0.8703116632532328, 'rollout_scores/running_mean': 0.29262859688606113, 'rollout_scores/running_std': 0.9445023126900196, 'time/rollout_time': 11.604977615177631, 'policy/sqrt_kl': 0.0, 'policy/kl_per_token': 0.0, 'kl_ctl_value': 0.001, '_timestamp': 1697575727.0544078})

Thanks for maintaining this lib. I will try to figure out the code on my own and fix the problem for now.
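
For reference, the mitigation the warning itself points to is wandb.define_metric, which plots metrics against a custom, explicitly logged step instead of wandb's internal monotonically increasing one. A minimal sketch (the project name and the "trlx_step" metric are placeholders, not anything trlX defines):

```python
import wandb

run = wandb.init(project="trlx-debug")  # hypothetical project name

# Route these metric families onto a custom x-axis so out-of-order steps are not dropped.
wandb.define_metric("trlx_step")
wandb.define_metric("rollout_scores/*", step_metric="trlx_step")
wandb.define_metric("time/*", step_metric="trlx_step")

wandb.log({"rollout_scores/mean": 0.08, "trlx_step": 0})
run.finish()
```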

@maxreciprocate (Collaborator)

Hi @freQuensy23-coder! I use wandb==0.15.4, and with the config you gave here I don't see this warning:
https://wandb.ai/sorry/trlx/runs/y7ttdket/logs

There are a few accelerator.log calls in different places that happen on the same tick, but they should be properly synchronized via the step argument, so I'm not sure what could be the reason behind this warning 🤔
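
For context, a minimal sketch of what synchronized logging through Accelerate looks like (the project name is a placeholder); as long as every call for the same tick passes the same explicit step, wandb should not complain:

```python
from accelerate import Accelerator

accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers("trlx-debug")  # hypothetical project name

step = 0
# Several log calls on the same tick stay aligned as long as they share `step`.
accelerator.log({"time/rollout_time": 11.6}, step=step)
accelerator.log({"rollout_scores/mean": 0.08}, step=step)

accelerator.end_training()
```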
