DPO trainer example #172

Open

sparsh35 opened this issue Oct 22, 2024 · 11 comments

Comments

@sparsh35
Contributor

Describe the bug
While running the DPO trainer example I get a bug related to batch size and sharding; maybe the shard axes are not set properly, or it could be a JAX error as well. The system used is a TPU v3-32 with 4 hosts.
To Reproduce
Steps to reproduce the behavior:
Just run the DPO trainer example.

The error is:

outputs = self.layers(

File "/home/spars/.local/lib/python3.10/site-packages/easydel/modules/qwen2/modeling_qwen_flax.py", line 862, in call
output = pecs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<function ring_attention at 0x7f2e71b196c0>, axis_name='sp', float32_logits=True, platform=None, backend=None, autocheck=True, blocksize_c=None, blocksize_k=128, blocksize_q=128, dtype=<class 'jax.numpy.float32'>, softmax_scale=0.08838834764831843, deterministic=True, dropout_rng=None)' appropriately.
layer(
File "/home/spars/.local/lib/python3.10/site-packages/easydel/modules/qwen2/modeling_qwen_flax.py", line 502, in call
attn_outputs = self.self_attn(
File "/home/spars/.local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 568, in inner
return rematted(variable_groups, rng_groups, *dyn_args)
File "/home/spars/.local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 565, in rematted
y = fn(scope, *args)
File "/home/spars/.local/lib/python3.10/site-packages/easydel/modules/qwen2/modeling_qwen_flax.py", line 382, in call
attentions = self.attention_performer(
File "/home/spars/.local/lib/python3.10/site-packages/easydel/modules/attention_module.py", line 529, in call
return self.ring_attention(
File "/home/spars/.local/lib/python3.10/site-packages/easydel/modules/attention_module.py", line 649, in ring_attention
attn_output = shard_map(
ValueError: shard_map applied to the function 'functools.partial(<function ring_attention at 0x7f9c76b396c0>, axis_name='sp', float32_logits=True, platform=None, backend=None, autocheck=True, blocksize_c=None, blocksize_k=128, blocksize_q=128, dtype=<class 'jax.numpy.float32'>, softmax_scale=0.08838834764831843, deterministic=True, dropout_rng=None)' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:

The mesh given has shape (1, 32, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').

  • args[0] of shape float32[1,1,28,128], where args[0] is bound to functools.partial(<function ring_attention at 0x7f9c76b396c0>, axis_name='sp', float32_logits=True, platform=None, backend=None, autocheck=True, blocksize_c=None, blocksize_k=128, blocksize_q=128, dtype=<class 'jax.numpy.float32'>, softmax_scale=0.08838834764831843, deterministic=True, dropout_rng=None)'s parameter 'query', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 32), but 32 does not evenly divide 1

  • args[1] of shape float32[1,1,28,128], where args[1] is bound to functools.partial(<function ring_attention at 0x7f9c76b396c0>, axis_name='sp', float32_logits=True, platform=None, backend=None, autocheck=True, blocksize_c=None, blocksize_k=128, blocksize_q=128, dtype=<class 'jax.numpy.float32'>, softmax_scale=0.08838834764831843, deterministic=True, dropout_rng=None)'s parameter 'key', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 32), but 32 does not evenly divide 1

  • args[2] of shape float32[1,1,28,128], where args[2] is bound to functools.partial(<function ring_attention at 0x7f9c76b396c0>, axis_name='sp', float32_logits=True, platform=None, backend=None, autocheck=True, blocksize_c=None, blocksize_k=128, blocksize_q=128, dtype=<class 'jax.numpy.float32'>, softmax_scale=0.08838834764831843, deterministic=True, dropout_rng=None)'s parameter 'value', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 32), but 32 does not evenly divide 1

  • args[3] of shape float32[1,1,1,1], where args[3] is bound to functools.partial(<function ring_attention at 0x7f9c76b396c0>, axis_name='sp', float32_logits=True, platform=None, backend=None, autocheck=True, blocksize_c=None, blocksize_k=128, blocksize_q=128, dtype=<class 'jax.numpy.float32'>, softmax_scale=0.08838834764831843, deterministic=True, dropout_rng=None)'s parameter 'bias', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), 'tp', 'sp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 32), but 32 does not evenly divide 1
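
For context: shard_map requires every array axis to be evenly divisible by the total size of the mesh axes it is mapped to, so a per-call batch axis of size 1 cannot be split across ('dp', 'fsdp') with total size 32. A minimal, self-contained JAX sketch of that rule (illustrative only; the mesh below just mirrors the shape in the error, not EasyDeL's actual setup):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Build a ('dp', 'fsdp', 'tp', 'sp') mesh shaped (1, N, 1, 1), as in the error.
devices = np.array(jax.devices()).reshape(1, -1, 1, 1)
mesh = Mesh(devices, axis_names=("dp", "fsdp", "tp", "sp"))

spec = P(("dp", "fsdp"), "sp", "tp", None)
double = shard_map(lambda x: x * 2, mesh=mesh, in_specs=(spec,), out_specs=spec)

# Fails: axis 0 has size 1, but ('dp', 'fsdp') has total size len(jax.devices()).
# double(jnp.zeros((1, 1, 28, 128)))

# Works: axis 0 is divisible by the ('dp', 'fsdp') product.
out = double(jnp.zeros((len(jax.devices()), 1, 28, 128)))
```

In practice this usually means either making the global batch size divisible by dp × fsdp, or shrinking the 'fsdp' axis of the mesh (the axis names above follow the error message; the exact configuration knob in EasyDeL is not shown here).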

@sparsh35
Contributor Author

Another issue, with the ORPO trainer:
TypeError: Can't instantiate abstract class ORPOTrainer with abstract methods _eval_epoch, _execute_eval_step, _execute_train_step, _finalize_training, _run_evaluation, _run_training_loop, _train_epoch

@sparsh35
Contributor Author

So I tried the DPO trainer test, python_test/trainers/dpo_test.py, but it also gets stuck at:

Traceback (most recent call last):
params={"params": model.shard_params(model.params)},
File "/home/spars/.local/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 272, in params
raise ValueError(
ValueError: params cannot be accessed from model when the model is created with _do_init=False. You must call init_weights manually and store the params outside of the model and pass it explicitly where needed.

How can I initialize it from scratch? I already have an SFT model; is there a way to transfer the parameters from the trained model to the freshly initialized one? I also couldn't follow the logic here: why was this necessary?
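
For context: with transformers Flax models, _do_init=False makes from_pretrained return the parameters separately, and model.params is intentionally unavailable. A hedged sketch of one way to work with that (the model path and the trainer keyword below are only placeholders, not EasyDeL's confirmed API):

```python
import jax
from transformers import FlaxAutoModelForCausalLM

# With _do_init=False, from_pretrained returns (model, params) and the model
# object itself holds no parameters, hence the ValueError on model.params.
model, params = FlaxAutoModelForCausalLM.from_pretrained(
    "path/to/my-sft-model",  # placeholder path
    _do_init=False,
)

# Option 1: keep the loaded SFT params and pass them explicitly wherever the
# trainer expects them, e.g.
# trainer = DPOTrainer(..., params={"params": params})  # kwarg name is illustrative

# Option 2: initialize fresh weights manually, as the error message suggests.
fresh_params = model.init_weights(jax.random.PRNGKey(0), input_shape=(1, 1))
```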

@erfanzar
Owner

Hi,
Yes, I can fix that for you. In the next 2-3 hours I'll give you a DPO example, along with an example for transferring the model.

@erfanzar
Owner

@sparsh35, I have added a tutorial for the DPO trainer, and I am sorry if I was slow on that; I was checking some other things.

https://github.com/erfanzar/EasyDeL/blob/main/notebooks/dpo-trainer.ipynb

@sparsh35
Contributor Author

@erfanzar No problem at all, I know you are the sole maintainer of this project; things do get hectic.

@sparsh35 sparsh35 reopened this Oct 31, 2024
@sparsh35
Contributor Author

Still getting an issue, with dtypes:
Traceback (most recent call last):
File "/home/spars/train/dpotraining.py", line 161, in
main()
File "/home/spars/train/dpotraining.py", line 135, in main
output=dpo_trainer.train()
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/dpo_trainer.py", line 1160, in train
output, run_exception = self._run_training_loop(
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/dpo_trainer.py", line 920, in _run_training_loop
current_step, run_exception = self._train_epoch(
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/dpo_trainer.py", line 993, in _train_epoch
loss, metrics, run_exception = self._execute_train_step(batch)
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/dpo_trainer.py", line 1102, in _execute_train_step
self.model_state, dpo_out = self.sharded_train_step_function(
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/func_utils/creators.py", line 271, in dpo_step
(__loss, (__chosen_rewards, __rejected_rewards)), grads = grad_fn(state.params)
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/func_utils/creators.py", line 254, in calculate_loss
losses = _loss_func(
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/func_utils/loss_funcs.py", line 77, in _kto_pair_dpo_loss
chosen_kl = jax.lax.clamp(
TypeError: lax.clamp requires arguments to have the same dtypes, got int32, bfloat16, float32.

I am using the same code as given in the example.
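
For context: jax.lax.clamp(min, x, max) requires all three operands to share a single dtype, so mixing an int literal, a bfloat16 KL term, and float32 bounds produces exactly this TypeError. A minimal illustration of the constraint and the usual fix of casting to one dtype (illustrative only; variable names are placeholders):

```python
import jax
import jax.numpy as jnp

kl = jnp.asarray(0.3, dtype=jnp.bfloat16)  # stand-in for the chosen_kl term

# Mixing dtypes raises: "lax.clamp requires arguments to have the same dtypes".
# jax.lax.clamp(0, kl, jnp.float32(1e4))

# Casting the bounds to the array's dtype avoids it.
lo = jnp.asarray(0.0, dtype=kl.dtype)
hi = jnp.asarray(1e4, dtype=kl.dtype)
clamped = jax.lax.clamp(lo, kl, hi)
```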

@erfanzar
Owner

erfanzar commented Oct 31, 2024

@sparsh35 hello, and thanks for re-opening the issue; I have that fixed now.

@sparsh35
Contributor Author

@erfanzar Thanks. I also modified the code for the hinge loss and KTO to check, and I think there is another bug related to JAX array sharding; this code may not be compatible with multi-host setups. I can't debug this error, can you help?
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/spars/train/dpotraining.py", line 153, in
main()
File "/home/spars/train/dpotraining.py", line 127, in main
output=dpo_trainer.train()
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/dpo_trainer.py", line 1160, in train
output, run_exception = self._run_training_loop(
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/dpo_trainer.py", line 920, in _run_training_loop
current_step, run_exception = self._train_epoch(
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/dpo_trainer.py", line 993, in _train_epoch
loss, metrics, run_exception = self._execute_train_step(batch)
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/dpo_trainer.py", line 1102, in _execute_train_step
self.model_state, dpo_out = self.sharded_train_step_function(
ValueError: Passing non-trivial shardings for numpy inputs is not allowed. To fix this error, either specify a replicated sharding explicitly or use jax.experimental.multihost_utils.host_local_array_to_global_array(...) to convert your host local numpy inputs to a jax.Array which you can pass to pjit. If the numpy input is the same on each process, then you can use jax.make_array_from_callback(...) to create a jax.Array which you can pass to pjit. Please see the jax.Array migration guide for more information https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. Got arg shape: (32, 1536), arg value: [[1 1 1 ... 0 0 0]
 [1 1 1 ... 0 0 0]]
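
For context: the fix the error message points to is converting the host-local numpy batch into a global jax.Array before handing it to the pjit-compiled step function. A hedged sketch, where the mesh and partition spec are only assumed to mirror the ones in the error, not the trainer's real configuration:

```python
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import multihost_utils

# A ('dp', 'fsdp', 'tp', 'sp') mesh like the one reported earlier in this issue.
mesh = Mesh(np.array(jax.devices()).reshape(1, -1, 1, 1), ("dp", "fsdp", "tp", "sp"))

host_batch = np.ones((32, 1536), dtype=np.int32)  # host-local numpy input ids

# Convert to a global jax.Array sharded over the batch axis, then pass this
# (instead of the raw numpy array) to the pjit-compiled train step.
global_batch = multihost_utils.host_local_array_to_global_array(
    host_batch, mesh, P(("dp", "fsdp"), None)
)
```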

@erfanzar
Owner

@sparsh35 thanks, I'll fix that one ASAP.

@erfanzar
Owner

Can you try again and tell me if you are still facing this issue?

@sparsh35
Contributor Author

sparsh35 commented Nov 1, 2024

No, this is the error:
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/spars/train/sa.py", line 124, in
output=dpo_trainer.train()
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/dpo_trainer.py", line 1198, in train
output, run_exception = self._run_training_loop(
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/dpo_trainer.py", line 951, in _run_training_loop
current_step, run_exception = self._train_epoch(
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/dpo_trainer.py", line 1024, in _train_epoch
loss, metrics, run_exception = self._execute_train_step(batch)
File "/home/spars/.local/lib/python3.10/site-packages/easydel/trainers/direct_preference_optimization_trainer/dpo_trainer.py", line 1140, in _execute_train_step
self.model_state, dpo_out = self.sharded_train_step_function(
File "/home/spars/.local/lib/python3.10/site-packages/jax/_src/array.py", line 1047, in _array_mlir_constant_handler
raise RuntimeError(
RuntimeError: Closing over jax.Array that spans non-addressable (non process local) devices is not allowed. Please pass such arrays as arguments to the function. Got jax.Array: bfloat16[3584,152064]
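
For context: this RuntimeError means a globally sharded jax.Array (here a bfloat16[3584,152064] embedding or lm-head table) is captured by the compiled function's closure instead of being passed in as an argument, which multi-host runs do not allow. A minimal sketch of the pattern, with placeholder names and reduced shapes:

```python
import jax
import jax.numpy as jnp

# Stand-in for a globally sharded parameter such as the bfloat16[3584, 152064]
# table from the error (shapes shrunk to keep the sketch cheap).
table = jnp.zeros((16, 32), dtype=jnp.bfloat16)

# Problematic in multi-host runs: `table` is closed over by the jitted function.
# @jax.jit
# def step_bad(x):
#     return x @ table

# Preferred: pass every sharded array explicitly as an argument.
@jax.jit
def step_ok(x, table):
    return x @ table

out = step_ok(jnp.ones((4, 16), dtype=jnp.bfloat16), table)
```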
