From 635c81bef120e4c3198844088dc7626dc43be82a Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Tue, 26 Mar 2024 16:24:57 +0000 Subject: [PATCH 1/2] add dynamic shared mem support --- .../backends/cuda/providers/default/codegen/ptx.cc | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc index 39c77d5a3..f4090eca8 100644 --- a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc +++ b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc @@ -43,6 +43,7 @@ using namespace mlir; #define BLOCK_SIZE_X_ATTR "BlockSize.x" #define BLOCK_SIZE_Y_ATTR "BlockSize.y" #define BLOCK_SIZE_Z_ATTR "BlockSize.z" +#define SHARED_MEMORY_SIZE "DynamicSharedMemorySize" #define ARG_RANKS_ATTR "arg_ranks" #define CALL_CONVENTION_ATTR "call_convention" @@ -92,6 +93,10 @@ struct PTXImpl { CUfunction func; auto status_func = ptx_compiler->GetOrCreateFunction( func, kernel_info.kernel_name, kernel_info.file_name); + size_t max_shared_mem = 48 << 10; + if (shared_size > max_shared_mem) { + cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_size); + } BRT_ENFORCE(status_func.IsOK(), status_func.ErrorMessage()); device2func.emplace(device_id, func); return func; @@ -170,11 +175,16 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info) ranks.push_back(GetRankFromOpArgIndex(info_, i)); } } + int64_t dynamic_shm_size = 0; + if (info.GetOperation()->hasAttrOfType(SHARED_MEMORY_SIZE)) + { + dynamic_shm_size = info.GetOperation()->getAttrOfType(SHARED_MEMORY_SIZE).getInt(); + } auto num_arg = GetOpArgNum(info_); impl_->grid = dim3(gx, gy, gz); impl_->block = dim3(bx, by, bz); - impl_->shared_size = 0; + impl_->shared_size = dynamic_shm_size; impl_->arg_reserve_size = 3; // initial 3 for grid/block/shared_size // store tensor meta From 0caf93e650cb04bc9c198c56313b74170aa3db65 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Tue, 2 Apr 2024 12:38:22 +0000 Subject: [PATCH 2/2] clang formaty --- .../lib/backends/cuda/providers/default/codegen/ptx.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc index 62a1702b3..9a4acb4ee 100644 --- a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc +++ b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc @@ -95,7 +95,8 @@ struct PTXImpl { func, kernel_info.kernel_name, kernel_info.file_name); size_t max_shared_mem = 48 << 10; if (shared_size > max_shared_mem) { - cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_size); + cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_size); } BRT_ENFORCE(status_func.IsOK(), status_func.ErrorMessage()); device2func.emplace(device_id, func); @@ -176,9 +177,10 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info) } } int64_t dynamic_shm_size = 0; - if (info.GetOperation()->hasAttrOfType(SHARED_MEMORY_SIZE)) - { - dynamic_shm_size = info.GetOperation()->getAttrOfType(SHARED_MEMORY_SIZE).getInt(); + if (info.GetOperation()->hasAttrOfType(SHARED_MEMORY_SIZE)) { + dynamic_shm_size = info.GetOperation() + ->getAttrOfType(SHARED_MEMORY_SIZE) + .getInt(); } auto num_arg = GetOpArgNum(info_);