Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a sketch of the instr. sched. group barriers #622

Draft
wants to merge 1 commit into
base: sjw-pipeline-infra
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -242,4 +242,11 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
}];
}

// Marker op inserted by the AMD backend's scheduling passes: it anchors the
// position where llvm.amdgcn.sched.group.barrier intrinsics will later be
// emitted, and is erased during lowering. It has no operands, results, or
// side-effect interfaces.
def TTG_GroupSched : TTG_Op<"group_sched"> {
let summary = "A placeholder Op for the instruction group scheduling";
let description = [{
A placeholder Op for the instruction group scheduling.
}];
}

#endif
4 changes: 4 additions & 0 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ std::string translateLLVMIRToASM(llvm::Module &module,
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
opt.TrapUnreachable = true;

opt.MCOptions.AsmVerbose = true;
opt.MCOptions.PreserveAsmComments = true;

std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
std::nullopt,
Expand Down
3 changes: 3 additions & 0 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def make_llir(src, metadata, options):
passes.convert.add_index_to_llvmir(pm)

passes.ttgpuir.add_allocate_shared_memory(pm)

amd.passes.ttgpuir.insert_sched_group_barriers(pm)
## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
## of the value of kernel arg `allow_flush_denorm`.
Expand All @@ -197,6 +199,7 @@ def make_llir(src, metadata, options):
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
amd.passes.ttgpuir.add_sched_group_barriers(pm)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
# This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block
Expand Down
3 changes: 3 additions & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0);
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
std::unique_ptr<OperationPass<ModuleOp>> createConvertBuiltinFuncToLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>>
createSchedGroupBarriersInsertionPass();
std::unique_ptr<OperationPass<ModuleOp>> createSchedGroupBarriersLoweringPass();

#define GEN_PASS_REGISTRATION
#include "TritonAMDGPUToLLVM/Passes.h.inc"
Expand Down
14 changes: 14 additions & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,18 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::Modul

}

// Inserts ttg.group_sched placeholder ops after each tt.dot so the lowering
// pass (SchedGroupBarriersLowering) knows where to emit the
// llvm.amdgcn.sched.group.barrier intrinsic sequence.
def SchedGroupBarriersInsertion : Pass<"insert-sched-group-barriers", "mlir::ModuleOp"> {
let summary = "Insert Scheduling Group Barriers";
let constructor = "mlir::triton::createSchedGroupBarriersInsertionPass()";

let dependentDialects = ["mlir::LLVM::LLVMDialect"];
}

// Rewrites ttg.group_sched placeholders into llvm.amdgcn.sched.barrier /
// llvm.amdgcn.sched.group.barrier intrinsic calls; must run after the
// TritonGPU-to-LLVM conversion so the surrounding ops are already LLVM-dialect.
def SchedGroupBarriersLowering : Pass<"lower-sched-group-barriers", "mlir::ModuleOp"> {
let summary = "Lower Scheduling Group Barriers to LLVM intrinsics";
let constructor = "mlir::triton::createSchedGroupBarriersLoweringPass()";

let dependentDialects = ["mlir::LLVM::LLVMDialect"];
}

