
Fix stride issues in flash_attn_interface #58

Open · wants to merge 2 commits into flash_attention_for_rocm

Conversation


@clintg6 clintg6 commented May 31, 2024

What:

Ensures tensors are contiguous in memory with matching strides during the backward pass.

Fixes #40

Why:

Multiple users/customers have been facing issues while training with the axolotl package; see #40.

maybe_contiguous fails to adequately check whether tensors are contiguous in certain packing scenarios; a sketch of the failure mode follows below.
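
For context, maybe_contiguous only falls back to .contiguous() when a tensor's last stride is not 1 (roughly, x.contiguous() if x.stride(-1) != 1 else x), so a sliced or packed tensor can pass that check while its outer strides still differ from the other backward inputs. Below is a minimal, hypothetical sketch of that failure mode; the shapes and slicing are illustrative, not taken from axolotl:

```python
import torch

# Varlen layout used by the flash-attn interface: (total_tokens, nheads, headdim).
# Shapes and the slicing below are purely illustrative, not taken from axolotl.
total, nheads, headdim = 1024, 8, 64
out = torch.randn(total, nheads, headdim, dtype=torch.float16)

# Hypothetical packing scenario: the incoming gradient is a narrow slice of a
# wider buffer.  Its last stride is still 1, so a check of the form
# "x if x.stride(-1) == 1 else x.contiguous()" leaves it untouched ...
grad_buffer = torch.randn(total, nheads, 2 * headdim, dtype=torch.float16)
dout = grad_buffer[..., :headdim]

print(dout.stride(-1))                # 1    -> passes the last-dim check
print(dout.is_contiguous())           # False
print(out.stride(0), dout.stride(0))  # 512 1024 -> sequence strides differ

# ... and the stride mismatch then surfaces in the backward kernel as
# "Expected dout_seq_stride == out_seq_stride to be true, but got false".
```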

Changes:

  • flash_attn_interface.py
    • _flash_attn_backward: added a stride check and a contiguous check (see the sketch after this list)
    • _flash_attn_varlen_backward: added a contiguous check

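The sketch below approximates the kind of guard these changes add before invoking the backward kernels; the helper name _ensure_backward_layout is hypothetical and the exact conditions in the PR may differ:

```python
import torch

def _ensure_backward_layout(dout: torch.Tensor, out: torch.Tensor):
    """Hypothetical helper approximating the checks this PR adds.

    Forces a plain contiguous layout when dout is non-contiguous or its
    strides do not match out's, so the ROCm backward kernel sees
    dout_seq_stride == out_seq_stride.
    """
    if not dout.is_contiguous() or dout.stride() != out.stride():
        dout = dout.contiguous()
    if not out.is_contiguous():
        out = out.contiguous()
    return dout, out
```

In _flash_attn_backward this kind of guard would run on dout and out before the kernel call; per the change list above, _flash_attn_varlen_backward only needs the contiguity portion.
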
Testing:

  • Manual testing: Successfully fine-tuned Phi-2, StableLM, and TinyLlama in axolotl
  • Automated testing: Benchmarking scripts ran as expected

@clintg6 clintg6 self-assigned this May 31, 2024
Development

Successfully merging this pull request may close these issues.

[Issue]: Expected dout_seq_stride == out_seq_stride to be true, but got false