Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NPU] Remove template in ext wrapper and fuse functions #27511

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "intel_npu/config/config.hpp"
#include "intel_npu/utils/logger/logger.hpp"
#include "intel_npu/utils/zero/zero_init.hpp"
#include "ze_graph_ext_wrappers_interface.hpp"
#include "ze_graph_ext_wrappers.hpp"

namespace intel_npu {

Expand Down Expand Up @@ -54,7 +54,7 @@ class DriverCompilerAdapter final : public ICompilerAdapter {
std::string serializeConfig(const Config& config, ze_graph_compiler_version_info_t compilerVersion) const;

std::shared_ptr<ZeroInitStructsHolder> _zeroInitStruct;
std::shared_ptr<ZeGraphExtWrappersInterface> _zeGraphExt;
std::shared_ptr<ZeGraphExtWrappers> _zeGraphExt;

ze_device_graph_properties_t _deviceGraphProperties = {};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

#include "intel_npu/common/igraph.hpp"
#include "intel_npu/utils/zero/zero_init.hpp"
#include "ze_graph_ext_wrappers_interface.hpp"
#include "ze_graph_ext_wrappers.hpp"

namespace intel_npu {

class DriverGraph final : public IGraph {
public:
DriverGraph(const std::shared_ptr<ZeGraphExtWrappersInterface>& zeGraphExt,
DriverGraph(const std::shared_ptr<ZeGraphExtWrappers>& zeGraphExt,
const std::shared_ptr<ZeroInitStructsHolder>& zeroInitStruct,
ze_graph_handle_t graphHandle,
NetworkMetadata metadata,
Expand All @@ -37,7 +37,7 @@ class DriverGraph final : public IGraph {
private:
bool release_blob(const Config& config);

std::shared_ptr<ZeGraphExtWrappersInterface> _zeGraphExt;
std::shared_ptr<ZeGraphExtWrappers> _zeGraphExt;
std::shared_ptr<ZeroInitStructsHolder> _zeroInitStruct;

Logger _logger;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "intel_npu/utils/logger/logger.hpp"
#include "intel_npu/utils/zero/zero_init.hpp"
#include "openvino/runtime/so_ptr.hpp"
#include "ze_graph_ext_wrappers_interface.hpp"
#include "ze_graph_ext_wrappers.hpp"

namespace intel_npu {

Expand All @@ -28,7 +28,7 @@ class PluginCompilerAdapter final : public ICompilerAdapter {
private:
std::shared_ptr<ZeroInitStructsHolder> _zeroInitStruct;

std::shared_ptr<ZeGraphExtWrappersInterface> _zeGraphExt;
std::shared_ptr<ZeGraphExtWrappers> _zeGraphExt;
ov::SoPtr<ICompiler> _compiler;

Logger _logger;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
#include "intel_npu/icompiler.hpp"
#include "intel_npu/utils/zero/zero_init.hpp"
#include "openvino/runtime/so_ptr.hpp"
#include "ze_graph_ext_wrappers_interface.hpp"
#include "ze_graph_ext_wrappers.hpp"

namespace intel_npu {

class PluginGraph final : public IGraph {
public:
PluginGraph(const std::shared_ptr<ZeGraphExtWrappersInterface>& zeGraphExt,
PluginGraph(const std::shared_ptr<ZeGraphExtWrappers>& zeGraphExt,
const ov::SoPtr<ICompiler>& compiler,
const std::shared_ptr<ZeroInitStructsHolder>& zeroInitStruct,
ze_graph_handle_t graphHandle,
Expand All @@ -38,7 +38,7 @@ class PluginGraph final : public IGraph {
~PluginGraph() override;

private:
std::shared_ptr<ZeGraphExtWrappersInterface> _zeGraphExt;
std::shared_ptr<ZeGraphExtWrappers> _zeGraphExt;
std::shared_ptr<ZeroInitStructsHolder> _zeroInitStruct;

const ov::SoPtr<ICompiler> _compiler;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,148 +10,60 @@
#include <type_traits>
#include <utility>

#include "intel_npu/network_metadata.hpp"
#include "intel_npu/utils/logger/logger.hpp"
#include "intel_npu/utils/zero/zero_init.hpp"
#include "intel_npu/utils/zero/zero_types.hpp"
#include "ze_graph_ext_wrappers_interface.hpp"

namespace intel_npu {

#define NotSupportQuery(T) (T == ZE_GRAPH_EXT_VERSION_1_2)

// ext version 1.3 or 1.4: supports API (pfnQueryNetworkCreate, pfnQueryNetworkDestroy,
// pfnQueryNetworkGetSupportedLayers)
#define SupportAPIGraphQueryNetworkV1(T) (T == ZE_GRAPH_EXT_VERSION_1_3 || T == ZE_GRAPH_EXT_VERSION_1_4)

// ext version >= 1.5: supports API (pfnCreate2, pfnQueryNetworkCreate2, pfnQueryContextMemory)
#define SupportAPIGraphQueryNetworkV2(T) ((!NotSupportQuery(T) && !SupportAPIGraphQueryNetworkV1(T)))

// For ext version >= 1.5, the pfnCreate2 API is available
#define NotSupportGraph2(T) \
(T == ZE_GRAPH_EXT_VERSION_1_2 || T == ZE_GRAPH_EXT_VERSION_1_3 || T == ZE_GRAPH_EXT_VERSION_1_4)

// A bug inside the driver makes the "pfnGraphGetArgumentMetadata" call not safe for use prior to
// "ze_graph_dditable_ext_1_6_t".
// See: E#117498
#define NotSupportArgumentMetadata(T) \
(T == ZE_GRAPH_EXT_VERSION_1_2 || T == ZE_GRAPH_EXT_VERSION_1_3 || T == ZE_GRAPH_EXT_VERSION_1_4 || \
T == ZE_GRAPH_EXT_VERSION_1_5)

#define UseCopyForNativeBinary(T) \
(T == ZE_GRAPH_EXT_VERSION_1_2 || T == ZE_GRAPH_EXT_VERSION_1_3 || T == ZE_GRAPH_EXT_VERSION_1_4 || \
T == ZE_GRAPH_EXT_VERSION_1_5 || T == ZE_GRAPH_EXT_VERSION_1_6)
using SerializedIR = std::pair<size_t, std::shared_ptr<uint8_t>>;

/**
* Adapter to use CiD through ZeroAPI
*/
template <ze_graph_ext_version_t TableExtension>
class ZeGraphExtWrappers final : public ZeGraphExtWrappersInterface {
class ZeGraphExtWrappers {
public:
ZeGraphExtWrappers(const std::shared_ptr<ZeroInitStructsHolder>& zeroInitStruct);
ZeGraphExtWrappers(const ZeGraphExtWrappers&) = delete;
ZeGraphExtWrappers& operator=(const ZeGraphExtWrappers&) = delete;
~ZeGraphExtWrappers();

std::unordered_set<std::string> queryGraph(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags) const override;
const std::string& buildFlags) const;
ze_graph_handle_t getGraphHandle(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags,
const uint32_t& flags) const override;
const uint32_t& flags) const;

ze_graph_handle_t getGraphHandle(const std::vector<uint8_t>& network) const override;
ze_graph_handle_t getGraphHandle(const std::vector<uint8_t>& network) const;

NetworkMetadata getNetworkMeta(ze_graph_handle_t graphHandle) const override;
NetworkMetadata getNetworkMeta(ze_graph_handle_t graphHandle) const;

_ze_result_t destroyGraph(ze_graph_handle_t graphHandle) override;
_ze_result_t destroyGraph(ze_graph_handle_t graphHandle);

void getGraphBinary(ze_graph_handle_t graphHandle,
std::vector<uint8_t>& blob,
const uint8_t*& blobPtr,
size_t& blobSize) const override;
size_t& blobSize) const;

void setGraphArgumentValue(ze_graph_handle_t graphHandle, uint32_t argi_, const void* argv) const override;
void setGraphArgumentValue(ze_graph_handle_t graphHandle, uint32_t argi_, const void* argv) const;

void initializeGraph(ze_graph_handle_t graphHandle, const Config& config) const override;
void initializeGraph(ze_graph_handle_t graphHandle, const Config& config) const;

private:
template <ze_graph_ext_version_t T = TableExtension, std::enable_if_t<!NotSupportQuery(T), bool> = true>
std::unordered_set<std::string> getQueryResultFromSupportedLayers(
ze_result_t result,
ze_graph_query_network_handle_t& hGraphQueryNetwork) const;

template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<NotSupportArgumentMetadata(T), bool> = true>
void getMetadata(ze_graph_handle_t graphHandle,
uint32_t index,
std::vector<IODescriptor>& inputs,
std::vector<IODescriptor>& outputs) const;

template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<!NotSupportArgumentMetadata(T), bool> = true>
void getMetadata(ze_graph_handle_t graphHandle,
uint32_t index,
std::vector<IODescriptor>& inputs,
std::vector<IODescriptor>& outputs) const;

template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<UseCopyForNativeBinary(T), bool> = true>
void getNativeBinary(ze_graph_handle_t graphHandle,
std::vector<uint8_t>& blob,
const uint8_t*& blobPtr,
size_t& blobSize) const;

template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<!UseCopyForNativeBinary(T), bool> = true>
void getNativeBinary(ze_graph_handle_t graphHandle,
std::vector<uint8_t>& /* unusedBlob */,
const uint8_t*& blobPtr,
size_t& blobSize) const;

template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<SupportAPIGraphQueryNetworkV2(T), bool> = true>
ze_result_t queryNetworkCreateV2(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags,
ze_graph_query_network_handle_t& hGraphQueryNetwork) const;

// ext version >= 1.5: supports API (pfnCreate2, pfnQueryNetworkCreate2, pfnQueryContextMemory)
template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<SupportAPIGraphQueryNetworkV2(T), bool> = true>
std::unordered_set<std::string> queryImpl(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags) const;

template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<SupportAPIGraphQueryNetworkV1(T), bool> = true>
ze_result_t queryNetworkCreateV1(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags,
ze_graph_query_network_handle_t& hGraphQueryNetwork) const;

// ext version 1.3 or 1.4: supports API (pfnQueryNetworkCreate, pfnQueryNetworkDestroy,
// pfnQueryNetworkGetSupportedLayers)
template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<SupportAPIGraphQueryNetworkV1(T), bool> = true>
std::unordered_set<std::string> queryImpl(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags) const;

// For ext version < 1.3
template <ze_graph_ext_version_t T = TableExtension, typename std::enable_if_t<NotSupportQuery(T), bool> = true>
std::unordered_set<std::string> queryImpl(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags) const;

template <ze_graph_ext_version_t T = TableExtension, typename std::enable_if_t<NotSupportGraph2(T), bool> = true>
void createGraph(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags,
const uint32_t& flags,
ze_graph_handle_t* graph) const;

template <ze_graph_ext_version_t T = TableExtension, typename std::enable_if_t<!NotSupportGraph2(T), bool> = true>
void createGraph(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags,
const uint32_t& flags,
ze_graph_handle_t* graph) const;

void initialize_graph_through_command_list(ze_graph_handle_t graphHandle, const Config& config) const;

std::shared_ptr<ZeroInitStructsHolder> _zeroInitStruct;
uint32_t _graphExtVersion;

Logger _logger;
};
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -155,29 +155,7 @@ DriverCompilerAdapter::DriverCompilerAdapter(const std::shared_ptr<ZeroInitStruc

_logger.info("DriverCompilerAdapter creating adapter using graphExtVersion");

switch (graphExtVersion) {
case ZE_GRAPH_EXT_VERSION_1_3:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_3>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_4:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_4>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_5:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_5>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_6:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_6>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_7:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_7>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_8:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_8>>(_zeroInitStruct);
break;
default:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_2>>(_zeroInitStruct);
break;
}
_zeGraphExt = std::make_shared<ZeGraphExtWrappers>(_zeroInitStruct);

_logger.info("initialize DriverCompilerAdapter complete, using graphExtVersion: %d.%d",
ZE_MAJOR_VERSION(graphExtVersion),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace intel_npu {

DriverGraph::DriverGraph(const std::shared_ptr<ZeGraphExtWrappersInterface>& zeGraphExt,
DriverGraph::DriverGraph(const std::shared_ptr<ZeGraphExtWrappers>& zeGraphExt,
const std::shared_ptr<ZeroInitStructsHolder>& zeroInitStruct,
ze_graph_handle_t graphHandle,
NetworkMetadata metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,29 +70,7 @@ PluginCompilerAdapter::PluginCompilerAdapter(const std::shared_ptr<ZeroInitStruc

_logger.info("PluginCompilerAdapter creating adapter using graphExtVersion");

switch (graphExtVersion) {
case ZE_GRAPH_EXT_VERSION_1_3:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_3>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_4:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_4>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_5:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_5>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_6:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_6>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_7:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_7>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_8:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_8>>(_zeroInitStruct);
break;
default:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_2>>(_zeroInitStruct);
break;
}
_zeGraphExt = std::make_shared<ZeGraphExtWrappers>(_zeroInitStruct);

_logger.info("initialize PluginCompilerAdapter complete, using graphExtVersion: %d.%d",
ZE_MAJOR_VERSION(graphExtVersion),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace intel_npu {

PluginGraph::PluginGraph(const std::shared_ptr<ZeGraphExtWrappersInterface>& zeGraphExt,
PluginGraph::PluginGraph(const std::shared_ptr<ZeGraphExtWrappers>& zeGraphExt,
const ov::SoPtr<ICompiler>& compiler,
const std::shared_ptr<ZeroInitStructsHolder>& zeroInitStruct,
ze_graph_handle_t graphHandle,
Expand Down
Loading
Loading