Skip to content

Commit

Permalink
Use reference counting on factories
Browse files Browse the repository at this point in the history
Previously the factories used by ur_ldrddi (used when there are multiple
backends) would add newly created objects to a map, but never release
them. This patch adds reference counting semantics to the allocation,
retention and release methods.

A lot of changes were also made to fix use-after-free issues,
specifically:
* The `release` functions now no longer use the handle after freeing
  it.
* `urDeviceTest` no longer frees devices that it dosen't own.
* Some tests for reference counting now explicitly retain an extra
  copy before releasing them.

No tests were added; this should be covered by tests for urThingRetain.

Closes: oneapi-src#1784 .
  • Loading branch information
RossBrunton committed Nov 14, 2024
1 parent f9f71f1 commit ccab45f
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 60 deletions.
16 changes: 11 additions & 5 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,17 @@ namespace ur_loader
%endif
%endif
## Before we can re-enable the releases we will need ref-counted object_t.
## See unified-runtime github issue #1784
##%if item['release']:
##// release loader handle
##${item['factory']}.release( ${item['name']} );
## Possibly handle release/retain ref counting - there are no ur_exp-image factories
%if 'factory' in item and '_exp_image_' not in item['factory']:
%if item['release']:
// release loader handle
context->factories.${item['factory']}.release( ${item['name']} );
%endif
%if item['retain']:
// increment refcount of handle
context->factories.${item['factory']}.retain( ${item['name']} );
%endif
%endif
%if not item['release'] and not item['retain'] and not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
try
{
Expand Down
20 changes: 14 additions & 6 deletions scripts/templates/valddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,19 @@ namespace ur_validation_layer
%endif
%endfor

%for tp in tracked_params:
<%
tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None)
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
%>
%if func_name in tp_handle_funcs['release']:
if( getContext()->enableLeakChecking )
{
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%endif
%endfor

${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );

%for tp in tracked_params:
Expand All @@ -114,15 +127,10 @@ namespace ur_validation_layer
}
}
%elif func_name in tp_handle_funcs['retain']:
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking )
{
getContext()->refCountContext->incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%elif func_name in tp_handle_funcs['release']:
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
{
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%endif
%endfor

Expand Down
29 changes: 25 additions & 4 deletions source/common/ur_singleton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,26 @@
#ifndef UR_SINGLETON_H
#define UR_SINGLETON_H 1

#include <cassert>
#include <memory>
#include <mutex>
#include <unordered_map>

//////////////////////////////////////////////////////////////////////////
/// a abstract factory for creation of singleton objects
template <typename singleton_tn, typename key_tn> class singleton_factory_t {
struct entry_t {
std::unique_ptr<singleton_tn> ptr;
size_t ref_count;
};

protected:
using singleton_t = singleton_tn;
using key_t = typename std::conditional<std::is_pointer<key_tn>::value,
size_t, key_tn>::type;

using ptr_t = std::unique_ptr<singleton_t>;
using map_t = std::unordered_map<key_t, ptr_t>;
using map_t = std::unordered_map<key_t, entry_t>;

std::mutex mut; ///< lock for thread-safety
map_t map; ///< single instance of singleton for each unique key
Expand Down Expand Up @@ -60,16 +66,31 @@ template <typename singleton_tn, typename key_tn> class singleton_factory_t {
if (map.end() == iter) {
auto ptr =
std::make_unique<singleton_t>(std::forward<Ts>(params)...);
iter = map.emplace(key, std::move(ptr)).first;
iter = map.emplace(key, entry_t{std::move(ptr), 0}).first;
} else {
iter->second.ref_count++;
}
return iter->second.get();
return iter->second.ptr.get();
}

void retain(key_tn key) {
std::lock_guard<std::mutex> lk(mut);
auto iter = map.find(getKey(key));
assert(iter != map.end());
iter->second.ref_count++;
}

//////////////////////////////////////////////////////////////////////////
/// once the key is no longer valid, release the singleton
void release(key_tn key) {
std::lock_guard<std::mutex> lk(mut);
map.erase(getKey(key));
auto iter = map.find(getKey(key));
assert(iter != map.end());
if (iter->second.ref_count == 0) {
map.erase(iter);
} else {
iter->second.ref_count--;
}
}

void clear() {
Expand Down
88 changes: 44 additions & 44 deletions source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease(
}
}

ur_result_t result = pfnAdapterRelease(hAdapter);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hAdapter, true);
}

ur_result_t result = pfnAdapterRelease(hAdapter);

return result;
}

