Skip to content

Commit

Permalink
[AMD] Improve instruction scheduling hints for more targets (triton-l…
Browse files Browse the repository at this point in the history
…ang#5059)

This PR implements in-source machine description database;
thus, it adapts our custom instruction scheduling for several
AMD architectures (i.e., MI200 and MI300). If necessary, it can be
further extended for Navi2 and Navi3 cards.

Also this PR disables `load/store` optimization in the AMDGPU compiler
backend when custom instruction scheduling is applied (i.e., when
`buffer_loads` and `ck_v3` software pipelining are used). This prevents
overestimating `ds_read`/`ds_write` instruction counts at the MLIR
level. 

Additionally, the `default` instruction scheduling variant was renamed
to `none`.
  • Loading branch information
ravil-mobile authored Nov 7, 2024
1 parent e66850b commit 9378d8f
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 52 deletions.
7 changes: 4 additions & 3 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ class HIPOptions:
backend_name: str = 'hip'

# The following option provides hints to the AMDGPU backend regarding instruction scheduling
# for all `tt.dot` operations in a kernel. The "default" variant preserves the default
# for all `tt.dot` operations in a kernel. The "none" variant preserves the default
# instruction scheduling of the AMDGPU backend which aims at maximizing occupancy.
# The option is experimental and may change at any time regarding its semantics and/or may
# be gone entirely anytime.
instruction_sched_variant: str = 'default'
instruction_sched_variant: str = 'none'

def __post_init__(self):
default_libdir = Path(__file__).parent / 'lib'
Expand Down Expand Up @@ -274,7 +274,8 @@ def make_llir(src, metadata, options):
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.num_stages, options.instruction_sched_variant)
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages,
options.instruction_sched_variant)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
Expand Down
5 changes: 3 additions & 2 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ createConvertBuiltinFuncToLLVMPass(bool ftz);
std::unique_ptr<OperationPass<ModuleOp>>
createTritonAMDGPUInsertInstructionSchedHintsPass();
std::unique_ptr<OperationPass<ModuleOp>>
createTritonAMDGPULowerInstructionSchedHintsPass(int32_t numStages,
std::string variant);
createTritonAMDGPULowerInstructionSchedHintsPass(StringRef arch,
int32_t numStages,
StringRef variant);

#define GEN_PASS_REGISTRATION
#include "TritonAMDGPUToLLVM/Passes.h.inc"
Expand Down
6 changes: 4 additions & 2 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,18 @@ def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruc

def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-instruction-sched-hints", "mlir::ModuleOp"> {
let summary = "Lower instruction scheduling hints to LLVM intrinsics";
let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*numStages=*/2, /*variant=*/\"\")";
let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*arch=*/\"\",/*numStages=*/2, /*variant=*/\"\")";

let dependentDialects = ["mlir::LLVM::LLVMDialect",
"mlir::ROCDL::ROCDLDialect",
"mlir::triton::amdgpu::TritonAMDGPUDialect"];

let options = [
Option<"arch", "arch", "std::string", /*default*/"\"\"",
"gfx target device architecture, e.g., gfx942">,
Option<"numStages", "num_stages", "int32_t", /*default*/"2",
"number of pipeline stages">,
Option<"variant", "variant", "std::string", /*default*/"\"default\"",
Option<"variant", "variant", "std::string", /*default*/"\"none\"",
"instruction scheduling variant">,
];
}
Expand Down
199 changes: 157 additions & 42 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#include "SchedInstructions.h"
#include "TritonAMDGPUToLLVM/Passes.h"
#include "TritonAMDGPUToLLVM/TargetUtils.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "llvm/TargetParser/TargetParser.h"

namespace mlir::triton {
#define GEN_PASS_DEF_TRITONAMDGPUINSERTINSTRUCTIONSCHEDHINTS
Expand Down Expand Up @@ -146,17 +148,81 @@ Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) {
return rewriter.create<ROCDL::IglpOpt>(loc, iglpValue);
}

// The following structs represent in-source database regarding a target
// machine. It provides instructions execution and issue cycles needed for
// scheduling.
struct MachineDescr {
virtual ~MachineDescr() = default;
virtual uint32_t getDsReadIssueCycle(uint32_t instrWidth) = 0;
virtual FailureOr<uint32_t> getMmaExecCycle(llvm::ArrayRef<int64_t> dims) = 0;
virtual uint32_t getMmaIssueCycle() = 0;
virtual uint32_t getNumLdsDataPaths() = 0;
static std::unique_ptr<MachineDescr> get(StringRef arch);
};

