[RFC] Could you please provide the latest sample code for using the latest parallel_nsa function? #2

Open · Kyfafyd opened this issue Feb 24, 2025 · 4 comments
Labels: enhancement (New feature or request)

Comments

Kyfafyd commented Feb 24, 2025

Proposal

As the title says.

Rationale

No response

Kyfafyd added the enhancement (New feature or request) label on Feb 24, 2025
Kyfafyd changed the title from "[RFC] Could you please provide the latest code for using parallel_nsa function?" to "[RFC] Could you please provide the latest sample code for using the latest parallel_nsa function?" on Feb 24, 2025
Kyfafyd (Author) commented Feb 24, 2025

Also, what about the usage for naive_nsa_with_compression and parallel_nsa_with_compression?

Hanyuezhuohua (Collaborator) commented:

Thank you for your interest in our project. The latest sample code for the parallel_nsa function has been provided, covering both selected and sliding attention. You can try it directly by following our instructions.

For the nsa_with_compression function, we will further integrate the compressed branch and online top-k selection into our kernel; however, the parallel version of this function is still under development.
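
For reference, here is a minimal, hypothetical sketch of such a call. The import path, argument order, and tensor layouts below are assumptions pieced together from this thread (queries/keys/values plus the g_slc, g_swa, block_indices, and block_counts inputs discussed later), so the repository's README remains the authoritative example.

```python
# Hypothetical sketch of calling parallel_nsa -- import path, argument names/order,
# and tensor layouts are assumptions, not the project's confirmed API.
import torch

from native_sparse_attention.ops import parallel_nsa  # assumed import path

B, T, HQ, H, D = 2, 256, 16, 1, 64      # batch, seq len, query heads, KV heads, head dim
S, block_size, window_size = 4, 32, 64  # selected blocks per token, block size, sliding window

q = torch.randn(B, T, HQ, D, dtype=torch.bfloat16, device='cuda')
k = torch.randn(B, T, H, D, dtype=torch.bfloat16, device='cuda')
v = torch.randn(B, T, H, D, dtype=torch.bfloat16, device='cuda')

# Gates for the selected and sliding-window branches, one value per token and query head.
g_slc = torch.rand(B, T, HQ, dtype=torch.bfloat16, device='cuda')
g_swa = torch.rand(B, T, HQ, dtype=torch.bfloat16, device='cuda')

# Indices of the key/value blocks each token attends to, stored per KV head,
# and the number of valid selected blocks per token (here: all S of them).
block_indices = torch.randint(0, T // block_size, (B, T, H, S), device='cuda')
block_counts = torch.full((B, T, H), S, dtype=torch.long, device='cuda')

o = parallel_nsa(q, k, v, g_slc, g_swa, block_indices, block_counts,
                 block_size=block_size, window_size=window_size)
print(o.shape)  # expected: (B, T, HQ, D)
```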

Kyfafyd (Author) commented Feb 25, 2025

Thanks for your response!
I would like to know whether the NSA kernel can be applied to attention with 48 heads, which is not a power of 2.
Additionally, how should the inputs for g_slc, g_swa, block_indices, and block_counts be provided in the attention computation of a vision transformer?

Hanyuezhuohua (Collaborator) commented:

The NSA kernel can be applied to attention with 48 query heads and 3 key-value heads. Since the query heads are grouped by key-value head, each group consists of 16 query heads.
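
A tiny arithmetic check of that grouping (nothing here depends on the library itself):

```python
# GQA grouping: query heads are split evenly across key-value heads.
HQ, H = 48, 3          # query heads, key-value heads
assert HQ % H == 0     # query heads must divide evenly across KV heads
group_size = HQ // H   # 16 query heads share each KV head
print(group_size)      # 16 -- a power of 2, even though 48 itself is not
```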

In addition, our new nsa_with_compression function allows you to obtain block indices. Specifically, g_slc and g_swa are computed by passing the input through an MLP module with a sigmoid activation. The block_counts parameter is a user-defined constant that controls the sparsity ratio.
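
A minimal sketch of how those inputs could be produced inside a vision-transformer block, assuming per-token, per-query-head gates; the module and projection names are hypothetical, but the recipe (a linear projection followed by a sigmoid, plus a constant block_counts) is the one described above.

```python
# Hypothetical gate module: project the token features and squash with a sigmoid
# to obtain per-token, per-query-head gates in [0, 1].
import torch
import torch.nn as nn

class NSAGates(nn.Module):  # name and structure are illustrative, not from the repo
    def __init__(self, hidden_size: int, num_q_heads: int):
        super().__init__()
        self.slc_proj = nn.Linear(hidden_size, num_q_heads)  # gate for selected attention
        self.swa_proj = nn.Linear(hidden_size, num_q_heads)  # gate for sliding-window attention

    def forward(self, x: torch.Tensor):
        # x: [B, T, hidden_size] -> each gate: [B, T, num_q_heads]
        return torch.sigmoid(self.slc_proj(x)), torch.sigmoid(self.swa_proj(x))

gates = NSAGates(hidden_size=768, num_q_heads=48)
x = torch.randn(2, 196, 768)   # e.g. a ViT with 14x14 = 196 patch tokens
g_slc, g_swa = gates(x)        # each gate has shape [2, 196, 48]

# block_counts as a user-defined constant: e.g. keep 16 selected blocks per token,
# which directly controls the sparsity ratio.
block_counts = 16
```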
