Skip to content

Commit

Permalink
[MLIR][DLTI] Add DLTI attr to MLIR-generated modules
Browse files Browse the repository at this point in the history
This change serves as a PoC of OV being able to communicate
hints to the MLIR-compiler that is responible for the subgraph.
As a PoC we just pass the magic number 32 as a tile size hint.

Later changes can incorporate OV-based logic for deriving the values for
these hints.
  • Loading branch information
rolfmorel committed Aug 8, 2024
1 parent f5a4471 commit a1a0774
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/common/transformations/src/transformations/mlir/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Passes.h"
Expand Down Expand Up @@ -133,6 +134,16 @@ mlir::OwningOpRef<mlir::ModuleOp> ngraph_to_mlir(MLIRContext* context,
auto func = moduleBuilder.create<mlir::func::FuncOp>(funcLoc, "entry", funcType);
auto block_builder = mlir::OpBuilder::atBlockBegin(func.addEntryBlock() /* TODO: Add logger here */);

// Affix target information attribute to the module to be used, at its discretion,
// by the MLIR-compiler that consumes this module.
auto tileSize = IntegerAttr::get(IntegerType::get(context, 32), 32);
auto key = StringAttr::get(context, "tile_size");
DataLayoutEntryInterface entry = DataLayoutEntryAttr::get(context, key, tileSize);
TargetDeviceSpecInterface deviceSpec = TargetDeviceSpecAttr::get(context, ArrayRef(entry));
auto deviceStr = StringAttr::get(context, "CPU");
auto sysSpec = TargetSystemSpecAttr::get(context, ArrayRef(std::pair(deviceStr, deviceSpec)));
module.getOperation()->setAttr("#dlti.sys_spec", sysSpec);

ConversionContext conversion_context(context, &block_builder);

for (size_t i = 0; i < inputs.size(); ++i) {
Expand Down Expand Up @@ -299,6 +310,7 @@ void injectMLIR(std::shared_ptr<ov::Model> model, MLIRContext* context, MlirMode
}

void loadDialects(MLIRContext* context) {
context->loadDialect<mlir::DLTIDialect>();
context->loadDialect<mlir::func::FuncDialect>();
context->loadDialect<mlir::linalg::LinalgDialect>();
context->loadDialect<mlir::bufferization::BufferizationDialect>();
Expand Down

0 comments on commit a1a0774

Please sign in to comment.