From ad076d39c3a8374f1bee14d0357a3edc8d243c44 Mon Sep 17 00:00:00 2001 From: Lisanna Dettwyler Date: Tue, 3 Sep 2024 15:41:45 +0000 Subject: [PATCH] Add zelReloadDrivers(flags) API Provides a means to re-initialize all of the drivers' library handles and DDI tables. The value of flags must match what was provided to zeInit(flags). Signed-off-by: Lisanna Dettwyler --- doc/loader_api.md | 6 + include/loader/ze_loader.h | 4 + source/lib/ze_lib.cpp | 16 ++ source/loader/ze_loader_api.cpp | 252 ++++++++++++++++++++++++++++++++ source/loader/ze_loader_api.h | 5 + test/CMakeLists.txt | 6 +- test/loader_api.cpp | 26 ++++ 7 files changed, 313 insertions(+), 2 deletions(-) diff --git a/doc/loader_api.md b/doc/loader_api.md index 2da1a96f..8701e56f 100644 --- a/doc/loader_api.md +++ b/doc/loader_api.md @@ -21,6 +21,12 @@ There are currently 3 versioned components assigned the following name strings: - `"validation layer"` - `"loader"` +### zelReloadDrivers + +Close, reload, and re-initialize through zeInit all driver libraries currently loaded. + +- __flags__ init flags that will be passed to each driver's implementation of zeInit, it should match what was previously provided at the first zeInit. + ### zelLoaderTranslateHandle diff --git a/include/loader/ze_loader.h b/include/loader/ze_loader.h index 2d5b75d2..c57da3c6 100644 --- a/include/loader/ze_loader.h +++ b/include/loader/ze_loader.h @@ -39,6 +39,10 @@ zelLoaderGetVersions( size_t *num_elems, //Pointer to num versions to get. zel_component_version_t *versions); //Pointer to array of versions. If set to NULL, num_elems is returned +ZE_APIEXPORT ze_result_t ZE_APICALL +zelReloadDrivers( + ze_init_flags_t flags); //Init flags, should match flags used in zeInit + typedef enum _zel_handle_type_t { ZEL_HANDLE_DRIVER, ZEL_HANDLE_DEVICE, diff --git a/source/lib/ze_lib.cpp b/source/lib/ze_lib.cpp index 2f93997b..4ea00036 100644 --- a/source/lib/ze_lib.cpp +++ b/source/lib/ze_lib.cpp @@ -171,6 +171,22 @@ zelLoaderGetVersions( #endif } +ze_result_t ZE_APICALL +zelReloadDrivers( + ze_init_flags_t flags) +{ +#ifdef DYNAMIC_LOAD_LOADER + if(nullptr == ze_lib::context->loader) + return ZE_RESULT_ERROR; + typedef ze_result_t (ZE_APICALL *zelReloadDriver_t)(ze_driver_handle_t hDriver); + auto reloadDrivers = reinterpret_cast( + GET_FUNCTION_PTR(ze_lib::context->loader, "zelReloadDriversInternal") ); + return reloadDrivers(flags); +#else + return zelReloadDriversInternal(flags); +#endif +} + ze_result_t ZE_APICALL zelLoaderTranslateHandle( diff --git a/source/loader/ze_loader_api.cpp b/source/loader/ze_loader_api.cpp index d86d92ca..fc9d673e 100644 --- a/source/loader/ze_loader_api.cpp +++ b/source/loader/ze_loader_api.cpp @@ -73,6 +73,258 @@ zelLoaderGetVersionsInternal( return ZE_RESULT_SUCCESS; } +ZE_DLLEXPORT ze_result_t ZE_APICALL +zelReloadDriversInternal( + ze_init_flags_t flags) +{ + for( auto& drv : loader::context->zeDrivers ) { + if(drv.initStatus != ZE_RESULT_SUCCESS) + continue; + + if (drv.handle) { + auto free_result = FREE_DRIVER_LIBRARY( drv.handle ); + auto failure = FREE_DRIVER_LIBRARY_FAILURE_CHECK(free_result); + if (failure) + return ZE_RESULT_ERROR_UNINITIALIZED; + } + + drv.handle = LOAD_DRIVER_LIBRARY( drv.name.c_str() ); + if (NULL == drv.handle) + return ZE_RESULT_ERROR_UNINITIALIZED; + + auto zeGetGlobalProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetGlobalProcAddrTable") ); + if (!zeGetGlobalProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetGlobalProcAddrTableResult = zeGetGlobalProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Global); + if (zeGetGlobalProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetGlobalProcAddrTableResult; + + auto zeGetRTASBuilderExpProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetRTASBuilderExpProcAddrTable") ); + if (!zeGetRTASBuilderExpProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetRTASBuilderExpProcAddrTableResult = zeGetRTASBuilderExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.RTASBuilderExp); + if (zeGetRTASBuilderExpProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetRTASBuilderExpProcAddrTableResult; + + auto zeGetRTASParallelOperationExpProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetRTASParallelOperationExpProcAddrTable") ); + if (!zeGetRTASParallelOperationExpProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetRTASParallelOperationExpProcAddrTableResult = zeGetRTASParallelOperationExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.RTASParallelOperationExp); + if (zeGetRTASParallelOperationExpProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetRTASParallelOperationExpProcAddrTableResult; + + auto zeGetDriverProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetDriverProcAddrTable") ); + if (!zeGetDriverProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetDriverProcAddrTableResult = zeGetDriverProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Driver); + if (zeGetDriverProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetDriverProcAddrTableResult; + + auto zeGetDriverExpProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetDriverExpProcAddrTable") ); + if (!zeGetDriverExpProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetDriverExpProcAddrTableResult = zeGetDriverExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.DriverExp); + if (zeGetDriverExpProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetDriverExpProcAddrTableResult; + + auto zeGetDeviceProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetDeviceProcAddrTable") ); + if (!zeGetDeviceProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetDeviceProcAddrTableResult = zeGetDeviceProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Device); + if (zeGetDeviceProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetDeviceProcAddrTableResult; + + auto zeGetDeviceExpProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetDeviceExpProcAddrTable") ); + if (!zeGetDeviceExpProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetDeviceExpProcAddrTableResult = zeGetDeviceExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.DeviceExp); + if (zeGetDeviceExpProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetDeviceExpProcAddrTableResult; + + auto zeGetContextProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetContextProcAddrTable") ); + if (!zeGetContextProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetContextProcAddrTableResult = zeGetContextProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Context); + if (zeGetContextProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetContextProcAddrTableResult; + + auto zeGetCommandQueueProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetCommandQueueProcAddrTable") ); + if (!zeGetCommandQueueProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetCommandQueueProcAddrTableResult = zeGetCommandQueueProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandQueue); + if (zeGetCommandQueueProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetCommandQueueProcAddrTableResult; + + auto zeGetCommandListProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetCommandListProcAddrTable") ); + if (!zeGetCommandListProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetCommandListProcAddrTableResult = zeGetCommandListProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandList); + if (zeGetCommandListProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetCommandListProcAddrTableResult; + + auto zeGetCommandListExpProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetCommandListExpProcAddrTable") ); + if (!zeGetCommandListExpProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetCommandListExpProcAddrTableResult = zeGetCommandListExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandListExp); + if (zeGetCommandListExpProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetCommandListExpProcAddrTableResult; + + auto zeGetEventProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetEventProcAddrTable") ); + if (!zeGetEventProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetEventProcAddrTableResult = zeGetEventProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Event); + if (zeGetEventProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetEventProcAddrTableResult; + + auto zeGetEventExpProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetEventExpProcAddrTable") ); + if (!zeGetEventExpProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetEventExpProcAddrTableResult = zeGetEventExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.EventExp); + if (zeGetEventExpProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetEventExpProcAddrTableResult; + + auto zeGetEventPoolProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetEventPoolProcAddrTable") ); + if (!zeGetEventPoolProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetEventPoolProcAddrTableResult = zeGetEventPoolProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.EventPool); + if (zeGetEventPoolProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetEventPoolProcAddrTableResult; + + auto zeGetFenceProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetFenceProcAddrTable") ); + if (!zeGetFenceProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetFenceProcAddrTableResult = zeGetFenceProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Fence); + if (zeGetFenceProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetFenceProcAddrTableResult; + + auto zeGetImageProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetImageProcAddrTable") ); + if (!zeGetImageProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetImageProcAddrTableResult = zeGetImageProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Image); + if (zeGetImageProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetImageProcAddrTableResult; + + auto zeGetImageExpProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetImageExpProcAddrTable") ); + if (!zeGetImageExpProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetImageExpProcAddrTableResult = zeGetImageExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.ImageExp); + if (zeGetImageExpProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetImageExpProcAddrTableResult; + + auto zeGetKernelProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetKernelProcAddrTable") ); + if (!zeGetKernelProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetKernelProcAddrTableResult = zeGetKernelProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Kernel); + if (zeGetKernelProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetKernelProcAddrTableResult; + + auto zeGetKernelExpProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetKernelExpProcAddrTable") ); + if (!zeGetKernelExpProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetKernelExpProcAddrTableResult = zeGetKernelExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.KernelExp); + if (zeGetKernelExpProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetKernelExpProcAddrTableResult; + + auto zeGetMemProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetMemProcAddrTable") ); + if (!zeGetMemProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetMemProcAddrTableResult = zeGetMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Mem); + if (zeGetMemProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetMemProcAddrTableResult; + + auto zeGetMemExpProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetMemExpProcAddrTable") ); + if (!zeGetMemExpProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetMemExpProcAddrTableResult = zeGetMemExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.MemExp); + if (zeGetMemExpProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetMemExpProcAddrTableResult; + + auto zeGetModuleProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetModuleProcAddrTable") ); + if (!zeGetModuleProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetModuleProcAddrTableResult = zeGetModuleProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Module); + if (zeGetModuleProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetModuleProcAddrTableResult; + + auto zeGetModuleBuildLogProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetModuleBuildLogProcAddrTable") ); + if (!zeGetModuleBuildLogProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetModuleBuildLogProcAddrTableResult = zeGetModuleBuildLogProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.ModuleBuildLog); + if (zeGetModuleBuildLogProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetModuleBuildLogProcAddrTableResult; + + auto zeGetPhysicalMemProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetPhysicalMemProcAddrTable") ); + if (!zeGetPhysicalMemProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetPhysicalMemProcAddrTableResult = zeGetPhysicalMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.PhysicalMem); + if (zeGetPhysicalMemProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetPhysicalMemProcAddrTableResult; + + auto zeGetSamplerProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetSamplerProcAddrTable") ); + if (!zeGetSamplerProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetSamplerProcAddrTableResult = zeGetSamplerProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Sampler); + if (zeGetSamplerProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetSamplerProcAddrTableResult; + + auto zeGetVirtualMemProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetVirtualMemProcAddrTable") ); + if (!zeGetVirtualMemProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetVirtualMemProcAddrTableResult = zeGetVirtualMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.VirtualMem); + if (zeGetVirtualMemProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetVirtualMemProcAddrTableResult; + + auto zeGetFabricEdgeExpProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetFabricEdgeExpProcAddrTable") ); + if (!zeGetFabricEdgeExpProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetFabricEdgeExpProcAddrTableResult = zeGetFabricEdgeExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.FabricEdgeExp); + if (zeGetFabricEdgeExpProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetFabricEdgeExpProcAddrTableResult; + + auto zeGetFabricVertexExpProcAddrTable = reinterpret_cast( + GET_FUNCTION_PTR( drv.handle, "zeGetFabricVertexExpProcAddrTable") ); + if (!zeGetFabricVertexExpProcAddrTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + auto zeGetFabricVertexExpProcAddrTableResult = zeGetFabricVertexExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.FabricVertexExp); + if (zeGetFabricVertexExpProcAddrTableResult != ZE_RESULT_SUCCESS) + return zeGetFabricVertexExpProcAddrTableResult; + + auto initResult = drv.dditable.ze.Global.pfnInit(flags); + // Bail out if any drivers that previously succeeded fail + if (initResult != ZE_RESULT_SUCCESS) + return initResult; + } + + return ZE_RESULT_SUCCESS; +} + ZE_DLLEXPORT ze_result_t ZE_APICALL zelLoaderTranslateHandleInternal( diff --git a/source/loader/ze_loader_api.h b/source/loader/ze_loader_api.h index 590f1432..686bcd6d 100644 --- a/source/loader/ze_loader_api.h +++ b/source/loader/ze_loader_api.h @@ -68,6 +68,11 @@ zelLoaderGetVersionsInternal( zel_component_version_t *versions); //Pointer to array of versions. If set to NULL, num_elems is returned +ZE_DLLEXPORT ze_result_t ZE_APICALL +zelReloadDriversInternal( + ze_init_flags_t flags); + + ZE_DLLEXPORT ze_result_t ZE_APICALL zelLoaderTranslateHandleInternal( zel_handle_type_t handleType, //Handle type diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 5b1f0e27..59789b43 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -18,5 +18,7 @@ if(MSVC) target_compile_options(tests PRIVATE "/MD$<$:d>") endif() -add_test(NAME tests COMMAND tests) -set_property(TEST tests PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1") \ No newline at end of file +add_test(NAME tests_api_version COMMAND tests --gtest_filter=LoaderAPI.GivenLevelZeroLoaderPresentWhenCallingzeGetLoaderVersionsAPIThenValidVersionIsReturned) +set_property(TEST tests_api_version PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1") +add_test(NAME tests_api_reload COMMAND tests --gtest_filter=LoaderAPI.GivenInitWhenCallingzelReloadDriversThenDriversStillWork) +set_property(TEST tests_api_reload PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1") diff --git a/test/loader_api.cpp b/test/loader_api.cpp index 4fdf8e19..32cafb1b 100644 --- a/test/loader_api.cpp +++ b/test/loader_api.cpp @@ -42,4 +42,30 @@ TEST( } } +TEST( + LoaderAPI, + GivenInitWhenCallingzelReloadDriversThenDriversStillWork +) { + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(0)); + + uint32_t count = 0; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGet(&count, nullptr)); + EXPECT_GT(count, 0); + + std::vector hDrivers(count); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGet(&count, hDrivers.data())); + + for (auto &driver : hDrivers) { + ze_driver_properties_t driverProperties; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGetProperties(driver, &driverProperties)); + } + + EXPECT_EQ(ZE_RESULT_SUCCESS, zelReloadDrivers(0)); + + for (auto &driver : hDrivers) { + ze_driver_properties_t driverProperties; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGetProperties(driver, &driverProperties)); + } +} + } // namespace