diff --git a/compiler/bindings/python/test/api/api_test.py b/compiler/bindings/python/test/api/api_test.py index 0d52bd3ba004..c12873a52288 100644 --- a/compiler/bindings/python/test/api/api_test.py +++ b/compiler/bindings/python/test/api/api_test.py @@ -52,6 +52,19 @@ def testFlagError(self): with self.assertRaises(ValueError): session.set_flags("--does-not-exist=1") + def testOptFlags(self): + session = Session() + flags = session.get_flags() + self.assertIn("--iree-opt-level=O0", flags) + self.assertIn("--iree-global-optimization-opt-level=O0", flags) + self.assertIn("--iree-opt-strip-assertions=false", flags) + + session.set_flags("--iree-opt-level=O2") + flags = session.get_flags() + self.assertIn("--iree-opt-level=O2", flags) + self.assertIn("--iree-global-optimization-opt-level=O0", flags) + self.assertIn("--iree-opt-strip-assertions=false", flags) + class DlInvocationTest(unittest.TestCase): def testCreate(self): session = Session() diff --git a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp index d12dbc5c6d4b..7507ea9767e8 100644 --- a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp +++ b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp @@ -234,6 +234,7 @@ struct GlobalInit { // Our session options can optionally be bound to the global command-line // environment. If that is not the case, then these will be nullptr, and // they should be default initialized at the session level. + GlobalPipelineOptions *clGlobalPipelineOptions = nullptr; PluginManagerOptions *clPluginManagerOptions = nullptr; BindingOptions *clBindingOptions = nullptr; InputDialectOptions *clInputOptions = nullptr; @@ -278,6 +279,7 @@ void GlobalInit::registerCommandLineOptions() { mlir::tracing::DebugConfig::registerCLOptions(); // Bind session options to the command line environment. + clGlobalPipelineOptions = &GlobalPipelineOptions::FromFlags::get(); clPluginManagerOptions = &PluginManagerOptions::FromFlags::get(); clBindingOptions = &BindingOptions::FromFlags::get(); clInputOptions = &InputDialectOptions::FromFlags::get(); @@ -323,6 +325,7 @@ struct Session { if (failed(binder.parseArguments(argc, argv, callback))) { return new Error(std::move(errorMessage)); } + return nullptr; } @@ -387,6 +390,7 @@ struct Session { bool pluginsActivated = false; LogicalResult pluginActivationStatus{failure()}; + GlobalPipelineOptions pipelineOptions; BindingOptions bindingOptions; InputDialectOptions inputOptions; PreprocessingOptions preprocessingOptions; @@ -410,6 +414,7 @@ Session::Session(GlobalInit &globalInit) // Bootstrap session options from the cl environment, if enabled. if (globalInit.usesCommandLine) { debugConfig = mlir::tracing::DebugConfig::createFromCLOptions(); + pipelineOptions = *globalInit.clGlobalPipelineOptions; pluginManagerOptions = *globalInit.clPluginManagerOptions; bindingOptions = *globalInit.clBindingOptions; inputOptions = *globalInit.clInputOptions; @@ -430,6 +435,7 @@ Session::Session(GlobalInit &globalInit) // Register each options struct with the binder so we can manipulate // mnemonically via the API. + pipelineOptions.bindOptions(binder); bindingOptions.bindOptions(binder); preprocessingOptions.bindOptions(binder); inputOptions.bindOptions(binder); @@ -938,6 +944,17 @@ void Invocation::dumpCompilationPhase(IREEVMPipelinePhase phase, bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { auto passManager = createPassManager(); + + auto globalBinder = OptionsBinder::global(); + auto &binder = + session.globalInit.usesCommandLine ? globalBinder : session.binder; + GlobalOptimizationOptions highLevelOptimizationOptions = + session.highLevelOptimizationOptions; + + // Set optimization options on a copy of the sessions options. + highLevelOptimizationOptions.applyOptimization(binder, + session.pipelineOptions); + switch (pipeline) { case IREE_COMPILER_PIPELINE_STD: { IREEVMPipelinePhase compileFrom; @@ -962,7 +979,7 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { buildIREEVMTransformPassPipeline( session.targetRegistry, session.bindingOptions, session.inputOptions, - session.preprocessingOptions, session.highLevelOptimizationOptions, + session.preprocessingOptions, highLevelOptimizationOptions, session.schedulingOptions, session.halTargetOptions, session.vmTargetOptions, pipelineHooks, *passManager, compileFrom, compileTo); @@ -995,7 +1012,7 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { } buildIREEPrecompileTransformPassPipeline( session.targetRegistry, session.bindingOptions, session.inputOptions, - session.preprocessingOptions, session.highLevelOptimizationOptions, + session.preprocessingOptions, highLevelOptimizationOptions, session.schedulingOptions, session.halTargetOptions, pipelineHooks, *passManager, compileFrom, compileTo); break; diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp index 6fa65587b7a5..169d82754ab3 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.cpp +++ b/compiler/src/iree/compiler/Pipelines/Options.cpp @@ -5,6 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree/compiler/Pipelines/Options.h" +#include "llvm/Passes/OptimizationLevel.h" IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::BindingOptions); IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::InputDialectOptions); @@ -12,9 +13,21 @@ IREE_DEFINE_COMPILER_OPTION_FLAGS( mlir::iree_compiler::GlobalOptimizationOptions); IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::SchedulingOptions); IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::PreprocessingOptions); +IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::GlobalPipelineOptions); namespace mlir::iree_compiler { +void GlobalPipelineOptions::bindOptions(OptionsBinder &binder) { + static llvm::cl::OptionCategory category( + "IREE global pipeline options controlling the entire compilation flow."); + + binder.opt( + "iree-opt-level", optLevel, + llvm::cl::desc("Global optimization level to apply to the entire " + "compilation flow."), + llvm::cl::cat(category)); +} + void BindingOptions::bindOptions(OptionsBinder &binder) { static llvm::cl::OptionCategory category( "IREE translation binding support options."); @@ -119,17 +132,33 @@ void PreprocessingOptions::bindOptions(OptionsBinder &binder) { llvm::cl::cat(category)); } +void GlobalOptimizationOptions::applyOptimization( + const OptionsBinder &binder, const GlobalPipelineOptions &globalLevel) { + binder.overrideDefault("iree-global-optimization-opt-level", optLevel, + globalLevel.optLevel); + binder.applyOptimization("iree-opt-strip-assertions", stripAssertions, + optLevel); +}; + void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) { static llvm::cl::OptionCategory category( "IREE options for controlling global optimizations."); + binder.optimizationLevel( + "iree-global-optimization-opt-level", optLevel, + llvm::cl::desc("Optimization level for the this pipeline"), + llvm::cl::cat(category)); binder.opt( "iree-opt-aggressively-propagate-transposes", aggressiveTransposePropagation, + {init_at_opt(llvm::OptimizationLevel::O0, false), + init_at_opt(llvm::OptimizationLevel::O3, true)}, llvm::cl::desc( "Propagates transposes to named ops even when the resulting op will " "be a linalg.generic"), llvm::cl::cat(category)); binder.opt("iree-opt-outer-dim-concat", outerDimConcat, + {init_at_opt(llvm::OptimizationLevel::O0, false), + init_at_opt(llvm::OptimizationLevel::O1, true)}, llvm::cl::desc("Transposes all concatenations to happen" "along the outer most dimension."), llvm::cl::cat(category)); @@ -159,6 +188,8 @@ void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) { "Reduces numeric precision to lower bit depths where possible."), llvm::cl::cat(category)); binder.opt("iree-opt-strip-assertions", stripAssertions, + {init_at_opt(llvm::OptimizationLevel::O0, false), + init_at_opt(llvm::OptimizationLevel::O1, true)}, llvm::cl::desc("Strips debug assertions after any useful " "information has been extracted."), llvm::cl::cat(category)); @@ -198,6 +229,8 @@ void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) { binder.opt( "iree-opt-generalize-matmul", generalizeMatmul, + {init_at_opt(llvm::OptimizationLevel::O0, false), + init_at_opt(llvm::OptimizationLevel::O2, true)}, llvm::cl::desc("Convert named matmul ops to linalg generic ops during " "global optimization to enable better fusion."), llvm::cl::cat(category)); diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h index e60c5d255874..5f72ac7513fc 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.h +++ b/compiler/src/iree/compiler/Pipelines/Options.h @@ -11,6 +11,12 @@ namespace mlir::iree_compiler { +struct GlobalPipelineOptions { + llvm::OptimizationLevel optLevel = llvm::OptimizationLevel::O0; + void bindOptions(OptionsBinder &binder); + using FromFlags = OptionsFromFlags; +}; + struct BindingOptions { // Whether to include runtime support functions for the IREE native ABI. bool native = true; @@ -72,6 +78,7 @@ struct InputDialectOptions { // 2. Through a Transform dialect spec file. // 3. Through a PDL spec file. struct PreprocessingOptions { + llvm::OptimizationLevel optLevel = llvm::OptimizationLevel::O0; std::string preprocessingPassPipeline; std::string preprocessingTransformSpecFilename; std::string preprocessingPDLSpecFilename; @@ -82,6 +89,7 @@ struct PreprocessingOptions { // Options controlling high level optimizations. struct GlobalOptimizationOptions { + llvm::OptimizationLevel optLevel = llvm::OptimizationLevel::O0; // Maximum byte size increase allowed for constant expr hoisting policy to // allow hoisting. The threshold is 1MB by default. int64_t constExprMaxSizeIncreaseThreshold = 1024 * 1024; @@ -131,6 +139,8 @@ struct GlobalOptimizationOptions { // Converts linalg named matmul ops to linalg generic ops. bool generalizeMatmul = false; + void applyOptimization(const OptionsBinder &binder, + const GlobalPipelineOptions &globalLevel); void bindOptions(OptionsBinder &binder); using FromFlags = OptionsFromFlags; }; diff --git a/compiler/src/iree/compiler/Utils/OptionUtils.cpp b/compiler/src/iree/compiler/Utils/OptionUtils.cpp index 2fce6c90702d..f427e6d73b63 100644 --- a/compiler/src/iree/compiler/Utils/OptionUtils.cpp +++ b/compiler/src/iree/compiler/Utils/OptionUtils.cpp @@ -6,15 +6,14 @@ #include "iree/compiler/Utils/OptionUtils.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Passes/OptimizationLevel.h" #include "llvm/Support/ManagedStatic.h" namespace mlir::iree_compiler { -void OptionsBinder::addGlobalOption(std::unique_ptr option) { - static llvm::ManagedStatic>> - globalOptions; - globalOptions->push_back(std::move(option)); -} +llvm::ManagedStatic> + OptionsBinder::globalOptions; LogicalResult OptionsBinder::parseArguments(int argc, const char *const *argv, ErrorCallback onError) { @@ -76,10 +75,10 @@ LogicalResult OptionsBinder::parseArguments(int argc, const char *const *argv, llvm::SmallVector OptionsBinder::printArguments(bool nonDefaultOnly) { llvm::SmallVector values; - for (auto &info : localOptions) { + for (auto &[flag, info] : getOptionsStorage()) { if (!info.print) continue; - if (nonDefaultOnly && !info.isChanged()) + if (nonDefaultOnly && !info.isDefault()) continue; std::string s; @@ -91,6 +90,22 @@ OptionsBinder::printArguments(bool nonDefaultOnly) { return values; } +OptionsBinder::OptionsStorage &OptionsBinder::getOptionsStorage() { + if (!scope) { + return *globalOptions; + } else { + return localOptions; + } +} + +const OptionsBinder::OptionsStorage &OptionsBinder::getOptionsStorage() const { + if (!scope) { + return *globalOptions; + } else { + return localOptions; + } +} + } // namespace mlir::iree_compiler // // Examples: @@ -191,3 +206,44 @@ void llvm::cl::parser::printOptionDiff( } void llvm::cl::parser::anchor() {} + +bool llvm::cl::parser::parse( + Option &O, StringRef ArgName, StringRef Arg, llvm::OptimizationLevel &Val) { + auto val = StringSwitch>(Arg) + .Case("O0", OptimizationLevel::O0) + .Case("O1", OptimizationLevel::O1) + .Case("O2", OptimizationLevel::O2) + .Case("O3", OptimizationLevel::O3) + // .Case("Os", OptimizationLevel::Os) + // .Case("Oz", OptimizationLevel::Oz) + .Default(std::nullopt); + if (!val) { + return O.error("'" + Arg + + "' value not a valid optimization level, use " + "O0/O1/O2/O3"); + } + Val = *val; + return false; +} + +void llvm::cl::parser::printOptionDiff( + const Option &O, llvm::OptimizationLevel V, const OptVal &Default, + size_t GlobalWidth) const { + printOptionName(O, GlobalWidth); + std::string Str; + { + llvm::raw_string_ostream SS(Str); + SS << V.getSpeedupLevel() << "/" << V.getSizeLevel(); + } + outs() << "= " << Str; + outs().indent(2) << " (default: "; + if (Default.hasValue()) { + auto defaultVal = Default.getValue(); + outs() << defaultVal.getSpeedupLevel() << "/" << defaultVal.getSizeLevel(); + } else { + outs() << "*no default*"; + } + outs() << ")\n"; +} + +void llvm::cl::parser::anchor() {} diff --git a/compiler/src/iree/compiler/Utils/OptionUtils.h b/compiler/src/iree/compiler/Utils/OptionUtils.h index d97228182410..8a9fe87f78e0 100644 --- a/compiler/src/iree/compiler/Utils/OptionUtils.h +++ b/compiler/src/iree/compiler/Utils/OptionUtils.h @@ -9,12 +9,67 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Passes/OptimizationLevel.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Support/LogicalResult.h" +namespace llvm { +inline raw_ostream &operator<<(raw_ostream &os, + const llvm::OptimizationLevel &opt) { + return os << 'O' << opt.getSpeedupLevel(); +} +} // namespace llvm + namespace mlir::iree_compiler { +struct opt_initializer_base { + virtual ~opt_initializer_base() = default; +}; + +template +struct opt_initializer : opt_initializer_base { + Ty init; + llvm::OptimizationLevel optLevel; + opt_initializer(const llvm::OptimizationLevel opt, const Ty &val) + : init(val), optLevel(opt) {} + void apply(const llvm::OptimizationLevel inLevel, Ty &val) const { + assert(inLevel.getSizeLevel() == 0 && "size level not implemented"); + if (inLevel.getSpeedupLevel() >= optLevel.getSpeedupLevel()) + val = init; + } + + /// Append to the description string of the flag. + /// e.g. " at O2 default is true" + void appendToDesc(std::string &desc) { + llvm::raw_string_ostream os(desc); + os << "\nAt optimization level " << optLevel << " the default is "; + prettyPrint(os, init); + } + +private: + // TODO: merge this with the printing in `OptionsBinder`. + template + static void prettyPrint(llvm::raw_ostream &os, T &val) { + os << val; + } + + template <> + void prettyPrint(llvm::raw_ostream &os, bool &val) { + os << (val ? "true" : "false"); + } +}; + +/// Initialize the value of a variable if the optimization level is at least +/// the specified level. +template +opt_initializer init_at_opt(llvm::OptimizationLevel optLevel, + const Ty &val) { + return opt_initializer(optLevel, val); +} + // Base class that can bind named options to fields of structs. // // Typically use by adding the following to your struct: @@ -42,25 +97,99 @@ class OptionsBinder { return OptionsBinder(std::make_unique()); } - template + template < + typename T, typename V, typename... Mods, + std::enable_if_t< + !(std::is_same_v, opt_initializer> || ...), + int> = 0> void opt(llvm::StringRef name, V &value, Mods... Ms) { + auto [changedCallback, clCallback] = makeChangedCallback(); if (!scope) { // Bind global options. auto opt = std::make_unique>( - name, llvm::cl::location(value), llvm::cl::init(value), + name, llvm::cl::location(value), llvm::cl::init(value), clCallback, std::forward(Ms)...); - addGlobalOption(std::move(opt)); + auto defaultCallback = makeDefaultCallback(&value); + getOptionsStorage()[name] = OptionInfo{std::move(opt), /*print=*/nullptr, + /*isChanged=*/changedCallback, + /*isDefault*/ defaultCallback}; } else { // Bind local options. auto option = std::make_unique>( name, llvm::cl::sub(*scope), llvm::cl::location(value), - llvm::cl::init(value), std::forward(Ms)...); + llvm::cl::init(value), clCallback, std::forward(Ms)...); auto printCallback = makePrintCallback(option->ArgStr, option->getParser(), &value); - auto changedCallback = makeChangedCallback(&value); - localOptions.push_back( - LocalOptionInfo{std::move(option), printCallback, changedCallback}); + auto defaultCallback = makeDefaultCallback(&value); + getOptionsStorage()[name] = OptionInfo{ + std::move(option), /*print=*/printCallback, + /*isChanged=*/changedCallback, /*isDefault*/ defaultCallback}; + } + } + + // Bind a flag with a single `opt_initialier` that specifies defaults at a + // given optimization level. + template + void opt(llvm::StringRef name, V &value, + std::initializer_list> inits, Mods... Ms) { + llvm::SmallVector> initsSorted(inits.begin(), + inits.end()); + llvm::sort(initsSorted, [](opt_initializer &lhs, + opt_initializer &rhs) { + return lhs.optLevel.getSpeedupLevel() < rhs.optLevel.getSpeedupLevel(); + }); + + llvm::cl::desc &desc = filterDescription(Ms...); + auto descStr = std::make_unique(desc.Desc); + for (auto &init : initsSorted) { + init.appendToDesc(*descStr); + } + desc.Desc = descStr->c_str(); + + opt(name, value, Ms...); + OptionInfo &info = getOptionsStorage()[name]; + info.extendedDesc = std::move(descStr); + for (auto &init : initsSorted) { + info.optInits.emplace_back(std::make_unique>(init)); + } + } + + template + void optimizationLevel(llvm::StringRef name, llvm::OptimizationLevel &value, + Mods... Ms) { + opt(name, value, Ms...); + } + + template + void applyOptimization(llvm::StringRef name, T &value, + llvm::OptimizationLevel optLevel) const { + const auto infoIt = getOptionsStorage().find(name); + assert(infoIt != getOptionsStorage().end() && "Option not found"); + auto changedCallback = infoIt->getSecond().isChanged; + assert(changedCallback && "Expected changed callback"); + if (changedCallback()) { + return; + } + const auto &setOptLevels = infoIt->getSecond().optInits; + for (const auto &init : setOptLevels) { + reinterpret_cast *>(init.get()) + ->apply(optLevel, value); + } + } + + bool isFlagSet(llvm::StringRef name) const { + const auto infoIt = getOptionsStorage().find(name); + assert(infoIt != getOptionsStorage().end() && "Option not found"); + const auto &isChanged = infoIt->getSecond().isChanged; + assert(isChanged && "Expected changed callback"); + return isChanged(); + } + + template + void overrideDefault(llvm::StringRef name, T &val, const T &update) const { + if (!isFlagSet(name)) { + val = update; } } @@ -74,21 +203,25 @@ class OptionsBinder { // and use it to update. list->setCallback( [&value](const T &newElement) { value.push_back(newElement); }); - addGlobalOption(std::move(list)); + auto defaultCallback = makeListDefaultCallback(&value); + getOptionsStorage()[name] = + OptionInfo{std::move(list), /*print=*/nullptr, + /*isChanged=*/nullptr, /*isDefault*/ defaultCallback}; } else { // Bind local options. auto list = std::make_unique>( name, llvm::cl::sub(*scope), std::forward(Ms)...); auto printCallback = makeListPrintCallback(list->ArgStr, list->getParser(), &value); - auto changedCallback = makeListChangedCallback(&value); + auto defaultCallback = makeListDefaultCallback(&value); // Since list does not support external storage, hook the callback // and use it to update. list->setCallback( [&value](const T &newElement) { value.push_back(newElement); }); - localOptions.push_back( - LocalOptionInfo{std::move(list), printCallback, changedCallback}); + getOptionsStorage()[name] = + OptionInfo{std::move(list), /*print=*/printCallback, + /*isChanged=*/nullptr, /*isDefault=*/defaultCallback}; } } @@ -104,18 +237,27 @@ class OptionsBinder { llvm::SmallVector printArguments(bool nonDefaultOnly = false); private: - struct LocalOptionInfo { - using ChangedCallback = std::function; + struct OptionInfo { using PrintCallback = std::function; + using ChangedCallback = std::function; + using DefaultCallback = std::function; std::unique_ptr option; PrintCallback print; ChangedCallback isChanged; + DefaultCallback isDefault; + + // For options with optimization level defaults. + llvm::SmallVector> optInits; + std::unique_ptr extendedDesc; }; + using OptionsStorage = llvm::DenseMap; + + OptionsStorage &getOptionsStorage(); + const OptionsStorage &getOptionsStorage() const; OptionsBinder() = default; OptionsBinder(std::unique_ptr scope) : scope(std::move(scope)) {} - void addGlobalOption(std::unique_ptr option); // LLVM makes a half-hearted (i.e. "best effort" == "no effort") attempt to // handle non-enumerated generic value based options, but the generic @@ -127,7 +269,7 @@ class OptionsBinder { static auto makePrintCallback(llvm::StringRef optionName, ParserTy &parser, V *value) -> decltype(static_cast(parser), - static_cast(*value), LocalOptionInfo::PrintCallback()) { + static_cast(*value), OptionInfo::PrintCallback()) { return [optionName, &parser, value](llvm::raw_ostream &os) { llvm::StringRef valueName(""); for (unsigned i = 0; i < parser.getNumOptions(); ++i) { @@ -148,7 +290,7 @@ class OptionsBinder { static auto makePrintCallback(llvm::StringRef optionName, ParserTy &parser, V *value) -> decltype(static_cast &>(parser), - LocalOptionInfo::PrintCallback()) { + OptionInfo::PrintCallback()) { return [optionName, value](llvm::raw_ostream &os) { os << "--" << optionName << "=" << *value; }; @@ -159,7 +301,7 @@ class OptionsBinder { static auto makePrintCallback(llvm::StringRef optionName, ParserTy &parser, bool *value) -> decltype(static_cast &>(parser), - LocalOptionInfo::PrintCallback()) { + OptionInfo::PrintCallback()) { return [optionName, value](llvm::raw_ostream &os) { os << "--" << optionName << "="; if (*value) { @@ -170,9 +312,21 @@ class OptionsBinder { }; } - // Scalar changed specialization. + // Returns a pair of callbacks, the first returns if the option has been + // parsed and the second is passed to llvm::cl to track if the option has been + // parsed. template - static LocalOptionInfo::ChangedCallback makeChangedCallback(V *currentValue) { + static std::pair> + makeChangedCallback() { + std::shared_ptr changed = std::make_shared(false); + return std::pair{ + [changed]() -> bool { return *changed; }, + llvm::cl::cb([changed](const V &) { *changed = true; })}; + } + + // Scalar default specialization. + template + static OptionInfo::DefaultCallback makeDefaultCallback(V *currentValue) { // Capture the current value as the initial value. V initialValue = *currentValue; return [currentValue, initialValue]() -> bool { @@ -180,10 +334,9 @@ class OptionsBinder { }; } - // List changed specialization. + // List default specialization. template - static LocalOptionInfo::ChangedCallback - makeListChangedCallback(V *currentValue) { + static OptionInfo::DefaultCallback makeListDefaultCallback(V *currentValue) { return [currentValue]() -> bool { return !currentValue->empty(); }; } @@ -194,7 +347,7 @@ class OptionsBinder { static auto makeListPrintCallback(llvm::StringRef optionName, ParserTy &parser, ListTy *values) -> decltype(static_cast &>(parser), - LocalOptionInfo::PrintCallback()) { + OptionInfo::PrintCallback()) { return [optionName, values](llvm::raw_ostream &os) { os << "--" << optionName << "="; for (auto it : llvm::enumerate(*values)) { @@ -205,8 +358,26 @@ class OptionsBinder { }; } + // Finds the description in args + template + static llvm::cl::desc &filterDescription(Args &...args) { + llvm::cl::desc *result = nullptr; + ( + [&] { + if constexpr (std::is_same_v, llvm::cl::desc>) { + assert(!result && "Multiple llvm::cl::desc in args"); + if (!result) + result = &args; + } + }(), + ...); + assert(result && "Expected llvm::cl::desc in args"); + return *result; + } + std::unique_ptr scope; - llvm::SmallVector localOptions; + OptionsStorage localOptions; + static llvm::ManagedStatic globalOptions; }; // Generic class that is used for allocating an Options class that initializes @@ -275,6 +446,18 @@ class parser : public basic_parser { void anchor() override; }; +template <> +class parser : public basic_parser { +public: + parser(Option &O) : basic_parser(O) {} + bool parse(Option &O, StringRef ArgName, StringRef Arg, + OptimizationLevel &Val); + StringRef getValueName() const override { return "optimization level"; } + void printOptionDiff(const Option &O, OptimizationLevel V, + const OptVal &Default, size_t GlobalWidth) const; + void anchor() override; +}; + } // namespace llvm::cl #endif // IREE_COMPILER_UTILS_FLAG_UTILS_H diff --git a/compiler/src/iree/compiler/Utils/unittests/BUILD.bazel b/compiler/src/iree/compiler/Utils/unittests/BUILD.bazel index 3581fddea250..4c651e4e2644 100644 --- a/compiler/src/iree/compiler/Utils/unittests/BUILD.bazel +++ b/compiler/src/iree/compiler/Utils/unittests/BUILD.bazel @@ -19,6 +19,7 @@ iree_compiler_cc_test( "//compiler/src/iree/compiler/Utils", "//compiler/src/iree/testing:gtest_main", "@com_google_googletest//:gtest", + "@llvm-project//llvm:Passes", "@llvm-project//llvm:Support", ], ) diff --git a/compiler/src/iree/compiler/Utils/unittests/CMakeLists.txt b/compiler/src/iree/compiler/Utils/unittests/CMakeLists.txt index a850b2d7c83c..e57f1ea3c185 100644 --- a/compiler/src/iree/compiler/Utils/unittests/CMakeLists.txt +++ b/compiler/src/iree/compiler/Utils/unittests/CMakeLists.txt @@ -16,6 +16,7 @@ iree_cc_test( SRCS "UtilsTest.cpp" DEPS + LLVMPasses LLVMSupport gmock gtest diff --git a/compiler/src/iree/compiler/Utils/unittests/UtilsTest.cpp b/compiler/src/iree/compiler/Utils/unittests/UtilsTest.cpp index db757a195332..1e5d2cdd086b 100644 --- a/compiler/src/iree/compiler/Utils/unittests/UtilsTest.cpp +++ b/compiler/src/iree/compiler/Utils/unittests/UtilsTest.cpp @@ -10,6 +10,7 @@ #include "iree/compiler/Utils/EmbeddedDataDirectory.h" #include "iree/compiler/Utils/Indexing.h" +#include "iree/compiler/Utils/OptionUtils.h" #include "iree/compiler/Utils/Permutation.h" #include "llvm/Support/FormatVariadic.h" @@ -106,3 +107,85 @@ TEST(BasisFromSizeStrides, OverlappingStrides) { EXPECT_FALSE( succeeded(basisFromSizesStrides({8, 4}, {6, 1}, basis, dimToResult))); } + +//=------------------------------------------------------------------------------=// +// OptionUtils tests +//=------------------------------------------------------------------------------=// + +namespace { +struct TestOptions { + llvm::OptimizationLevel parentOption = llvm::OptimizationLevel::O0; + bool childOption = false; + + void bindOptions(OptionsBinder &binder) { + binder.opt("parent-option", parentOption); + binder.opt("child-option", childOption); + } + + void applyOptimization(const OptionsBinder &binder, + llvm::OptimizationLevel globalOptLevel) { + binder.overrideDefault("parent-option", parentOption, globalOptLevel); + if (parentOption == llvm::OptimizationLevel::O3) { + binder.overrideDefault("child-option", childOption, true); + } + } +}; +} // namespace + +TEST(OptionUtils, DefaultTest) { + auto binder = OptionsBinder::local(); + TestOptions opts; + EXPECT_EQ(opts.parentOption, llvm::OptimizationLevel::O0); + EXPECT_EQ(opts.childOption, false); + + opts.bindOptions(binder); + LogicalResult parseResult = binder.parseArguments(0, nullptr); + + EXPECT_TRUE(succeeded(parseResult)); + EXPECT_EQ(opts.parentOption, llvm::OptimizationLevel::O0); + EXPECT_EQ(opts.childOption, false); +} + +TEST(OptionUtils, OverrideParent) { + auto binder = OptionsBinder::local(); + TestOptions opts; + opts.bindOptions(binder); + LogicalResult parseResult = binder.parseArguments(0, nullptr); + + opts.applyOptimization(binder, llvm::OptimizationLevel::O1); + EXPECT_TRUE(succeeded(parseResult)); + EXPECT_EQ(opts.parentOption, llvm::OptimizationLevel::O1); + EXPECT_EQ(opts.childOption, false); +} + +TEST(OptionUtils, NoOverrideParent) { + auto binder = OptionsBinder::local(); + TestOptions opts; + opts.bindOptions(binder); + + int argc = 1; + const char *argv[] = {"--parent-option=O2"}; + LogicalResult parseResult = binder.parseArguments(argc, argv); + EXPECT_EQ(opts.parentOption, llvm::OptimizationLevel::O2); + EXPECT_EQ(opts.childOption, false); + + opts.applyOptimization(binder, llvm::OptimizationLevel::O1); + EXPECT_TRUE(succeeded(parseResult)); + EXPECT_EQ(opts.parentOption, llvm::OptimizationLevel::O2); + EXPECT_EQ(opts.childOption, false); +} + +TEST(OptionUtils, OverrideParentAndChild) { + auto binder = OptionsBinder::local(); + TestOptions opts; + opts.bindOptions(binder); + + LogicalResult parseResult = binder.parseArguments(0, nullptr); + EXPECT_EQ(opts.parentOption, llvm::OptimizationLevel::O0); + EXPECT_EQ(opts.childOption, false); + + opts.applyOptimization(binder, llvm::OptimizationLevel::O3); + EXPECT_TRUE(succeeded(parseResult)); + EXPECT_EQ(opts.parentOption, llvm::OptimizationLevel::O3); + EXPECT_EQ(opts.childOption, true); +} diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py index e60836207114..bc3cd611b420 100644 --- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py +++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py @@ -258,14 +258,11 @@ def SDXL_PUNET_INT8_FP8_OUT( ROCM_COMPILE_FLAGS = [ "--iree-hal-target-backends=rocm", f"--iree-hip-target={rocm_chip}", + "--iree-opt-level=O3", "--iree-opt-const-eval=false", - "--iree-opt-strip-assertions=true", "--iree-global-opt-propagate-transposes=true", "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true", "--iree-dispatch-creation-enable-aggressive-fusion=true", - "--iree-opt-aggressively-propagate-transposes=true", - "--iree-opt-outer-dim-concat=true", - "--iree-opt-generalize-matmul=true", "--iree-vm-target-truncate-unsupported-floats", "--iree-llvmgpu-enable-prefetch=true", "--iree-opt-data-tiling=false", diff --git a/tools/test/BUILD.bazel b/tools/test/BUILD.bazel index 23a64dbe83d7..74024b458230 100644 --- a/tools/test/BUILD.bazel +++ b/tools/test/BUILD.bazel @@ -21,12 +21,14 @@ iree_lit_test_suite( [ "benchmark_flags.txt", "compile_pipelines.mlir", + "compile_flags.mlir", "compile_to_continuation.mlir", "compile_to_phase.mlir", "executable_benchmarks.mlir", "executable_configurations.mlir", "executable_sources.mlir", "iree-benchmark-executable.mlir", + "iree-compile-help.txt", "iree-benchmark-module.mlir", "iree-convert-parameters.txt", "iree-dump-parameters.txt", diff --git a/tools/test/CMakeLists.txt b/tools/test/CMakeLists.txt index 9d515d8d94ff..d0689e63cbed 100644 --- a/tools/test/CMakeLists.txt +++ b/tools/test/CMakeLists.txt @@ -15,6 +15,7 @@ iree_lit_test_suite( lit SRCS "benchmark_flags.txt" + "compile_flags.mlir" "compile_pipelines.mlir" "compile_to_continuation.mlir" "compile_to_phase.mlir" @@ -23,6 +24,7 @@ iree_lit_test_suite( "executable_sources.mlir" "iree-benchmark-executable.mlir" "iree-benchmark-module.mlir" + "iree-compile-help.txt" "iree-convert-parameters.txt" "iree-dump-parameters.txt" "iree-run-mlir.mlir" diff --git a/tools/test/compile_flags.mlir b/tools/test/compile_flags.mlir new file mode 100644 index 000000000000..57836a78268e --- /dev/null +++ b/tools/test/compile_flags.mlir @@ -0,0 +1,34 @@ +// RUN: iree-compile --compile-to=global-optimization %s \ +// RUN: | FileCheck %s --check-prefix=NO-STRIP-CHECK +// RUN: iree-compile --compile-to=global-optimization \ +// RUN: --iree-global-optimization-opt-level=O2 \ +// RUN: --iree-opt-strip-assertions=false %s \ +// RUN: | FileCheck %s --check-prefix=NO-STRIP-CHECK +// RUN: iree-compile --compile-to=global-optimization \ +// RUN: --iree-opt-level=O2 \ +// RUN: --iree-opt-strip-assertions=false %s \ +// RUN: | FileCheck %s --check-prefix=NO-STRIP-CHECK +// RUN: iree-compile --compile-to=global-optimization \ +// RUN: --iree-opt-level=O2 \ +// RUN: --iree-global-optimization-opt-level=O0 %s \ +// RUN: | FileCheck %s --check-prefix=NO-STRIP-CHECK +// +// RUN: iree-compile --compile-to=global-optimization \ +// RUN: --iree-opt-strip-assertions %s \ +// RUN: | FileCheck %s --check-prefix=STRIP-CHECK +// RUN: iree-compile --compile-to=global-optimization \ +// RUN: --iree-opt-level=O2 %s \ +// RUN: | FileCheck %s --check-prefix=STRIP-CHECK +// RUN: iree-compile --compile-to=global-optimization \ +// RUN: --iree-global-optimization-opt-level=O2 %s \ +// RUN: | FileCheck %s --check-prefix=STRIP-CHECK + +util.func public @main(%0 : i1){ + cf.assert %0, "assert" + util.return +} + +// NO-STRIP-CHECK: util.func +// NO-STRIP-CHECK: cf.assert +// STRIP-CHECK: util.func +// STRIP-CHECK-NOT: cf.assert diff --git a/tools/test/iree-compile-help.txt b/tools/test/iree-compile-help.txt new file mode 100644 index 000000000000..c6f7a4eee0ed --- /dev/null +++ b/tools/test/iree-compile-help.txt @@ -0,0 +1,10 @@ +// RUN: iree-compile --help | \ +// RUN: FileCheck %s + +// Basic checks of optimization level defaults and printing. + +// CHECK: --iree-opt-level= +// CHECK: --iree-global-optimization-opt-level= +// CHECK: --iree-opt-strip-assertions +// CHECK-NEXT: At optimization level O0 the default is false +// CHECK-NEXT: At optimization level O1 the default is true