template <typename Derived> struct MachineDescrImpl : MachineDescr {
uint32_t getDsReadIssueCycle(uint32_t instrWidth) final {
return instrWidth == 16 ? 8 : 4;
}

FailureOr<uint32_t> getMmaExecCycle(llvm::ArrayRef<int64_t> dims) final {
if (dims.size() != 3)
return failure();
auto it =
Derived::mmaTable.find(std::make_tuple(dims[0], dims[1], dims[2]));
if (it != Derived::mmaTable.end())
return it->second;
return failure();
}

uint32_t getMmaIssueCycle() final { return Derived::mmaIssueCycle; };
uint32_t getNumLdsDataPaths() final { return Derived::numLdsDataPaths; }

using MmaTable =
llvm::DenseMap<std::tuple<int64_t, int64_t, int64_t>, uint32_t>;
};

struct CDNA2Kind : public MachineDescrImpl<CDNA2Kind> {
static const inline MmaTable mmaTable{{{32, 32, 8}, 64}, {{16, 16, 16}, 32}};
static const inline uint32_t mmaIssueCycle{4};
static const inline uint32_t numLdsDataPaths{2};
};

struct CDNA3Kind : public MachineDescrImpl<CDNA3Kind> {
static const inline MmaTable mmaTable{{{32, 32, 8}, 32}, {{16, 16, 16}, 16}};
static const inline uint32_t mmaIssueCycle{4};
static const inline uint32_t numLdsDataPaths{2};
};

std::unique_ptr<MachineDescr> MachineDescr::get(StringRef arch) {
AMD::ISAFamily family = AMD::deduceISAFamily(arch);
switch (family) {
case AMD::ISAFamily::CDNA3: {
return std::make_unique<MachineDescrImpl<CDNA3Kind>>();
}
case AMD::ISAFamily::CDNA2: {
return std::make_unique<MachineDescrImpl<CDNA2Kind>>();
}
default: {
return nullptr;
}
}
return nullptr;
}

