Skip to content

Commit

Permalink
Fix "use after release" issues
Browse files Browse the repository at this point in the history
In some cases, we use handles after releasing them, or incorrectly
release handles we shouldn't. This doesn't cause any issues currently,
but will when we start using reference counting in the loader.
  • Loading branch information
RossBrunton committed Nov 26, 2024
1 parent 49a3d6c commit a07352d
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 51 deletions.
20 changes: 14 additions & 6 deletions scripts/templates/valddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,19 @@ namespace ur_validation_layer
%endif
%endfor

%for tp in tracked_params:
<%
tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None)
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
%>
%if func_name in tp_handle_funcs['release']:
if( getContext()->enableLeakChecking )
{
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%endif
%endfor

${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );

%for tp in tracked_params:
Expand All @@ -114,15 +127,10 @@ namespace ur_validation_layer
}
}
%elif func_name in tp_handle_funcs['retain']:
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking )
{
getContext()->refCountContext->incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%elif func_name in tp_handle_funcs['release']:
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
{
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%endif
%endfor

Expand Down
88 changes: 44 additions & 44 deletions source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease(
}
}

ur_result_t result = pfnAdapterRelease(hAdapter);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hAdapter, true);
}

ur_result_t result = pfnAdapterRelease(hAdapter);

return result;
}

Expand All @@ -99,7 +99,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain(

ur_result_t result = pfnAdapterRetain(hAdapter);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hAdapter, true);
}

Expand Down Expand Up @@ -558,7 +558,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain(

ur_result_t result = pfnRetain(hDevice);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hDevice, false);
}

Expand All @@ -583,12 +583,12 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease(
}
}

ur_result_t result = pfnRelease(hDevice);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hDevice, false);
}

ur_result_t result = pfnRelease(hDevice);

return result;
}

Expand Down Expand Up @@ -861,7 +861,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain(

ur_result_t result = pfnRetain(hContext);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hContext, false);
}

Expand All @@ -886,12 +886,12 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
}
}

ur_result_t result = pfnRelease(hContext);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hContext, false);
}

ur_result_t result = pfnRelease(hContext);

return result;
}

Expand Down Expand Up @@ -1248,7 +1248,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain(

ur_result_t result = pfnRetain(hMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hMem, false);
}

Expand All @@ -1273,12 +1273,12 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease(
}
}

ur_result_t result = pfnRelease(hMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hMem, false);
}

ur_result_t result = pfnRelease(hMem);

return result;
}

Expand Down Expand Up @@ -1657,7 +1657,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain(

ur_result_t result = pfnRetain(hSampler);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hSampler, false);
}

Expand All @@ -1682,12 +1682,12 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease(
}
}

ur_result_t result = pfnRelease(hSampler);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hSampler, false);
}

ur_result_t result = pfnRelease(hSampler);

return result;
}

Expand Down Expand Up @@ -2154,7 +2154,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain(

ur_result_t result = pfnPoolRetain(pPool);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(pPool, false);
}

Expand All @@ -2178,12 +2178,12 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease(
}
}

ur_result_t result = pfnPoolRelease(pPool);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(pPool, false);
}

ur_result_t result = pfnPoolRelease(pPool);

return result;
}

Expand Down Expand Up @@ -2631,7 +2631,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain(

ur_result_t result = pfnRetain(hPhysicalMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hPhysicalMem, false);
}

Expand All @@ -2656,12 +2656,12 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease(
}
}

ur_result_t result = pfnRelease(hPhysicalMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hPhysicalMem, false);
}

ur_result_t result = pfnRelease(hPhysicalMem);

return result;
}

Expand Down Expand Up @@ -2952,7 +2952,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain(

ur_result_t result = pfnRetain(hProgram);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hProgram, false);
}

Expand All @@ -2977,12 +2977,12 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease(
}
}

ur_result_t result = pfnRelease(hProgram);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hProgram, false);
}

ur_result_t result = pfnRelease(hProgram);

return result;
}

Expand Down Expand Up @@ -3618,7 +3618,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(

ur_result_t result = pfnRetain(hKernel);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hKernel, false);
}

Expand All @@ -3643,12 +3643,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease(
}
}

ur_result_t result = pfnRelease(hKernel);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hKernel, false);
}

ur_result_t result = pfnRelease(hKernel);

return result;
}

Expand Down Expand Up @@ -4138,7 +4138,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain(

ur_result_t result = pfnRetain(hQueue);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hQueue, false);
}

Expand All @@ -4163,12 +4163,12 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease(
}
}

ur_result_t result = pfnRelease(hQueue);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hQueue, false);
}

ur_result_t result = pfnRelease(hQueue);

return result;
}

Expand Down Expand Up @@ -4454,7 +4454,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain(

ur_result_t result = pfnRetain(hEvent);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hEvent, false);
}

Expand All @@ -4478,12 +4478,12 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease(
}
}

ur_result_t result = pfnRelease(hEvent);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hEvent, false);
}

ur_result_t result = pfnRelease(hEvent);

return result;
}

Expand Down
1 change: 1 addition & 0 deletions test/conformance/adapter/urAdapterRelease.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ struct urAdapterReleaseTest : uur::runtime::urAdapterTest {

TEST_F(urAdapterReleaseTest, Success) {
uint32_t referenceCountBefore = 0;
ASSERT_SUCCESS(urAdapterRetain(adapter));

ASSERT_SUCCESS(urAdapterGetInfo(adapter, UR_ADAPTER_INFO_REFERENCE_COUNT,
sizeof(referenceCountBefore),
Expand Down
2 changes: 2 additions & 0 deletions test/conformance/device/urDeviceRelease.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ struct urDeviceReleaseTest : uur::urAllDevicesTest {};

TEST_F(urDeviceReleaseTest, Success) {
for (auto device : devices) {
ASSERT_SUCCESS(urDeviceRetain(device));

uint32_t prevRefCount = 0;
ASSERT_SUCCESS(uur::GetObjectReferenceCount(device, prevRefCount));

Expand Down
1 change: 0 additions & 1 deletion test/conformance/testing/include/uur/fixtures.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ struct urDeviceTest : urPlatformTest,
}

void TearDown() override {
EXPECT_SUCCESS(urDeviceRelease(device));
UUR_RETURN_ON_FATAL_FAILURE(urPlatformTest::TearDown());
}

Expand Down

0 comments on commit a07352d

Please sign in to comment.