
[QST] Why use column-major tv layout to encode the mma? #1226

Closed
mammoth831 opened this issue Dec 3, 2023 · 2 comments

@mammoth831

CuTe uses a column-major encoding for the TV layouts of the mma's multiplicands. The docs explain:

Since CuTe layouts return indices rather than coordinates, we choose a column-major encoding of the (m,n) coordinates.

But this seems to lead to wrong results for a row-major matrix.

E.g., for multiplicand A of this mma instruction:

template <>
struct MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
{
  using ElementDVal = half_t;
  using ElementAVal = half_t;
  using ElementBVal = half_t;
  using ElementCVal = half_t;
  using Shape_MNK = Shape<_16,_8,_8>;
  using ThrID     = Layout<_32>;
  using ALayout   = SM80_16x8_Row;
  using BLayout   = SM80_8x8_Row;
  using CLayout   = SM80_16x8_Row;
};
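For context, ALayout above is a (thread, value) -> index map. In the CuTe sources from around this release it is defined roughly as below (quoted from memory, so treat the exact definition as an assumption):

// (T32,V4) -> (M16,N8): the result is a column-major index into the 16x8 tile, i.e. idx = m + 16*n
using SM80_16x8_Row = Layout<Shape <Shape <_4,_8>, Shape <_2,_2>>,
                             Stride<Stride<_32,_1>, Stride<_16,_8>>>;

Evaluating this layout at thread 0, values 0..3 gives the indices 0, 16, 8, 24, which decode (column-major) to the coordinates (m,n) = (0,0), (0,1), (8,0), (8,1) -- the same fragment assignment the PTX manual lists for m16n8k8.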

Then I use the following code to read values from global memory directly.

#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/copy_atom.hpp>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

using namespace cute;

template <class TiledMma, class T>
__global__ void mma_test(const T* A) {
        TiledMma tiled_mma;
        using GmemCopyAtom = Copy_Atom<DefaultCopy, T>;
        auto inst_m = size<0>(typename TiledMma::AtomShape_MNK{});
        auto inst_k = size<2>(typename TiledMma::AtomShape_MNK{});

        // 16x8 view of A in global memory
        auto gA = make_tensor(make_gmem_ptr(A), make_shape(inst_m, inst_k));

        // Per-thread register fragment of A, partitioned by the tiled MMA
        auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x);
        Tensor tCrA  = thr_mma.partition_fragment_A(gA);

        // Tiled copy gmem -> registers, retiled to match the MMA partitioning
        auto thr_copy_A       = make_tiled_copy_A(GmemCopyAtom{}, tiled_mma).get_thread_slice(threadIdx.x);
        Tensor tCgA           = thr_copy_A.partition_S(gA);
        Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);

        clear(tCrA);
        copy(tCgA, tCrA_copy_view);

        // Print the four A values owned by thread 0
        if (thread0()) {
          for (int i = 0; i < 4; ++i)
            printf("t0v%d: %f \n", i, float(tCrA_copy_view(i)));
        }
}

int main() {
        using TiledMma = TiledMMA<MMA_Atom<SM80_16x8x8_F16F16F16F16_TN>>;
        auto inst_m = size<0>(TiledMma::AtomShape_MNK{});
        auto inst_k = size<2>(TiledMma::AtomShape_MNK{});
        thrust::host_vector<half_t> h_A(inst_m * inst_k);

        for (int i = 0; i < inst_m * inst_k; ++i) h_A[i] = half_t(i);
        thrust::device_vector<half_t> d_A = h_A;

        mma_test<TiledMma><<<1, 32>>>(thrust::raw_pointer_cast(d_A.data()));
        cudaDeviceSynchronize();  // flush device-side printf output before the host prints

        printf("Multiplicand A(16x8):\n");
        for (int i = 0; i < inst_m; ++i) {
          for (int j = 0; j < inst_k; ++j) {
            // row-major matrix
            printf("%3.1f ", float(h_A[i * inst_k + j]));
          }
          printf("\n");
        }
}

Since I fill the matrix with naturally increasing elements and A is a row-major matrix, I think thread 0 should read 0, 1, 64, 65 (per the PTX reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k8).
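As a cross-check (my own arithmetic under that row-major assumption, not something taken from the PTX doc), the expected thread-0 values can be computed directly from the fragment coordinates:

#include <cstdio>
int main() {
  // Per the PTX fragment layout, thread 0 of m16n8k8 owns A coordinates
  // (m,n) = (0,0), (0,1), (8,0), (8,1).
  // For a row-major 16x8 A filled with A[m][n] = m*8 + n, that gives:
  int coords[4][2] = {{0,0}, {0,1}, {8,0}, {8,1}};
  for (auto& c : coords) printf("%d ", c[0] * 8 + c[1]);  // prints: 0 1 64 65
}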

But the output is:

t0v0: 0.000000
t0v1: 16.000000
t0v2: 8.000000
t0v3: 24.000000

Should we use a row-major TV layout to encode a row-major multiplicand? Or is there something wrong in my code?

@ccecka

ccecka commented Dec 3, 2023

The TV Layouts in the MMAs are not "row-major" or "col-major"; they describe the partitioning pattern of each instruction and can be applied to any Tensor with any Layout.

If you would like to use row-major data, then you should use row-major data:

Tensor gA = make_tensor(make_gmem_ptr(A), make_shape(inst_m, inst_k), make_stride(inst_k, Int<1>{}));
// or
Tensor gA = make_tensor(make_gmem_ptr(A), make_shape(inst_m, inst_k), GenRowMajor{});
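Applied to the kernel above, a sketch of the intended change (not tested here) would be:

// Row-major A: coordinate (m,n) maps to offset m*inst_k + n
auto gA = make_tensor(make_gmem_ptr(A),
                      make_shape(inst_m, inst_k),
                      make_stride(inst_k, Int<1>{}));

With this layout, thread 0's fragment coordinates (0,0), (0,1), (8,0), (8,1) pick up the values 0, 1, 64, 65, matching the PTX fragment figure. Incidentally, the 0, 16, 8, 24 in the original output are exactly the column-major offsets m + 16*n of those same coordinates, consistent with gA having defaulted to a column-major layout.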

@mammoth831 (Author)


Thank you! It solved my problem.