struct InstructionSchedHintsRewriter
: public OpRewritePattern<triton::amdgpu::InstructionSchedHint> {

InstructionSchedHintsRewriter(MLIRContext *ctx, int32_t numStages,
std::string variant)
InstructionSchedHintsRewriter(MLIRContext *ctx, StringRef arch,
int32_t numStages, std::string variant)
: OpRewritePattern(ctx), numStages(numStages) {

this->machineDescr = MachineDescr::get(arch);
std::transform(variant.begin(), variant.end(), variant.begin(),
[](unsigned char c) { return std::tolower(c); });

this->schedulingType = llvm::StringSwitch<SchedulingType>(variant)
.Case("default", SchedulingType::NONE)
.Case("none", SchedulingType::NONE)
.Case("iglp0", SchedulingType::IGLP0)
.Case("iglp1", SchedulingType::IGLP1)
.Case("ck_v3", SchedulingType::CK_V3)
Expand Down Expand Up @@ -194,6 +260,11 @@ struct InstructionSchedHintsRewriter
return;
}

if (!machineDescr) {
schedHint.emitError("unknown target architecture detected");
return;
}

const uint32_t numDsReadInstA = schedHint.getNumDsReadsA().getValue();
const uint32_t numDsReadInstB = schedHint.getNumDsReadsB().getValue();

Expand All @@ -205,38 +276,53 @@ struct InstructionSchedHintsRewriter
const uint32_t numBufferLoadInstB =
schedHint.getNumGlobalLoadsB().getValue();

if (numBufferLoadInstA == 0)
if (numBufferLoadInstA == 0) {
schedHint.emitError("buffer load count for tile A must be initialized");
return;
}

if (numBufferLoadInstB == 0)
if (numBufferLoadInstB == 0) {
schedHint.emitError("buffer load count for tile B must be initialized");
return;
}

const uint32_t numMfmaInst = schedHint.getNumMMAs().getValue();
const uint32_t numMmaInst = schedHint.getNumMMAs().getValue();

auto mfmaType = cast<RankedTensorType>(schedHint.getNumMMAs().getType());
const uint32_t nPerXDL = mfmaType.getShape()[1];
const uint32_t mfmaCycle = nPerXDL == 16 ? 16 : 32;
auto mmaType = cast<RankedTensorType>(schedHint.getNumMMAs().getType());
auto maybeMmaExecCycle = machineDescr->getMmaExecCycle(mmaType.getShape());
if (llvm::failed(maybeMmaExecCycle)) {
schedHint.emitError("unknown mma instruction type");
return;
}
const uint32_t mmaExecCycle = maybeMmaExecCycle.value();

auto dsReadsAType = cast<VectorType>(schedHint.getNumDsReadsA().getType());
auto dsReadsBType = cast<VectorType>(schedHint.getNumDsReadsB().getType());

const uint32_t dsReadAIssueCycle = dsReadsAType.getShape()[0] == 16 ? 8 : 4;
const uint32_t dsReadBIssueCycle = dsReadsBType.getShape()[0] == 16 ? 8 : 4;
const uint32_t dsReadAIssueCycle =
machineDescr->getDsReadIssueCycle(dsReadsAType.getShape()[0]);
const uint32_t dsReadBIssueCycle =
machineDescr->getDsReadIssueCycle(dsReadsBType.getShape()[0]);

const auto dsReadAMfmaRate =
(mfmaCycle - 4 + 2 * dsReadAIssueCycle - 1) / (2 * dsReadAIssueCycle);
const auto dsReadBMfmaRate =
(mfmaCycle - 4 + 2 * dsReadBIssueCycle - 1) / (2 * dsReadBIssueCycle);
const uint32_t mmaIssueCycle = this->machineDescr->getMmaIssueCycle();
const uint32_t numLdsDataPaths = this->machineDescr->getNumLdsDataPaths();

const auto numDsreadAMfma =
(numDsReadInstA + dsReadAMfmaRate - 1) / dsReadAMfmaRate;
const auto numDsreadBMfma =
(numDsReadInstB + dsReadBMfmaRate - 1) / dsReadBMfmaRate;
const auto dsReadAMmaRate = (mmaExecCycle - mmaIssueCycle +
numLdsDataPaths * dsReadAIssueCycle - 1) /
(numLdsDataPaths * dsReadAIssueCycle);
const auto dsReadBMmaRate = (mmaExecCycle - mmaIssueCycle +
numLdsDataPaths * dsReadBIssueCycle - 1) /
(numLdsDataPaths * dsReadBIssueCycle);

const auto numDsreadAMma =
(numDsReadInstA + dsReadAMmaRate - 1) / dsReadAMmaRate;
const auto numDsreadBMma =
(numDsReadInstB + dsReadBMmaRate - 1) / dsReadBMmaRate;

// stage 1
const auto numMfmaStage1 = numMfmaInst - (numDsreadAMfma + numDsreadBMfma);
const auto numMfmaPerIssue =
numMfmaStage1 / (numBufferLoadInstA + numBufferLoadInstB);
const auto numMmaStage1 = numMmaInst - (numDsreadAMma + numDsreadBMma);
const auto numMmaPerIssue =
numMmaStage1 / (numBufferLoadInstA + numBufferLoadInstB);

const auto numDswritePerIssueA = numDsWriteInstA / numBufferLoadInstA;
const auto numDswritePerIssueB = numDsWriteInstB / numBufferLoadInstB;
Expand All @@ -254,7 +340,7 @@ struct InstructionSchedHintsRewriter
rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0);
createSchedGroupBarrier(rewriter, loc,
mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma,
numMfmaPerIssue - numDswritePerIssueA, 0);
numMmaPerIssue - numDswritePerIssueA, 0);
}

for (size_t i = 0; i < numBufferLoadInstB; ++i) {
Expand All @@ -270,37 +356,63 @@ struct InstructionSchedHintsRewriter
rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0);
createSchedGroupBarrier(rewriter, loc,
mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma,
numMfmaPerIssue - numDswritePerIssueB, 0);
numMmaPerIssue - numDswritePerIssueB, 0);
}

// stage 2
for (size_t i = 0; i < numDsreadAMfma; ++i) {
if ((numDsReadInstA - (i + 1) * dsReadAMfmaRate) >= dsReadAMfmaRate) {
for (size_t i = 0; i < numDsreadAMma; ++i) {
if ((numDsReadInstA - (i + 1) * dsReadAMmaRate) >= dsReadAMmaRate) {
createSchedGroupBarrier(rewriter, loc,
mlir::amdgpu::sched_barrier_opt_enum::ds_read,
dsReadAMfmaRate, 0);
dsReadAMmaRate, 0);
} else {
createSchedGroupBarrier(
rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::ds_read,
numDsReadInstA - (numDsreadAMfma - 1) * dsReadAMfmaRate, 0);
numDsReadInstA - (numDsreadAMma - 1) * dsReadAMmaRate, 0);
}
createSchedGroupBarrier(
rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, 1, 0);
}

