diff --git a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc
index 25c14c9bf..9a4acb4ee 100644
--- a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc
+++ b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc
@@ -43,6 +43,7 @@ using namespace mlir;
 #define BLOCK_SIZE_X_ATTR "BlockSize.x"
 #define BLOCK_SIZE_Y_ATTR "BlockSize.y"
 #define BLOCK_SIZE_Z_ATTR "BlockSize.z"
+#define SHARED_MEMORY_SIZE "DynamicSharedMemorySize"
 #define ARG_RANKS_ATTR "arg_ranks"
 #define CALL_CONVENTION_ATTR "call_convention"
 
@@ -92,6 +93,11 @@ struct PTXImpl {
     CUfunction func;
     auto status_func = ptx_compiler->GetOrCreateFunction(
         func, kernel_info.kernel_name, kernel_info.file_name);
     BRT_ENFORCE(status_func.IsOK(), status_func.ErrorMessage());
+    size_t max_shared_mem = 48 << 10; // 48KB default dynamic shared-memory limit
+    if (shared_size > max_shared_mem) {
+      cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+                         shared_size);
+    }
     device2func.emplace(device_id, func);
     return func;
@@ -170,11 +176,17 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info)
       ranks.push_back(GetRankFromOpArgIndex(info_, i));
     }
   }
 
+  int64_t dynamic_shm_size = 0; // from the DynamicSharedMemorySize attr, if set
+  if (info.GetOperation()->hasAttrOfType<IntegerAttr>(SHARED_MEMORY_SIZE)) {
+    dynamic_shm_size = info.GetOperation()
+                           ->getAttrOfType<IntegerAttr>(SHARED_MEMORY_SIZE)
+                           .getInt();
+  }
   auto num_arg = GetOpArgNum(info_);
   impl_->grid = dim3(gx, gy, gz);
   impl_->block = dim3(bx, by, bz);
-  impl_->shared_size = 0;
+  impl_->shared_size = dynamic_shm_size;
   impl_->arg_reserve_size = 3; // initial 3 for grid/block/shared_size
 
   // store tensor meta
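Background on the `cuFuncSetAttribute` hunk: by default the CUDA driver rejects launches that request more than 48KB of dynamic shared memory per block, so any kernel needing more must opt in per `CUfunction` before launch. Below is a minimal standalone sketch of that driver-API pattern, not code from this patch; the module path `kernel.ptx`, the kernel name `my_kernel`, the 96KB request, and the `DRV_CHECK` helper are placeholders (in ptx.cc the `CUfunction` instead comes from `ptx_compiler->GetOrCreateFunction`), and the real opt-in ceiling varies by architecture and should be queried via `CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN`.

```cpp
// Sketch only: opting a driver-API kernel into >48KB of dynamic shared memory.
#include <cstdio>
#include <cuda.h>

#define DRV_CHECK(expr)                                                        \
  do {                                                                         \
    CUresult rc_ = (expr);                                                     \
    if (rc_ != CUDA_SUCCESS) {                                                 \
      const char *msg = nullptr;                                               \
      cuGetErrorString(rc_, &msg);                                             \
      std::fprintf(stderr, "%s failed: %s\n", #expr, msg ? msg : "unknown");   \
      return 1;                                                                \
    }                                                                          \
  } while (0)

int main() {
  DRV_CHECK(cuInit(0));
  CUdevice dev;
  DRV_CHECK(cuDeviceGet(&dev, 0));
  CUcontext ctx;
  DRV_CHECK(cuCtxCreate(&ctx, 0, dev));

  // Placeholder module and kernel names.
  CUmodule mod;
  DRV_CHECK(cuModuleLoad(&mod, "kernel.ptx"));
  CUfunction func;
  DRV_CHECK(cuModuleGetFunction(&func, mod, "my_kernel"));

  size_t shared_size = 96 << 10;   // dynamic shared memory the launch needs
  size_t default_limit = 48 << 10; // per-block limit without the opt-in
  if (shared_size > default_limit) {
    // Without this attribute, cuLaunchKernel fails for requests above 48KB.
    DRV_CHECK(cuFuncSetAttribute(func,
                                 CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                                 static_cast<int>(shared_size)));
  }

  // Launch with the dynamic shared-memory size passed as sharedMemBytes.
  DRV_CHECK(cuLaunchKernel(func, /*gridDim=*/1, 1, 1, /*blockDim=*/256, 1, 1,
                           static_cast<unsigned>(shared_size),
                           /*stream=*/nullptr, /*kernelParams=*/nullptr,
                           /*extra=*/nullptr));
  DRV_CHECK(cuCtxSynchronize());

  DRV_CHECK(cuModuleUnload(mod));
  DRV_CHECK(cuCtxDestroy(ctx));
  return 0;
}
```

The `dynamic_shm_size` read from the op attribute and stored in `impl_->shared_size` presumably fills the same `sharedMemBytes` slot of the eventual kernel launch, which is why the per-function opt-in above is only applied when the requested size exceeds the 48KB default.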