Skip to content

Commit

Permalink
Update sgemm_wmma_tf32_stage.cu (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
DefTruth authored Oct 14, 2024
1 parent cf4f9d7 commit ba4998d
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions sgemm/sgemm_wmma_tf32_stage.cu
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,27 @@ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stage2(
constexpr int NUM_THREADS= (
WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256

// constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; // 16x4*2=128
// constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; // 16x2*4=128
// constexpr int BK = WMMA_K; // 8
// constexpr int OFFSET=0;

// int dev_id = 0;
// cudaGetDevice(&dev_id);
// cudaDeviceProp dev_prop;
// cudaGetDeviceProperties(&dev_prop, dev_id);
// int smem_max_size = (K_STAGE * BM * (BK+OFFSET) * sizeof(float) +
// K_STAGE * BK * (BN+OFFSET) * sizeof(float));
// smem_max_size = (smem_max_size < dev_prop.sharedMemPerMultiprocessor ?
// smem_max_size : dev_prop.sharedMemPerMultiprocessor);

// cudaFuncSetAttribute(
// sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_kernel<
// WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N,
// WARP_TILE_M, WARP_TILE_N, K_STAGE, 0>,
// cudaFuncAttributeMaxDynamicSharedMemorySize,
// smem_max_size);

dim3 block(NUM_THREADS);
dim3 grid(div_ceil(N, WMMA_N * WMMA_TILE_N * WARP_TILE_N),
div_ceil(M, WMMA_M * WMMA_TILE_M * WARP_TILE_M));
Expand Down

0 comments on commit ba4998d

Please sign in to comment.