diff --git a/sgemm/sgemm_wmma_tf32_stage.cu b/sgemm/sgemm_wmma_tf32_stage.cu index 06ea39cd..e0cfa479 100644 --- a/sgemm/sgemm_wmma_tf32_stage.cu +++ b/sgemm/sgemm_wmma_tf32_stage.cu @@ -283,6 +283,27 @@ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stage2( constexpr int NUM_THREADS= ( WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 + // constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; // 16x4*2=128 + // constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; // 16x2*4=128 + // constexpr int BK = WMMA_K; // 8 + // constexpr int OFFSET=0; + + // int dev_id = 0; + // cudaGetDevice(&dev_id); + // cudaDeviceProp dev_prop; + // cudaGetDeviceProperties(&dev_prop, dev_id); + // int smem_max_size = (K_STAGE * BM * (BK+OFFSET) * sizeof(float) + + // K_STAGE * BK * (BN+OFFSET) * sizeof(float)); + // smem_max_size = (smem_max_size < dev_prop.sharedMemPerMultiprocessor ? + // smem_max_size : dev_prop.sharedMemPerMultiprocessor); + + // cudaFuncSetAttribute( + // sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_kernel< + // WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, + // WARP_TILE_M, WARP_TILE_N, K_STAGE, 0>, + // cudaFuncAttributeMaxDynamicSharedMemorySize, + // smem_max_size); + dim3 block(NUM_THREADS); dim3 grid(div_ceil(N, WMMA_N * WMMA_TILE_N * WARP_TILE_N), div_ceil(M, WMMA_M * WMMA_TILE_M * WARP_TILE_M));