Skip to content

Pseudo-code for fusing a convolution (or GEMM) operator with an epilogue function

Chao Liu edited this page Oct 21, 2021 · 5 revisions

Transform a forward convolution operator into a GEMM operator

// Pseudo-code: lower a forward convolution into an implicit GEMM and run it,
// with a user-supplied element-wise epilogue functor fused into the C write-out.
//
// Template parameters:
//   InDesc / WeiDesc / OutDesc        - tensor descriptor types for input/weight/output
//   InDataType / WeiDataType / OutDataType - element types of the three tensors
//   CEpilogueFunctor                  - element-wise functor applied to each C value
template <typename InDesc,
          typename WeiDesc,
          typename OutDesc,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename CEpilogueFunctor>
__global__ void implicit_gemm_forward_convolution(InDesc in_desc,
                                                  WeiDesc wei_desc,
                                                  OutDesc out_desc,
                                                  const InDataType* p_in,
                                                  const WeiDataType* p_wei,
                                                  OutDataType* p_out,
                                                  CEpilogueFunctor c_epilogue_functor)
{
    // transform In/Wei/Out tensor descriptors to A/B/C matrix descriptors
    auto a_desc = transform_tensor_descriptor(in_desc, ...);
    auto b_desc = transform_tensor_descriptor(wei_desc, ...);
    auto c_desc = transform_tensor_descriptor(out_desc, ...);

    // block-to-C-tile mapping: tells each thread block which M/N tile of C it owns
    // (GridwiseGEMM's definition takes a CBlockMapping type and Run takes the
    // mapping object; the original pseudo-code omitted both)
    auto c_block_mapping = make_c_block_mapping(c_desc, ...);

    // instantiate GridwiseGEMM with parameters:
    //    A/B/C (generic) matrix descriptor types
    //    block-to-C-tile mapping type
    //    A/B/C datatype
    //    Block and warp level matrix tile sizes
    //    other low level optimization parameters
    auto gridwise_gemm = GridwiseGEMM<decltype(a_desc),
                                      decltype(b_desc),
                                      decltype(c_desc),
                                      decltype(c_block_mapping),
                                      InDataType,
                                      WeiDataType,
                                      OutDataType,
                                      ...,
                                      CEpilogueFunctor>{};

    // NOTE(review): original called Run with a_m_k_desc / b_n_k_desc / c_m_n_desc,
    // none of which are defined; the locals built above are the intended arguments.
    gridwise_gemm.Run(
        a_desc, b_desc, c_desc, c_block_mapping, p_in, p_wei, p_out, c_epilogue_functor);
}

Execute a GEMM operator on GPU, with an epilogue element-wise operator applied to C matrix

// Grid-level GEMM with a fused element-wise epilogue on C.
//
// Each thread block computes one MPerBlock x NPerBlock tile of C, marching
// over K in steps of KPerBlock with a 2-stage (prefetch) pipeline:
// global -> register prefetch overlapped with shared-memory GEMM.
//
// Template parameters:
//   AGridDesc / BGridDesc / CGridDesc - global matrix descriptor types
//   CBlockMapping                     - maps block-ID -> (m, n) tile origin in C
//   ADataType / BDataType / CDataType - element types
//   MPerBlock / NPerBlock / KPerBlock - block-level tile sizes
//   CEpilogueFunctor                  - element-wise functor applied to C values
template <typename AGridDesc,
          typename BGridDesc,
          typename CGridDesc,
          typename CBlockMapping,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          int MPerBlock,
          int NPerBlock,
          int KPerBlock,
          typename CEpilogueFunctor>
struct GridwiseGEMM
{
    // utility function for Gridwise GEMM
    __host__ __device__ static bool CheckValidity(...){...}

    __host__ __device__ static int CalculateGridSize(...){...}

