Skip to content

Commit

Permalink
bugfix: use actual sm count for num_sm90_ctas (#762)
Browse files: browse the repository at this point in the history
  • Loading branch information
LLLLKKKK authored Jan 29, 2025
1 parent 0e25eb2 commit e5a3bef
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion include/flashinfer/attention/scheduler.cuh
Original file line number | Diff line number | Diff line change
Expand Up @@ -824,7 +824,11 @@ inline cudaError_t PrefillSM90Plan(void* float_buffer, size_t float_workspace_si
cta_tile_q = 192;
}

const int num_sm90_ctas = 132; // for sm90, the num_ctas is fixed
int device = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&device));
int num_sm90_ctas = 0;
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&num_sm90_ctas, cudaDevAttrMultiProcessorCount, device));

CTACostHeap cta_cost_heap(num_sm90_ctas);
std::vector<std::vector<IdType>> cta_qo_tile_indices(num_sm90_ctas, std::vector<IdType>()),
Expand Down

0 comments on commit e5a3bef

Please sign in to comment.