diff --git a/cmake/graph-compiler.cmake b/cmake/graph-compiler.cmake index 30b9e5b6913d14..ef0217b527a3fe 100644 --- a/cmake/graph-compiler.cmake +++ b/cmake/graph-compiler.cmake @@ -12,13 +12,13 @@ if (NOT DEFINED GRAPH_COMPILER_LIBS) FetchContent_Declare( GC GIT_REPOSITORY https://github.com/intel/graph-compiler.git - GIT_TAG main + GIT_TAG xgniu/constant_weights_folding # zhicong/perf_test2 # yifei/mlp_benching_new FIND_PACKAGE_ARGS NAMES GraphCompiler ) set(GC_ENABLE_OPT OFF) set(GC_ENABLE_TEST OFF) - set(GC_ENABLE_DNNL OFF) + set(GC_ENABLE_DNNL_API OFF) set(GC_ENABLE_LEGACY OFF) set(GC_ENABLE_BINDINGS_PYTHON OFF) set(OV_BUILD_SHARED_LIBS_TMP ${BUILD_SHARED_LIBS}) @@ -31,6 +31,9 @@ if (NOT DEFINED GRAPH_COMPILER_LIBS) GcInterface GcJitWrapper GcCpuRuntime + # For some branches: + # MLIRCPURuntimeTransforms + # MLIRMicrokernelTransforms ) set_property(GLOBAL PROPERTY GRAPH_COMPILER_LIBS ${GRAPH_COMPILER_LIBS}) endif () diff --git a/src/common/transformations/src/transformations/mlir/convert.cpp b/src/common/transformations/src/transformations/mlir/convert.cpp index 60b86b263bfa14..085c097818abba 100644 --- a/src/common/transformations/src/transformations/mlir/convert.cpp +++ b/src/common/transformations/src/transformations/mlir/convert.cpp @@ -137,13 +137,73 @@ mlir::OwningOpRef ngraph_to_mlir(MLIRContext* context, // Affix target information attribute to the module to be used, at its discretion, // by the MLIR-compiler that consumes this module. auto tileSize = IntegerAttr::get(IntegerType::get(context, 32), 32); - auto key = StringAttr::get(context, "tile_size"); - DataLayoutEntryInterface entry = DataLayoutEntryAttr::get(context, key, tileSize); - TargetDeviceSpecInterface deviceSpec = TargetDeviceSpecAttr::get(context, ArrayRef(entry)); + auto tileSizeKey = StringAttr::get(context, "tile_size"); + DataLayoutEntryInterface tileSizeEntry = DataLayoutEntryAttr::get(context, tileSizeKey, tileSize); + + int numThreadsInt = 1; + if (char* numThreadsEnv = std::getenv("OMP_NUM_THREADS")) { + numThreadsInt = std::atoi(numThreadsEnv); + } + auto numThreads = IntegerAttr::get(IntegerType::get(context, 32), numThreadsInt); + auto numThreadsKey = StringAttr::get(context, "num_threads"); + DataLayoutEntryInterface numThreadsEntry = DataLayoutEntryAttr::get(context, numThreadsKey, numThreads); + + int L1CacheSizeInt = 49152; + if (char* L1CacheSizeEnv = std::getenv("L1_CACHE_SIZE")) { + L1CacheSizeInt = std::atoi(L1CacheSizeEnv); + } + auto L1CacheSize = IntegerAttr::get(IntegerType::get(context, 32), L1CacheSizeInt); + auto L1CacheSizeKey = StringAttr::get(context, "L1_cache_size_in_bytes"); + DataLayoutEntryInterface L1CacheSizeEntry = DataLayoutEntryAttr::get(context, L1CacheSizeKey, L1CacheSize); + + int L2CacheSizeInt = 2097152; + if (char* L2CacheSizeEnv = std::getenv("L2_CACHE_SIZE")) { + L2CacheSizeInt = std::atoi(L2CacheSizeEnv); + } + auto L2CacheSize = IntegerAttr::get(IntegerType::get(context, 32), L2CacheSizeInt); + auto L2CacheSizeKey = StringAttr::get(context, "L2_cache_size_in_bytes"); + DataLayoutEntryInterface L2CacheSizeEntry = DataLayoutEntryAttr::get(context, L2CacheSizeKey, L2CacheSize); + + int L3CacheSizeInt = 1966080; + if (char* L3CacheSizeEnv = std::getenv("L3_CACHE_SIZE")) { + L3CacheSizeInt = std::atoi(L3CacheSizeEnv); + } + auto L3CacheSize = IntegerAttr::get(IntegerType::get(context, 32), L3CacheSizeInt); + auto L3CacheSizeKey = StringAttr::get(context, "L3_cache_size_in_bytes"); + DataLayoutEntryInterface L3CacheSizeEntry = DataLayoutEntryAttr::get(context, 
L3CacheSizeKey, L3CacheSize); + + int maxVectorWidthInt = 512; + if (char* maxVectorWidthEnv = std::getenv("MAX_VECTOR_WIDTH")) { + maxVectorWidthInt = std::atoi(maxVectorWidthEnv); + } + auto maxVectorWidth = IntegerAttr::get(IntegerType::get(context, 32), maxVectorWidthInt); + auto maxVectorWidthKey = StringAttr::get(context, "max_vector_width"); + DataLayoutEntryInterface maxVectorWidthEntry = DataLayoutEntryAttr::get(context, maxVectorWidthKey, maxVectorWidth); + + TargetDeviceSpecInterface deviceSpec = TargetDeviceSpecAttr::get(context, + ArrayRef({tileSizeEntry, + numThreadsEntry, + L1CacheSizeEntry, + L2CacheSizeEntry, + L3CacheSizeEntry, + maxVectorWidthEntry})); auto deviceStr = StringAttr::get(context, "CPU"); auto sysSpec = TargetSystemSpecAttr::get(context, ArrayRef(std::pair(deviceStr, deviceSpec))); module.getOperation()->setAttr("#dlti.sys_spec", sysSpec); + std::vector compiletime_const_args_index; + for (size_t i = 0; i < inputs.size(); ++i) { + auto parent = inputs[i].get_node_shared_ptr(); + if (auto data_const = std::dynamic_pointer_cast(parent)) { + OPENVINO_MLIR_DEBUG_PRINT("Mark #" << i << " input as Constant tensor\n"); + compiletime_const_args_index.push_back(i); + } + } + func.getOperation()->setAttr("compiletime_const_args_index", + moduleBuilder.getI32ArrayAttr(compiletime_const_args_index)); + + func.getOperation()->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), UnitAttr::get(context)); + ConversionContext conversion_context(context, &block_builder); for (size_t i = 0; i < inputs.size(); ++i) { diff --git a/src/common/transformations/src/transformations/mlir/mlir_op.cpp b/src/common/transformations/src/transformations/mlir/mlir_op.cpp index afe777ef949e6f..78c1ece1fc91a2 100644 --- a/src/common/transformations/src/transformations/mlir/mlir_op.cpp +++ b/src/common/transformations/src/transformations/mlir/mlir_op.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" @@ -22,6 +23,9 @@ #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" +#include "mlir-c/ExecutionEngine.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Wrap.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" @@ -241,6 +245,13 @@ struct MemRefDescriptor { } } + MemRefDescriptor(ov::mlir::CachedBuffer buffer) + : allocated(buffer.buffer), + aligned(buffer.buffer), + offset(0), + shape(buffer.shape), + strides(buffer.strides) {} + void* allocated; void* aligned; int64_t offset; @@ -267,6 +278,100 @@ namespace mlir { using namespace ::mlir; +static std::unordered_set executed_ops; + +void MLIREvaluate::set_folding_info() { + { + auto expectArgs = engine->lookup("__num_orig_args"); + if (!expectArgs) { + llvm::consumeError(expectArgs.takeError()); + return; + } + folding_info.num_orig_args = *reinterpret_cast(*expectArgs); + } + + { + auto expectFold = engine->lookupPacked(defaultFoldName); + if (!expectFold) { + llvm::consumeError(expectFold.takeError()); + return; + } + folding_info.fold_func = *expectFold; + } + + { + auto expectBufferIds = engine->lookup("__runtime_fold_buffer_ids"); + if (!expectBufferIds) { + llvm::consumeError(expectBufferIds.takeError()); + return; + } + auto raw = reinterpret_cast(*expectBufferIds); + folding_info.fold_buffer_ids = llvm::ArrayRef{raw + 1, raw[0]}; + } + + { + auto expectFold = 
engine->lookup("__fold_args"); + if (!expectFold) { + llvm::consumeError(expectFold.takeError()); + return; + } + auto raw = reinterpret_cast(*expectFold); + folding_info.fold_args = llvm::ArrayRef{raw + 1, raw[0]}; + } + + { + auto expect = engine->lookup("__compute_args"); + if (!expect) { + llvm::consumeError(expect.takeError()); + return; + } + auto raw = reinterpret_cast(*expect); + folding_info.compute_args = llvm::ArrayRef{raw + 1, raw[0]}; + } + + { + auto expect = engine->lookup("__folded_ranks"); + if (!expect) { + llvm::consumeError(expect.takeError()); + return; + } + auto raw = reinterpret_cast(*expect); + folding_info.folded_ranks = llvm::ArrayRef{raw, folding_info.fold_buffer_ids.size()}; + } + + { + auto expect = engine->lookup("__folded_shapes"); + if (!expect) { + llvm::consumeError(expect.takeError()); + return; + } + int32_t size = folding_info.fold_buffer_ids.size(); // element bytes of each buffer + for (auto r : folding_info.folded_ranks) { + size += r; + } + auto raw = reinterpret_cast(*expect); + llvm::ArrayRef folded_shapes = llvm::ArrayRef{raw, size}; + int pos = 0; + for (int i = 0; i < folding_info.folded_ranks.size(); ++i) { + std::vector shape(folded_shapes.begin() + pos, + folded_shapes.begin() + pos + folding_info.folded_ranks[i] + 1); + pos += folding_info.folded_ranks[i] + 1; + folding_info.folded_shapes.push_back(shape); + } + } + + for (auto id : folding_info.fold_buffer_ids) { + std::vector shape = folding_info.folded_shapes[id]; + size_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + shape.pop_back(); // delete the last which is bytes of element + std::vector strides(shape.size(), 1); + for (int i = strides.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * shape[i + 1]; + } + void* buffer = std::aligned_alloc(/*alignment*/ 64, size); + cached_const_buffers[id] = CachedBuffer{buffer, shape, strides}; + } +} MLIREvaluate::MLIREvaluate(OwningOpRef _module, MlirMode mode) : module(std::move(_module)) { @@ -291,16 +396,37 @@ MLIREvaluate::MLIREvaluate(OwningOpRef _module, MlirMode mode) : /*sizeLevel=*/0, // FIXME: HARDCODED /*targetMachine=*/nullptr); - mlir::ExecutionEngineOptions engineOptions; - engineOptions.transformer = optPipeline; // opt level looks to be overriden in lowerToLLVMIR, but is still used - // in `create` independently - engineOptions.llvmModuleBuilder = lowerToLLVMIR; - auto maybeEngine = mlir::ExecutionEngine::create(module.get(), engineOptions); - if (maybeEngine) { - engine = std::move(maybeEngine.get()); - } else { - llvm::errs() << "failed to construct an execution engine\n"; - abort(); + // mlir::ExecutionEngineOptions engineOptions; + // engineOptions.transformer = optPipeline; // opt level looks to be overriden in lowerToLLVMIR, but is still used + // // in `create` independently + // engineOptions.llvmModuleBuilder = lowerToLLVMIR; + // engineOptions.enableObjectDump = true; + // auto maybeEngine = mlir::ExecutionEngine::create(module.get(), engineOptions); + + int optLevel = 3; + const std::vector sharedLibPaths; + // sharedLibPaths = {"/home/xiaoguang/ov-gc/llvm-project/llvm-install/lib/libmlir_c_runner_utils.so", + // "/home/xiaoguang/ov-gc/llvm-project/llvm-install/lib/libmlir_runner_utils.so", + // "/home/xiaoguang/ov-gc/graph-compiler/build/lib/libGcCpuRuntime.so"}; + bool enableObjectDump = true; + llvm::SmallVector libPaths; + for (const std::string &path : sharedLibPaths) + libPaths.push_back({path.c_str(), path.length()}); + MlirExecutionEngine executionEngine = + 
mlirExecutionEngineCreate(wrap(module.get()), optLevel, libPaths.size(), + libPaths.data(), enableObjectDump); + if (mlirExecutionEngineIsNull(executionEngine)) + throw std::runtime_error("Failure while creating the ExecutionEngine."); + + engine = std::unique_ptr(static_cast(executionEngine.ptr)); + // engine->dumpToObjectFile("./dumped_ov.o"); + + set_folding_info(); +} + +MLIREvaluate::~MLIREvaluate() { + for (auto pair : cached_const_buffers) { + std::free(pair.second.buffer); } } @@ -333,36 +459,102 @@ NodePtr MLIROp::clone_with_new_inputs(const ov::OutputVector& new_args) const { } bool MLIROp::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const { - std::vector memref_args; - for (size_t i = 0; i < inputs.size(); ++i) { - memref_args.push_back(MemRefDescriptor(inputs[i])); - } - for (size_t i = 0; i < outputs.size(); ++i) { - // TODO: Optimize by adding all dimensions to dimensions_map, not only dynamic - Shape target; - PartialShape expected = get_output_partial_shape(i); - for(size_t j = 0; j < expected.size(); ++j) { - auto dim = expected[j]; - if(dim.is_dynamic()) { - int input_index, dim_index; - std::tie(input_index, dim_index) = dimensions_map[i][j]; - target.push_back(inputs[input_index].get_shape()[dim_index]); - } else { - target.push_back(dim.get_length()); + OPENVINO_MLIR_DEBUG_PRINT("[ DEBUG ] input size: " << inputs.size() << ", output size: " << outputs.size() << "\n"); + if (engine->folding_info.fold_func == nullptr) { // No folding, call entry() directly + std::vector memref_args; + for (size_t i = 0; i < inputs.size(); ++i) { + memref_args.push_back(MemRefDescriptor(inputs[i])); + } + for (size_t i = 0; i < outputs.size(); ++i) { + // TODO: Optimize by adding all dimensions to dimensions_map, not only dynamic + Shape target; + PartialShape expected = get_output_partial_shape(i); + for (size_t j = 0; j < expected.size(); ++j) { + auto dim = expected[j]; + if (dim.is_dynamic()) { + int input_index, dim_index; + std::tie(input_index, dim_index) = dimensions_map[i][j]; + target.push_back(inputs[input_index].get_shape()[dim_index]); + } else { + target.push_back(dim.get_length()); + } } + // std::cerr << "[ DEBUG ] Set outputs[" << i << "].shape(" << target << ")\n"; + outputs[i].set_shape(target); + memref_args.push_back(MemRefDescriptor(outputs[i])); } - //std::cerr << "[ DEBUG ] Set outputs[" << i << "].shape(" << target << ")\n"; - outputs[i].set_shape(target); - memref_args.push_back(MemRefDescriptor(outputs[i])); - } - std::vector args; - std::for_each(memref_args.begin(), memref_args.end(), [&args](MemRefDescriptor& x) { - x.append_to_packed_args(args); - }); + std::vector args; + std::for_each(memref_args.begin(), memref_args.end(), [&args](MemRefDescriptor& x) { + x.append_to_packed_args(args); + }); + + OPENVINO_MLIR_DEBUG_PRINT("[ DEBUG ] Call entry func directly\n"); + return engine->invoke_packed(args); + } else { // call fold() first, then call entry() + if (executed_ops.count(this) == 0) { // Call fold() + std::vector memref_args; + // Args of fold(): {constant inputs, folded inputs}. 
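+            // Index convention (matches the checks below): an id smaller than
+            // num_orig_args refers to one of the original entry() arguments (a
+            // constant input here); a larger id selects a cached constant buffer
+            // via (id - num_orig_args).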
+ for (auto id : engine->folding_info.fold_args) { + if (id < engine->folding_info.num_orig_args) { + memref_args.push_back(MemRefDescriptor(inputs[id])); + } else { + int64_t buffer_id = id - engine->folding_info.num_orig_args; + assert(engine->cached_const_buffers.find(buffer_id) != engine->cached_const_buffers.end()); + memref_args.push_back(MemRefDescriptor(engine->cached_const_buffers[buffer_id])); + } + } + std::vector args; + std::for_each(memref_args.begin(), memref_args.end(), [&args](MemRefDescriptor& x) { + x.append_to_packed_args(args); + }); + OPENVINO_MLIR_DEBUG_PRINT("[ DEBUG ] First executon, call fold func\n"); + engine->folding_info.fold_func(args.data()); + + // TODO: Find a better way to check if the op has executed. + // This is a const function and can not modify member attributes directly. + executed_ops.insert(this); + } + // call entry() + std::vector memref_args; + // Args of entry(): {non-constant inputs, outputs, folded inputs}. + for (auto id : engine->folding_info.compute_args) { + // num_orig_args = inputs.size() + outputs.size() + // if (id < engine->folding_info.num_orig_args) { + if (id < inputs.size()) { // non-constant input + memref_args.push_back(MemRefDescriptor(inputs[id])); + } else if (id < engine->folding_info.num_orig_args) { // output + int i = id - inputs.size(); // output id + Shape target; + PartialShape expected = get_output_partial_shape(i); + for (size_t j = 0; j < expected.size(); ++j) { + auto dim = expected[j]; + if (dim.is_dynamic()) { + int input_index, dim_index; + std::tie(input_index, dim_index) = dimensions_map[i][j]; + target.push_back(inputs[input_index].get_shape()[dim_index]); + } else { + target.push_back(dim.get_length()); + } + } + // std::cerr << "[ DEBUG ] Set outputs[" << i << "].shape(" << target << ")\n"; + outputs[i].set_shape(target); + memref_args.push_back(MemRefDescriptor(outputs[i])); + } else { // folded input + int64_t buffer_id = id - engine->folding_info.num_orig_args; + assert(engine->cached_const_buffers.find(buffer_id) != engine->cached_const_buffers.end()); + memref_args.push_back(MemRefDescriptor(engine->cached_const_buffers[buffer_id])); + } + } - //std::cerr << "[ INFO ] Running kernel in MLIROp::evaluate\n"; - return engine->invoke_packed(args); + std::vector args; + std::for_each(memref_args.begin(), memref_args.end(), [&args](MemRefDescriptor& x) { + x.append_to_packed_args(args); + }); + OPENVINO_MLIR_DEBUG_PRINT("[ DEBUG ] Call entry func\n"); + // std::cerr << "[ INFO ] Running kernel in MLIROp::evaluate\n"; + return engine->invoke_packed(args); + } } bool MLIROp::has_evaluate() const { diff --git a/src/common/transformations/src/transformations/mlir/mlir_op.hpp b/src/common/transformations/src/transformations/mlir/mlir_op.hpp index a0ffc58713c30e..77d420fa28bca5 100644 --- a/src/common/transformations/src/transformations/mlir/mlir_op.hpp +++ b/src/common/transformations/src/transformations/mlir/mlir_op.hpp @@ -29,15 +29,37 @@ enum MlirMode { MLIR_MODE_DEFAULT, }; +using JitModuleFuncT = void (*)(void**); +static const char defaultFoldName[] = "runtime_fold"; + +struct FoldingInfo { + int32_t num_orig_args; + llvm::ArrayRef fold_args; + llvm::ArrayRef compute_args; + llvm::ArrayRef fold_buffer_ids; + llvm::ArrayRef folded_ranks; + std::vector> folded_shapes; + JitModuleFuncT fold_func = nullptr; +}; + +struct CachedBuffer { + void* buffer; + std::vector shape; + std::vector strides; +}; class MLIREvaluate { OwningOpRef module; // FIXME: needs to be kept? 
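+    // JIT execution engine; the folding metadata and cached buffers declared below
+    // are recovered from symbols it exports (see set_folding_info()).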
std::unique_ptr engine; + void set_folding_info(); public: MLIREvaluate(OwningOpRef _module, MlirMode mode); + ~MLIREvaluate(); bool invoke_packed(std::vector& args); + FoldingInfo folding_info; + std::unordered_map cached_const_buffers; }; diff --git a/src/common/transformations/src/transformations/mlir/op/matmul.cpp b/src/common/transformations/src/transformations/mlir/op/matmul.cpp index b8c50db1ed4f31..eee5b6a10929b7 100644 --- a/src/common/transformations/src/transformations/mlir/op/matmul.cpp +++ b/src/common/transformations/src/transformations/mlir/op/matmul.cpp @@ -20,30 +20,42 @@ struct ConvertMatMul { void operator()(ConversionContext& context, NodePtr node) { auto loc = createLocation(context.context, node); auto& builder = context.builder(); - // TODO: Support broadcasts + + auto matmul_node = std::dynamic_pointer_cast(node); + assert(matmul_node); + bool isTransposedA = matmul_node->get_transpose_a(); + bool isTransposedB = matmul_node->get_transpose_b(); + assert(!(isTransposedA && isTransposedB)); + const auto inputs = context.getInputs(node); + mlir::SmallVector ins{inputs[0]}; + + if (isTransposedB) { + auto shape = node->get_input_partial_shape(1); + auto transposedShape = ov::PartialShape({shape[1], shape[0]}); + auto transposedType = importTensor(context.context, transposedShape, node->get_input_element_type(1)); + mlir::SmallVector dynamicDims = context.get_dynamic_dimension_values(transposedShape);; + auto empty = builder.create(loc, transposedType, dynamicDims); + auto transposeOp = builder.create( + loc, inputs[1], empty, mlir::SmallVector{1, 0}); + ins.push_back(transposeOp.getResult()[0]); + } else { + ins.push_back(inputs[1]); + } + const auto ov_output_element_type = node->get_output_element_type(0); const auto ov_output_shape = node->get_output_partial_shape(0); + // TODO: Support broadcasts auto outType = importTensor(context.context, ov_output_shape, ov_output_element_type); auto dynamic_dimensions = context.get_dynamic_dimension_values(ov_output_shape); auto empty = builder.create(loc, outType, dynamic_dimensions); auto zero = getConstant(builder, ov_output_element_type, 0); auto fill = builder.create(loc, mlir::ValueRange{zero}, mlir::ValueRange{empty}); - - mlir::SmallVector ins{inputs[0], inputs[1]}; mlir::SmallVector outs{fill.getResult(0)}; - auto matmul_node = std::dynamic_pointer_cast(node); - assert(matmul_node); - bool isTransposedA = matmul_node->get_transpose_a(); - bool isTransposedB = matmul_node->get_transpose_b(); - assert(!(isTransposedA && isTransposedB)); - Operation* matmul; if (isTransposedA) { matmul = builder.create(loc, ins, outs); - } else if (isTransposedB) { - matmul = builder.create(loc, ins, outs); } else { matmul = builder.create(loc, ins, outs); } diff --git a/src/common/transformations/src/transformations/mlir/op/relu.cpp b/src/common/transformations/src/transformations/mlir/op/relu.cpp index 6f7157f9bd4bd4..364785979a2611 100644 --- a/src/common/transformations/src/transformations/mlir/op/relu.cpp +++ b/src/common/transformations/src/transformations/mlir/op/relu.cpp @@ -25,10 +25,16 @@ struct ConvertRelu { auto outType = importTensor(context.context, ov_output_shape, ov_output_element_type); auto dynamic_dimensions = context.get_dynamic_dimension_values(ov_output_shape); auto empty = builder.create(loc, outType, dynamic_dimensions); - auto zero = getConstant(builder, ov_output_element_type, 0); - auto fill = builder.create(loc, mlir::ValueRange{zero}, mlir::ValueRange{empty}); + auto denseAttr = DenseElementsAttr::get( + 
outType, + builder.getFloatAttr(importPrecision(builder.getContext(), ov_output_element_type), 0.0)); + auto zeros = builder.create(loc, denseAttr); auto relu = - builder.create(loc, mlir::ValueRange{input, fill.getResult(0)}, mlir::ValueRange{empty}); + builder.create(loc, mlir::ValueRange{input, zeros.getResult()}, mlir::ValueRange{empty}); + // auto zero = getConstant(builder, ov_output_element_type, 0); + // auto fill = builder.create(loc, mlir::ValueRange{zero}, mlir::ValueRange{empty}); + // auto relu = + // builder.create(loc, mlir::ValueRange{input, fill.getResult(0)}, mlir::ValueRange{empty}); context.addOutputs(node, relu); } }; diff --git a/tools/mlir_bench/MLIR_MLP_BENCH_bf16_128_1024.mlir b/tools/mlir_bench/MLIR_MLP_BENCH_bf16_128_1024.mlir new file mode 100644 index 00000000000000..5759e67ba3645d --- /dev/null +++ b/tools/mlir_bench/MLIR_MLP_BENCH_bf16_128_1024.mlir @@ -0,0 +1,21 @@ +module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>, #dlti.dl_entry<"num_threads", 4 : i32>, #dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>, #dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>, #dlti.dl_entry<"L3_cache_size_in_bytes", 1966080 : i32>, #dlti.dl_entry<"max_vector_width", 512 : i32>>>} { + func.func @entry(%arg0: memref<128x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<128x1024xbf16>, %arg3: memref<128x1024xbf16>) attributes {compiletime_const_args_index = [1 : i32, 2 : i32]} { + %0 = bufferization.to_tensor %arg0 restrict : memref<128x1024xbf16> + %1 = bufferization.to_tensor %arg1 restrict : memref<1024x1024xbf16> + %2 = bufferization.to_tensor %arg2 restrict : memref<128x1024xbf16> + %3 = tensor.empty() : tensor<1024x1024xbf16> + %transposed = linalg.transpose ins(%1 : tensor<1024x1024xbf16>) outs(%3 : tensor<1024x1024xbf16>) permutation = [1, 0] + %4 = tensor.empty() : tensor<128x1024xbf16> + %cst = arith.constant 0.000000e+00 : bf16 + %5 = linalg.fill ins(%cst : bf16) outs(%4 : tensor<128x1024xbf16>) -> tensor<128x1024xbf16> + %6 = linalg.matmul ins(%0, %transposed : tensor<128x1024xbf16>, tensor<1024x1024xbf16>) outs(%5 : tensor<128x1024xbf16>) -> tensor<128x1024xbf16> + %7 = tensor.empty() : tensor<128x1024xbf16> + %8 = linalg.add ins(%6, %2 : tensor<128x1024xbf16>, tensor<128x1024xbf16>) outs(%7 : tensor<128x1024xbf16>) -> tensor<128x1024xbf16> + %9 = tensor.empty() : tensor<128x1024xbf16> + %cst_0 = arith.constant 0.000000e+00 : bf16 + %10 = linalg.fill ins(%cst_0 : bf16) outs(%9 : tensor<128x1024xbf16>) -> tensor<128x1024xbf16> + %11 = linalg.max ins(%8, %10 : tensor<128x1024xbf16>, tensor<128x1024xbf16>) outs(%9 : tensor<128x1024xbf16>) -> tensor<128x1024xbf16> + bufferization.materialize_in_destination %11 in restrict writable %arg3 : (tensor<128x1024xbf16>, memref<128x1024xbf16>) -> () + return + } +} diff --git a/tools/mlir_bench/MLIR_MLP_BENCH_f32_128_1024.mlir b/tools/mlir_bench/MLIR_MLP_BENCH_f32_128_1024.mlir new file mode 100644 index 00000000000000..ec144f84af44c3 --- /dev/null +++ b/tools/mlir_bench/MLIR_MLP_BENCH_f32_128_1024.mlir @@ -0,0 +1,21 @@ +module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>, #dlti.dl_entry<"num_threads", 4 : i32>, #dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>, #dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>, #dlti.dl_entry<"L3_cache_size_in_bytes", 1966080 : i32>, #dlti.dl_entry<"max_vector_width", 512 : 
i32>>>} { + func.func @entry(%arg0: memref<128x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<128x1024xf32>, %arg3: memref<128x1024xf32>) attributes {compiletime_const_args_index = [1 : i32, 2 : i32]} { + %0 = bufferization.to_tensor %arg0 restrict : memref<128x1024xf32> + %1 = bufferization.to_tensor %arg1 restrict : memref<1024x1024xf32> + %2 = bufferization.to_tensor %arg2 restrict : memref<128x1024xf32> + %3 = tensor.empty() : tensor<1024x1024xf32> + %transposed = linalg.transpose ins(%1 : tensor<1024x1024xf32>) outs(%3 : tensor<1024x1024xf32>) permutation = [1, 0] + %4 = tensor.empty() : tensor<128x1024xf32> + %cst = arith.constant 0.000000e+00 : f32 + %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<128x1024xf32>) -> tensor<128x1024xf32> + %6 = linalg.matmul ins(%0, %transposed : tensor<128x1024xf32>, tensor<1024x1024xf32>) outs(%5 : tensor<128x1024xf32>) -> tensor<128x1024xf32> + %7 = tensor.empty() : tensor<128x1024xf32> + %8 = linalg.add ins(%6, %2 : tensor<128x1024xf32>, tensor<128x1024xf32>) outs(%7 : tensor<128x1024xf32>) -> tensor<128x1024xf32> + %9 = tensor.empty() : tensor<128x1024xf32> + %cst_0 = arith.constant 0.000000e+00 : f32 + %10 = linalg.fill ins(%cst_0 : f32) outs(%9 : tensor<128x1024xf32>) -> tensor<128x1024xf32> + %11 = linalg.max ins(%8, %10 : tensor<128x1024xf32>, tensor<128x1024xf32>) outs(%9 : tensor<128x1024xf32>) -> tensor<128x1024xf32> + bufferization.materialize_in_destination %11 in restrict writable %arg3 : (tensor<128x1024xf32>, memref<128x1024xf32>) -> () + return + } +} diff --git a/tools/mlir_bench/mlp_bench.sh b/tools/mlir_bench/mlp_bench.sh index 2317607d20c662..e59e6a731e6450 100755 --- a/tools/mlir_bench/mlp_bench.sh +++ b/tools/mlir_bench/mlp_bench.sh @@ -6,17 +6,18 @@ # Runs OV MLP benchmarks. die_syntax() { - echo "Syntax: $0 [-t (f32|f16|bf16|...)] [-b (mlp)] [-D] [-l 3]" + echo "Syntax: $0 [-t (f32|f16|bf16|...)] [-b (mlp)] [-D] [-l 3] [-n 1]" echo "" echo " -t: Optional data type" echo " -b: Optional baseline model" echo " -l: Optional number of layers (def:3)" echo " -D: Set model shapes to dynamic" + echo " -n: Set number of threads (default: 1)" exit 1 } # Cmd-line opts -while getopts "t:l:b:D" arg; do +while getopts "t:l:b:D:n:" arg; do case ${arg} in t) DATA_TYPE=${OPTARG} @@ -30,6 +31,9 @@ while getopts "t:l:b:D" arg; do D) IS_DYNAMIC=true ;; + n) + NUM_THREADS=${OPTARG} + ;; ?) echo "Invalid option: ${OPTARG}" die_syntax @@ -38,14 +42,28 @@ while getopts "t:l:b:D" arg; do done if [ ! $NUM_LAYERS ]; then - NUM_LAYERS=3 + NUM_LAYERS=1 +fi + +if [ ! $NUM_THREADS ]; then + NUM_THREADS=1 +fi + +if [ $NUM_THREADS == 1 ]; then + NUMA_CTL="numactl --physcpubind=4 --membind=0" +elif [ $NUM_THREADS == 4 ]; then + NUMA_CTL="numactl --physcpubind=4-7 --membind=0" +elif [ $NUM_THREADS == 8 ]; then + NUMA_CTL="numactl --physcpubind=4-11 --membind=0" +else + NUMA_CTL="" fi OV_ROOT=$(git rev-parse --show-toplevel) BENCH_ROOT=$(realpath "${OV_ROOT}/tools/mlir_bench") MODEL_GEN=$(realpath "${BENCH_ROOT}/ov_model_gen.py") -BENCH_RUNNER=benchmark_app +BENCH_RUNNER=${OV_ROOT}/bin/intel64/Release/benchmark_app # Initial validation. if ! [ -d "${OV_ROOT}" ]; then @@ -70,14 +88,13 @@ if [ "${BASELINE_MODEL}" ] && [ "${IS_DYNAMIC}" ]; then fi # Kernel config. -#LAYERS=( 1024 2048 4096 8192 ) -#MINI_BATCHES=( 128 256 512 ) +# LAYERS=( 1024 2048 4096 8192 ) +# MINI_BATCHES=( 128 256 512 ) LAYERS=( 1024 ) -MINI_BATCHES=( 256 ) +MINI_BATCHES=( 128 ) if [ ! 
"${DATA_TYPE}" ]; then DATA_TYPE="f32" fi -MODEL_NAME="MLIR_MLP_BENCH.xml" echo "Result type: time [ms] - NUM LAYERS: ${NUM_LAYERS}" for MB in "${MINI_BATCHES[@]}"; do @@ -96,10 +113,12 @@ for MB in "${MINI_BATCHES[@]}"; do MODEL_CONFIG=(-l="${MODEL_STRING}") fi echo "MODEL_CONFIG=${MODEL_CONFIG}" + MODEL_NAME="MLIR_MLP_BENCH_${DATA_TYPE}_${MB}_${LAYER}.xml" GEN_FLAGS=(-t ${DATA_TYPE} -n ${MODEL_NAME}) if [ "${IS_DYNAMIC}" ]; then GEN_FLAGS+=(--dynamic) fi + # echo "Gen Model cmd: python3 ${MODEL_GEN} ${MODEL_CONFIG[@]} ${GEN_FLAGS[@]}" python3 ${MODEL_GEN} "${MODEL_CONFIG[@]}" "${GEN_FLAGS[@]}" if [ $? != 0 ]; then echo "Failed to generate model" @@ -114,10 +133,11 @@ for MB in "${MINI_BATCHES[@]}"; do if [ "${IS_DYNAMIC}" ]; then DATA_SHAPE=(-data_shape [${MB},${LAYER}]) fi - # Benchmark config. Disable parallelism. - PERF_FLAGS="-niter 1000 -hint none -nstreams 1 -nthreads 1" + # Benchmark config. Enable openmp parallelism. + PERF_FLAGS="-niter 100 -hint none -nstreams 1 -nthreads ${NUM_THREADS}" BENCH_FLAGS="-m ${MODEL_NAME} -d CPU -ip ${PRECISION} -infer_precision ${DATA_TYPE} ${DATA_SHAPE[@]} ${PERF_FLAGS}" - ${BENCH_RUNNER} ${BENCH_FLAGS} 2>/dev/null | \ + echo "Bench cmd: OMP_NUM_THREADS=${NUM_THREADS} ${NUMA_CTL} ${BENCH_RUNNER} ${BENCH_FLAGS}" + OMP_NUM_THREADS=${NUM_THREADS} ${NUMA_CTL} ${BENCH_RUNNER} ${BENCH_FLAGS} 2>/dev/null | \ sed -nE "s/.*\[ INFO \]\s*Median:\s*([0-9.]+).*/\\1/p" done done diff --git a/tools/mlir_bench/ov_model_gen.py b/tools/mlir_bench/ov_model_gen.py index f233bac14edd2d..2636a9139bc299 100644 --- a/tools/mlir_bench/ov_model_gen.py +++ b/tools/mlir_bench/ov_model_gen.py @@ -178,12 +178,14 @@ def __init__(self, sizes_mnk, type=None, layers=3): super(BaselineMLP, self).__init__() m = sizes_mnk[0] n = sizes_mnk[1] + k = sizes_mnk[2] + self.weight = torch.empty((k, n), dtype=type).data.normal_(0, 0.01) self.bias = torch.empty((m, n), dtype=type).data.fill_(0.01) self.relu = nn.ReLU() self.layers = layers - def forward(self, a, b): + def forward(self, a): for _ in range(0,self.layers): - c = torch.matmul(a, b) + c = torch.matmul(a, self.weight) c = torch.add(c, self.bias) a = self.relu(c) return a @@ -201,8 +203,9 @@ def baseline_MLP(model_desc: str, data_type: str, is_dynamic: bool) -> tuple[nn. n = input_shapes[1] k = input_shapes[2] ov_type = get_ov_type(data_type) - inputs = [(ov.PartialShape([m, k]), ov_type), (ov.PartialShape([k, n]), ov_type)] - return (mlp, inputs) + inputs = [(ov.PartialShape([m, k]), ov_type)] + example_inputs = [torch.empty((m, k), dtype=get_torch_type(data_type)).data.normal_(0, 0.01)] + return (mlp, inputs, example_inputs) def generate_baseline_model(model_desc: str, data_type: str, file_name: str, is_dynamic: bool = False): @@ -213,8 +216,8 @@ def generate_baseline_model(model_desc: str, data_type: str, file_name: str, is_ else: assert False, f"Unsupported baseline model data type {model_name}" - ov_model = ov.convert_model(baseline_tuple[0], input=baseline_tuple[1]) - ov.save_model(ov_model, f"{file_name}") + ov_model = ov.convert_model(baseline_tuple[0], example_input=baseline_tuple[2], input=baseline_tuple[1]) + ov.save_model(ov_model, f"{file_name}", compress_to_fp16=False) return ov_model diff --git a/tools/mlir_bench/run_torch_linear.py b/tools/mlir_bench/run_torch_linear.py new file mode 100644 index 00000000000000..e93a8e6241d5ed --- /dev/null +++ b/tools/mlir_bench/run_torch_linear.py @@ -0,0 +1,137 @@ +import openvino as ov +import torch +import torch.nn as nn +from ml_dtypes import bfloat16 + +# input: (128, 1024). 
weight: (1024, 512). bias: (512). output: (128, 512) +batch_size = 128 +in_features = 1024 +out_features = 512 + +use_bf16 = True +dtype = torch.bfloat16 if use_bf16 else torch.float + +# class ToyNet(nn.Module): +# def __init__(self, in_features, out_features): +# super(ToyNet, self).__init__() +# # input: (N, in_feature), weight: (out_feature, in_feature), bias: (out_feature). +# # out = x * W^T + b +# self.fc = nn.Linear(in_features, out_features) + +# def forward(self, x): +# out = self.fc(x) +# return out + +# model = ToyNet(in_features, out_features) + + + +class ToyNet(nn.Module): + def __init__(self, in_features, out_features): + super(ToyNet, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=dtype), requires_grad=False) + self.bias = nn.Parameter(torch.randn(out_features, dtype=dtype), requires_grad=False) + def forward(self, x): + out = x @ self.weight + self.bias + out = torch.relu(out) + return out + +model = ToyNet(in_features, out_features) + + + + +#class ToyNet(nn.Module): +# def __init__(self, in_features, out_features): +# super(ToyNet, self).__init__() +# # input: (N, in_feature), weight: (out_feature, in_feature), bias: (out_feature). +# # out = x * W^T + b +# self.fc1 = nn.Linear(in_features, in_features) +# self.fc2 = nn.Linear(in_features, out_features) +# self.fc3 = nn.Linear(out_features, out_features) +# +# def forward(self, x): +# out = torch.relu(self.fc1(x)) +# out = torch.relu(self.fc2(out)) +# out = torch.relu(self.fc3(out)) +# return out + +# # (128x64, 64x64) -> (128x64). (128x64, 64x64) -> (128x64). Two same matmuls. +#model = ToyNet(in_features, out_features) + +# in_features = 16 +# out_features = 64 +# # (128x16, 16x16) -> (128x16). (128x16, 16x64) -> (128x64). Two different matmuls. 
+# model = ToyNet(in_features, out_features) + + +# load PyTorch model into memory +# model = torch.hub.load("pytorch/vision", "shufflenet_v2_x1_0", weights="DEFAULT") + +# convert the model into OpenVINO model +example = torch.randn(batch_size, in_features, dtype=dtype) +print("===== Convert model start =====") +ov_model = ov.convert_model(model, example_input=(example,)) +ov_model.reshape([batch_size, in_features]) +print(ov_model) +print("===== Convert model finish =====") + +# print(ov_model.is_dynamic()) +# print(ov_model.get_ordered_ops()) +# weight = ov_model.get_ordered_ops()[1] +# print(weight.get_output_tensor(0)) +# print(weight.get_output_tensor(0).shape) + +# print("Compress weights start") +# ov_model = compress_weights(ov_model) +# print(ov_model) +# print("Compress weights finish") + +# openvino.save_model(model: ov::Model, output_model: object, compress_to_fp16: bool = True) +print("save model start") +saved_path = './toy-net.xml' +ov.save_model(ov_model, saved_path, compress_to_fp16=False) +print("save model finish") + +# compile the model for CPU device +core = ov.Core() +device_name = 'CPU' + +# Find 'EXPORT_IMPORT' capability in supported capabilities +#caching_supported = 'EXPORT_IMPORT' in core.get_property(device_name, 'OPTIMIZATION_CAPABILITIES') +#print("===== On device: ", device_name, ", model caching supported: ", caching_supported, " =====") + + +print("===== Compile model start =====") +compiled_model = core.compile_model(ov_model, device_name) +print(compiled_model) +print("===== Compile model finish =====") + +if use_bf16: + example_np = example.view(dtype=torch.uint16).numpy().view(bfloat16) +else: + example_np = example.numpy() + +# infer the model on random data +print("===== #1 run start =====") + +output = compiled_model({0: example_np}) +print("===== #1 run finish =====") + +print("===== #2 run start =====") +output = compiled_model({0: example_np}) +print("===== #2 run finish =====") + + +#model_from_read = core.read_model(saved_path) +#print("Read model:\n", model_from_read) +#compiled_model_from_read = core.compile_model(model=model_from_read, device_name=device_name) +#print("Read model compiled:\n", compiled_model_from_read) + +#print("===== #1 run start =====") +#output = compiled_model_from_read({0: example.numpy()}) +#print("===== #1 run finish =====") + +# /home/xiaoguang/ov-llm/openvino-slyalin-mlir/bin/intel64/Release/benchmark_app -m toy-net.xml -d CPU -ip f32 -infer_precision f32 -niter 10 -hint none -nstreams 1 -nthreads 16 diff --git a/tools/mlir_bench/run_torch_relu.py b/tools/mlir_bench/run_torch_relu.py new file mode 100644 index 00000000000000..de8c4dbd5be5ff --- /dev/null +++ b/tools/mlir_bench/run_torch_relu.py @@ -0,0 +1,60 @@ +import openvino as ov +import torch +import torch.nn as nn +from ml_dtypes import bfloat16 + +# input: (128, 1024, 512) +A = 128 +B = 1 +C = 512 + +use_bf16 = True +dtype = torch.bfloat16 if use_bf16 else torch.float + +class ToyNet(nn.Module): + def __init__(self): + super(ToyNet, self).__init__() + + def forward(self, x): + out = torch.relu(x) + return out + +model = ToyNet() + + +# convert the model into OpenVINO model +example = torch.randn(A, B, C, dtype=dtype) +print("===== Convert model start =====") +ov_model = ov.convert_model(model, example_input=(example,)) +ov_model.reshape([A, B, C]) +print(ov_model) +print("===== Convert model finish =====") + +print("save model start") +saved_path = './toy-net.xml' +ov.save_model(ov_model, saved_path, compress_to_fp16=False) +print("save model finish") + +# 
compile the model for CPU device
+core = ov.Core()
+device_name = 'CPU'
+
+print("===== Compile model start =====")
+compiled_model = core.compile_model(ov_model, device_name)
+print(compiled_model)
+print("===== Compile model finish =====")
+
+if use_bf16:
+    example_np = example.view(dtype=torch.uint16).numpy().view(bfloat16)
+else:
+    example_np = example.numpy()
+
+# infer the model on random data
+print("===== #1 run start =====")
+
+output = compiled_model({0: example_np})
+print("===== #1 run finish =====")
+
+print("===== #2 run start =====")
+output = compiled_model({0: example_np})
+print("===== #2 run finish =====")
diff --git a/tools/mlir_bench/run_xml_model_accuracy.py b/tools/mlir_bench/run_xml_model_accuracy.py
new file mode 100644
index 00000000000000..17c8b5e9046702
--- /dev/null
+++ b/tools/mlir_bench/run_xml_model_accuracy.py
@@ -0,0 +1,61 @@
+import openvino as ov
+import numpy as np
+import ml_dtypes
+import torch
+import os
+
+xml_file_name = "MLIR_MLP_BENCH_f32_128_1024.xml"
+# xml_file_name = "MLIR_MLP_BENCH_bf16_128_1024.xml"
+
+# compile the model for CPU device
+core = ov.Core()
+device_name = 'CPU'
+
+model = core.read_model(xml_file_name)
+# model.reshape([batch_size, in_features])
+print("Ops in the model: ", model.get_ordered_ops())
+
+# only support single-input model: https://docs.openvino.ai/2024/api/ie_python_api/_autosummary/openvino.runtime.CompiledModel.html#openvino.runtime.CompiledModel.__call__
+for param in model.get_parameters():
+    shape = param.get_output_shape(0)
+    type = param.get_output_element_type(0).get_type_name()
+    if type == "f32":
+        input = torch.randn(*shape, dtype=torch.float32).numpy()
+    elif type == "bf16":
+        input = torch.randn(*shape, dtype=torch.bfloat16).view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16)
+    else:
+        print("Unsupported data type: ", type)
+        exit()
+
+
+os.environ['OV_MLIR'] = '0'
+print("MLIR Enabled: ", os.getenv('OV_MLIR'))
+compiled_model = core.compile_model(model=model, device_name=device_name)
+print(compiled_model.get_runtime_model().get_ordered_ops())
+
+print("===== #1 run start =====")
+output_ov = compiled_model({0: input})
+print("===== #1 run finish =====")
+
+
+os.environ['OV_MLIR'] = '1'
+os.environ['OV_MLIR_DEBUG'] = '1'
+print("MLIR Enabled: ", os.getenv('OV_MLIR'))
+compiled_model = core.compile_model(model=model, device_name=device_name)
+print(compiled_model.get_runtime_model().get_ordered_ops())
+
+print("===== #1 run start =====")
+output_gc_1 = compiled_model({0: input})
+print("===== #1 run finish =====")
+
+print("===== #2 run start =====")
+output_gc_2 = compiled_model({0: input})
+print("===== #2 run finish =====")
+
+# Default 'rtol=1e-05, atol=1e-08' will report error.
+if type == "f32":
+    tol = 1e-5
+elif type == "bf16":
+    tol = 1e-3
+print("OV output vs GC #1 output close: ", np.allclose(output_ov[0].astype(np.float32), output_gc_1[0].astype(np.float32), rtol=tol, atol=tol))
+print("OV output vs GC #2 output close: ", np.allclose(output_ov[0].astype(np.float32), output_gc_2[0].astype(np.float32), rtol=tol, atol=tol))
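+
+# Optional extra check (a sketch, not required by the flow above): allclose() alone
+# hides how large a mismatch is, so also report the worst-case absolute deviation.
+max_abs_diff = np.max(np.abs(output_ov[0].astype(np.float32) - output_gc_1[0].astype(np.float32)))
+print("OV output vs GC #1 max abs diff: ", max_abs_diff)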