    // GEMM function: one thread block produces one MPerBlock x NPerBlock tile of C.
    // p_c_grid is written through, so it must not be const (original declared it
    // const and also misspelled "const" on p_a_grid and c_grid_desc's name).
    __device__ void Run(AGridDesc a_grid_desc,
                        BGridDesc b_grid_desc,
                        CGridDesc c_grid_desc,
                        CBlockMapping c_block_mapping,
                        const ADataType* p_a_grid,
                        const BDataType* p_b_grid,
                        CDataType* p_c_grid,
                        CEpilogueFunctor c_epilogue_functor)
    {
        // each block calculates its own starting M/N index, based on block-ID
        auto block_m_n_begin_idx = c_block_mapping(get_block_id());
        int block_m_begin        = block_m_n_begin_idx[0];
        int block_n_begin        = block_m_n_begin_idx[1];

        // define tensor descriptors for A/B block matrix tiles in shared memory;
        // K is the outer (dim 0) dimension of both tiles
        auto a_block_desc = make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, MPerBlock));
        auto b_block_desc = make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));

        // allocate shared memory for A/B block matrix tile
        // (original omitted the element types)
        __shared__ ADataType p_a_block[a_block_desc.GetElementSpaceSize()];
        __shared__ BDataType p_b_block[b_block_desc.GetElementSpaceSize()];

        // instantiate BlockwiseGEMM with parameters:
        //    A/B (generic) block matrix descriptor types
        //    A/B/C datatype
        //    Block and warp level matrix tile sizes
        //    other low level optimization parameters
        auto blockwise_gemm = BlockwiseGEMM<decltype(a_block_desc),
                                            decltype(b_block_desc),
                                            ADataType,
                                            BDataType,
                                            CDataType,
                                            MPerBlock,
                                            NPerBlock,
                                            KPerBlock,
                                            ....>{};

        // get XDlops-compliant C thread tensor descriptor and register buffer
        auto c_thread_desc   = blockwise_gemm.GetCThreadTensorDescriptor();
        auto c_thread_buffer = blockwise_gemm.GetCThreadBuffer();

        // instantiate BlockwiseTensorSliceTransfer operator for A/B block matrix
        // tile with parameters:
        //     Source/Destination matrix descriptor types, "Source" is in global
        //         memory, "Destination" is in shared memory; slicing window sizes
        //     Source/Destination matrix datatypes
        //     other low level optimization parameters
        // construct BlockwiseTensorSliceTransfer operator with
        //     Starting coordinate of source slice window ({0, block_m_begin}) in A
        //         global matrix (in global memory) -- dim 0 is K, dim 1 is M/N,
        //         matching Sequence<KPerBlock, MPerBlock> and the {KPerBlock, 0}
        //         window moves below (original put the M/N offset in the K dim)
        //     Starting coordinate of destination slice window ({0, 0}) in A block
        //         tile matrix (in shared memory)
        auto a_blockwise_transfer = BlockwiseTensorSliceTransfer<decltype(a_grid_desc),
                                                                 decltype(a_block_desc),
                                                                 Sequence<KPerBlock, MPerBlock>,
                                                                 ....>{{0, block_m_begin}, {0, 0}};

        auto b_blockwise_transfer = BlockwiseTensorSliceTransfer<decltype(b_grid_desc),
                                                                 decltype(b_block_desc),
                                                                 Sequence<KPerBlock, NPerBlock>,
                                                                 ....>{{0, block_n_begin}, {0, 0}};

        // Preload first A/B matrix tile into shared memory
        {
            a_blockwise_transfer.RunRead(a_grid_desc, p_a_grid);
            b_blockwise_transfer.RunRead(b_grid_desc, p_b_grid);

            a_blockwise_transfer.RunWrite(a_block_desc, p_a_block);
            b_blockwise_transfer.RunWrite(b_block_desc, p_b_block);
        }

        // total K extent, shared by A and B (dim 0 of both grid descriptors);
        // the original referenced an undefined K
        const int K = a_grid_desc.GetLength(0);

        // 2-stage pipeline blockwise GEMM.
        // One tile is already preloaded, so the main loop covers the remaining
        // K/KPerBlock - 1 tiles; starting at k = 0 (as the original did) would
        // prefetch one tile past the end of A/B.
        for(int k = KPerBlock; k < K; k += KPerBlock)
        {
            // move source slicing windows on A/B global matrix to next tile
            a_blockwise_transfer.MoveSrcSliceWindow(a_grid_desc, {KPerBlock, 0});
            b_blockwise_transfer.MoveSrcSliceWindow(b_grid_desc, {KPerBlock, 0});

            // Prefetch next A/B matrix tile into register buffer
            a_blockwise_transfer.RunRead(a_grid_desc, p_a_grid);
            b_blockwise_transfer.RunRead(b_grid_desc, p_b_grid);

            // make the preceding tile's shared-memory writes visible to all threads
            block_sync_shared_memory();

            // do GEMM on current A/B matrix tile
            blockwise_gemm.Run(p_a_block, p_b_block, c_thread_buffer);

            // all threads must finish reading the tile before it is overwritten
            block_sync_shared_memory();

            // write next A/B matrix tile into shared memory
            a_blockwise_transfer.RunWrite(a_block_desc, p_a_block);
            b_blockwise_transfer.RunWrite(b_block_desc, p_b_block);
        }

        // GEMM tail loop: consume the last tile left in shared memory
        {
            block_sync_shared_memory();

            blockwise_gemm.Run(p_a_block, p_b_block, c_thread_buffer);
        }

        // apply epilogue functor element-wise to the C accumulators in registers
        static_for<...>{}(
            [&](auto i) { c_thread_buffer(i) = c_epilogue_functor(c_thread_buffer[i]); });

        // instantiate C thread tensor transfer operator
        //     Source/Destination tensor descriptor types, "Source" is in register,
        //         "Destination" is in global memory; slicing window sizes
        //     Source/Destination tensor datatypes
        //     other low level optimization parameters
        // construct ThreadwiseTensorSliceTransfer operator with
        //     Starting coordinate of source slice window ({0, 0, ...})
        //         in C thread tensor (in register)
        //     Starting coordinate of destination slice window ({0, 0, ...})
        //         in C global matrix (in global memory)
        auto threadwise_transfer = ThreadwiseTensorSliceTransfer<...>{...};

        threadwise_transfer.Run(c_thread_desc, c_thread_buffer, c_grid_desc, p_c_grid);
    }
};