Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Add support for two-way type-cast from dpct::kernel_library and dpct::kernel_function to uint64_t conversion #2606

Merged
merged 8 commits into from
Jan 10, 2025
11 changes: 7 additions & 4 deletions clang/runtime/dpct-rt/include/dpct/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,10 @@ class kernel_library {
public:
constexpr kernel_library() : ptr{nullptr} {}
constexpr kernel_library(void *ptr) : ptr{ptr} {}
kernel_library(uint64_t addr) : ptr(reinterpret_cast<void *>(addr)) {}

operator void *() const { return ptr; }
explicit operator uint64_t() const { return reinterpret_cast<uint64_t>(ptr); }

private:
void *ptr;
Expand Down Expand Up @@ -393,15 +395,16 @@ class kernel_function {
public:
constexpr kernel_function() : ptr{nullptr} {}
constexpr kernel_function(dpct::kernel_functor ptr) : ptr{ptr} {}
kernel_function(uint64_t addr) : ptr(reinterpret_cast<dpct::kernel_functor>(addr)) {}

operator void *() const { return ((void *)ptr); }

void operator()(sycl::queue &q, const sycl::nd_range<3> &range,
unsigned int a, void **args, void **extra) {
unsigned int a, void **args, void **extra) const {
ptr(q, range, a, args, extra);
}

explicit operator uint64_t() const { return (uint64_t)this; }
explicit operator uint64_t() const { return reinterpret_cast<uint64_t>(this); }
the-slow-one marked this conversation as resolved.
Show resolved Hide resolved

private:
dpct::kernel_functor ptr;
Expand All @@ -411,7 +414,7 @@ class kernel_function {
/// \param [in] library Handle to the kernel library.
/// \param [in] name Name of the kernel function.
static inline dpct::kernel_function
get_kernel_function(kernel_library &library, const std::string &name) {
get_kernel_function(const kernel_library &library, const std::string &name) {
#ifdef _WIN32
dpct::kernel_functor fn = reinterpret_cast<dpct::kernel_functor>(
GetProcAddress(static_cast<HMODULE>(static_cast<void *>(library)),
Expand All @@ -434,7 +437,7 @@ get_kernel_function(kernel_library &library, const std::string &name) {
/// function.
/// \param [in] kernelParams Array of pointers to kernel arguments.
/// \param [in] extra Extra arguments.
static inline void invoke_kernel_function(dpct::kernel_function &function,
static inline void invoke_kernel_function(const dpct::kernel_function &function,
sycl::queue &queue,
sycl::range<3> groupRange,
sycl::range<3> localRange,
Expand Down
26 changes: 19 additions & 7 deletions clang/test/dpct/kernel-function-typecast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,24 @@

typedef uint64_t u64;

// CHECK: u64 foo(dpct::kernel_function cuFunc, dpct::kernel_library cuMod) {
u64 foo(CUfunction cuFunc, CUmodule cuMod) {
// CHECK: cuFunc = dpct::get_kernel_function(cuMod, "kfoo");
cuModuleGetFunction(&cuFunc, cuMod, "kfoo");
u64 function = (u64)cuFunc;
// CHECK: void exec_kernel(dpct::kernel_function cuFunc, dpct::kernel_library cuMod, dpct::queue_ptr stream) {
void exec_kernel(CUfunction cuFunc, CUmodule cuMod, CUstream stream) {
u64 mod;
u64 function;

return function;
}
// verify the conversion from dpct::kernel_library to uint64_t
mod = (u64)cuMod;

// verify the conversion from uint64_t to dpct::kernel_library
// CHECK: cuFunc = dpct::get_kernel_function((dpct::kernel_library)mod, "kfoo");
cuModuleGetFunction(&cuFunc, (CUmodule)mod, "kfoo");

// verify the conversion from dpct::kernel_function to uint64_t
function = (u64)cuFunc;

void *config[] = {0};

// verify the conversion from uint64_t to dpct::kernel_function
// CHECK: dpct::invoke_kernel_function((dpct::kernel_function)function, *stream, sycl::range<3>(100, 100, 100), sycl::range<3>(100, 100, 100), 1024, NULL, config);
cuLaunchKernel((CUfunction)function, 100, 100, 100, 100, 100, 100, 1024, stream, NULL, config);
}
Loading