
[QST] Shapes of partition_fragment_{A,B} #1299

Closed

jeromeku opened this issue Jan 10, 2024 · 6 comments

Comments
@jeromeku
Contributor

What is your question?

Assuming a TiledMma defined as:

  using MMA_Op = SM80_16x8x16_F32F16F16F32_TN;
  using MMA_Traits = MMA_Traits<MMA_Op>;
  using MMA_Atom = MMA_Atom<MMA_Traits>;

  constexpr int kNWarps = 4;
  using ThreadLayoutMNK = Layout<Shape<Int<4>, _1, _1>>;
  using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;

  using TiledMma = TiledMMA<MMA_Atom, ThreadLayoutMNK, ValLayoutMNK>;

where the MMA_Op SM80_16x8x16_F32F16F16F32_TN is defined as:

struct SM80_16x8x16_F32F16F16F32_TN
{
  using DRegisters = float[4];
  using ARegisters = uint32_t[4];
  using BRegisters = uint32_t[2];
  using CRegisters = float[4];
  ...
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,  %1,  %2,  %3},"
      "{%4,  %5,  %6,  %7},"
      "{%8,  %9},"
      "{%10, %11, %12, %13};\n"
      : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
      :  "r"(a0),  "r"(a1),  "r"(a2),  "r"(a3),
         "r"(b0),  "r"(b1),
         "f"(c0),  "f"(c1),  "f"(c2),  "f"(c3));
  ...
};

Then I define the following A and B tensors:

 Tensor gA = make_tensor(make_gmem_ptr(Aptr), Shape<Int<128>, Int<32>>{}, Stride<Int<32>, Int<1>>{});
 Tensor gB = make_tensor(make_gmem_ptr(Bptr), Shape<Int<128>, Int<32>>{}, Stride<Int<32>, Int<1>>{});

When I get a thr_mma slice as follows:

  TiledMma tiled_mma;
  auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x);
  print(thr_mma, "thr_mma");
  Tensor tSrA = thr_mma.partition_fragment_A(gA);
  print(tSrA);
  Tensor tSrB = thr_mma.partition_fragment_B(gB);
  print(tSrB);

I get the following output:

thr_mma:

thr_mma
TiledMMA
  TiledThr:  (_4,_1,_1):(_1,_0,_0)
  TiledVal:  (_1,_2,_1):(_0,_1,_0)
  TiledPerm: (_,_,_)
  TiledShape_MNK: (_64,_16,_16)
  ThrLayoutVMNK:  (_32,_4,_1,_1):(_1,_32,_0,_0)
MMA_Atom
  ThrID:         _32:_1
  LayoutA_TV:    ((_4,_8),(_2,_2,_2)):((_32,_1),(_16,_8,_128))
  LayoutB_TV:    ((_4,_8),(_2,_2)):((_16,_1),(_8,_64))
  LayoutC_TV:    ((_4,_8),(_2,_2)):((_32,_1),(_16,_8))

tSrA:

ptr[16b](0x7f1a20feefc0) o ((_2,_2,_2),_2,_2):((_1,_2,_4),_16,_8)

tSrB:

ptr[16b](0x7f1a20fef090) o ((_2,_2),_16,_2):((_1,_2),_8,_4)

Since gA has shape 128 x 32, this requires 4 tiles of the tiled_mma, each of which is 64 x 16. Thus tSrA makes sense: for the given MmaAtom, each thread is responsible for 2 x 2 x 2 = 8 fp16 A values, and the 2 x 2 = 4 refers to the required tiling.

For tSrB, how is the shape derived? The (2, 2) makes sense, since for the MmaAtom each thread is responsible for 4 fp16 B values. However, I'm having trouble reconciling how the 16 and 2 are derived. One could reason that 16 = 128 / 8 and 2 = 32 / 16, where the divisors 8 and 16 come from the N and K of the MmaAtom, and the 16 and 2 are the number of MmaAtoms needed to cover the desired 128 x 32.

However, tSrA wouldn't make sense under that logic, since it would then have shape ((2, 2, 2), 8, 2), with 8 = 128 / 16 and 2 = 32 / 16, where the 16s come from the MmaAtom's M and K.

@ccecka
ccecka commented Jan 10, 2024

Scratch.pdf

I've attached the output of

  using MMA_Op = SM80_16x8x16_F32F16F16F32_TN;
  using MMA_Traits = MMA_Traits<MMA_Op>;
  using MMA_Atom = MMA_Atom<MMA_Traits>;

  using ThreadLayoutMNK = Layout<Shape<Int<4>, _1, _1>>;
  using TiledMma = TiledMMA<MMA_Atom, ThreadLayoutMNK>;

  print_latex(TiledMma{});

(Note that a recent update removed the ValLayoutMNK parameter of TiledMMA; see the new interface.)

While the Atom is 16x8x16, an AtomLayout of 4x1x1 makes the TiledMma 64x8x16. This is what tile_shape(TiledMma{}) and tile_size<0>(TiledMma{}), tile_size<1>(TiledMma{}), tile_size<2>(TiledMma{}) also reflect.

For A, 128 / 64 = 2 and 32 / 16 = 2.
For B, 128 / 8 = 16 and 32 / 16 = 2.
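
To check these numbers without launching a kernel, here is a minimal host-side sketch of my own (not from the thread); it assumes the CUTLASS 3.x CuTe headers built with nvcc and only uses the tile_shape/tile_size calls mentioned above:

#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits_sm80.hpp>

using namespace cute;

int main() {
  // Same atom as above, AtomLayoutMNK 4x1x1, no ValLayout (new interface)
  using TiledMma = TiledMMA<MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
                            Layout<Shape<_4,_1,_1>>>;

  print(tile_shape(TiledMma{})); print("\n");   // (_64,_8,_16)

  // Modes 1 and 2 of the partitioned fragments = gmem tile / TiledMMA tile:
  //   A (128 x 32): 128 / 64 = 2 tiles in M, 32 / 16 = 2 tiles in K
  //   B (128 x 32): 128 /  8 = 16 tiles in N, 32 / 16 = 2 tiles in K
  static_assert(128 / decltype(tile_size<0>(TiledMma{}))::value ==  2, "M tiles for A");
  static_assert(128 / decltype(tile_size<1>(TiledMma{}))::value == 16, "N tiles for B");
  static_assert( 32 / decltype(tile_size<2>(TiledMma{}))::value ==  2, "K tiles");
  return 0;
}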

@jeromeku
Contributor Author

@ccecka

Thanks! That makes sense.

How do I obtain the same TiledMma as with ValLayoutMNK = <_1, _2, _1> (a 16 x 16 tile of B across 128 threads) using the new interface? If I use ThreadLayoutMNK = <_4, _2, _1>, I get a 16 x 16 tile of B but with 256 threads.

@ccecka
ccecka commented Jan 10, 2024

Is

TiledMMA<MMA_Atom, Layout<Shape<_2,_2,_1>>>

what you're looking for?

That's 32 * 4 = 128 threads, making a TiledMma that is 32x16x16.
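
As a quick static check of those numbers, a sketch reusing the includes from the earlier host-side snippet (the alias name here is illustrative, not from the thread):

// 32 threads per atom * (2*2*1) atoms = 128 threads;
// a 16x8x16 atom tiled by a 2x2x1 atom layout gives 32x16x16.
using TiledMma_2x2x1 = TiledMMA<MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
                                Layout<Shape<_2,_2,_1>>>;

static_assert(decltype(size(TiledMma_2x2x1{}))::value == 128, "thread count");
static_assert(decltype(tile_size<0>(TiledMma_2x2x1{}))::value == 32, "M");
static_assert(decltype(tile_size<1>(TiledMma_2x2x1{}))::value == 16, "N");
static_assert(decltype(tile_size<2>(TiledMma_2x2x1{}))::value == 16, "K");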

@jeromeku
Contributor Author

That could work, but it doesn't give the same TV mapping as the original.

Original, with ValLayoutMNK = <_1, _2, _1>:
tiled_mma_val_layout.pdf

New, no ValLayoutMNK but with TiledMMA<MMA_Atom, Layout<Shape<_2,_2,_1>>>:
tiled_2x2x1.pdf

@ccecka
ccecka commented Jan 11, 2024

Then you probably want

TiledMMA<MMA_Atom,
         Layout<Shape<_4,_1,_1>>,
         Tile<_64,_16,_16>>;

to produce a 64x16x16 TiledMMA from the 64x8x16 TiledMMA.

(But it will have the same partitioning effect as without the Tile because the first mode is always the fragment for a single MMA.)
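
To compare the resulting TV mapping against the attached PDFs, the suggested type can be fed to print_latex the same way as earlier in the thread (a sketch; the alias name is mine and it reuses the earlier includes):

using TiledMma_64x16x16 = TiledMMA<MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
                                   Layout<Shape<_4,_1,_1>>,
                                   Tile<_64,_16,_16>>;

print_latex(TiledMma_64x16x16{});   // compare the TV map against the PDFs above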

@jeromeku
Contributor Author

Thanks!
