Skip to content

Commit

Permalink
Add dynamic shared memory allocation
Browse files (browse the repository at this point in the history)
  • Loading branch information
michael-kenzel committed Aug 9, 2023
1 parent d360324 commit 05587c5
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 21 deletions.
36 changes: 27 additions & 9 deletions src/thorin/be/c/c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,9 @@ void CCodeGen::emit_module() {
stream_.fmt("__device__ inline int blockDim_{}() {{ return blockDim.{}; }}\n", x, x);
stream_.fmt("__device__ inline int gridDim_{}() {{ return gridDim.{}; }}\n", x, x);
}

stream_.fmt("\n"
"extern __shared__ unsigned char __dynamic_smem[];\n");
}

stream_.endl() << func_impls_.str();
Expand Down Expand Up @@ -742,16 +745,25 @@ void CCodeGen::emit_epilogue(Continuation* cont) {
world().edef(body->arg(1), "reserve_shared: couldn't extract memory size");

auto ret_cont = body->arg(2)->as_nom<Continuation>();
auto elem_type = ret_cont->param(1)->type()->as<PtrType>()->pointee()->as<ArrayType>()->elem_type();
func_impls_.fmt("{}{} {}_reserved[{}];\n",
addr_space_prefix(AddrSpace::Shared), convert(elem_type),
cont->unique_name(), emit_constant(body->arg(1)));
if (lang_ == Lang::HLS && !hls_top_scope) {
func_impls_.fmt("#pragma HLS dependence variable={}_reserved inter false\n", cont->unique_name());
func_impls_.fmt("#pragma HLS data_pack variable={}_reserved\n", cont->unique_name());
func_impls_<< "#if defined( __VITIS_HLS__ )\n __attribute__((packed))\n #endif\n";

if (body->arg(1)->as<PrimLit>()->ps32_value().data() == 0) {
auto ptr_type = ret_cont->param(1)->type()->as<PtrType>();
bb.tail.fmt("p_{} = ({})__dynamic_smem;\n", ret_cont->param(1)->unique_name(), convert(ptr_type));
}
else {
auto elem_type = ret_cont->param(1)->type()->as<PtrType>()->pointee()->as<ArrayType>()->elem_type();

func_impls_.fmt("{}{} {}_reserved[{}];\n",
addr_space_prefix(AddrSpace::Shared), convert(elem_type),
cont->unique_name(), emit_constant(body->arg(1)));

if (lang_ == Lang::HLS && !hls_top_scope) {
func_impls_.fmt("#pragma HLS dependence variable={}_reserved inter false\n", cont->unique_name());
func_impls_.fmt("#pragma HLS data_pack variable={}_reserved\n", cont->unique_name());
func_impls_<< "#if defined( __VITIS_HLS__ )\n __attribute__((packed))\n #endif\n";
}
bb.tail.fmt("p_{} = {}_reserved;\n", ret_cont->param(1)->unique_name(), cont->unique_name());
}
bb.tail.fmt("p_{} = {}_reserved;\n", ret_cont->param(1)->unique_name(), cont->unique_name());
bb.tail.fmt("goto {};", label_name(ret_cont));
} else if (callee->intrinsic() == Intrinsic::Pipeline) {
assert((lang_ == Lang::OpenCL || lang_ == Lang::HLS) && "pipelining not supported on this backend");
Expand Down Expand Up @@ -1435,6 +1447,12 @@ std::string CCodeGen::emit_fun_head(Continuation* cont, bool is_proto) {
}
needs_comma = true;
}

if (lang_ == Lang::OpenCL) {
if (needs_comma) s.fmt(", ");
s.fmt("__local unsigned char* __dynamic_smem");
}

s << ")";
return s.str();
}
Expand Down
1 change: 1 addition & 0 deletions src/thorin/be/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct LaunchArgs {
Device,
Space,
Config,
SMem,
Body,
Return,
Num
Expand Down
28 changes: 21 additions & 7 deletions src/thorin/be/llvm/llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,16 @@ Continuation* CodeGen::emit_reserve(llvm::IRBuilder<>&, const Continuation* cont
THORIN_UNREACHABLE;
}

/// Returns the module's `__dynamic_smem` global variable, creating it on first use.
///
/// The module is first searched for an existing global of that name; if one is
/// found it is returned unchanged. Otherwise a new global is created as a
/// zero-length array of i8 with external linkage, no initializer, address
/// space 3 and an alignment of 16, and that new global is returned.
///
/// NOTE(review): address space 3 is presumably the GPU shared-memory space
/// (the same value is passed for the statically sized reservation in
/// emit_reserve_shared) -- confirm for every target that reaches this path.
llvm::GlobalVariable* CodeGen::emit_dynamic_shared_memory_allocation() {
    static constexpr auto name = "__dynamic_smem";

    auto* smem = module().getGlobalVariable(name);
    if (smem == nullptr) {
        // Zero-length byte array: the size is not known at code generation time.
        auto* byte_array_type = llvm::ArrayType::get(llvm::Type::getInt8Ty(context()), 0);
        smem = new llvm::GlobalVariable(
            module(),
            byte_array_type,
            /*isConstant=*/false,
            llvm::GlobalValue::ExternalLinkage,
            /*Initializer=*/nullptr,
            name,
            /*InsertBefore=*/nullptr,
            llvm::GlobalVariable::NotThreadLocal,
            /*AddressSpace=*/3);
        smem->setAlignment(llvm::Align(16));
    }
    return smem;
}

Continuation* CodeGen::emit_reserve_shared(llvm::IRBuilder<>& irbuilder, const Continuation* continuation, bool init_undef) {
assert(continuation->has_body());
auto body = continuation->body();
Expand All @@ -1315,13 +1325,17 @@ Continuation* CodeGen::emit_reserve_shared(llvm::IRBuilder<>& irbuilder, const C
auto num_elems = body->arg(1)->as<PrimLit>()->ps32_value();
auto cont = body->arg(2)->as_nom<Continuation>();
auto type = convert(cont->param(1)->type());
// construct array type
auto elem_type = cont->param(1)->type()->as<PtrType>()->pointee()->as<ArrayType>()->elem_type();
auto smem_type = this->convert(continuation->world().definite_array_type(elem_type, num_elems));
auto name = continuation->unique_name();
// NVVM doesn't allow '.' in global identifier
std::replace(name.begin(), name.end(), '.', '_');
auto global = emit_global_variable(smem_type, name, 3, init_undef);
auto global = [&] {
if (num_elems.data() == 0)
return emit_dynamic_shared_memory_allocation();
// construct array type
auto elem_type = cont->param(1)->type()->as<PtrType>()->pointee()->as<ArrayType>()->elem_type();
auto smem_type = this->convert(continuation->world().definite_array_type(elem_type, num_elems));
auto name = continuation->unique_name();
// NVVM doesn't allow '.' in global identifier
std::replace(name.begin(), name.end(), '.', '_');
return emit_global_variable(smem_type, name, 3, init_undef);
}();
auto call = irbuilder.CreatePointerCast(global, type);
emit_phi_arg(irbuilder, cont->param(1), call);
return cont;
Expand Down
1 change: 1 addition & 0 deletions src/thorin/be/llvm/llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class CodeGen : public thorin::CodeGen, public thorin::Emitter<llvm::Value*, llv

virtual Continuation* emit_reserve(llvm::IRBuilder<>&, const Continuation*);
Continuation* emit_reserve_shared(llvm::IRBuilder<>&, const Continuation*, bool=false);
llvm::GlobalVariable* emit_dynamic_shared_memory_allocation();

virtual std::string get_alloc_name() const = 0;
llvm::BasicBlock* cont2bb(Continuation* cont) { return cont2bb_[cont].first; }
Expand Down
13 changes: 9 additions & 4 deletions src/thorin/be/llvm/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,24 @@ Continuation* Runtime::emit_host_code(CodeGen& code_gen, llvm::IRBuilder<>& buil
assert(continuation->has_body());
auto body = continuation->body();
// to-target is the desired kernel call
// target(mem, device, (dim.x, dim.y, dim.z), (block.x, block.y, block.z), body, return, free_vars)
// target(mem, device, (dim.x, dim.y, dim.z), (block.x, block.y, block.z), smem, body, return, free_vars)
auto target = body->callee()->as_nom<Continuation>();
assert_unused(target->is_intrinsic());
assert(body->num_args() >= LaunchArgs::Num && "required arguments are missing");

auto& world = continuation->world();

// arguments
auto target_device_id = code_gen.emit(body->arg(LaunchArgs::Device));
auto target_platform = builder.getInt32(platform);
auto target_device = builder.CreateOr(target_platform, builder.CreateShl(target_device_id, builder.getInt32(4)));

auto it_space = body->arg(LaunchArgs::Space);
auto it_config = body->arg(LaunchArgs::Config);
auto kernel = body->arg(LaunchArgs::Body)->as<Global>()->init()->as<Continuation>();

auto& world = continuation->world();
auto smem = code_gen.emit(body->arg(LaunchArgs::SMem));

auto kernel = body->arg(LaunchArgs::Body)->as<Global>()->init()->as<Continuation>();
auto kernel_name = builder.CreateGlobalStringPtr(kernel->name() == "hls_top" ? kernel->name() : kernel->unique_name());
auto file_name = builder.CreateGlobalStringPtr(world.name() + ext);
const size_t num_kernel_args = body->num_args() - LaunchArgs::Num;
Expand Down Expand Up @@ -181,6 +184,7 @@ Continuation* Runtime::emit_host_code(CodeGen& code_gen, llvm::IRBuilder<>& buil
launch_kernel(code_gen, builder, target_device,
file_name, kernel_name,
grid_size, block_size,
smem,
args, sizes, aligns, allocs, types,
builder.getInt32(num_kernel_args));

Expand All @@ -191,10 +195,11 @@ llvm::Value* Runtime::launch_kernel(
CodeGen& code_gen, llvm::IRBuilder<>& builder, llvm::Value* device,
llvm::Value* file, llvm::Value* kernel,
llvm::Value* grid, llvm::Value* block,
llvm::Value* smem,
llvm::Value* args, llvm::Value* sizes, llvm::Value* aligns, llvm::Value* allocs, llvm::Value* types,
llvm::Value* num_args)
{
llvm::Value* launch_args[] = { device, file, kernel, grid, block, args, sizes, aligns, allocs, types, num_args };
llvm::Value* launch_args[] = { device, file, kernel, grid, block, smem, args, sizes, aligns, allocs, types, num_args };
return builder.CreateCall(get(code_gen, "anydsl_launch_kernel"), launch_args);
}

Expand Down
1 change: 1 addition & 0 deletions src/thorin/be/llvm/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Runtime {
CodeGen&, llvm::IRBuilder<>&, llvm::Value* device,
llvm::Value* file, llvm::Value* kernel,
llvm::Value* grid, llvm::Value* block,
llvm::Value* smem,
llvm::Value* args, llvm::Value* sizes, llvm::Value* aligns, llvm::Value* allocs, llvm::Value* types,
llvm::Value* num_args);

Expand Down
2 changes: 1 addition & 1 deletion src/thorin/be/llvm/runtime.inc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace thorin {
declare noalias ptr @anydsl_alloc(i32, i64);
declare noalias ptr @anydsl_alloc_unified(i32, i64);
declare void @anydsl_release(i32, ptr);
declare void @anydsl_launch_kernel(i32, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, i32);
declare void @anydsl_launch_kernel(i32, ptr, ptr, ptr, ptr, i32, ptr, ptr, ptr, ptr, ptr, i32);
declare void @anydsl_parallel_for(i32, i32, i32, ptr, ptr);
declare void @anydsl_fibers_spawn(i32, i32, i32, ptr, ptr);
declare i32 @anydsl_spawn_thread(ptr, ptr);
Expand Down

0 comments on commit 05587c5

Please sign in to comment.