#endif
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_triton_library(TritonAMDGPUToLLVM
OptimizeLDSUsage.cpp
OptimizeLDSUtility.cpp
SPMDOpToLLVM.cpp
SchedInstructions.cpp

DEPENDS
TritonAMDGPUConversionPassIncGen
Expand Down
207 changes: 207 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
#include "TritonAMDGPUToLLVM/Passes.h"

#include "TritonAMDGPUTransforms/MfmaGroup.h"
#include "Utility.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/AxisInfo.h"

namespace mlir {
namespace triton {
#define GEN_PASS_DEF_SCHEDGROUPBARRIERSINSERTION
#define GEN_PASS_DEF_SCHEDGROUPBARRIERSLOWERING
#include "TritonAMDGPUToLLVM/Passes.h.inc"
} // namespace triton
} // namespace mlir

using namespace mlir;

namespace {
// Instruction-class bit masks passed as the first argument of the
// llvm.amdgcn.sched.barrier / llvm.amdgcn.sched.group.barrier intrinsics.
// The bit values mirror the mask encoding documented for those intrinsics in
// LLVM's AMDGPUUsage; NONE (mask 0) means "no instructions may cross".
// NOTE(review): bit 0x1 (ALU, in the LLVM encoding) is intentionally absent
// here — confirm against the target LLVM version if it is ever needed.
enum class InstructionMaskEnum : int64_t {
NONE = 0x0000000,
VALU = 0x00000002,
SALU = 0x00000004,
MFMA = 0x00000008,
ALL_VMEM = 0x00000010,
VMEM_READ = 0x00000020,
VMEM_WRITE = 0x00000040,
ALL_DS = 0x00000080,
DS_READ = 0x00000100,
DS_WRITE = 0x00000200
};

// Temporary compile-time switch for the experimental instruction-scheduling
// passes; both the insertion and the lowering pass are no-ops while this is
// false. TODO(review): promote to a proper pass option or an environment
// variable instead of an edit-to-enable constant.
constexpr bool modifyScheduling{false};

// Emits a call to the llvm.amdgcn.sched.group.barrier intrinsic at the
// builder's current insertion point. All three arguments (instruction mask,
// group size, sync-group id) are materialized as i32 constants.
void buildSchedGroupBarrier(PatternRewriter &builder,
                            InstructionMaskEnum maskValue, int sizeValue,
                            int groupIdValue) {
  Location loc = builder.getUnknownLoc();
  Type i32Ty = builder.getI32Type();
  auto makeI32Const = [&](int64_t constant) {
    return builder.create<LLVM::ConstantOp>(
        loc, builder.getIntegerAttr(i32Ty, constant));
  };
  auto mask = makeI32Const(static_cast<int64_t>(maskValue));
  auto size = makeI32Const(sizeValue);
  auto groupId = makeI32Const(groupIdValue);
  auto intrinsicName = StringAttr::get(builder.getContext(),
                                       "llvm.amdgcn.sched.group.barrier");
  builder.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
                                        ValueRange{mask, size, groupId},
                                        LLVM::FastmathFlagsAttr{});
}

// Emits a call to the llvm.amdgcn.sched.barrier intrinsic with the given
// instruction mask and returns the created call operation.
Operation *generatedSchedBarrier(PatternRewriter &rewriter,
                                 InstructionMaskEnum maskValue) {
  Location loc = rewriter.getUnknownLoc();
  auto intrinsicName =
      StringAttr::get(rewriter.getContext(), "llvm.amdgcn.sched.barrier");
  auto maskConst = rewriter.create<LLVM::ConstantOp>(
      loc, rewriter.getIntegerAttr(rewriter.getI32Type(),
                                   static_cast<int64_t>(maskValue)));
  return rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
                                                ValueRange{maskConst},
                                                LLVM::FastmathFlagsAttr{});
}

// Rewrites one ttg.group_sched placeholder into a sequence of
// llvm.amdgcn.sched.barrier / sched.group.barrier intrinsic calls that
// interleave MFMA, ds_read, and ds_write instructions, one schedule round
// per global load issued in the enclosing block.
struct SchedGroupBarriersRewriter
    : public OpRewritePattern<triton::gpu::GroupSched> {
  using OpRewritePattern<triton::gpu::GroupSched>::OpRewritePattern;

  LogicalResult matchAndRewrite(triton::gpu::GroupSched schedBarrier,
                                PatternRewriter &rewriter) const override {
    Block *block = schedBarrier->getBlock();

    // Global loads are lowered via the "__predicated_load_vector*" helper
    // calls, so count those.
    size_t numGlbLoads = 0;
    block->walk([&numGlbLoads](LLVM::CallOp callOp) {
      StringRef calleeName = callOp.getCallee().value();
      if (calleeName.contains("__predicated_load_vector"))
        ++numGlbLoads;
    });

    // LDS reads: llvm.load through an address-space-3 pointer.
    size_t numDsReads = 0;
    block->walk([&numDsReads](LLVM::LoadOp op) {
      auto operandType = op.getOperand().getType();
      if (auto ptr = llvm::dyn_cast<LLVM::LLVMPointerType>(operandType))
        if (ptr.getAddressSpace() == 3)
          ++numDsReads;
    });

    // LDS writes: llvm.store whose destination (operand 1) is address space 3.
    size_t numDsWrites = 0;
    block->walk([&numDsWrites](LLVM::StoreOp op) {
      auto operandType = op.getOperand(1).getType();
      if (auto ptr = llvm::dyn_cast<LLVM::LLVMPointerType>(operandType))
        if (ptr.getAddressSpace() == 3)
          ++numDsWrites;
    });

    // MFMA ops are matched by name substring (covers the rocdl mfma family).
    size_t numMfmas = 0;
    block->walk([&numMfmas](Operation *op) {
      StringRef opName = op->getName().getStringRef();
      if (opName.contains("mfma"))
        ++numMfmas;
    });

    // TODO(review): wrap in LLVM_DEBUG once a DEBUG_TYPE is defined; this
    // currently prints unconditionally.
    llvm::dbgs() << "group scheduling info: ["
                 << "numGlbLoads: " << numGlbLoads << ", "
                 << "numDsReads: " << numDsReads << ", "
                 << "numDsWrites: " << numDsWrites << ", "
                 << "numMfmas: " << numMfmas << "]\n";

    // Without any counted global loads the per-issue divisions below would
    // divide by zero; there is nothing to schedule, so just drop the marker.
    if (numGlbLoads == 0) {
      rewriter.eraseOp(schedBarrier);
      return success();
    }

    // Place the opening sched.barrier just after the second s_barrier in the
    // block, if present; otherwise it is emitted at the current insertion
    // point (before the placeholder).
    size_t barrierCounter = 0;
    block->walk([&barrierCounter, &rewriter](ROCDL::BarrierOp op) {
      if (barrierCounter == 1) {
        rewriter.setInsertionPointAfter(op);
        return WalkResult::interrupt();
      }
      ++barrierCounter;
      return WalkResult::advance();
    });
    generatedSchedBarrier(rewriter, InstructionMaskEnum::NONE);

    rewriter.setInsertionPointAfter(schedBarrier);
    const size_t numIssues = numGlbLoads;
    // Each round below schedules 3 MFMAs explicitly; the remainder of the
    // per-issue MFMA budget goes into the final group. Clamp so the unsigned
    // subtraction cannot wrap around to a huge (then negative int) size.
    const size_t mfmasPerIssue = numMfmas / numIssues;
    const size_t trailingMfmas = mfmasPerIssue > 3 ? mfmasPerIssue - 3 : 0;
    for (size_t i = 0; i < numIssues; ++i) {
      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::MFMA, 1, 0);
      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::DS_READ,
                             numDsReads / numIssues, 0);
      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::MFMA, 1, 0);
      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::DS_WRITE,
                             numDsWrites / numIssues, 0);
      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::MFMA, 1, 0);
      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::MFMA,
                             trailingMfmas, 0);
    }
    // Closing sched.barrier; the call result is intentionally unused.
    generatedSchedBarrier(rewriter, InstructionMaskEnum::NONE);
    rewriter.eraseOp(schedBarrier);
    return success();
  }
};

