wmmae/wmma_extension

WMMA API Extension

This extension provides the following features without using extra shared memory:

  • mapping between memory and fragments (primitive functions)
  • operations for vectors
    • loading a vector as a fragment
    • storing a fragment as a vector
  • C++ interface for mma instructions
  • Error Correction (TCEC) for SGEMM emulation
  • arithmetic operators for fragments (+, -, *, /, fma)
  • utils
  • etc.

Important

Please specify a virtual architecture that matches the target GPU, since the mapping between fragments and memory depends on the architecture. For instance, a program compiled with -arch=sm_70 will not work correctly on Ampere GPUs.

Requirements

  • CUDA (10.2 or later)
  • C++ (17 or later)

Supported architectures / fragment

  • sm_70: ((16, 16, 16), fp16/fp32)
  • sm_75: ((16, 16, 16), fp16/fp32)
  • sm_80: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32)
  • sm_89: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32)
  • sm_90: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32) (the wgmma instruction is not supported yet)

Functions

Primitive functions

foreach

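This function calculates the mapping between memory elements and fragment elements. The callback is invoked with a memory index (mem_index) together with the list of fragment element indices (frag_index_list, of length fragment_index_count) that correspond to it.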

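// compute_t and convert_to<T>() stand for the user's storage type and conversion function (e.g. float and __float2half)
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;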
__shared__ compute_t matrix[16 * 16];
mtk::wmma::foreach<decltype(frag_b)>(
        [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
            const auto m = mem_index % 16;
            const auto n = mem_index / 16;
            for (unsigned i = 0; i < fragment_index_count; i++)
                frag_b.x[frag_index_list[i]] = convert_to<half>(matrix[n * 16 + m]);
        });
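The index arithmetic in the callback can be checked on the host. The plain C++ sketch below does not call the library; `kTile`, `RowCol`, `decompose_col_major`, and `compose_col_major` are hypothetical names introduced for illustration. It assumes, as in the example, that the shared-memory tile is column-major with a leading dimension of 16, so `matrix[n * 16 + m]` reads the same element as `matrix[mem_index]`.

```cpp
#include <cassert>

// Host-side illustration only (hypothetical helpers, not part of wmma_extension).
// For a 16x16 column-major tile with leading dimension 16:
//   mem_index = n * 16 + m, where m is the row and n is the column.
constexpr unsigned kTile = 16;

struct RowCol {
    unsigned m; // row index
    unsigned n; // column index
};

// Split a column-major linear index into (row, column), as the callback does.
constexpr RowCol decompose_col_major(unsigned mem_index) {
    return RowCol{mem_index % kTile, mem_index / kTile};
}

// Rebuild the column-major linear index from (row, column).
constexpr unsigned compose_col_major(RowCol rc) {
    return rc.n * kTile + rc.m;
}

// Verify that decompose/compose is the identity for every element of the tile.
inline bool round_trip_holds() {
    for (unsigned idx = 0; idx < kTile * kTile; idx++) {
        const RowCol rc = decompose_col_major(idx);
        assert(rc.m < kTile && rc.n < kTile);
        if (compose_col_major(rc) != idx) return false;
    }
    return true;
}
```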

foreach_ij

This function calculates the mapping between matrix element positions (i, j) and fragment elements.

nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t matrix[16 * 16];
mtk::wmma::foreach_ij<decltype(frag_b)>(
        [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned i, const unsigned j) {
            for (unsigned f = 0; f < fragment_index_count; f++)
                frag_b.x[frag_index_list[f]] = convert_to<half>(matrix[j * 16 + i]);
        });
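Because foreach_ij passes the coordinates (i, j) rather than a memory index, the callback decides how the matrix is laid out in memory. The host-side sketch below (plain C++, not library code; all names are hypothetical) contrasts the column-major offset used in the example with a row-major one and checks that both read the same logical element.

```cpp
#include <array>
#include <cassert>

// Host-side illustration only (hypothetical helpers, not part of wmma_extension).
constexpr unsigned kLd = 16; // leading dimension of the 16x16 tile

// Offset of logical element (i, j) in column-major storage,
// as in the foreach_ij example (matrix[j * 16 + i]).
inline unsigned col_major_offset(unsigned i, unsigned j) {
    assert(i < kLd && j < kLd);
    return j * kLd + i;
}

// Offset of logical element (i, j) in row-major storage.
inline unsigned row_major_offset(unsigned i, unsigned j) {
    assert(i < kLd && j < kLd);
    return i * kLd + j;
}

using Tile = std::array<float, kLd * kLd>;

// Store the logical matrix A(i, j) = 100 * i + j using the given layout.
template <class OffsetFn>
Tile make_tile(OffsetFn offset) {
    Tile t{};
    for (unsigned i = 0; i < kLd; i++)
        for (unsigned j = 0; j < kLd; j++)
            t[offset(i, j)] = static_cast<float>(100 * i + j);
    return t;
}

// Reading by (i, j) through the matching offset function yields the same
// logical element regardless of the storage layout.
inline bool layouts_agree() {
    const Tile col = make_tile(col_major_offset);
    const Tile row = make_tile(row_major_offset);
    for (unsigned i = 0; i < kLd; i++)
        for (unsigned j = 0; j < kLd; j++)
            if (col[col_major_offset(i, j)] != row[row_major_offset(i, j)])
                return false;
    return true;
}
```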

foreach_v

For matrix A/B

This function calculates the mapping between the elements of a given vector and fragment elements.

nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t vector[16];
mtk::wmma::foreach_v<decltype(frag_b)>(
        [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
            for (unsigned i = 0; i < fragment_index_count; i++)
                frag_b.x[frag_index_list[i]] = convert_to<half>(vector[mem_index]);
        });
// is equivalent to `load_vector`

For accumulator

nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;
__shared__ compute_t vector[16];
mtk::wmma::foreach_v<decltype(frag_c)>(nvcuda::wmma::mem_col_major,
        [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
            for (unsigned i = 0; i < fragment_index_count; i++)
                vector[mem_index] = convert_to<compute_t>(frag_c.x[frag_index_list[i]]);
        });
// is equivalent to `store_vector`

map

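This function returns the mapping between a matrix element (i, j) and the fragment elements (tid, fid) that hold it, where tid is the thread (lane) ID and fid is the index into the fragment's x array. An element may be held by more than one thread, so the results are returned as lists of length list_size.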

nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
unsigned tid_list[2];
unsigned fid_list[2];
unsigned list_size;
// i, j: row and column indices of the target matrix element (defined by the caller)
mtk::wmma::map<decltype(frag_b)>(tid_list, fid_list, list_size, i, j);
for (unsigned k = 0; k < list_size; k++) {
  if ((threadIdx.x & 0x1f) == tid_list[k]) {
    frag_b.x[fid_list[k]] = 3.0f;
  }
}
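The lane check in the loop relies on `threadIdx.x & 0x1f` being the thread's lane ID within its warp. The host-side sketch below (hypothetical helper names, not library code) checks that masking with 0x1f equals taking the remainder modulo 32. It assumes the lane ID can be derived from threadIdx.x alone, which holds for a one-dimensional thread block.

```cpp
#include <cassert>

// Host-side illustration only (hypothetical helpers, not part of wmma_extension).
// A WMMA fragment is spread over the 32 threads of a warp, so the thread IDs
// returned by map are lane IDs in [0, 32).
constexpr unsigned kWarpSize = 32;

// Keep the low five bits of a (one-dimensional) thread index.
constexpr unsigned lane_id(unsigned thread_index) {
    return thread_index & 0x1fu;
}

// For unsigned integers, masking with 0x1f is the same as % 32.
inline bool mask_matches_modulo(unsigned upto) {
    for (unsigned t = 0; t < upto; t++)
        if (lane_id(t) != t % kWarpSize) return false;
    return true;
}
```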

Functions for vector

Sample

#include <mma.h>
#include <wmma_extension/wmma_extension.hpp>

__global__ void kernel() {
    nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> frag_a;
    nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
    nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;

    __shared__ float vec16[16];

    mtk::wmma::load_vector(frag_a, vec16);
    mtk::wmma::load_vector(frag_b, vec16);

    nvcuda::wmma::fill_fragment(frag_c, 0.0f);
    nvcuda::wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);

    mtk::wmma::store_vector(vec16, frag_c, nvcuda::wmma::mem_col_major);
}
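To check the output of a kernel like the one above, a CPU reference for the tile product can help. The sketch below assumes only the documented semantics of nvcuda::wmma::mma_sync (D = A * B + C) and column-major storage with a leading dimension of 16; `reference_mma` and the type aliases are hypothetical names, and the sketch does not model how load_vector places a vector inside a fragment. Inputs are float here, whereas on the GPU A and B are half.

```cpp
#include <array>
#include <cassert>

// Host-side reference only (hypothetical, not part of wmma_extension).
// Computes D = A * B + C for one 16x16x16 tile; all matrices are column-major.
constexpr unsigned kM = 16, kN = 16, kK = 16;

using MatA = std::array<float, kM * kK>; // A(m, k) at k * kM + m
using MatB = std::array<float, kK * kN>; // B(k, n) at n * kK + k
using MatC = std::array<float, kM * kN>; // C(m, n) at n * kM + m

inline MatC reference_mma(const MatA& a, const MatB& b, const MatC& c) {
    MatC d{};
    for (unsigned n = 0; n < kN; n++) {
        for (unsigned m = 0; m < kM; m++) {
            float acc = c[n * kM + m];
            for (unsigned k = 0; k < kK; k++)
                acc += a[k * kM + m] * b[n * kK + k];
            d[n * kM + m] = acc;
        }
    }
    assert(d.size() == kM * kN);
    return d;
}
```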

Other functions

make_identity_matrix / add_eye


  • Arguments
    • dst_fragment : Destination fragment (accumulator)
    • alpha : diagonal element

fill_zero

  • Argument
    • dst_fragment : Destination fragment
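The helpers above act on accumulator fragments, so their effect is easiest to picture on an ordinary 16x16 matrix. The host-side analogues below are hypothetical (`fill_zero_ref`, `make_identity_ref`, `add_eye_ref` are not library functions) and assume that make_identity_matrix writes the identity, add_eye adds alpha to every diagonal element, and fill_zero clears every element; the library source is the authority on the exact behaviour.

```cpp
#include <array>
#include <cassert>

// Host-side analogues only (hypothetical, not part of wmma_extension),
// acting on a 16x16 matrix. The diagonal offset i * 16 + i is the same
// for row-major and column-major storage.
constexpr unsigned kDim = 16;
using Acc = std::array<float, kDim * kDim>;

// Analogue of fill_zero: clear every element.
inline void fill_zero_ref(Acc& m) { m.fill(0.0f); }

// Analogue of make_identity_matrix: zeros everywhere, ones on the diagonal.
inline void make_identity_ref(Acc& m) {
    fill_zero_ref(m);
    for (unsigned i = 0; i < kDim; i++) m[i * kDim + i] = 1.0f;
}

// Analogue of add_eye: add alpha to every diagonal element.
inline void add_eye_ref(Acc& m, float alpha) {
    for (unsigned i = 0; i < kDim; i++) m[i * kDim + i] += alpha;
    assert(kDim * kDim == m.size());
}
```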

Debugging functions

print_fragment

This function prints the elements of a fragment.

  • Arguments
    • frag : Target fragment
    • name : Name printed with the fragment (char*, optional)

Publication

@inproceedings{ootomo_wmmae_2023,
  author = {Ootomo, Hiroyuki and Yokota, Rio},
  title = {Reducing Shared Memory Footprint to Leverage High Throughput on Tensor Cores and Its Flexible API Extension Library},
  year = {2023},
  series = {HPC Asia '23}
}

LICENSE

MIT