Pseudo code of fusion of a convolution (or GEMM) operator and an epilogue function
template <typename InDesc,
typename WeiDesc,
typename OutDesc,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename CEpilogueFunctor>
__global__ void implicit_gemm_forward_convolution(InDesc in_desc,
WeiDesc wei_desc,
OutDesc out_desc,
const InDataType* p_in,
const WeiDataType* p_wei,
OutDataType* p_out,
CEpilogueFunctor c_epilogue_functor)
{
// transform In/Wei/Out tensor descriptors to A/B/C matrix descriptors
auto a_desc = transform_tensor_descriptor(in_desc, ...);
auto b_desc = transform_tensor_descriptor(wei_desc, ...);
auto c_desc = transform_tensor_descriptor(out_desc, ...);
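// for example, for an NCHW-input / KCYX-weight forward convolution, one common
// implicit-GEMM mapping (shown only as an illustration; the exact transforms are elided above) is:
//   GemmM = N * Ho * Wo   (output pixels, from the input/output tensors)
//   GemmN = K             (output channels, from the weight tensor)
//   GemmK = C * Y * X     (input channels times filter window)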
// instantiate GridwiseGEMM with parameters:
// A/B/C (generic) matrix descriptor types
// A/B/C datatype
// Block and warp level matrix tile sizes
// other low level optimization parameters
auto gridwise_gemm = GridwiseGEMM<decltype(a_desc),
decltype(b_desc),
decltype(c_desc),
InDataType,
WeiDataType,
OutDataType,
CEpilogueFunctor,
....>{};
// construct mapping from block-ID to C matrix tile index (details elided)
auto c_block_mapping = ...;
gridwise_gemm.Run(a_desc, b_desc, c_desc, c_block_mapping, p_in, p_wei, p_out, c_epilogue_functor);
}
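The CEpilogueFunctor argument is what gets fused: GridwiseGEMM (below) applies it element-wise to the C thread buffer while the accumulators are still in registers, before they are written out to global memory. A minimal sketch of such a functor, assuming it only needs to be callable on a single accumulator value, might look like the following; the name LeakyReluEpilogue and its alpha parameter are illustrative, not part of the library:

// illustrative element-wise epilogue: leaky-ReLU with a configurable negative slope
struct LeakyReluEpilogue
{
    float alpha_;
    __host__ __device__ explicit LeakyReluEpilogue(float alpha) : alpha_(alpha) {}
    // applied to each element of the C thread buffer after the blockwise GEMM
    __device__ float operator()(float c) const { return c > 0.f ? c : alpha_ * c; }
};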
template <typename AGridDesc,
typename BGridDesc,
typename CGridDesc,
typename CBlockMapping,
typename ADataType,
typename BDataType,
typename CDataType,
int MPerBlock,
int NPerBlock,
int KPerBlock,
typename CEpilogueFunctor>
struct GridwiseGEMM
{
// utility function for Gridwise GEMM
__host__ __device__ static bool CheckValidity(...){...}
__host__ __device__ static int CalculateGridSize(...){...}
// GEMM function
__device__ void Run(AGridDesc a_grid_desc,
BGridDesc b_grid_desc,
CGridDesc c_grid_desc,
CBlockMapping c_block_mapping,
const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
CEpilogueFunctor c_epilogue_functor)
{
// each block calculates its own starting M/N index, based on block-ID
auto block_m_n_begin_idx = c_block_mapping(get_block_id());
int block_m_begin = block_m_n_begin_idx[0];
int block_n_begin = block_m_n_begin_idx[1];
// define tensor descriptors for A/B block matrix tiles in shared memory
auto a_block_desc = make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, MPerBlock));
auto b_block_desc = make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
// allocate shared memory for A/B block matrix tile
__shared__ ADataType p_a_block[a_block_desc.GetElementSpaceSize()];
__shared__ BDataType p_b_block[b_block_desc.GetElementSpaceSize()];
// instantiate BlockwiseGEMM with parameters:
// A/B (generic) block matrix descriptor types
// A/B/C datatype
// Block and warp level matrix tile sizes
// other low level optimization parameters
auto blockwise_gemm = BlockwiseGEMM<decltype(a_block_desc),
decltype(b_block_desc),
ADataType,
BDataType,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
....>{};
// get XDLOPS-compliant C thread tensor descriptor and register buffer
auto c_thread_desc = blockwise_gemm.GetCThreadTensorDescriptor();
auto c_thread_buffer = blockwise_gemm.GetCThreadBuffer();
// instantiate BlockwiseTensorSliceTransfer operator for A/B block matrix
// tile with parameters:
// Source/Destination matrix descriptor types, "Source" is in global
// memory, "Destination" is in shared memory
// Slicing window sizes
// Source/Destination matrix datatypes
// other low level optimization parameters
// construct BlockwiseTensorSliceTransfer operator with
// Starting coordinate of source slice window ({0, block_m_begin}) in A
// global matrix (in global memory)
// Starting coordinate of destination slice window ({0, 0}) in A block tile matrix
// (in shared memory)
auto a_blockwise_transfer = BlockwiseTensorSliceTransfer<decltype(a_grid_desc),
decltype(a_block_desc),
Sequence<KPerBlock, MPerBlock>,
....>{{0, block_m_begin}, {0, 0}};
auto b_blockwise_transfer = BlockwiseTensorSliceTransfer<decltype(b_grid_desc),
decltype(b_block_desc),
Sequence<KPerBlock, NPerBlock>,
....>{{0, block_n_begin}, {0, 0}};
// Preload A/B matrix tile into shared memory
{
a_blockwise_transfer.RunRead(a_grid_desc, p_a_grid);
b_blockwise_transfer.RunRead(b_grid_desc, p_b_grid);
a_blockwise_transfer.RunWrite(a_block_desc, p_a_block);
b_blockwise_transfer.RunWrite(b_block_desc, p_b_block);
}
// 2-stage pipeline blockwise GEMM main loop (the last tile is handled by the tail below)
for(int k = 0; k < K - KPerBlock; k += KPerBlock)
{
// move source slicing windows on A/B global matrix to next tile
a_blockwise_transfer.MoveSrcSliceWindow(a_grid_desc, {KPerBlock, 0});
b_blockwise_transfer.MoveSrcSliceWindow(b_grid_desc, {KPerBlock, 0});
// Prefetch next A/B matrix tile into register buffer
a_blockwise_transfer.RunRead(a_grid_desc, p_a_grid);
b_blockwise_transfer.RunRead(b_grid_desc, p_b_grid);
block_sync_shared_memory();
// do GEMM on current A/B matrix tile
blockwise_gemm.Run(p_a_block, p_b_block, c_thread_buffer);
block_sync_shared_memory();
// write next A/B matrix tile into shared memory
a_blockwise_transfer.RunWrite(a_block_desc, p_a_block);
b_blockwise_transfer.RunWrite(b_block_desc, p_b_block);
}
// GEMM tail loop
{
block_sync_shared_memory();
blockwise_gemm.Run(p_a_block, p_b_block, c_thread_buffer);
}
// epilogue functor on C thread buffer
static_for<...>{}(
[&](auto i) { c_thread_buffer(i) = c_epilogue_functor(c_thread_buffer[i]); });
// instantiate ThreadwiseTensorSliceTransfer operator for C thread tensor with parameters:
// Source/Destination tensor descriptor types, "Source" is in register,
// "Destination" is in global memory
// Slicing window sizes
// Source/Destination tensor datatypes
// other low level optimization parameters
// construct ThreadwiseTensorSliceTransfer operator with
// Starting coordinate of source slice window ({0, 0, ...})
// in C thread tensor (in register)
// Starting coordinate of destination slice window ({0, 0, ...})
// in C global matrix (in global memory)
auto threadwise_transfer = ThreadwiseTensorSliceTransfer<...>{...};
threadwise_transfer.Run(c_thread_desc, c_thread_buffer, c_grid_desc, p_c_grid);
}
};
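Putting the pieces together, a host-side launch of the fused kernel might look roughly like the sketch below; the grid/block sizes, the device pointers and the LeakyReluEpilogue functor are assumptions for illustration, not the library's actual launch code:

// illustrative launch of the fused convolution + epilogue kernel
// (in_desc/wei_desc/out_desc, p_in/p_wei/p_out, grid_size and block_size
// are assumed to have been prepared on the host)
implicit_gemm_forward_convolution<<<grid_size, block_size>>>(
    in_desc, wei_desc, out_desc, p_in, p_wei, p_out, LeakyReluEpilogue{0.1f});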