From 6083ba0ca1ca75734a006315c5225d53f6bb3247 Mon Sep 17 00:00:00 2001 From: Ross Brunton Date: Wed, 4 Sep 2024 17:25:11 +0100 Subject: [PATCH] Use reference counting on factories Previously the factories used by ur_ldrddi (used when there are multiple backends) would add newly created objects to a map, but never release them. This patch adds reference counting semantics to the allocation, retention and release methods. A lot of changes were also made to fix use-after-free issues, specifically: * The `release` functions now no longer use the handle after freeing it. * `urDeviceTest` no longer frees devices that it dosen't own. * Some tests for reference counting now explicitly retain an extra copy before releasing them. No tests were added; this should be covered by tests for urThingRetain. Closes: #1784 . --- scripts/templates/ldrddi.cpp.mako | 16 ++++-- source/common/ur_singleton.hpp | 29 +++++++++-- source/loader/ur_ldrddi.cpp | 85 +++++++++++++++++++++++++++++++ 3 files changed, 121 insertions(+), 9 deletions(-) diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index 9c797a0ec3..1b7d19fa67 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -273,11 +273,17 @@ namespace ur_loader %endif %endif - ## Before we can re-enable the releases we will need ref-counted object_t. - ## See unified-runtime github issue #1784 - ##%if item['release']: - ##// release loader handle - ##${item['factory']}.release( ${item['name']} ); + ## Possibly handle release/retain ref counting - there are no ur_exp-image factories + %if 'factory' in item and '_exp_image_' not in item['factory']: + %if item['release']: + // release loader handle + context->factories.${item['factory']}.release( ${item['name']} ); + %endif + %if item['retain']: + // increment refcount of handle + context->factories.${item['factory']}.retain( ${item['name']} ); + %endif + %endif %if not item['release'] and not item['retain'] and not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle': try { diff --git a/source/common/ur_singleton.hpp b/source/common/ur_singleton.hpp index b469c8b8a7..057d58c067 100644 --- a/source/common/ur_singleton.hpp +++ b/source/common/ur_singleton.hpp @@ -11,6 +11,7 @@ #ifndef UR_SINGLETON_H #define UR_SINGLETON_H 1 +#include #include #include #include @@ -18,13 +19,18 @@ ////////////////////////////////////////////////////////////////////////// /// a abstract factory for creation of singleton objects template class singleton_factory_t { + struct entry_t { + std::unique_ptr ptr; + size_t ref_count; + }; + protected: using singleton_t = singleton_tn; using key_t = typename std::conditional::value, size_t, key_tn>::type; using ptr_t = std::unique_ptr; - using map_t = std::unordered_map; + using map_t = std::unordered_map; std::mutex mut; ///< lock for thread-safety map_t map; ///< single instance of singleton for each unique key @@ -60,16 +66,31 @@ template class singleton_factory_t { if (map.end() == iter) { auto ptr = std::make_unique(std::forward(params)...); - iter = map.emplace(key, std::move(ptr)).first; + iter = map.emplace(key, entry_t{std::move(ptr), 0}).first; + } else { + iter->second.ref_count++; } - return iter->second.get(); + return iter->second.ptr.get(); + } + + void retain(key_tn key) { + std::lock_guard lk(mut); + auto iter = map.find(getKey(key)); + assert(iter != map.end()); + iter->second.ref_count++; } ////////////////////////////////////////////////////////////////////////// /// once the key is no longer valid, release the singleton void release(key_tn key) { std::lock_guard lk(mut); - map.erase(getKey(key)); + auto iter = map.find(getKey(key)); + assert(iter != map.end()); + if (iter->second.ref_count == 0) { + map.erase(iter); + } else { + iter->second.ref_count--; + } } void clear() { diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index 86a6ad95a0..831f62f76d 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -85,6 +85,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( // forward to device-platform result = pfnAdapterRelease(hAdapter); + // release loader handle + context->factories.ur_adapter_factory.release(hAdapter); + return result; } @@ -110,6 +113,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain( // forward to device-platform result = pfnAdapterRetain(hAdapter); + // increment refcount of handle + context->factories.ur_adapter_factory.retain(hAdapter); + return result; } @@ -614,6 +620,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( // forward to device-platform result = pfnRetain(hDevice); + // increment refcount of handle + context->factories.ur_device_factory.retain(hDevice); + return result; } @@ -640,6 +649,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease( // forward to device-platform result = pfnRelease(hDevice); + // release loader handle + context->factories.ur_device_factory.release(hDevice); + return result; } @@ -910,6 +922,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( // forward to device-platform result = pfnRetain(hContext); + // increment refcount of handle + context->factories.ur_context_factory.retain(hContext); + return result; } @@ -936,6 +951,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease( // forward to device-platform result = pfnRelease(hContext); + // release loader handle + context->factories.ur_context_factory.release(hContext); + return result; } @@ -1238,6 +1256,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( // forward to device-platform result = pfnRetain(hMem); + // increment refcount of handle + context->factories.ur_mem_factory.retain(hMem); + return result; } @@ -1264,6 +1285,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease( // forward to device-platform result = pfnRelease(hMem); + // release loader handle + context->factories.ur_mem_factory.release(hMem); + return result; } @@ -1615,6 +1639,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( // forward to device-platform result = pfnRetain(hSampler); + // increment refcount of handle + context->factories.ur_sampler_factory.retain(hSampler); + return result; } @@ -1641,6 +1668,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease( // forward to device-platform result = pfnRelease(hSampler); + // release loader handle + context->factories.ur_sampler_factory.release(hSampler); + return result; } @@ -2074,6 +2104,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( // forward to device-platform result = pfnPoolRetain(pPool); + // increment refcount of handle + context->factories.ur_usm_pool_factory.retain(pPool); + return result; } @@ -2099,6 +2132,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( // forward to device-platform result = pfnPoolRelease(pPool); + // release loader handle + context->factories.ur_usm_pool_factory.release(pPool); + return result; } @@ -2484,6 +2520,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( // forward to device-platform result = pfnRetain(hPhysicalMem); + // increment refcount of handle + context->factories.ur_physical_mem_factory.retain(hPhysicalMem); + return result; } @@ -2512,6 +2551,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( // forward to device-platform result = pfnRelease(hPhysicalMem); + // release loader handle + context->factories.ur_physical_mem_factory.release(hPhysicalMem); + return result; } @@ -2759,6 +2801,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( // forward to device-platform result = pfnRetain(hProgram); + // increment refcount of handle + context->factories.ur_program_factory.retain(hProgram); + return result; } @@ -2785,6 +2830,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease( // forward to device-platform result = pfnRelease(hProgram); + // release loader handle + context->factories.ur_program_factory.release(hProgram); + return result; } @@ -3382,6 +3430,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( // forward to device-platform result = pfnRetain(hKernel); + // increment refcount of handle + context->factories.ur_kernel_factory.retain(hKernel); + return result; } @@ -3408,6 +3459,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease( // forward to device-platform result = pfnRelease(hKernel); + // release loader handle + context->factories.ur_kernel_factory.release(hKernel); + return result; } @@ -3858,6 +3912,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( // forward to device-platform result = pfnRetain(hQueue); + // increment refcount of handle + context->factories.ur_queue_factory.retain(hQueue); + return result; } @@ -3884,6 +3941,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease( // forward to device-platform result = pfnRelease(hQueue); + // release loader handle + context->factories.ur_queue_factory.release(hQueue); + return result; } @@ -4188,6 +4248,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( // forward to device-platform result = pfnRetain(hEvent); + // increment refcount of handle + context->factories.ur_event_factory.retain(hEvent); + return result; } @@ -4213,6 +4276,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease( // forward to device-platform result = pfnRelease(hEvent); + // release loader handle + context->factories.ur_event_factory.release(hEvent); + return result; } @@ -6745,6 +6811,9 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalMemoryExp( // forward to device-platform result = pfnReleaseExternalMemoryExp(hContext, hDevice, hExternalMem); + // release loader handle + context->factories.ur_exp_external_mem_factory.release(hExternalMem); + return result; } @@ -6835,6 +6904,10 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalSemaphoreExp( result = pfnReleaseExternalSemaphoreExp(hContext, hDevice, hExternalSemaphore); + // release loader handle + context->factories.ur_exp_external_semaphore_factory.release( + hExternalSemaphore); + return result; } @@ -7062,6 +7135,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( // forward to device-platform result = pfnRetainExp(hCommandBuffer); + // increment refcount of handle + context->factories.ur_exp_command_buffer_factory.retain(hCommandBuffer); + return result; } @@ -7092,6 +7168,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp( // forward to device-platform result = pfnReleaseExp(hCommandBuffer); + // release loader handle + context->factories.ur_exp_command_buffer_factory.release(hCommandBuffer); + return result; } @@ -8408,6 +8487,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( // forward to device-platform result = pfnRetainCommandExp(hCommand); + // increment refcount of handle + context->factories.ur_exp_command_buffer_command_factory.retain(hCommand); + return result; } @@ -8439,6 +8521,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( // forward to device-platform result = pfnReleaseCommandExp(hCommand); + // release loader handle + context->factories.ur_exp_command_buffer_command_factory.release(hCommand); + return result; }