Skip to content

Commit

Permalink
Add a CMake option to enable TOSA. Default to ON. (#4021)
Browse files Browse the repository at this point in the history
Fixes #4019.

---------

Signed-off-by: Benoit Jacob <[email protected]>
  • Loading branch information
bjacob authored Feb 12, 2025
1 parent ddc180f commit c9694c6
Show file tree
Hide file tree
Showing 12 changed files with 62 additions and 24 deletions.
10 changes: 10 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,17 @@ option(TORCH_MLIR_ENABLE_WERROR_FLAG "Enable `-Werror` flag on supported directo
option(TORCH_MLIR_USE_INSTALLED_PYTORCH "If depending on PyTorch use it as installed in the current Python environment" ON)

option(TORCH_MLIR_ENABLE_REFBACKEND "Enable reference backend" ON)

if(TORCH_MLIR_ENABLE_REFBACKEND)
add_definitions(-DTORCH_MLIR_ENABLE_REFBACKEND)
endif()

set(TORCH_MLIR_TABLEGEN_FLAGS "")

option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON)
if(TORCH_MLIR_ENABLE_STABLEHLO)
add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND TORCH_MLIR_TABLEGEN_FLAGS "-DTORCH_MLIR_ENABLE_STABLEHLO")
endif()
# It is possible that both stablehlo and torch_mlir projects are used in some compiler project.
# In this case, we don't want to use stablehlo that is downloaded by torch_mlir (in external/stablehlo)
Expand All @@ -50,6 +54,12 @@ endif()
# stablehlo targets AND includes available (for example with `add_subdirectory` and `include_directories`).
option(TORCH_MLIR_USE_EXTERNAL_STABLEHLO "Use stablehlo from top level project" OFF)

option(TORCH_MLIR_ENABLE_TOSA "Add TOSA support" ON)
if(TORCH_MLIR_ENABLE_TOSA)
add_definitions(-DTORCH_MLIR_ENABLE_TOSA)
list(APPEND TORCH_MLIR_TABLEGEN_FLAGS "-DTORCH_MLIR_ENABLE_TOSA")
endif()

option(TORCH_MLIR_OUT_OF_TREE_BUILD "Specifies an out of tree build" OFF)

# PyTorch native extension gate. If OFF, then no features which depend on
Expand Down
10 changes: 5 additions & 5 deletions include/torch-mlir/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
add_subdirectory(TorchOnnxToTorch)

set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_STABLEHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif()



mlir_tablegen(Passes.h.inc -gen-pass-decls ${TORCH_MLIR_TABLEGEN_FLAGS})

add_public_tablegen_target(TorchMLIRConversionPassIncGen)

add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc)
2 changes: 2 additions & 0 deletions include/torch-mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def ConvertTorchToTensor : Pass<"convert-torch-to-tensor", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToTensorPass()";
}

#ifdef TORCH_MLIR_ENABLE_TOSA
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
let summary = "Convert Torch ops to TOSA ops";
let description = [{
Expand All @@ -122,6 +123,7 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
}];
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
}
#endif

def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_STABLEHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif()

mlir_tablegen(Passes.h.inc -gen-pass-decls ${TORCH_MLIR_TABLEGEN_FLAGS})

add_public_tablegen_target(TorchMLIRTorchConversionPassIncGen)

add_mlir_doc(Passes TorchMLIRTorchConversionTransforms ./ -gen-pass-doc)
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,15 @@ namespace TorchConversion {
/// linalg-on-tensors backend contract.
void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);

// Do not register the TOSA options if the TOSA target is disabled
#ifdef TORCH_MLIR_ENABLE_TOSA
/// Creates a pipeline that lowers from the torch backend contract to the
/// TOSA backend contract.
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);

std::unique_ptr<OperationPass<ModuleOp>> createVerifyTosaBackendContractPass();
#endif // TORCH_MLIR_ENABLE_TOSA

// Do not register the stablehlo options if the stablehlo target is disabled
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
struct StablehloBackendPipelineOptions
Expand Down Expand Up @@ -57,7 +62,7 @@ createFinalizingBackendTypeConversionForStablehloPass();

std::unique_ptr<OperationPass<ModuleOp>>
createVerifyStablehloBackendContractPass();
#endif
#endif // TORCH_MLIR_ENABLE_STABLEHLO

std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();

Expand All @@ -77,8 +82,6 @@ createConvertCustomQuantOpPass();
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyLinalgOnTensorsBackendContractPass();

std::unique_ptr<OperationPass<ModuleOp>> createVerifyTosaBackendContractPass();

} // namespace TorchConversion

/// Registers all Torch transformation passes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,12 @@ def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-
let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
}

