[QST] TiledCopy #1358

Closed
jeromeku opened this issue Feb 24, 2024 · 2 comments

@jeromeku
Contributor

What is your question?
I'm getting a "CopyAtom, src / dst layout doesn't vectorize into registers" error when trying to implement the following tiled copy:

using g2s_copy_op = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
using g2s_copy_traits = Copy_Traits<g2s_copy_op>;
using g2s_copy_atom = Copy_Atom<g2s_copy_traits, T>;

using G2SCopyA =
    decltype(make_tiled_copy(g2s_copy_atom{},
                             make_layout(make_shape(Int<16>{}, Int<2>{}),
                                         make_stride(Int<2>{}, Int<1>{})),
                             make_layout(make_shape(Int<1>{}, Int<8>{}))));

Tensor gA = make_tensor(make_gmem_ptr(A),
                        make_layout(make_shape(Int<M>{}, Int<K>{}),
                                    make_stride(Int<K>{}, 1)));
Tensor sA = make_tensor(make_smem_ptr(smemA),
                        make_layout(make_shape(Int<M>{}, Int<K>{}),
                                    make_stride(Int<K>{}, 1)));

G2SCopyA g2s_tiled_copy_a;
auto g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(threadIdx.x);
auto tAgA_copy = g2s_thr_copy_a.partition_S(gA);
auto tAsA_copy = g2s_thr_copy_a.partition_D(sA);

cute::copy(g2s_tiled_copy_a, tAgA_copy, tAsA_copy);

In the above, M and K are multiples of 16 and T = cutlass::half_t. smemA is a statically allocated shared-memory array of T with M x K elements.

My understanding is that the tiled_copy I've defined is using 32 threads (16 x 2) and that each thread is copying 8 elements (1 x 8) such that each copy tile is shape 16 x 16, and the intent is to have each thread do a vectorized copy from global to shared memory. Where am I going wrong?

@ccecka

ccecka commented Feb 24, 2024

Looks fine. Try using static Int<1>{} rather than dynamic 1 in the data layouts.
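
For the snippet in the question, this presumably means writing the stride of both gA and sA as make_stride(Int<K>{}, Int<1>{}) instead of make_stride(Int<K>{}, 1). The likely reason is that a 128-bit copy moves 8 contiguous half_t values per thread, and that contiguity can only be verified at compile time when the unit stride is part of the layout's type. The toy C++ model below is not CuTe code: StaticInt, ToyLayout, make_toy_layout, and can_vectorize are hypothetical names, used only to sketch why a static 1 and a dynamic 1 behave differently.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-in for a compile-time integer such as cute::Int<N>.
template <int N>
struct StaticInt {
  static constexpr int value = N;
};

// A stride is provably 1 only when its *type* encodes the value 1.
// A plain int may hold 1 at run time, but the compiler cannot rely on it.
template <class Stride>
struct is_static_unit_stride : std::false_type {};

template <>
struct is_static_unit_stride<StaticInt<1>> : std::true_type {};

// Toy layout descriptor: only the innermost stride matters for this sketch.
template <class InnerStride>
struct ToyLayout {
  InnerStride inner_stride;
};

template <class InnerStride>
constexpr ToyLayout<InnerStride> make_toy_layout(InnerStride s) {
  return ToyLayout<InnerStride>{s};
}

// True only when contiguity of the innermost mode is known at compile time,
// which is the precondition assumed here for a vectorized (128-bit) access.
template <class InnerStride>
constexpr bool can_vectorize(ToyLayout<InnerStride> const&) {
  return is_static_unit_stride<InnerStride>::value;
}

// Static unit stride: the check passes at compile time.
static_assert(can_vectorize(make_toy_layout(StaticInt<1>{})),
              "static stride 1 is provably contiguous");

// Dynamic unit stride: same run-time value, but the type is just int,
// so the compile-time check cannot succeed.
static_assert(!can_vectorize(make_toy_layout(1)),
              "dynamic stride 1 is not provably contiguous");
```

In this model the two layouts hold the same stride value, yet only the one built from StaticInt<1> passes the compile-time check, which mirrors the static-versus-dynamic distinction in the suggestion above.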

@jeromeku
Contributor Author

@ccecka

Thanks!
