From c8309abef95f361920f65fb0a68f43f2cfbe7698 Mon Sep 17 00:00:00 2001 From: "Neil R. Spruit" Date: Tue, 9 Jul 2024 16:15:48 -0700 Subject: [PATCH] Fix DriverGet to handle failed drivers and avoid layer init during checks (#167) - During init, instrumentation may call driver get before uinitialized drivers can be removed. DriverGet has been updated to return the driver count only for valid drivers and update drivers that are not init to be skipped in subsequent DriverGet calls. - Dont run layer init during checks for driver init to avoid creating invalid layer ddi table calls. Signed-off-by: Neil R. Spruit --- scripts/templates/ldrddi.cpp.mako | 14 +++++++++-- source/lib/ze_lib.cpp | 6 ----- source/loader/ze_ldrddi.cpp | 14 +++++++++-- source/loader/ze_loader.cpp | 40 ------------------------------- 4 files changed, 24 insertions(+), 50 deletions(-) diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index 9fb05c8b..d43747cc 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -82,7 +82,13 @@ namespace loader uint32_t library_driver_handle_count = 0; result = drv.dditable.${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( &library_driver_handle_count, nullptr ); - if( ${X}_RESULT_SUCCESS != result ) break; + if( ${X}_RESULT_SUCCESS != result ) { + // If Get Drivers fails with Uninitialized, then update the driver init status to prevent reporting this driver in the next get call. + if (${X}_RESULT_ERROR_UNINITIALIZED == result) { + drv.initStatus = result; + } + continue; + } if( nullptr != ${obj['params'][1]['name']} && *${obj['params'][0]['name']} !=0) { @@ -109,8 +115,12 @@ namespace loader total_driver_handle_count += library_driver_handle_count; } - if( ${X}_RESULT_SUCCESS == result ) + // If the last driver get failed, but at least one driver succeeded, then return success with total count. + if( ${X}_RESULT_SUCCESS == result || total_driver_handle_count > 0) *${obj['params'][0]['name']} = total_driver_handle_count; + if (total_driver_handle_count > 0) { + result = ${X}_RESULT_SUCCESS; + } %else: %for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)): diff --git a/source/lib/ze_lib.cpp b/source/lib/ze_lib.cpp index 4e26d29a..ea58768b 100644 --- a/source/lib/ze_lib.cpp +++ b/source/lib/ze_lib.cpp @@ -109,12 +109,6 @@ namespace ze_lib // Check which drivers support the ze_driver_flag_t specified // No need to check if only initializing sysman result = zelLoaderDriverCheck(flags); - // reInit the ze ddi tables after verifying the zeInit() with dummy tables. - // This ensures the tracing and validation layers are pointing to the correct function pointers after init. - if( ZE_RESULT_SUCCESS == result ) - { - result = zeInit(); - } } if( ZE_RESULT_SUCCESS == result ) diff --git a/source/loader/ze_ldrddi.cpp b/source/loader/ze_ldrddi.cpp index d8b38212..44ea5a4b 100644 --- a/source/loader/ze_ldrddi.cpp +++ b/source/loader/ze_ldrddi.cpp @@ -89,7 +89,13 @@ namespace loader uint32_t library_driver_handle_count = 0; result = drv.dditable.ze.Driver.pfnGet( &library_driver_handle_count, nullptr ); - if( ZE_RESULT_SUCCESS != result ) break; + if( ZE_RESULT_SUCCESS != result ) { + // If Get Drivers fails with Uninitialized, then update the driver init status to prevent reporting this driver in the next get call. + if (ZE_RESULT_ERROR_UNINITIALIZED == result) { + drv.initStatus = result; + } + continue; + } if( nullptr != phDrivers && *pCount !=0) { @@ -116,8 +122,12 @@ namespace loader total_driver_handle_count += library_driver_handle_count; } - if( ZE_RESULT_SUCCESS == result ) + // If the last driver get failed, but at least one driver succeeded, then return success with total count. + if( ZE_RESULT_SUCCESS == result || total_driver_handle_count > 0) *pCount = total_driver_handle_count; + if (total_driver_handle_count > 0) { + result = ZE_RESULT_SUCCESS; + } return result; } diff --git a/source/loader/ze_loader.cpp b/source/loader/ze_loader.cpp index 8361bd10..dd4a61d5 100644 --- a/source/loader/ze_loader.cpp +++ b/source/loader/ze_loader.cpp @@ -200,46 +200,6 @@ namespace loader return ZE_RESULT_ERROR_UNINITIALIZED; } - if(nullptr != validationLayer) { - getTable = reinterpret_cast( - GET_FUNCTION_PTR(validationLayer, "zeGetGlobalProcAddrTable") ); - if(!getTable) { - if (debugTraceEnabled) { - std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable function pointer null with validation layer. Returning "; - debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); - } - return ZE_RESULT_ERROR_UNINITIALIZED; - } - getTableResult = getTable( version, &global); - if(getTableResult != ZE_RESULT_SUCCESS) { - if (debugTraceEnabled) { - std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable() with validation layer failed with "; - debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); - } - return ZE_RESULT_ERROR_UNINITIALIZED; - } - } - - if(nullptr != tracingLayer) { - getTable = reinterpret_cast( - GET_FUNCTION_PTR(tracingLayer, "zeGetGlobalProcAddrTable") ); - if(!getTable) { - if (debugTraceEnabled) { - std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable function pointer null with tracing layer. Returning "; - debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); - } - return ZE_RESULT_ERROR_UNINITIALIZED; - } - getTableResult = getTable( version, &global); - if(getTableResult != ZE_RESULT_SUCCESS) { - if (debugTraceEnabled) { - std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable() with tracing layer failed with "; - debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); - } - return ZE_RESULT_ERROR_UNINITIALIZED; - } - } - auto pfnInit = global.pfnInit; if(nullptr == pfnInit) { if (debugTraceEnabled) {