#ifdef TORCH_MLIR_ENABLE_TOSA
def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
}
#endif

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
Expand Down
7 changes: 5 additions & 2 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ set(LinkedLibs
MLIRSCFDialect
MLIRTensorDialect
MLIRTensorInferTypeOpInterfaceImpl
MLIRTosaDialect
MLIRSupport

# Dialects.
Expand All @@ -33,7 +32,11 @@ set(LinkedLibs
)

if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs StablehloLinalgTransforms StablehloPasses)
list(APPEND LinkedLibs StablehloLinalgTransforms StablehloPasses)
endif()

if(TORCH_MLIR_ENABLE_TOSA)
list(APPEND LinkedLibs MLIRTosaDialect)
endif()

if(TORCH_MLIR_ENABLE_REFBACKEND)
Expand Down
8 changes: 6 additions & 2 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ add_subdirectory(TorchToArith)
add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF)
add_subdirectory(TorchToTensor)
add_subdirectory(TorchToTosa)
if(TORCH_MLIR_ENABLE_TOSA)
add_subdirectory(TorchToTosa)
endif()
if(TORCH_MLIR_ENABLE_STABLEHLO)
add_subdirectory(TorchToStablehlo)
endif()
Expand All @@ -16,13 +18,15 @@ set(linked_libs TorchMLIRTorchToArith
TorchMLIRTorchToLinalg
TorchMLIRTorchToSCF
TorchMLIRTorchToTensor
TorchMLIRTorchToTosa
TorchMLIRTorchToTMTensor
TorchMLIRTorchConversionToMLProgram
TorchMLIRConversionUtils)
if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND linked_libs TorchMLIRTorchToStablehlo)
endif()
if(TORCH_MLIR_ENABLE_TOSA)
list(APPEND linked_libs TorchMLIRTorchToTosa)
endif()

add_mlir_library(TorchMLIRConversionPasses
Passes.cpp
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h"

#ifdef TORCH_MLIR_ENABLE_TOSA
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#endif // TORCH_MLIR_ENABLE_TOSA

//===----------------------------------------------------------------------===//
// Pass registration
Expand Down
14 changes: 10 additions & 4 deletions lib/Dialect/TorchConversion/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "stablehlo/transforms/Passes.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#endif
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

#ifdef TORCH_MLIR_ENABLE_TOSA
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
using namespace mlir::tosa;
#endif

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// Pass registration
Expand All @@ -46,12 +49,13 @@ void mlir::torch::registerTorchConversionPasses() {
"Pipeline lowering torch backend contract to linalg-on-tensors backend "
"contract.",
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);

#ifdef TORCH_MLIR_ENABLE_TOSA
mlir::PassPipelineRegistration<>(
"torch-backend-to-tosa-backend-pipeline",
"Pipeline lowering torch backend contract to TOSA backend "
"contract.",
TorchConversion::createTorchBackendToTosaBackendPipeline);
#endif
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::PassPipelineRegistration<
TorchConversion::StablehloBackendPipelineOptions>(
Expand Down Expand Up @@ -107,6 +111,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass());
}

#ifdef TORCH_MLIR_ENABLE_TOSA
void TorchConversion::createTorchBackendToTosaBackendPipeline(
OpPassManager &pm) {
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
Expand All @@ -130,6 +135,7 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
// correct form.
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
}
#endif

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
void TorchConversion::createTorchBackendToStablehloBackendPipeline(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifdef TORCH_MLIR_ENABLE_TOSA
#include "PassDetail.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -63,3 +63,4 @@ std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createVerifyTosaBackendContractPass() {
return std::make_unique<VerifyTosaBackendContractPass>();
}
#endif // TORCH_MLIR_ENABLE_TOSA
10 changes: 8 additions & 2 deletions lib/InitAll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Dialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
Expand All @@ -36,6 +35,10 @@
#include "stablehlo/transforms/Passes.h"
#endif

#ifdef TORCH_MLIR_ENABLE_TOSA
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#endif

void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::func::FuncDialect>();
registry.insert<mlir::torch::Torch::TorchDialect>();
Expand All @@ -54,7 +57,10 @@ void mlir::torch::registerOptionalInputDialects(
registry.insert<complex::ComplexDialect, linalg::LinalgDialect,
memref::MemRefDialect, ml_program::MLProgramDialect,
scf::SCFDialect, sparse_tensor::SparseTensorDialect,
tensor::TensorDialect, tosa::TosaDialect>();
tensor::TensorDialect>();
#ifdef TORCH_MLIR_ENABLE_TOSA
registry.insert<tosa::TosaDialect>();
#endif
}

void mlir::torch::registerAllPasses() {
Expand Down

0 comments on commit c9694c6

Please sign in to comment.