From 05587c58720bd307edbc69c49837c6c7946c56e5 Mon Sep 17 00:00:00 2001 From: Michael Kenzel Date: Fri, 4 Aug 2023 23:30:06 +0200 Subject: [PATCH] Add dynamic shared memory allocation --- src/thorin/be/c/c.cpp | 36 +++++++++++++++++++++++++--------- src/thorin/be/codegen.h | 1 + src/thorin/be/llvm/llvm.cpp | 28 +++++++++++++++++++------- src/thorin/be/llvm/llvm.h | 1 + src/thorin/be/llvm/runtime.cpp | 13 ++++++++---- src/thorin/be/llvm/runtime.h | 1 + src/thorin/be/llvm/runtime.inc | 2 +- 7 files changed, 61 insertions(+), 21 deletions(-) diff --git a/src/thorin/be/c/c.cpp b/src/thorin/be/c/c.cpp index 3f77582fd..3ad2bebe4 100644 --- a/src/thorin/be/c/c.cpp +++ b/src/thorin/be/c/c.cpp @@ -510,6 +510,9 @@ void CCodeGen::emit_module() { stream_.fmt("__device__ inline int blockDim_{}() {{ return blockDim.{}; }}\n", x, x); stream_.fmt("__device__ inline int gridDim_{}() {{ return gridDim.{}; }}\n", x, x); } + + stream_.fmt("\n" + "extern __shared__ unsigned char __dynamic_smem[];\n"); } stream_.endl() << func_impls_.str(); @@ -742,16 +745,25 @@ void CCodeGen::emit_epilogue(Continuation* cont) { world().edef(body->arg(1), "reserve_shared: couldn't extract memory size"); auto ret_cont = body->arg(2)->as_nom(); - auto elem_type = ret_cont->param(1)->type()->as()->pointee()->as()->elem_type(); - func_impls_.fmt("{}{} {}_reserved[{}];\n", - addr_space_prefix(AddrSpace::Shared), convert(elem_type), - cont->unique_name(), emit_constant(body->arg(1))); - if (lang_ == Lang::HLS && !hls_top_scope) { - func_impls_.fmt("#pragma HLS dependence variable={}_reserved inter false\n", cont->unique_name()); - func_impls_.fmt("#pragma HLS data_pack variable={}_reserved\n", cont->unique_name()); - func_impls_<< "#if defined( __VITIS_HLS__ )\n __attribute__((packed))\n #endif\n"; + + if (body->arg(1)->as()->ps32_value().data() == 0) { + auto ptr_type = ret_cont->param(1)->type()->as(); + bb.tail.fmt("p_{} = ({})__dynamic_smem;\n", ret_cont->param(1)->unique_name(), convert(ptr_type)); + } + else { + auto elem_type = ret_cont->param(1)->type()->as()->pointee()->as()->elem_type(); + + func_impls_.fmt("{}{} {}_reserved[{}];\n", + addr_space_prefix(AddrSpace::Shared), convert(elem_type), + cont->unique_name(), emit_constant(body->arg(1))); + + if (lang_ == Lang::HLS && !hls_top_scope) { + func_impls_.fmt("#pragma HLS dependence variable={}_reserved inter false\n", cont->unique_name()); + func_impls_.fmt("#pragma HLS data_pack variable={}_reserved\n", cont->unique_name()); + func_impls_<< "#if defined( __VITIS_HLS__ )\n __attribute__((packed))\n #endif\n"; + } + bb.tail.fmt("p_{} = {}_reserved;\n", ret_cont->param(1)->unique_name(), cont->unique_name()); } - bb.tail.fmt("p_{} = {}_reserved;\n", ret_cont->param(1)->unique_name(), cont->unique_name()); bb.tail.fmt("goto {};", label_name(ret_cont)); } else if (callee->intrinsic() == Intrinsic::Pipeline) { assert((lang_ == Lang::OpenCL || lang_ == Lang::HLS) && "pipelining not supported on this backend"); @@ -1435,6 +1447,12 @@ std::string CCodeGen::emit_fun_head(Continuation* cont, bool is_proto) { } needs_comma = true; } + + if (lang_ == Lang::OpenCL) { + if (needs_comma) s.fmt(", "); + s.fmt("__local unsigned char* __dynamic_smem"); + } + s << ")"; return s.str(); } diff --git a/src/thorin/be/codegen.h b/src/thorin/be/codegen.h index cd7c5689f..cbd0056b6 100644 --- a/src/thorin/be/codegen.h +++ b/src/thorin/be/codegen.h @@ -32,6 +32,7 @@ struct LaunchArgs { Device, Space, Config, + SMem, Body, Return, Num diff --git a/src/thorin/be/llvm/llvm.cpp b/src/thorin/be/llvm/llvm.cpp index 6f4e83191..76a59d769 100644 --- a/src/thorin/be/llvm/llvm.cpp +++ b/src/thorin/be/llvm/llvm.cpp @@ -1306,6 +1306,16 @@ Continuation* CodeGen::emit_reserve(llvm::IRBuilder<>&, const Continuation* cont THORIN_UNREACHABLE; } +llvm::GlobalVariable* CodeGen::emit_dynamic_shared_memory_allocation() { + static constexpr auto name = "__dynamic_smem"; + if (auto found = module().getGlobalVariable(name)) + return found; + auto type = llvm::ArrayType::get(llvm::Type::getInt8Ty(context()), 0); + auto global = new llvm::GlobalVariable(module(), type, false, llvm::GlobalValue::ExternalLinkage, nullptr, name, nullptr, llvm::GlobalVariable::NotThreadLocal, 3); + global->setAlignment(llvm::Align(16)); + return global; +} + Continuation* CodeGen::emit_reserve_shared(llvm::IRBuilder<>& irbuilder, const Continuation* continuation, bool init_undef) { assert(continuation->has_body()); auto body = continuation->body(); @@ -1315,13 +1325,17 @@ Continuation* CodeGen::emit_reserve_shared(llvm::IRBuilder<>& irbuilder, const C auto num_elems = body->arg(1)->as()->ps32_value(); auto cont = body->arg(2)->as_nom(); auto type = convert(cont->param(1)->type()); - // construct array type - auto elem_type = cont->param(1)->type()->as()->pointee()->as()->elem_type(); - auto smem_type = this->convert(continuation->world().definite_array_type(elem_type, num_elems)); - auto name = continuation->unique_name(); - // NVVM doesn't allow '.' in global identifier - std::replace(name.begin(), name.end(), '.', '_'); - auto global = emit_global_variable(smem_type, name, 3, init_undef); + auto global = [&] { + if (num_elems.data() == 0) + return emit_dynamic_shared_memory_allocation(); + // construct array type + auto elem_type = cont->param(1)->type()->as()->pointee()->as()->elem_type(); + auto smem_type = this->convert(continuation->world().definite_array_type(elem_type, num_elems)); + auto name = continuation->unique_name(); + // NVVM doesn't allow '.' in global identifier + std::replace(name.begin(), name.end(), '.', '_'); + return emit_global_variable(smem_type, name, 3, init_undef); + }(); auto call = irbuilder.CreatePointerCast(global, type); emit_phi_arg(irbuilder, cont->param(1), call); return cont; diff --git a/src/thorin/be/llvm/llvm.h b/src/thorin/be/llvm/llvm.h index 43ceb5a75..611166e72 100644 --- a/src/thorin/be/llvm/llvm.h +++ b/src/thorin/be/llvm/llvm.h @@ -83,6 +83,7 @@ class CodeGen : public thorin::CodeGen, public thorin::Emitter&, const Continuation*); Continuation* emit_reserve_shared(llvm::IRBuilder<>&, const Continuation*, bool=false); + llvm::GlobalVariable* emit_dynamic_shared_memory_allocation(); virtual std::string get_alloc_name() const = 0; llvm::BasicBlock* cont2bb(Continuation* cont) { return cont2bb_[cont].first; } diff --git a/src/thorin/be/llvm/runtime.cpp b/src/thorin/be/llvm/runtime.cpp index 6b256002a..a3217c618 100644 --- a/src/thorin/be/llvm/runtime.cpp +++ b/src/thorin/be/llvm/runtime.cpp @@ -64,11 +64,13 @@ Continuation* Runtime::emit_host_code(CodeGen& code_gen, llvm::IRBuilder<>& buil assert(continuation->has_body()); auto body = continuation->body(); // to-target is the desired kernel call - // target(mem, device, (dim.x, dim.y, dim.z), (block.x, block.y, block.z), body, return, free_vars) + // target(mem, device, (dim.x, dim.y, dim.z), (block.x, block.y, block.z), smem, body, return, free_vars) auto target = body->callee()->as_nom(); assert_unused(target->is_intrinsic()); assert(body->num_args() >= LaunchArgs::Num && "required arguments are missing"); + auto& world = continuation->world(); + // arguments auto target_device_id = code_gen.emit(body->arg(LaunchArgs::Device)); auto target_platform = builder.getInt32(platform); @@ -76,9 +78,10 @@ Continuation* Runtime::emit_host_code(CodeGen& code_gen, llvm::IRBuilder<>& buil auto it_space = body->arg(LaunchArgs::Space); auto it_config = body->arg(LaunchArgs::Config); - auto kernel = body->arg(LaunchArgs::Body)->as()->init()->as(); - auto& world = continuation->world(); + auto smem = code_gen.emit(body->arg(LaunchArgs::SMem)); + + auto kernel = body->arg(LaunchArgs::Body)->as()->init()->as(); auto kernel_name = builder.CreateGlobalStringPtr(kernel->name() == "hls_top" ? kernel->name() : kernel->unique_name()); auto file_name = builder.CreateGlobalStringPtr(world.name() + ext); const size_t num_kernel_args = body->num_args() - LaunchArgs::Num; @@ -181,6 +184,7 @@ Continuation* Runtime::emit_host_code(CodeGen& code_gen, llvm::IRBuilder<>& buil launch_kernel(code_gen, builder, target_device, file_name, kernel_name, grid_size, block_size, + smem, args, sizes, aligns, allocs, types, builder.getInt32(num_kernel_args)); @@ -191,10 +195,11 @@ llvm::Value* Runtime::launch_kernel( CodeGen& code_gen, llvm::IRBuilder<>& builder, llvm::Value* device, llvm::Value* file, llvm::Value* kernel, llvm::Value* grid, llvm::Value* block, + llvm::Value* smem, llvm::Value* args, llvm::Value* sizes, llvm::Value* aligns, llvm::Value* allocs, llvm::Value* types, llvm::Value* num_args) { - llvm::Value* launch_args[] = { device, file, kernel, grid, block, args, sizes, aligns, allocs, types, num_args }; + llvm::Value* launch_args[] = { device, file, kernel, grid, block, smem, args, sizes, aligns, allocs, types, num_args }; return builder.CreateCall(get(code_gen, "anydsl_launch_kernel"), launch_args); } diff --git a/src/thorin/be/llvm/runtime.h b/src/thorin/be/llvm/runtime.h index 4ebf5f5ef..2767440d9 100644 --- a/src/thorin/be/llvm/runtime.h +++ b/src/thorin/be/llvm/runtime.h @@ -30,6 +30,7 @@ class Runtime { CodeGen&, llvm::IRBuilder<>&, llvm::Value* device, llvm::Value* file, llvm::Value* kernel, llvm::Value* grid, llvm::Value* block, + llvm::Value* smem, llvm::Value* args, llvm::Value* sizes, llvm::Value* aligns, llvm::Value* allocs, llvm::Value* types, llvm::Value* num_args); diff --git a/src/thorin/be/llvm/runtime.inc b/src/thorin/be/llvm/runtime.inc index 44ee7f9f8..fe4ab4144 100644 --- a/src/thorin/be/llvm/runtime.inc +++ b/src/thorin/be/llvm/runtime.inc @@ -6,7 +6,7 @@ namespace thorin { declare noalias ptr @anydsl_alloc(i32, i64); declare noalias ptr @anydsl_alloc_unified(i32, i64); declare void @anydsl_release(i32, ptr); - declare void @anydsl_launch_kernel(i32, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, i32); + declare void @anydsl_launch_kernel(i32, ptr, ptr, ptr, ptr, i32, ptr, ptr, ptr, ptr, ptr, i32); declare void @anydsl_parallel_for(i32, i32, i32, ptr, ptr); declare void @anydsl_fibers_spawn(i32, i32, i32, ptr, ptr); declare i32 @anydsl_spawn_thread(ptr, ptr);