Skip to content

Commit

Permalink
PR openxla#6599: Fp8 Fast Accumulation support for cublasLt
Browse files Browse the repository at this point in the history
Imported from GitHub PR openxla#6599

FP8 cublasLt matmul uses fast accumulation when both operands' precision are DEFAULT. Otherwise fall back to high precision acuumulation. Issue#openxla#6168

This PR is closely related to Flax PR-![3416](google/flax#3416).
Copybara import of the project:

--
a4140da by shuw <[email protected]>:

Add FP8 fast accumulation support for cublasLt.

--
9684568 by shuw <[email protected]>:

Improve based on review #1

--
e906d76 by shuw <[email protected]>:

Improve based on review #2

Merging this change closes openxla#6599

COPYBARA_INTEGRATE_REVIEW=openxla#6599 from wenscarl:fp8_fast_accumulation e906d76
PiperOrigin-RevId: 578948593
  • Loading branch information
wenscarl authored and copybara-github committed Nov 2, 2023
1 parent 98fee3e commit b716639
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
32 changes: 32 additions & 0 deletions xla/service/gpu/tests/gemm_rewrite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6940,6 +6940,38 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
)");
}

TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) {
#if CUDA_VERSION < 12000
GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
#endif // CUDA_VERSION < 12000
const char* hlo_template = R"(
HloModule test
ENTRY test {
x = f8e4m3fn[1600,3200] parameter(0)
y = f8e4m3fn[3200,1600] parameter(1)
x_f32 = f32[1600,3200] convert(x)
y_f32 = f32[3200,1600] convert(y)
x_scale = f32[] parameter(2)
y_scale = f32[] parameter(3)
x_scale_bcast = f32[1600,3200] broadcast(x_scale), dimensions={}
y_scale_bcast = f32[3200,1600] broadcast(y_scale), dimensions={}
x_unscaled = f32[1600,3200] multiply(x_f32, x_scale_bcast)
y_unscaled = f32[3200,1600] multiply(y_f32, y_scale_bcast)
ROOT out = f32[1600,1600] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={<<precision>>,<<precision>>}
}
)";

absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
replacements["<<precision>>"] = "default";
const auto hlo_text_default = absl::StrReplaceAll(hlo_template, replacements);
EXPECT_TRUE(RunAndCompare(hlo_text_default, ErrorSpec{1e-3, 1e-3}));

replacements["<<precision>>"] = "highest";
const auto hlo_text_highest = absl::StrReplaceAll(hlo_template, replacements);
EXPECT_TRUE(RunAndCompare(hlo_text_highest, ErrorSpec{1e-4, 1e-4}));
}

TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) {
#if CUDA_VERSION < 12000
GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
Expand Down
13 changes: 11 additions & 2 deletions xla/stream_executor/cuda/cuda_blas_lt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ cudaDataType_t BlasLt::MatrixLayout::type() const {
/*static*/ tsl::StatusOr<BlasLt::MatmulDesc> BlasLt::MatmulDesc::Create(
blas::ComputationType compute_type, blas::DataType scale_type,
blas::Transpose trans_a, blas::Transpose trans_b,
gpu::BlasLt::Epilogue epilogue, PointerMode pointer_mode) {
gpu::BlasLt::Epilogue epilogue, bool enable_fast_accum,
PointerMode pointer_mode) {
VLOG(2) << "MatmulDesc::Create: compute_type: " << (int)compute_type
<< " scale:" << (int)scale_type << " trans a/b: " << (int)trans_a
<< "," << (int)trans_b << " epilogue:" << (int)epilogue
Expand All @@ -210,6 +211,8 @@ cudaDataType_t BlasLt::MatrixLayout::type() const {
AsCublasOperation(trans_b)));
TF_ASSIGN_OR_RETURN(cublasLtEpilogue_t epi, AsCublasLtEpilogue(epilogue));
TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, epi));
TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
static_cast<int8_t>(enable_fast_accum)));
return std::move(desc);
}

Expand Down Expand Up @@ -315,11 +318,17 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg,
cfg.compute_precision));
}

// FP8 matmuls have a fast accumulation mode that is less precise than the
// default accumulation mode. Use the fast accumulation mode if the compute
// precision is DEFAULT.
bool enable_fast_accum = (xla::primitive_util::IsF8Type(lhs_layout.dtype) ||
xla::primitive_util::IsF8Type(rhs_layout.dtype)) &&
cfg.compute_precision == 0;
TF_ASSIGN_OR_RETURN(
auto op_desc,
MatmulDesc::Create(*compute_type,
gpu::GetScaleType(output_dtype, *compute_type),
trans_a, trans_b, epilogue));
trans_a, trans_b, epilogue, enable_fast_accum));

TF_ASSIGN_OR_RETURN(auto a_desc, MatrixLayout::Create(lhs_layout));
TF_ASSIGN_OR_RETURN(auto b_desc, MatrixLayout::Create(rhs_layout));
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/cuda/cuda_blas_lt.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class BlasLt : public gpu::BlasLt {
blas::ComputationType compute_type, blas::DataType scale_type,
blas::Transpose trans_a = blas::Transpose::kNoTranspose,
blas::Transpose trans_b = blas::Transpose::kNoTranspose,
Epilogue epilogue = Epilogue::kDefault,
Epilogue epilogue = Epilogue::kDefault, bool enable_fast_accum = false,
PointerMode pointer_mode = PointerMode::kHost);

cublasComputeType_t compute_type() const;
Expand Down

0 comments on commit b716639

Please sign in to comment.