Expand All @@ -99,7 +99,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain(

ur_result_t result = pfnAdapterRetain(hAdapter);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hAdapter, true);
}

Expand Down Expand Up @@ -558,7 +558,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain(

ur_result_t result = pfnRetain(hDevice);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hDevice, false);
}

Expand All @@ -583,12 +583,12 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease(
}
}

ur_result_t result = pfnRelease(hDevice);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hDevice, false);
}

ur_result_t result = pfnRelease(hDevice);

return result;
}

Expand Down Expand Up @@ -861,7 +861,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain(

ur_result_t result = pfnRetain(hContext);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hContext, false);
}

Expand All @@ -886,12 +886,12 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
}
}

ur_result_t result = pfnRelease(hContext);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hContext, false);
}

ur_result_t result = pfnRelease(hContext);

return result;
}

Expand Down Expand Up @@ -1248,7 +1248,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain(

ur_result_t result = pfnRetain(hMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hMem, false);
}

Expand All @@ -1273,12 +1273,12 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease(
}
}

ur_result_t result = pfnRelease(hMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hMem, false);
}

ur_result_t result = pfnRelease(hMem);

return result;
}

Expand Down Expand Up @@ -1657,7 +1657,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain(

ur_result_t result = pfnRetain(hSampler);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hSampler, false);
}

Expand All @@ -1682,12 +1682,12 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease(
}
}

ur_result_t result = pfnRelease(hSampler);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hSampler, false);
}

ur_result_t result = pfnRelease(hSampler);

return result;
}

Expand Down Expand Up @@ -2154,7 +2154,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain(

ur_result_t result = pfnPoolRetain(pPool);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(pPool, false);
}

Expand All @@ -2178,12 +2178,12 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease(
}
}

ur_result_t result = pfnPoolRelease(pPool);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(pPool, false);
}

ur_result_t result = pfnPoolRelease(pPool);

return result;
}

Expand Down Expand Up @@ -2631,7 +2631,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain(

ur_result_t result = pfnRetain(hPhysicalMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hPhysicalMem, false);
}

Expand All @@ -2656,12 +2656,12 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease(
}
}

ur_result_t result = pfnRelease(hPhysicalMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hPhysicalMem, false);
}

ur_result_t result = pfnRelease(hPhysicalMem);

return result;
}

Expand Down Expand Up @@ -2952,7 +2952,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain(

ur_result_t result = pfnRetain(hProgram);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hProgram, false);
}

Expand All @@ -2977,12 +2977,12 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease(
}
}

ur_result_t result = pfnRelease(hProgram);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hProgram, false);
}

ur_result_t result = pfnRelease(hProgram);

return result;
}

Expand Down Expand Up @@ -3618,7 +3618,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(

ur_result_t result = pfnRetain(hKernel);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hKernel, false);
}

Expand All @@ -3643,12 +3643,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease(
}
}

ur_result_t result = pfnRelease(hKernel);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hKernel, false);
}

ur_result_t result = pfnRelease(hKernel);

return result;
}

Expand Down Expand Up @@ -4138,7 +4138,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain(

ur_result_t result = pfnRetain(hQueue);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hQueue, false);
}

Expand All @@ -4163,12 +4163,12 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease(
}
}

ur_result_t result = pfnRelease(hQueue);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hQueue, false);
}

ur_result_t result = pfnRelease(hQueue);

return result;
}

Expand Down Expand Up @@ -4454,7 +4454,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain(

ur_result_t result = pfnRetain(hEvent);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hEvent, false);
}

Expand All @@ -4478,12 +4478,12 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease(
}
}

ur_result_t result = pfnRelease(hEvent);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hEvent, false);
}

ur_result_t result = pfnRelease(hEvent);

return result;
}

Expand Down
Loading

0 comments on commit ccab45f

Please sign in to comment.