Add Ulysses DistributedAttention compatibility (#5525)
The `DistributedAttention` in DeepSpeed-Ulysses is compatible with the training code in [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/model/transformer.py#L811) because that code passes only positional arguments. However, it is not compatible with the common scenario of passing keyword arguments, such as the following when using Flash Attention:

```python
ulysses_attn = DistributedAttention(
    local_attention=flash_attn_func,
    sequence_process_group=None,
    scatter_idx=2,
    gather_idx=1,
)

attn_output = ulysses_attn(
    query_states,
    key_states,
    value_states,
    dropout,
    softmax_scale,
    causal=causal,
)
```

Therefore, a `**kwargs` parameter has been added to increase compatibility with more local attention implementations while keeping the code modifications minimal.

Co-authored-by: Kwen-Chen <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
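
For illustration, here is a minimal sketch of the idea behind the change, not the actual DeepSpeed diff: `forward` accepts `*args, **kwargs` and passes them straight through to the wrapped local attention, so keyword arguments such as `softmax_scale` or `causal` reach callables like `flash_attn_func`. The class name `DistributedAttentionSketch` and the elided all-to-all steps are placeholders.

```python
import torch
from torch import Tensor


class DistributedAttentionSketch(torch.nn.Module):
    """Hypothetical sketch of an Ulysses-style attention wrapper that
    forwards extra arguments to the local attention callable."""

    def __init__(self, local_attention, sequence_process_group,
                 scatter_idx: int = 2, gather_idx: int = 1):
        super().__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

    def forward(self, query: Tensor, key: Tensor, value: Tensor,
                *args, **kwargs) -> Tensor:
        # ... all-to-all over the sequence/head dimensions would happen here ...

        # The key point: *args and **kwargs are passed through unchanged,
        # so callers can supply keyword arguments (e.g. causal=True) that
        # the local attention implementation expects.
        context = self.local_attn(query, key, value, *args, **kwargs)

        # ... reverse all-to-all to restore the sequence layout would happen here ...
        return context
```

With a pass-through like this, call sites such as the Flash Attention example above work whether arguments are given positionally or by keyword.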