for (size_t i = 0; i < numDsreadBMfma; ++i) {
if ((numDsReadInstB - (i + 1) * dsReadBMfmaRate) >= dsReadBMfmaRate) {
for (size_t i = 0; i < numDsreadBMma; ++i) {
if ((numDsReadInstB - (i + 1) * dsReadBMmaRate) >= dsReadBMmaRate) {
createSchedGroupBarrier(rewriter, loc,
mlir::amdgpu::sched_barrier_opt_enum::ds_read,
dsReadBMfmaRate, 0);
dsReadBMmaRate, 0);
} else {
createSchedGroupBarrier(
rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::ds_read,
numDsReadInstB - (numDsreadBMfma - 1) * dsReadBMfmaRate, 0);
numDsReadInstB - (numDsreadBMma - 1) * dsReadBMmaRate, 0);
}
createSchedGroupBarrier(
rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, 1, 0);
}

// The AMDGPU compiler backend can fold consecutive `ds_read/ds_write`
// instructions into wider variants as a part of its load/store optimization
// during the instruction selection pass. If it happens, then it means that
// we are overestimated these types of instructions at the current level of
// the IR. In this scenario, the inserted `sched.group.barriers` will result
// in "fooling" the scheduling solver which can mess up the final assembly.
// To avoid this, we switch off the backend load/store folding optimization
// which is going to prevent instructions folding. In this case, the
// instruction widths of `ds_read/ds_write` instructions are going to match
// their LLVM representations. This is implemented as follows.

// TODO: The current implementation disables `ds_read/ds_write` folding for
// all basic blocks in the currently processed function. We should try to
// avoid it. The compiler backend team proposed to play we the load/store
// alignment values within the currently processed basic block as an
// alternative solution.
auto funcOp = schedHint->getParentOfType<LLVM::LLVMFuncOp>();
MLIRContext *ctx = schedHint->getContext();
llvm::SmallVector<StringAttr> targetFeatures;
if (auto attr = funcOp.getTargetFeatures()) {
llvm::copy(attr->getFeatures(), std::back_inserter(targetFeatures));
}
targetFeatures.push_back(str_attr("-load-store-opt"));
funcOp.setTargetFeaturesAttr(
::mlir::LLVM::TargetFeaturesAttr::get(ctx, targetFeatures));
}

LogicalResult
Expand Down Expand Up @@ -364,16 +476,19 @@ struct InstructionSchedHintsRewriter
private:
int32_t numStages;
SchedulingType schedulingType;
std::unique_ptr<MachineDescr> machineDescr;
};

struct TritonAMDGPULowerInstructionSchedHints
: public triton::impl::TritonAMDGPULowerInstructionSchedHintsBase<
TritonAMDGPULowerInstructionSchedHints> {

explicit TritonAMDGPULowerInstructionSchedHints(int32_t numStages,
std::string variant) {
explicit TritonAMDGPULowerInstructionSchedHints(StringRef arch,
int32_t numStages,
StringRef variant) {
this->arch = std::move(arch.str());
this->numStages = numStages;
this->variant = variant;
this->variant = std::move(variant.str());
}

void runOnOperation() override {
Expand All @@ -389,9 +504,8 @@ struct TritonAMDGPULowerInstructionSchedHints

RewritePatternSet patterns(ctx);

patterns.add<InstructionSchedHintsRewriter>(ctx, this->numStages,

this->variant);
patterns.add<InstructionSchedHintsRewriter>(ctx, this->arch,
this->numStages, this->variant);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
Expand Down Expand Up @@ -425,10 +539,11 @@ struct TritonAMDGPUInsertInstructionSchedHints

namespace mlir::triton {
std::unique_ptr<OperationPass<ModuleOp>>
createTritonAMDGPULowerInstructionSchedHintsPass(int32_t numStages,
std::string variant) {
return std::make_unique<TritonAMDGPULowerInstructionSchedHints>(numStages,
variant);
createTritonAMDGPULowerInstructionSchedHintsPass(StringRef arch,
int32_t numStages,
StringRef variant) {
return std::make_unique<TritonAMDGPULowerInstructionSchedHints>(
arch, numStages, variant);
}

std::unique_ptr<OperationPass<ModuleOp>>
Expand Down
7 changes: 4 additions & 3 deletions third_party/amd/python/triton_amd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
pm.addPass(createTritonAMDGPUInsertInstructionSchedHintsPass());
});
m.def("lower_instruction_sched_hints",
[](mlir::PassManager &pm, int32_t numStages, std::string variant) {
pm.addPass(createTritonAMDGPULowerInstructionSchedHintsPass(numStages,
variant));
[](mlir::PassManager &pm, const std::string &arch, int32_t numStages,
const std::string &variant) {
pm.addPass(createTritonAMDGPULowerInstructionSchedHintsPass(
arch, numStages, variant));
});
m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm,
const std::string &arch) {
Expand Down

0 comments on commit 9378d8f

Please sign in to comment.