Skip to content

Commit

Permalink
Merge pull request #1952 from kbenzie/benie/bounds-checking-off-by-de…
Browse files Browse the repository at this point in the history
…fault

Make USM parameter bounds checking configurable
  • Loading branch information
omarahmed1111 authored Aug 13, 2024
2 parents 26f1dfc + 216d30e commit 6c98e0e
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 68 deletions.
2 changes: 2 additions & 0 deletions scripts/core/INTRO.rst
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ Layers currently included with the runtime are as follows:
- Description
* - UR_LAYER_PARAMETER_VALIDATION
- Enables non-adapter-specific parameter validation (e.g. checking for null values).
* - UR_LAYER_BOUNDS_CHECKING
- Enables non-adapter-specific bounds checking of USM allocations for enqueued commands. Automatically enables UR_LAYER_PARAMETER_VALIDATION.
* - UR_LAYER_LEAK_CHECKING
- Performs some leak checking for API calls involving object creation/destruction.
* - UR_LAYER_LIFETIME_VALIDATION
Expand Down
14 changes: 13 additions & 1 deletion scripts/templates/valddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,16 @@ namespace ur_validation_layer
{
%for key, values in sorted_param_checks:
%for val in values:
if( ${val} )
%if 'boundsError' in val:
if ( getContext()->enableBoundsChecking ) {
if ( ${val} ) {
return ${key};
}
}
%else:
if ( ${val} )
return ${key};
%endif

%endfor
%endfor
Expand Down Expand Up @@ -178,9 +186,13 @@ namespace ur_validation_layer

if (enabledLayerNames.count(nameFullValidation)) {
enableParameterValidation = true;
enableBoundsChecking = true;
enableLeakChecking = true;
enableLifetimeValidation = true;
} else {
if (enabledLayerNames.count(nameBoundsChecking)) {
enableBoundsChecking = true;
}
if (enabledLayerNames.count(nameParameterValidation)) {
enableParameterValidation = true;
}
Expand Down
180 changes: 114 additions & 66 deletions source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4822,9 +4822,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferRead(
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
}

if (auto boundsError = bounds(hBuffer, offset, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hBuffer, offset, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -4902,9 +4904,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWrite(
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
}

if (auto boundsError = bounds(hBuffer, offset, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hBuffer, offset, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -5033,9 +5037,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
return UR_RESULT_ERROR_INVALID_SIZE;
}

if (auto boundsError = bounds(hBuffer, bufferOrigin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hBuffer, bufferOrigin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -5168,9 +5174,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
return UR_RESULT_ERROR_INVALID_SIZE;
}

if (auto boundsError = bounds(hBuffer, bufferOrigin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hBuffer, bufferOrigin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -5248,14 +5256,18 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopy(
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
}

if (auto boundsError = bounds(hBufferSrc, srcOffset, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hBufferSrc, srcOffset, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (auto boundsError = bounds(hBufferDst, dstOffset, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hBufferDst, dstOffset, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -5383,14 +5395,18 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
return UR_RESULT_ERROR_INVALID_SIZE;
}

if (auto boundsError = bounds(hBufferSrc, srcOrigin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hBufferSrc, srcOrigin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (auto boundsError = bounds(hBufferDst, dstOrigin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hBufferDst, dstOrigin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -5492,9 +5508,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferFill(
return UR_RESULT_ERROR_INVALID_SIZE;
}

if (auto boundsError = bounds(hBuffer, offset, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hBuffer, offset, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -5579,9 +5597,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageRead(
return UR_RESULT_ERROR_INVALID_SIZE;
}

if (auto boundsError = boundsImage(hImage, origin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = boundsImage(hImage, origin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -5667,9 +5687,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageWrite(
return UR_RESULT_ERROR_INVALID_SIZE;
}

if (auto boundsError = boundsImage(hImage, origin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = boundsImage(hImage, origin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -5756,14 +5778,18 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemImageCopy(
return UR_RESULT_ERROR_INVALID_SIZE;
}

if (auto boundsError = boundsImage(hImageSrc, srcOrigin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = boundsImage(hImageSrc, srcOrigin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (auto boundsError = boundsImage(hImageDst, dstOrigin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = boundsImage(hImageDst, dstOrigin, region);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -5850,9 +5876,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferMap(
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
}

if (auto boundsError = bounds(hBuffer, offset, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hBuffer, offset, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -6012,9 +6040,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill(
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
}

if (auto boundsError = bounds(hQueue, pMem, 0, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hQueue, pMem, 0, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -6089,14 +6119,18 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMMemcpy(
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
}

if (auto boundsError = bounds(hQueue, pDst, 0, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hQueue, pDst, 0, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (auto boundsError = bounds(hQueue, pSrc, 0, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hQueue, pSrc, 0, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -6169,9 +6203,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMPrefetch(
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
}

if (auto boundsError = bounds(hQueue, pMem, 0, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hQueue, pMem, 0, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -6230,9 +6266,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMAdvise(
return UR_RESULT_ERROR_INVALID_SIZE;
}

if (auto boundsError = bounds(hQueue, pMem, 0, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hQueue, pMem, 0, size);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}
}

Expand Down Expand Up @@ -6332,9 +6370,11 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill2D(
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
}

if (auto boundsError = bounds(hQueue, pMem, 0, pitch * height);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hQueue, pMem, 0, pitch * height);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -6431,14 +6471,18 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST;
}

if (auto boundsError = bounds(hQueue, pDst, 0, dstPitch * height);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hQueue, pDst, 0, dstPitch * height);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (auto boundsError = bounds(hQueue, pSrc, 0, srcPitch * height);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
if (getContext()->enableBoundsChecking) {
if (auto boundsError = bounds(hQueue, pSrc, 0, srcPitch * height);
boundsError != UR_RESULT_SUCCESS) {
return boundsError;
}
}

if (phEventWaitList != NULL && numEventsInWaitList > 0) {
Expand Down Expand Up @@ -10997,9 +11041,13 @@ ur_result_t context_t::init(ur_dditable_t *dditable,

if (enabledLayerNames.count(nameFullValidation)) {
enableParameterValidation = true;
enableBoundsChecking = true;
enableLeakChecking = true;
enableLifetimeValidation = true;
} else {
if (enabledLayerNames.count(nameBoundsChecking)) {
enableBoundsChecking = true;
}
if (enabledLayerNames.count(nameParameterValidation)) {
enableParameterValidation = true;
}
Expand Down
5 changes: 4 additions & 1 deletion source/loader/layers/validation/ur_validation_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class __urdlllocal context_t : public proxy_layer_context_t,
public AtomicSingleton<context_t> {
public:
bool enableParameterValidation = false;
bool enableBoundsChecking = false;
bool enableLeakChecking = false;
bool enableLifetimeValidation = false;
logger::Logger logger;
Expand All @@ -35,7 +36,7 @@ class __urdlllocal context_t : public proxy_layer_context_t,

static std::vector<std::string> getNames() {
return {nameFullValidation, nameParameterValidation, nameLeakChecking,
nameLifetimeValidation};
nameBoundsChecking, nameLifetimeValidation};
}
ur_result_t init(ur_dditable_t *dditable,
const std::set<std::string> &enabledLayerNames,
Expand All @@ -49,6 +50,8 @@ class __urdlllocal context_t : public proxy_layer_context_t,
"UR_LAYER_FULL_VALIDATION";
inline static const std::string nameParameterValidation =
"UR_LAYER_PARAMETER_VALIDATION";
inline static const std::string nameBoundsChecking =
"UR_LAYER_BOUNDS_CHECKING";
inline static const std::string nameLeakChecking = "UR_LAYER_LEAK_CHECKING";
inline static const std::string nameLifetimeValidation =
"UR_LAYER_LIFETIME_VALIDATION";
Expand Down

0 comments on commit 6c98e0e

Please sign in to comment.