// Pass that lowers every ttg.group_sched placeholder in the module to
// llvm.amdgcn.sched.* intrinsic calls via SchedGroupBarriersRewriter.
// No-op while the `modifyScheduling` switch is disabled.
struct SchedGroupBarriersLowering
    : public triton::impl::SchedGroupBarriersLoweringBase<
          SchedGroupBarriersLowering> {

  void runOnOperation() override {
    // Experimental feature gate; bail out before building any state.
    if (!modifyScheduling)
      return;

    MLIRContext *ctx = &getContext();
    ConversionTarget target(*ctx);
    target.addLegalDialect<LLVM::LLVMDialect>();
    // The placeholder must not survive this pass.
    target.addIllegalOp<triton::gpu::GroupSched>();

    RewritePatternSet patterns(ctx);
    patterns.add<SchedGroupBarriersRewriter>(ctx);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

// Pass that drops a ttg.group_sched placeholder right after every tt.dot in
// the module, marking where the lowering pass should emit its scheduling
// intrinsics. No-op while the `modifyScheduling` switch is disabled.
struct SchedGroupBarriersInsertion
    : public triton::impl::SchedGroupBarriersInsertionBase<
          SchedGroupBarriersInsertion> {

  // Inserts one placeholder immediately after `dot`, restoring the builder's
  // insertion point on exit.
  void insertPlaceholder(mlir::OpBuilder &builder, triton::DotOp dot) {
    mlir::OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointAfter(dot);
    builder.create<triton::gpu::GroupSched>(builder.getUnknownLoc());
  }

  void runOnOperation() override {
    if (!modifyScheduling)
      return;

    mlir::OpBuilder builder(&getContext());
    getOperation().walk([this, &builder](triton::DotOp dotOp) {
      insertPlaceholder(builder, dotOp);
    });
  }
};
} // namespace

namespace mlir {
namespace triton {

/// Factory for the pass that rewrites ttg.group_sched placeholders into
/// llvm.amdgcn.sched.* intrinsic calls.
std::unique_ptr<OperationPass<ModuleOp>>
createSchedGroupBarriersLoweringPass() {
  return std::make_unique<SchedGroupBarriersLowering>();
}

/// Factory for the pass that inserts ttg.group_sched placeholders after
/// tt.dot operations.
std::unique_ptr<OperationPass<ModuleOp>>
createSchedGroupBarriersInsertionPass() {
  return std::make_unique<SchedGroupBarriersInsertion>();
}

} // namespace triton
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class TritonLLVMConversionTarget : public ConversionTarget {
addIllegalDialect<triton::nvidia_gpu::TritonNvidiaGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
addLegalOp<triton::gpu::GroupSched>();
}
};

Expand Down
6 changes: 6 additions & 0 deletions third_party/amd/python/triton_amd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,19 @@ const char *const amdTargetTriple = "amdgcn-amd-amdhsa";

void init_triton_amd_passes_ttgpuir(py::module &&m) {
using namespace mlir::triton;
m.def("insert_sched_group_barriers", [](mlir::PassManager &pm) {
pm.addPass(createSchedGroupBarriersInsertionPass());
});
m.def("add_to_llvmir",
[](mlir::PassManager &pm, const std::string &arch, bool ftz) {
pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz));
});
m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) {
pm.addPass(createConvertBuiltinFuncToLLVMPass());
});
m.def("add_sched_group_barriers", [](mlir::PassManager &pm) {
pm.addPass(createSchedGroupBarriersLoweringPass());
});
m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm,
const std::string &arch) {
pm.addPass(
Expand Down