Skip to content

Commit

Permalink
Merge pull request oneapi-src#2264 from Bensuo/ben/cmdbuf-local-arg-fix
Browse files Browse the repository at this point in the history
[CMDBUF] Fix incorrect handling of shared local mem args in CUDA/HIP
  • Loading branch information
callumfare authored Nov 6, 2024
2 parents f01741a + b7d78ba commit 5955bad
Show file tree
Hide file tree
Showing 9 changed files with 581 additions and 4 deletions.
2 changes: 1 addition & 1 deletion include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -8329,7 +8329,7 @@ typedef struct ur_exp_command_buffer_update_value_arg_desc_t {
///< ::UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC
const void *pNext; ///< [in][optional] pointer to extension-specific structure
uint32_t argIndex; ///< [in] Argument index.
uint32_t argSize; ///< [in] Argument size.
size_t argSize; ///< [in] Argument size.
const ur_kernel_arg_value_properties_t *pProperties; ///< [in][optional] Pointer to value properties.
const void *pNewValueArg; ///< [in][optional] Argument value representing matching kernel arg type to
///< set at argument index.
Expand Down
2 changes: 1 addition & 1 deletion scripts/core/exp-command-buffer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ members:
- type: uint32_t
name: argIndex
desc: "[in] Argument index."
- type: uint32_t
- type: size_t
name: argSize
desc: "[in] Argument size."
- type: "const ur_kernel_arg_value_properties_t *"
Expand Down
7 changes: 6 additions & 1 deletion source/adapters/cuda/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,12 @@ updateKernelArguments(kernel_command_handle *Command,

ur_result_t Result = UR_RESULT_SUCCESS;
try {
Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue);
// Local memory args are passed as value args with nullptr value
if (ArgValue) {
Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue);
} else {
Kernel->setKernelLocalArg(ArgIndex, ArgSize);
}
} catch (ur_result_t Err) {
Result = Err;
return Result;
Expand Down
7 changes: 6 additions & 1 deletion source/adapters/hip/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,12 @@ updateKernelArguments(ur_exp_command_buffer_command_handle_t Command,
const void *ArgValue = ValueArgDesc.pNewValueArg;

try {
Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue);
// Local memory args are passed as value args with nullptr value
if (ArgValue) {
Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue);
} else {
Kernel->setKernelLocalArg(ArgIndex, ArgSize);
}
} catch (ur_result_t Err) {
return Err;
}
Expand Down
1 change: 1 addition & 0 deletions test/conformance/device_code/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ add_device_binary(${CMAKE_CURRENT_SOURCE_DIR}/sequence.cpp)
add_device_binary(${CMAKE_CURRENT_SOURCE_DIR}/standard_types.cpp)
add_device_binary(${CMAKE_CURRENT_SOURCE_DIR}/subgroup.cpp)
add_device_binary(${CMAKE_CURRENT_SOURCE_DIR}/linker_error.cpp)
add_device_binary(${CMAKE_CURRENT_SOURCE_DIR}/saxpy_usm_local_mem.cpp)

set(KERNEL_HEADER ${UR_CONFORMANCE_DEVICE_BINARIES_DIR}/kernel_entry_points.h)
add_custom_command(OUTPUT ${KERNEL_HEADER}
Expand Down
30 changes: 30 additions & 0 deletions test/conformance/device_code/saxpy_usm_local_mem.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2024 Intel Corporation
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
// See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <sycl/sycl.hpp>

int main() {
size_t array_size = 16;
size_t local_size = 4;
sycl::queue sycl_queue;
uint32_t *X = sycl::malloc_shared<uint32_t>(array_size, sycl_queue);
uint32_t *Y = sycl::malloc_shared<uint32_t>(array_size, sycl_queue);
uint32_t *Z = sycl::malloc_shared<uint32_t>(array_size, sycl_queue);
uint32_t A = 42;

sycl_queue.submit([&](sycl::handler &cgh) {
sycl::local_accessor<uint32_t, 1> local_mem(local_size, cgh);
cgh.parallel_for<class saxpy_usm_local_mem>(
sycl::nd_range<1>{{array_size}, {local_size}},
[=](sycl::nd_item<1> itemId) {
auto i = itemId.get_global_linear_id();
auto local_id = itemId.get_local_linear_id();
local_mem[local_id] = i;
Z[i] = A * X[i] + Y[i] + local_mem[local_id] +
itemId.get_local_range(0);
});
});
return 0;
}
1 change: 1 addition & 0 deletions test/conformance/exp_command_buffer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ add_conformance_test_with_kernels_environment(exp_command_buffer
update/usm_saxpy_kernel_update.cpp
update/event_sync.cpp
update/kernel_event_sync.cpp
update/local_memory_update.cpp
)
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@
{{OPT}}KernelCommandEventSyncUpdateTest.TwoWaitEvents/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
{{OPT}}KernelCommandEventSyncUpdateTest.InvalidWaitUpdate/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
{{OPT}}KernelCommandEventSyncUpdateTest.InvalidSignalUpdate/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
{{OPT}}LocalMemoryUpdateTest.UpdateParameters/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
{{OPT}}LocalMemoryUpdateTest.UpdateParametersAndLocalSize/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
{{OPT}}LocalMemoryMultiUpdateTest.UpdateParameters/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
{{OPT}}LocalMemoryMultiUpdateTest.UpdateWithoutBlocking/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
Loading

0 comments on commit 5955bad

Please sign in to comment.