Skip to content

Commit

Permalink
Use reference counting on factories
Browse files Browse the repository at this point in the history
Previously the factories used by ur_ldrddi (used when there are multiple
backends) would add newly created objects to a map, but never release
them. This patch adds reference counting semantics to the allocation,
retention and release methods.

A lot of changes were also made to fix use-after-free issues,
specifically:
* The `release` functions now no longer use the handle after freeing
  it.
* `urDeviceTest` no longer frees devices that it dosen't own.
* Some tests for reference counting now explicitly retain an extra
  copy before releasing them.

No tests were added; this should be covered by tests for urThingRetain.

Closes: oneapi-src#1784 .
  • Loading branch information
RossBrunton committed Nov 27, 2024
1 parent a07352d commit 6083ba0
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 9 deletions.
16 changes: 11 additions & 5 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,17 @@ namespace ur_loader
%endif
%endif
## Before we can re-enable the releases we will need ref-counted object_t.
## See unified-runtime github issue #1784
##%if item['release']:
##// release loader handle
##${item['factory']}.release( ${item['name']} );
## Possibly handle release/retain ref counting - there are no ur_exp-image factories
%if 'factory' in item and '_exp_image_' not in item['factory']:
%if item['release']:
// release loader handle
context->factories.${item['factory']}.release( ${item['name']} );
%endif
%if item['retain']:
// increment refcount of handle
context->factories.${item['factory']}.retain( ${item['name']} );
%endif
%endif
%if not item['release'] and not item['retain'] and not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
try
{
Expand Down
29 changes: 25 additions & 4 deletions source/common/ur_singleton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,26 @@
#ifndef UR_SINGLETON_H
#define UR_SINGLETON_H 1

#include <cassert>
#include <memory>
#include <mutex>
#include <unordered_map>

//////////////////////////////////////////////////////////////////////////
/// a abstract factory for creation of singleton objects
template <typename singleton_tn, typename key_tn> class singleton_factory_t {
struct entry_t {
std::unique_ptr<singleton_tn> ptr;
size_t ref_count;
};

protected:
using singleton_t = singleton_tn;
using key_t = typename std::conditional<std::is_pointer<key_tn>::value,
size_t, key_tn>::type;

using ptr_t = std::unique_ptr<singleton_t>;
using map_t = std::unordered_map<key_t, ptr_t>;
using map_t = std::unordered_map<key_t, entry_t>;

std::mutex mut; ///< lock for thread-safety
map_t map; ///< single instance of singleton for each unique key
Expand Down Expand Up @@ -60,16 +66,31 @@ template <typename singleton_tn, typename key_tn> class singleton_factory_t {
if (map.end() == iter) {
auto ptr =
std::make_unique<singleton_t>(std::forward<Ts>(params)...);
iter = map.emplace(key, std::move(ptr)).first;
iter = map.emplace(key, entry_t{std::move(ptr), 0}).first;
} else {
iter->second.ref_count++;
}
return iter->second.get();
return iter->second.ptr.get();
}

void retain(key_tn key) {
std::lock_guard<std::mutex> lk(mut);
auto iter = map.find(getKey(key));
assert(iter != map.end());
iter->second.ref_count++;
}

//////////////////////////////////////////////////////////////////////////
/// once the key is no longer valid, release the singleton
void release(key_tn key) {
std::lock_guard<std::mutex> lk(mut);
map.erase(getKey(key));
auto iter = map.find(getKey(key));
assert(iter != map.end());
if (iter->second.ref_count == 0) {
map.erase(iter);
} else {
iter->second.ref_count--;
}
}

void clear() {
Expand Down
85 changes: 85 additions & 0 deletions source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease(
// forward to device-platform
result = pfnAdapterRelease(hAdapter);

// release loader handle
context->factories.ur_adapter_factory.release(hAdapter);

return result;
}

Expand All @@ -110,6 +113,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain(
// forward to device-platform
result = pfnAdapterRetain(hAdapter);

// increment refcount of handle
context->factories.ur_adapter_factory.retain(hAdapter);

return result;
}

Expand Down Expand Up @@ -614,6 +620,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain(
// forward to device-platform
result = pfnRetain(hDevice);

// increment refcount of handle
context->factories.ur_device_factory.retain(hDevice);

return result;
}

Expand All @@ -640,6 +649,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease(
// forward to device-platform
result = pfnRelease(hDevice);

// release loader handle
context->factories.ur_device_factory.release(hDevice);

return result;
}

Expand Down Expand Up @@ -910,6 +922,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain(
// forward to device-platform
result = pfnRetain(hContext);

// increment refcount of handle
context->factories.ur_context_factory.retain(hContext);

return result;
}

Expand All @@ -936,6 +951,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
// forward to device-platform
result = pfnRelease(hContext);

// release loader handle
context->factories.ur_context_factory.release(hContext);

return result;
}

Expand Down Expand Up @@ -1238,6 +1256,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain(
// forward to device-platform
result = pfnRetain(hMem);

// increment refcount of handle
context->factories.ur_mem_factory.retain(hMem);

return result;
}

Expand All @@ -1264,6 +1285,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease(
// forward to device-platform
result = pfnRelease(hMem);

// release loader handle
context->factories.ur_mem_factory.release(hMem);

return result;
}

Expand Down Expand Up @@ -1615,6 +1639,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain(
// forward to device-platform
result = pfnRetain(hSampler);

// increment refcount of handle
context->factories.ur_sampler_factory.retain(hSampler);

return result;
}

Expand All @@ -1641,6 +1668,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease(
// forward to device-platform
result = pfnRelease(hSampler);

// release loader handle
context->factories.ur_sampler_factory.release(hSampler);

return result;
}

Expand Down Expand Up @@ -2074,6 +2104,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain(
// forward to device-platform
result = pfnPoolRetain(pPool);

// increment refcount of handle
context->factories.ur_usm_pool_factory.retain(pPool);

return result;
}

Expand All @@ -2099,6 +2132,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease(
// forward to device-platform
result = pfnPoolRelease(pPool);

// release loader handle
context->factories.ur_usm_pool_factory.release(pPool);

return result;
}

Expand Down Expand Up @@ -2484,6 +2520,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain(
// forward to device-platform
result = pfnRetain(hPhysicalMem);

// increment refcount of handle
context->factories.ur_physical_mem_factory.retain(hPhysicalMem);

return result;
}

Expand Down Expand Up @@ -2512,6 +2551,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease(
// forward to device-platform
result = pfnRelease(hPhysicalMem);

// release loader handle
context->factories.ur_physical_mem_factory.release(hPhysicalMem);

return result;
}

Expand Down Expand Up @@ -2759,6 +2801,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain(
// forward to device-platform
result = pfnRetain(hProgram);

// increment refcount of handle
context->factories.ur_program_factory.retain(hProgram);

return result;
}

Expand All @@ -2785,6 +2830,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease(
// forward to device-platform
result = pfnRelease(hProgram);

// release loader handle
context->factories.ur_program_factory.release(hProgram);

return result;
}

Expand Down Expand Up @@ -3382,6 +3430,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
// forward to device-platform
result = pfnRetain(hKernel);

// increment refcount of handle
context->factories.ur_kernel_factory.retain(hKernel);

return result;
}

Expand All @@ -3408,6 +3459,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease(
// forward to device-platform
result = pfnRelease(hKernel);

// release loader handle
context->factories.ur_kernel_factory.release(hKernel);

return result;
}

Expand Down Expand Up @@ -3858,6 +3912,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain(
// forward to device-platform
result = pfnRetain(hQueue);

// increment refcount of handle
context->factories.ur_queue_factory.retain(hQueue);

return result;
}

Expand All @@ -3884,6 +3941,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease(
// forward to device-platform
result = pfnRelease(hQueue);

// release loader handle
context->factories.ur_queue_factory.release(hQueue);

return result;
}

Expand Down Expand Up @@ -4188,6 +4248,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain(
// forward to device-platform
result = pfnRetain(hEvent);

// increment refcount of handle
context->factories.ur_event_factory.retain(hEvent);

return result;
}

Expand All @@ -4213,6 +4276,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease(
// forward to device-platform
result = pfnRelease(hEvent);

// release loader handle
context->factories.ur_event_factory.release(hEvent);

return result;
}

Expand Down Expand Up @@ -6745,6 +6811,9 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalMemoryExp(
// forward to device-platform
result = pfnReleaseExternalMemoryExp(hContext, hDevice, hExternalMem);

// release loader handle
context->factories.ur_exp_external_mem_factory.release(hExternalMem);

return result;
}

Expand Down Expand Up @@ -6835,6 +6904,10 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalSemaphoreExp(
result =
pfnReleaseExternalSemaphoreExp(hContext, hDevice, hExternalSemaphore);

// release loader handle
context->factories.ur_exp_external_semaphore_factory.release(
hExternalSemaphore);

return result;
}

Expand Down Expand Up @@ -7062,6 +7135,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp(
// forward to device-platform
result = pfnRetainExp(hCommandBuffer);

// increment refcount of handle
context->factories.ur_exp_command_buffer_factory.retain(hCommandBuffer);

return result;
}

Expand Down Expand Up @@ -7092,6 +7168,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp(
// forward to device-platform
result = pfnReleaseExp(hCommandBuffer);

// release loader handle
context->factories.ur_exp_command_buffer_factory.release(hCommandBuffer);

return result;
}

Expand Down Expand Up @@ -8408,6 +8487,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp(
// forward to device-platform
result = pfnRetainCommandExp(hCommand);

// increment refcount of handle
context->factories.ur_exp_command_buffer_command_factory.retain(hCommand);

return result;
}

Expand Down Expand Up @@ -8439,6 +8521,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp(
// forward to device-platform
result = pfnReleaseCommandExp(hCommand);

// release loader handle
context->factories.ur_exp_command_buffer_command_factory.release(hCommand);

return result;
}

Expand Down

0 comments on commit 6083ba0

Please sign in to comment.