Rocm jaxlib v0.4.28 triton #7

Open · wants to merge 6 commits into base: rocm-jaxlib-v0.4.28
181 changes: 181 additions & 0 deletions third_party/triton/temporary/amd_pr7.patch
@@ -0,0 +1,181 @@
==== triton/BUILD#46 - /google/src/cloud/csigg/triton_amd/triton/BUILD ====
# action=edit type=text
--- triton/BUILD 2024-04-11 02:00:21.000000000 -0700
+++ triton/BUILD 2024-04-21 23:52:01.000000000 -0700
@@ -725,12 +725,12 @@
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:GPUDialect",
+ "@llvm-project//mlir:GPUToROCDLTransforms",
"@llvm-project//mlir:IR",
- "@llvm-project//mlir:LLVMCommonConversion",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:Pass",
- "@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:ROCDLDialect",
"@llvm-project//mlir:Transforms",
],
)
==== triton/third_party/amd/BUILD#None - /google/src/cloud/csigg/triton_amd/triton/third_party/amd/BUILD ====
# action=add type=text
diff --git a/third_party/amd/BUILD b/third_party/amd/BUILD
new file mode 100644
index 0000000..ee4bc37
--- /dev/null
+++ b/third_party/amd/BUILD
@@ -0,0 +1,128 @@
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+
+
+package(
+ # copybara:uncomment_begin
+ # default_applicable_licenses = ["//:license"],
+ # default_visibility = [
+ # "//third_party/tensorflow/compiler/xla/service/gpu:__subpackages__",
+ # "//:__subpackages__",
+ # ],
+ # copybara:uncomment_end_and_comment_begin
+ default_visibility = ["//visibility:public"],
+ # copybara:comment_end
+)
+
+# TODO(csigg): fix, enable error upstream, remove.
+_no_unused_variable = select({
+ "//:compiler_is_msvc": [],
+ "//conditions:default": ["-Wno-unused-variable"],
+})
+
+cc_library(
+ name = "TritonAMDGPUTransforms",
+ srcs = glob([
+ "lib/TritonAMDGPUTransforms/*.h",
+ "lib/TritonAMDGPUTransforms/*.cpp",
+ ]),
+ hdrs = glob([
+ "include/TritonAMDGPU/*.h",
+ "include/TritonAMDGPUTransforms/*.h",
+ "lib/TritonAMDGPUTransforms/*.h",
+ ]),
+ copts = _no_unused_variable,
+ includes = [
+ "..",
+ "include",
+ "lib/TritonAMDGPUTransforms",
+ ],
+ deps = [
+ ":triton_conversion_amdgpu_to_llvm_passes_inc_gen",
+ "@llvm-project//mlir:ConvertToLLVM",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LLVMCommonConversion",
+ "@llvm-project//mlir:LLVMDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:Transforms",
+ "//:TritonAnalysis",
+ "//:TritonDialects",
+ "//:TritonGPUToLLVM",
+ "//:TritonGPUTransforms",
+ ],
+)
+
+cc_library(
+ name = "TritonAMDGPUToLLVM",
+ srcs = glob([
+ "lib/TritonAMDGPUToLLVM/**/*.h",
+ "lib/TritonAMDGPUToLLVM/**/*.cpp",
+ ]),
+ hdrs = glob([
+ "include/TritonAMDGPUToLLVM/**/*.h",
+ ]) + [
+ "lib/TritonAMDGPUToLLVM/Utility.h",
+ ],
+ copts = _no_unused_variable,
+ includes = [
+ "..",
+ "include",
+ "lib/TritonAMDGPUToLLVM",
+ ],
+ deps = [
+ ":triton_transforms_amdgpu_to_llvm_passes_inc_gen",
+ "@llvm-project//mlir:ConvertToLLVM",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LLVMCommonConversion",
+ "@llvm-project//mlir:LLVMDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:Transforms",
+ "//:TritonAnalysis",
+ "//:TritonDialects",
+ "//:TritonGPUToLLVM",
+ ":TritonAMDGPUTransforms",
+ ],
+)
+
+td_library(
+ name = "td_files",
+ srcs = glob(["include/**/*.td"]),
+ includes = ["include"],
+ deps = ["//:td_files"],
+)
+
+gentbl_cc_library(
+ name = "triton_transforms_amdgpu_to_llvm_passes_inc_gen",
+ tbl_outs = [
+ (
+ [
+ "--gen-pass-decls",
+ "--name=TritonAMDGPUToLLVM",
+ ],
+ "include/TritonAMDGPUToLLVM/Passes.h.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "include/TritonAMDGPUToLLVM/Passes.td",
+ deps = [":td_files"],
+)
+
+
+gentbl_cc_library(
+ name = "triton_conversion_amdgpu_to_llvm_passes_inc_gen",
+ tbl_outs = [
+ (
+ [
+ "--gen-pass-decls",
+ "--name=TritonAMDGPU",
+ ],
+ "include/TritonAMDGPUTransforms/Passes.h.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "include/TritonAMDGPUTransforms/Passes.td",
+ deps = [":td_files"],
+)
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
index f59efd6..cf601f0 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -1132,6 +1132,21 @@ struct FpToFpOpConversion
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
inVals.push_back(operands[i][0]);
}
+
+ bool isSrcFP16 = srcElementType.isF16();
+ bool isSrcBF16 = srcElementType.isBF16();
+
+ if ((isSrcFP16 || isSrcBF16)
+ && isDstFP32) {
+ SmallVector<Value> outVals;
+ for (Value &v : inVals) {
+ if(isSrcFP16)
+ outVals.push_back(convertFp16ToFp32(loc, rewriter, v));
+ else
+ outVals.push_back(convertBf16ToFp32(loc, rewriter, v));
+ }
+ return outVals;
+ }
if (useFP16IntermediateSrc)
for (Value &v : inVals)
v = convertFp32ToFp16NZ(loc, rewriter, v);
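
The ElementwiseOpToLLVM.cpp hunk above short-circuits FP16/BF16 -> FP32 conversions by upcasting each element directly instead of routing it through an intermediate type. As a rough standalone illustration of what such upcasts do at the bit level (plain C++, not the MLIR rewriter helpers used in the patch; the function names here are illustrative only):

#include <cstdint>
#include <cstdio>
#include <cstring>

// bf16 is the upper 16 bits of an IEEE-754 float, so upcasting is a shift.
float Bf16ToFp32(uint16_t bf16_bits) {
  uint32_t bits = static_cast<uint32_t>(bf16_bits) << 16;
  float result;
  std::memcpy(&result, &bits, sizeof(result));
  return result;
}

// fp16 -> fp32: re-bias the exponent and widen the mantissa.
// Subnormals are flushed to zero here to keep the sketch short.
float Fp16ToFp32(uint16_t h) {
  uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
  uint32_t exp = (h >> 10) & 0x1Fu;
  uint32_t mant = h & 0x3FFu;
  uint32_t bits;
  if (exp == 0) {
    bits = sign;                                // zero / subnormal (simplified)
  } else if (exp == 0x1F) {
    bits = sign | 0x7F800000u | (mant << 13);   // inf / NaN
  } else {
    bits = sign | ((exp - 15 + 127) << 23) | (mant << 13);
  }
  float result;
  std::memcpy(&result, &bits, sizeof(result));
  return result;
}

int main() {
  std::printf("bf16 0x3F80 -> %f\n", Bf16ToFp32(0x3F80));  // 1.0
  std::printf("fp16 0x3C00 -> %f\n", Fp16ToFp32(0x3C00));  // 1.0
  return 0;
}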
1 change: 1 addition & 0 deletions third_party/triton/temporary/series.bzl
@@ -7,4 +7,5 @@ internal patch during the next triton integration process.

temporary_patch_list = [
"//third_party/triton/temporary:pipelining.patch",
"//third_party/triton/temporary:amd_pr7.patch",
]
1 change: 1 addition & 0 deletions third_party/tsl/third_party/gpus/rocm_configure.bzl
@@ -707,6 +707,7 @@ def _create_local_rocm_repository(repository_ctx):
"-DTENSORFLOW_USE_ROCM=1",
"-D__HIP_PLATFORM_AMD__",
"-DEIGEN_USE_HIP",
"-DUSE_ROCM",
])

rocm_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
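
The new -DUSE_ROCM define lets C++ translation units select ROCm-specific code paths at compile time. A minimal hypothetical sketch of how a source file could key off the flag (the function below is illustrative and not from this repository):

#include <cstdio>

// Compiled with -DUSE_ROCM (as added to rocm_defines above), the ROCm branch
// is selected; otherwise the generic path is used.
const char* BackendName() {
#ifdef USE_ROCM
  return "ROCm";
#else
  return "non-ROCm";
#endif
}

int main() {
  std::printf("backend: %s\n", BackendName());
  return 0;
}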
10 changes: 7 additions & 3 deletions xla/service/gpu/BUILD
@@ -597,15 +597,19 @@ cc_library(
"@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
"@triton//:TritonGPUTransforms",
"@triton//:TritonLLVMIR",
]) + if_rocm_is_configured([
"@triton//:TritonGPUTransforms",
"@triton//:TritonLLVMIR",
"@triton//third_party/amd:TritonAMDGPUToLLVM",
"@triton//third_party/amd:TritonAMDGPUTransforms",
]),
)

xla_test(
name = "ir_emitter_triton_test",
srcs = if_cuda_is_configured(["ir_emitter_triton_test.cc"]),
srcs = if_gpu_is_configured(["ir_emitter_triton_test.cc"]),
backends = [
"gpu_a100",
"gpu_h100",
"gpu",
],
shard_count = 20,
tags = ["nomac"],
6 changes: 5 additions & 1 deletion xla/service/gpu/gemm_fusion.cc
@@ -803,7 +803,11 @@ absl::StatusOr<bool> GemmFusion::Run(
const absl::flat_hash_set<absl::string_view>& execution_threads) {
auto cuda_compute_capability =
std::get_if<se::CudaComputeCapability>(&gpu_version_);
if (!cuda_compute_capability || !cuda_compute_capability->IsAtLeastAmpere()) {
auto rocm_compute_capability =
std::get_if<se::RocmComputeCapability>(&gpu_version_);

if ((!cuda_compute_capability || !cuda_compute_capability->IsAtLeastAmpere())
&& !rocm_compute_capability) {
return absl::FailedPreconditionError(
"Triton support is only enabled for Ampere GPUs and up.");
}
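
The gemm_fusion.cc hunk gates Triton GEMM fusion on either an Ampere-or-newer CUDA capability or any ROCm capability by probing the gpu_version_ variant with std::get_if. A self-contained sketch of that gating pattern, using stand-in capability types rather than XLA's se:: classes:

#include <cstdio>
#include <variant>

// Stand-ins for se::CudaComputeCapability / se::RocmComputeCapability.
struct CudaCapability {
  int major = 0;
  bool IsAtLeastAmpere() const { return major >= 8; }
};
struct RocmCapability {};

using GpuVersion = std::variant<CudaCapability, RocmCapability>;

// Mirrors the patched condition: allow Ampere+ CUDA devices or any ROCm device.
bool TritonGemmSupported(const GpuVersion& version) {
  const auto* cuda = std::get_if<CudaCapability>(&version);
  const auto* rocm = std::get_if<RocmCapability>(&version);
  if ((cuda == nullptr || !cuda->IsAtLeastAmpere()) && rocm == nullptr) {
    return false;  // corresponds to the FailedPreconditionError path
  }
  return true;
}

int main() {
  std::printf("CUDA sm_70: %d\n", TritonGemmSupported(CudaCapability{7}));  // 0
  std::printf("CUDA sm_80: %d\n", TritonGemmSupported(CudaCapability{8}));  // 1
  std::printf("ROCm:       %d\n", TritonGemmSupported(RocmCapability{}));   // 1
  return 0;
}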
9 changes: 7 additions & 2 deletions xla/service/gpu/gpu_compiler.cc
@@ -1420,6 +1420,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
gpu_target_config.device_description.gpu_compute_capability();
pipeline.AddPass<AlgorithmChecker>(gpu_version);
const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(&gpu_version);
const auto* rocm_cc = std::get_if<se::RocmComputeCapability>(&gpu_version);

// Rewrite FP8 GEMMs ahead of Triton which currently lacks support for FP8
// and may rewrite quantized FP8 GEMMs as higher-precision GEMMs.
@@ -1428,6 +1429,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) {
pipeline.AddPass<GemmFusion>(gpu_version);
}
if (debug_options.xla_gpu_enable_triton_gemm() && rocm_cc != nullptr) {
pipeline.AddPass<GemmFusion>(gpu_version);
}
// Rewrite non-FP8 GEMMs.
pipeline.AddPass<GemmRewriter>(gpu_version, /*f8_rewrite=*/false);

@@ -1449,8 +1453,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
// ReductionDimensionGrouper, as that makes matching the softmax pattern
// harder.
if (debug_options.xla_gpu_enable_triton_softmax_fusion() &&
cuda_cc != nullptr &&
cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) {
((cuda_cc != nullptr &&
cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) ||
rocm_cc != nullptr)) {
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(simplifier_options);
pipeline.AddPass<SoftmaxRewriterTriton>(gpu_version);
}
37 changes: 22 additions & 15 deletions xla/service/gpu/ir_emitter_triton_rocm.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
// TODO(ROCm): Enable and include ROCm Triton passes when ROCm Triton is
// included in build.
// #include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project
@@ -30,6 +30,8 @@ limitations under the License.
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
#include "third_party/amd/include/TritonAMDGPUTransforms/Passes.h"

namespace xla {
namespace gpu {
@@ -54,9 +56,10 @@ absl::Status CreateTritonPipeline(
const int ccAsInt = 0;
// TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64.
const int threadsPerWarp = 32;
auto ccRocm = std::get<se::RocmComputeCapability>(cc);

// Based on make_ttir() in
// @triton//:third_party/nvidia/backend/compiler.py
// @triton//:third_party/amd/backend/compiler.py
pm.addPass(mlir::createInlinerPass());
pm.addPass(mt::createRewriteTensorPointerPass());
pm.addPass(mt::createCombineOpsPass());
@@ -67,43 +70,47 @@
pm.addPass(mlir::createSymbolDCEPass());

// Based on make_ttgir() in
// @triton//:third_party/nvidia/backend/compiler.py
// @triton//:third_party/amd/backend/compiler.py
pm.addPass(mt::createConvertTritonToTritonGPUPass(
config.num_warps, threadsPerWarp, config.num_ctas, ccAsInt));
pm.addPass(mt::gpu::createCoalescePass());
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createOptimizeThreadLocalityPass());
pm.addPass(mt::gpu::createAccelerateMatmulPass(ccAsInt));
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass());
pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps,
config.num_ctas, ccAsInt));
pm.addPass(mt::gpu::createPrefetchPass());

if(config.num_stages == 0 and ccRocm.has_mma_instr_support()) {
pm.addPass(mlir::createTritonAMDGPUStreamPipelinePass());
pm.addPass(mlir::createCanonicalizerPass());
}
pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createReduceDataDuplicationPass());
pm.addPass(mt::gpu::createReorderInstructionsPass());
pm.addPass(mlir::createTritonAMDGPUDecomposeConversionsPass());
if(config.num_stages == 0) {
pm.addPass(mt::gpu::createReorderInstructionsPass());
}
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
pm.addPass(mlir::createCanonicalizerPass());

// Based on make_llir() in
// @triton//:third_party/nvidia/backend/compiler.py
// pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass());
// @triton//:third_party/amd/backend/compiler.py
pm.addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass());
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createConvertIndexToLLVMPass());
pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
// pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass());
pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass());
pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
// Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createConvertControlFlowToLLVMPass());

pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
// There is no clusters in ROCm for now.
out_cluster_info.clusterDimX = 1;
out_cluster_info.clusterDimY = 1;
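
The ir_emitter_triton_rocm.cc changes assemble the ROCm Triton pass pipeline and register some passes conditionally, for example the AMD stream-pipeline passes only when config.num_stages == 0 and the target reports MMA instruction support. A generic sketch of that conditional-registration pattern, with placeholder pass names and a toy pipeline type rather than the MLIR PassManager API:

#include <cstdio>
#include <string>
#include <vector>

// Minimal stand-in for a pass pipeline: an ordered list of named stages.
struct Pipeline {
  std::vector<std::string> passes;
  void AddPass(std::string name) { passes.push_back(std::move(name)); }
};

struct TritonConfig {
  int num_stages = 0;
  bool has_mma_instr_support = false;
};

// Mirrors the shape of the patched CreateTritonPipeline: some passes are
// unconditional, others depend on the config. Pass names are placeholders.
Pipeline BuildRocmPipeline(const TritonConfig& config) {
  Pipeline pm;
  pm.AddPass("convert-triton-to-tritongpu");
  pm.AddPass("tritongpu-coalesce");
  if (config.num_stages == 0 && config.has_mma_instr_support) {
    pm.AddPass("tritonamdgpu-stream-pipeline");
    pm.AddPass("canonicalize");
  }
  pm.AddPass("tritongpu-remove-layout-conversions");
  if (config.num_stages == 0) {
    pm.AddPass("tritongpu-reorder-instructions");
  }
  pm.AddPass("convert-triton-amdgpu-to-llvm");
  return pm;
}

int main() {
  Pipeline pm =
      BuildRocmPipeline({/*num_stages=*/0, /*has_mma_instr_support=*/true});
  for (const std::string& p : pm.passes) std::printf("%s\n", p.c_str());
  return 0;
}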