From 3197c810eb9fa50c97d2632e8d54f55615903b29 Mon Sep 17 00:00:00 2001 From: zl Date: Thu, 12 Dec 2024 19:26:52 +0800 Subject: [PATCH 1/2] FP8 groupwise scaling along M --- ...specialized_gemm_with_blockwise_scaling.cu | 2 +- ...specialized_gemm_with_groupwise_scaling.cu | 770 ++++++++++++++++++ .../CMakeLists.txt | 5 + .../host/gemm_with_groupwise_scaling.h | 507 ++++++++++++ .../collective/builders/sm90_gmma_builder.inl | 45 +- .../gemm/collective/fp8_accumulation.hpp | 24 +- ..._warpspecialized_fp8_blockwise_scaling.hpp | 105 ++- include/cutlass/gemm/dispatch_policy.hpp | 8 +- 8 files changed, 1418 insertions(+), 48 deletions(-) create mode 100644 examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu create mode 100644 examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu index 0228f8b1df..b6dcc178f5 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu @@ -123,7 +123,7 @@ using ArchTag = cutlass::arch::Sm90; // T using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster -using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<>; using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu new file mode 100644 index 0000000000..74c85b82e6 --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu @@ -0,0 +1,770 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Grouped scale Hopper FP8 GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + + This example demonstrate a grouped scaled FP8 GEMM using the new CUTLASS 3.0. + APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows: + + 1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA) + which are more efficient than the Ampere tensor core instructions. + + 2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large + blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous + copies between thread blocks in a cluster. + + 3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details). + + 4. This example shows all important fusions used by FP8 gemm kernels, i.e., grouped scale factor along M for + A, blocked scale factor along K for A tensor, blocked scale factor for B tensor, the abs_max value of D tensor. + + 5. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the + CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can + improve performance. 
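    As a concrete illustration of the scale-tensor geometry (this note assumes the default
    configuration used below: TileShape = 128x128x128 and ScaleMsPerTile = 2, hence
    ScaleGranularityM = 128 / 2 = 64): each 128x128 CTA tile of A consumes two scale values per
    128-wide K block, one for every group of 64 consecutive rows, while B keeps a single scale
    value per 128x128 NxK block. For a problem whose extents divide the tile shape evenly, the
    scale tensor of A therefore has (M/128) * 2 rows and K/128 columns, and the scale tensor of
    B has K/128 rows and N/128 columns (per batch).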
+ + Examples: + + $ ./examples/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling \ + --m=2816 --n=3072 --k=16384 \ + --save_aux=false --save_amax=false \ + --device_scale=false --raster=h --swizzle=2 +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +// Includes from examples directory +#include "helper.h" +#include "hopper_fp8_commandline.hpp" +#include "reference/host/gemm_with_groupwise_scaling.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// Auxiliary matrix configuration and other fusion types +using ElementAux = ElementC; +using LayoutAux = LayoutC; +using ElementAmax = float; +using ElementBias = float; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementBlockScale = float; // Element type for blockscaling during accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size +using 
ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster + +constexpr int ScaleMsPerTile = 2; +constexpr int ScaleGranularityM = size<0>(TileShape{}) / ScaleMsPerTile; + +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + +using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopWithBlockWiseScaling, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Extract information from Gemm kernel. +using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; +using ElementScalar = typename EpilogueOutputOp::ElementScalar; +using ElementAmax = typename EpilogueOutputOp::ElementAmax; +using ActivationFunctor = typename EpilogueOutputOp::ActivationFn; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; +using StrideAux = StrideD; + +constexpr bool IsDFp8 = + cute::is_same_v or + cute::is_same_v; + +constexpr bool IsAuxFp8 = + cute::is_same_v or + cute::is_same_v; + +static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile, + "FP8 scaling granularity must evenly divide tile shape along M."); + +static_assert(cute::is_same_v, + "ElementAccumulator and ElementBlockScale should be same datatype"); + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +StrideAux stride_aux; +uint64_t seed; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_D; +uint32_t mma_promotion_interval; +cutlass::HostTensor blockscale_tensor_A; +cutlass::HostTensor blockscale_tensor_B; +cutlass::HostTensor tensor_ref_D; +cutlass::HostTensor tensor_aux; +cutlass::HostTensor tensor_ref_aux; + +using LayoutScalar = cutlass::layout::PackedVectorLayout; +cutlass::HostTensor scalar_alpha; +cutlass::HostTensor scalar_beta; +cutlass::HostTensor scale_A; +cutlass::HostTensor scale_B; +cutlass::HostTensor scale_C; +cutlass::HostTensor scale_D; +cutlass::HostTensor scale_aux; +cutlass::HostTensor abs_max_D; +cutlass::HostTensor reference_abs_max_D; +cutlass::HostTensor abs_max_aux; +cutlass::HostTensor reference_abs_max_aux; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + 
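// Illustrative note (not used by the kernel): with the configuration above,
// ScaleGranularityM = size<0>(TileShape) / ScaleMsPerTile = 128 / 2 = 64, and this is the value
// that parameterizes the block-scaled FP8 kernel schedule selected as KernelSchedule (see the
// CollectiveBuilder changes later in this patch). Conceptually, an element A(row, k) picks up
// the scale entry addressed by
//
//   int scale_group_m = row / ScaleGranularityM;   // one scale per 64 consecutive rows of A
//   int scale_block_k = k / size<2>(TileShape{});  // one scale per 128-wide K block
//
// while B(col, k) keeps a single scale per 128x128 NxK block, as in the blockwise scaling example.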
+///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; + } + +/// Helper to initialize a block of device data (scale_tensors) + template + bool initialize_scale_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + + scope_min = -1; + scope_max = 1; + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; + } + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + // Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and 
TileShape + auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); + auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{}))); + auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access. + auto blockscale_n = cute::get<1>(blockscale_shape); + auto blockscale_k = cute::get<2>(blockscale_shape); + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + stride_aux = stride_D; + + + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto groupscale_a_coord = cutlass::make_Coord(groupscale_m * options.l, blockscale_k); + auto blockscale_b_coord = cutlass::make_Coord(blockscale_k, blockscale_n * options.l); + + tensor_A.resize(a_coord); + blockscale_tensor_A.resize(groupscale_a_coord); + tensor_B.resize(b_coord); + blockscale_tensor_B.resize(blockscale_b_coord); + tensor_C.resize(c_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + + cutlass::Distribution::Kind dist_A = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind dist_B = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind dist_C = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind dist_scaleA = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind dist_scaleB = cutlass::Distribution::Uniform; + + initialize_tensor(tensor_A.host_view(), dist_A, seed + 2022); + initialize_tensor(tensor_B.host_view(), dist_B, seed + 2023); + initialize_tensor(tensor_C.host_view(), dist_C, seed + 2024); + initialize_scale_tensor(blockscale_tensor_A.host_view(), dist_scaleA, seed + 2025); + initialize_scale_tensor(blockscale_tensor_B.host_view(), dist_scaleB, seed + 2026); + +#if 0 // Dump blockscaled tensors + std::cout << "blockscale_tensor_A: " << groupscale_a_coord << std::endl; + std::cout << blockscale_tensor_A.host_view() << "\n"; + std::cout << "blockscale_tensor_B: " << blockscale_b_coord << std::endl; + std::cout << blockscale_tensor_B.host_view() << "\n"; +#endif + + // Print group scaling tensors on the host side. 
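  // Worked example of the extents computed above (assuming the problem size from the sample
  // command line, m=2816, n=3072, k=16384, and batch count l=1):
  //   blockscale_m = 2816 / 128 = 22,  blockscale_n = 3072 / 128 = 24,  blockscale_k = 16384 / 128 = 128
  //   groupscale_m = 22 * ScaleMsPerTile = 44
  // so blockscale_tensor_A is sized 44 x 128 and blockscale_tensor_B is sized 128 x 24.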
+ tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + blockscale_tensor_A.sync_device(); + blockscale_tensor_B.sync_device(); + + mma_promotion_interval = 4; + + if (options.save_aux) { + tensor_aux.resize(c_coord); + tensor_aux.sync_device(); + tensor_ref_aux.resize(c_coord); + } + + if (options.device_scale) { + scalar_alpha.resize(cutlass::make_Coord(1)); + scalar_beta.resize(cutlass::make_Coord(1)); + scale_A.resize(cutlass::make_Coord(1)); + scale_B.resize(cutlass::make_Coord(1)); + scale_C.resize(cutlass::make_Coord(1)); + scale_D.resize(cutlass::make_Coord(1)); + scale_aux.resize(cutlass::make_Coord(1)); + + cutlass::reference::host::TensorFill(scalar_alpha.host_view(), options.alpha); + cutlass::reference::host::TensorFill(scalar_beta.host_view(), options.beta); + cutlass::reference::host::TensorFill(scale_A.host_view(), options.scale_a); + cutlass::reference::host::TensorFill(scale_B.host_view(), options.scale_b); + cutlass::reference::host::TensorFill(scale_C.host_view(), options.scale_c); + cutlass::reference::host::TensorFill(scale_D.host_view(), options.scale_d); + cutlass::reference::host::TensorFill(scale_aux.host_view(), options.scale_aux); + + scalar_alpha.sync_device(); + scalar_beta.sync_device(); + scale_A.sync_device(); + scale_B.sync_device(); + scale_C.sync_device(); + scale_D.sync_device(); + scale_aux.sync_device(); + } + + if (IsDFp8 && options.save_amax) { + abs_max_D.resize(cutlass::make_Coord(1)); + abs_max_D.sync_device(); + reference_abs_max_D.resize(cutlass::make_Coord(1)); + } + + if (IsAuxFp8 && options.save_aux && options.save_amax) { + abs_max_aux.resize(cutlass::make_Coord(1)); + abs_max_aux.sync_device(); + reference_abs_max_aux.resize(cutlass::make_Coord(1)); + } +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), + stride_A, + tensor_B.device_data(), + stride_B, + mma_promotion_interval, + blockscale_tensor_A.device_data(), + blockscale_tensor_B.device_data() + }, + { + {}, // epilogue.thread + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = scalar_alpha.device_data(); + fusion_args.beta_ptr = scalar_beta.device_data(); + fusion_args.scale_a = options.scale_a; + fusion_args.scale_b = options.scale_b; + fusion_args.scale_c = options.scale_c; + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + + // ignored if tensor types are not fp8 + fusion_args.scale_d = options.scale_d; + fusion_args.scale_aux = options.scale_aux; + fusion_args.scale_d_ptr = scale_D.device_data(); + fusion_args.scale_aux_ptr = scale_aux.device_data(); + + // leaving/setting these as nullptr disables the fusion at runtime + fusion_args.bias_ptr = nullptr; + + if (options.save_aux) { + fusion_args.aux_ptr = tensor_aux.device_data(); + fusion_args.dAux = stride_aux; + if (options.save_amax) { + fusion_args.amax_aux_ptr = abs_max_aux.device_data(); + } + } + + if (options.save_amax) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); + } + + arguments.scheduler.raster_order = options.raster; + // The 
tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8) + arguments.scheduler.max_swizzle_size = options.swizzle; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + // Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape + auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); + auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{}))); + auto blockscale_m = cute::get<0>(blockscale_shape); + auto blockscale_n = cute::get<1>(blockscale_shape); + auto blockscale_k = cute::get<2>(blockscale_shape); + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout( + cute::make_shape(options.m, options.k, options.l), + stride_A + ) + ); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout( + cute::make_shape(options.n, options.k, options.l), + stride_B + ) + ); + auto C = cute::make_tensor(tensor_C.host_data(), + cute::make_layout( + cute::make_shape(options.m, options.n, options.l), + stride_C + ) + ); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout( + cute::make_shape(options.m, options.n, options.l), + stride_D + ) + ); + auto Aux = cute::make_tensor(tensor_ref_aux.host_data(), + cute::make_layout( + cute::make_shape(options.m, options.n, options.l), + stride_aux + ) + ); + + auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(), + cute::make_layout( + cute::make_shape(blockscale_m, ScaleMsPerTile, blockscale_k, options.l), + cute::make_stride(blockscale_k * ScaleMsPerTile, 1, ScaleMsPerTile, blockscale_m * blockscale_k * ScaleMsPerTile) + ) + ); + auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(), + cute::make_layout( + cute::make_shape(blockscale_n, blockscale_k, options.l), + cute::make_stride(blockscale_k, 1, blockscale_n * blockscale_k) + ) + ); + + using unused_t = decltype(D); + + cutlass::reference::host::GettMainloopParams mainloop_params{ + A, B, // Operand Tensors + blockscale_A, blockscale_B // Groupwise scaling Tensors + }; + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + unused_t, // bias + decltype(Aux), + unused_t, // valpha + unused_t, // vbeta + ActivationFunctor + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.Aux = Aux; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + epilogue_params.scale_a = options.scale_a; + epilogue_params.scale_b = options.scale_b; + epilogue_params.scale_c = options.scale_c; + epilogue_params.scale_d = options.scale_d; + epilogue_params.scale_aux = options.scale_aux; + epilogue_params.abs_max_D = reference_abs_max_D.host_data(); + epilogue_params.abs_max_Aux = reference_abs_max_aux.host_data(); + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // compare_reference + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + + if (false) { + std::cout << "tensor_ref_D.host_view() {" << std::endl + << tensor_ref_D.host_view() << std::endl + << "}" << std::endl; + std::cout << "tensor_D.host_view() {" << std::endl + << tensor_D.host_view() << std::endl + << "}" << std::endl; + } + + if (IsDFp8 
&& options.save_amax) { + abs_max_D.sync_host(); + passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0)); + } + + if (options.save_aux) { + tensor_aux.sync_host(); + passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view()); + if (IsAuxFp8 && options.save_amax) { + abs_max_aux.sync_host(); + passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0)); + } + } + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + // if (!result.passed) { + // exit(-1); + // } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::string raster = "Heuristic"; + + if (options.raster == RasterOrderOptions::AlongN) { + raster = "Along N"; + } + else if (options.raster == RasterOrderOptions::AlongM) { + raster = "Along M"; + } + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+ return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/CMakeLists.txt b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/CMakeLists.txt index cdfd522c0f..3453ed409a 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/CMakeLists.txt +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/CMakeLists.txt @@ -30,3 +30,8 @@ cutlass_example_add_executable( 67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling 67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu ) + +cutlass_example_add_executable( + 67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling + 67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu + ) \ No newline at end of file diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h new file mode 100644 index 0000000000..efe4f8c46f --- /dev/null +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h @@ -0,0 +1,507 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GETT in host-side code. +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/gemm.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/relatively_equal.h" +#include +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::reference::host { + +template +struct ElementTraits { + using type = T; +}; + +template +struct ElementTraits().get()), void> > > { + using type = decltype(std::declval().get()); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ElementAccumulator_, + class TensorA_, // (M, K, L) + class TensorB_, // (N, K, L) + class TensorScaleA_, // (m, k, L) + class TensorScaleB_, // (n, k, L) + class TileShape_ +> +struct GettMainloopParams { + using ElementAccumulator = ElementAccumulator_; + using TensorA = TensorA_; + using TensorB = TensorB_; + using EngineA = typename TensorA::engine_type; + using LayoutA = typename TensorA::layout_type; + using EngineB = typename TensorB::engine_type; + using LayoutB = typename TensorB::layout_type; + + using TensorScaleA = TensorScaleA_; + using TensorScaleB = TensorScaleB_; + using TileShape = TileShape_; + using EngineScaleA = typename TensorScaleA::engine_type; + using EngineScaleB = typename TensorScaleB::engine_type; + + TensorA A{}; + TensorB B{}; + TensorScaleA ScaleA{}; + TensorScaleB ScaleB{}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template< + class ElementScalar_, + class ElementScalingFactor_, + class ElementAccumulator_, + class ElementCompute_, + class TensorC_, // (M, N, L) + class TensorD_, // (M, N, L) + class VectorBias_ = TensorD_, // (M, 1) + class TensorAux_ = TensorD_, // (M, N, L) + class VectorAlpha_ = TensorD_, // (M, 1) + class VectorBeta_ = VectorAlpha_, // (M, 1) + class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + class BiasBinaryOp_ = cutlass::plus, + bool PerColumnBias_ = false +> +struct GettEpilogueParams { + using ElementScalar = ElementScalar_; + using ElementScalingFactor = ElementScalingFactor_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using TensorC = TensorC_; + using TensorD = TensorD_; + using TensorAux = TensorAux_; + using VectorBias = VectorBias_; + using VectorAlpha = VectorAlpha_; + using VectorBeta = VectorBeta_; + using ActivationFunctor = ActivationFunctor_; + using BiasBinaryOp = BiasBinaryOp_; + + using EngineC = typename TensorC::engine_type; + using LayoutC = typename TensorC::layout_type; + using EngineD = typename 
TensorD::engine_type; + using LayoutD = typename TensorD::layout_type; + static constexpr bool PerColumnBias = PerColumnBias_; + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + + TensorC C{}; + TensorD D{}; + VectorBias Bias{}; + TensorAux Aux{}; + VectorAlpha Valpha{}; + VectorBeta Vbeta{}; + ElementCompute st = ElementCompute(1); + + ElementAccumulator* abs_max_D = nullptr; + ElementAccumulator* abs_max_Aux = nullptr; + + ElementScalingFactor scale_a = ElementScalingFactor(1); + ElementScalingFactor scale_b = ElementScalingFactor(1); + ElementScalingFactor scale_c = ElementScalingFactor(1); + ElementScalingFactor scale_d = ElementScalingFactor(1); + ElementScalingFactor scale_aux = ElementScalingFactor(1); + + bool beta_per_channel_scaling = false; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - General Tensor-Tensor contraction reference kernel with Groupwise scaling +template < + class MainloopParams, + class EpilogueParams +> +void Gett( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + + static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{}); + static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{}); + // printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n"); + // printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n"); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + gett_epilogue(epilogue_params, m, n, l, acc); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Mainloop +template +void gett_mainloop( + MainloopParams const& mainloop_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); + static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; + + using ElementA = typename ElementTraits::type; + using ElementB = typename ElementTraits::type; + using ElementBlockScaleA = typename ElementTraits::type; + using ElementBlockScaleB = typename ElementTraits::type; + + using RingOp = multiply_add; + RingOp fma_op; + + multiplies scale_op; + + static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});; + + // Tempo accumulators to seperate blockwise accumulation + typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN]; + + // Zero out accumulators + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + acc_temp[m_b][n_b] = ElementAccumulator(0); + } + } + + int64_t block_m = m / kBlockM; + int64_t block_n = n / kBlockN; + cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, _, l); + cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, l); + + const int ScaleGranularityM = cute::size<0>(typename 
MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleA.shape()); + assert(cute::size<0>(typename MainloopParams::TileShape{}) == ScaleGranularityM * cute::size<1>(mainloop_params.ScaleA.shape())); + + // Compute on this k-block + for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { + + // Load Blockwise scaling factor from blockscale Tensors for B + int64_t block_k = k / kBlockK; + cute::Tensor scale_a = blockscale_A(_, block_k); + ElementBlockScaleB scale_b = blockscale_B[block_k]; + + // Load A + ElementAccumulator a_frag[kBlockM]; + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { + // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. + a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); + } else { + a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // Load B + ElementAccumulator b_frag[kBlockN]; + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { + // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. + b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); + } else { + b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // do compute + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]); + } + } + + // Apply Groupwise-scaling at kBlockK boundary + // (a) Apply group and block scaling factors on the partial accumulated results (acc_temp) at the kBlocK boundary + // (b) Zero-out partial temporary (acc_temp), + // (c) Update permanent (accu) + if ((k+1) % kBlockK == 0) { + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a[m_b / ScaleGranularityM] * scale_b; + acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b]; + acc_temp[m_b][n_b] = ElementAccumulator(0); + } + } + } + + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Epilogue +template +void gett_epilogue( + EpilogueParams const& epilogue_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); + static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; + + using ElementCompute = typename EpilogueParams::ElementCompute; + using ElementC = typename EpilogueParams::TensorC::value_type; + using ElementD = typename EpilogueParams::TensorD::value_type; + using ElementAux = typename EpilogueParams::TensorAux::value_type; + using ElementBias = typename EpilogueParams::VectorBias::value_type; + using ElementScalar = typename EpilogueParams::ElementScalar; + using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; + using ActivationFunctor = typename EpilogueParams::ActivationFunctor; + using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; + + constexpr bool PerColBias = EpilogueParams::PerColumnBias; + constexpr bool IsScalingAndAmaxOutputNeeded = + cute::is_same_v or + cute::is_same_v; + + constexpr bool IsScalingAndAmaxAuxOutputNeeded = + cute::is_same_v or + cute::is_same_v; + + constexpr bool IsReLUAuxNeeded = + 
(cute::is_same_v> or + cute::is_same_v>) and + cute::is_same_v; + constexpr bool IsClamp = + cute::is_same_v>; + + constexpr bool IsBackpropFusion = + cute::is_same_v> or + cute::is_same_v>; + + // Input related converter + NumericConverter accumulator_converter; + NumericConverter source_converter; + NumericConverter bias_converter; + [[maybe_unused]] NumericConverter aux_source_converter; + + // Scale related converter + NumericConverter scale_converter; + NumericConverter scaling_factor_converter; + + // Abs max converter + [[maybe_unused]] NumericConverter abs_max_output_converter; + + // Output related converter + NumericConverter destination_converter; + [[maybe_unused]] NumericConverter aux_destination_converter; + NumericConverter dBias_converter; + + // Epilogue operations + multiply_add epilogue_fma; + multiplies mul; + plus add; + + // Activation operation + ActivationFunctor activation; + + // Bias binary operation + BiasBinaryOp bias_op; + + // Do conversion + ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); + ElementCompute converted_beta = scale_converter(epilogue_params.beta); + ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); + ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); + ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c); + ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); + ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); + + // Init local var + [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); + [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); + + converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); + converted_beta = mul(converted_beta, converted_scale_c); + + ElementCompute inter_accum[kBlockM][kBlockN]; + + for (int m_b = 0; m_b < kBlockM; ++m_b) { + ElementCompute local_dBias = ElementCompute(0); + + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + // Convert every type to ElementCompute first, do compute, convert to output type, write it out + ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); + // per-row alpha + if (raw_pointer_cast(epilogue_params.Valpha.data())) { + converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b)); + } + ElementCompute output = mul(converted_alpha, converted_acc); + + if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { + ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? 
n + n_b : m + m_b)); + output = bias_op(output, converted_bias); + } + + if (raw_pointer_cast(epilogue_params.C.data())) { + ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); + // per-row beta + if (epilogue_params.Vbeta.data()) { + converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b)); + } + output = epilogue_fma(converted_beta, converted_src, output); + } + + if constexpr (IsBackpropFusion) { + ElementAux aux_input = ElementAux(0); + if (raw_pointer_cast(epilogue_params.Aux.data())) { + aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); + } + + output = activation(output, aux_source_converter(aux_input)); + local_dBias = add(local_dBias, output); + } + else { + if (raw_pointer_cast(epilogue_params.Aux.data())) { + auto aux_output = output; + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); + aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); + } + + if constexpr (IsReLUAuxNeeded) { + epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0); + } else { + epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); + } + } + + if constexpr (IsClamp) { // Treat Clamp as ReLU + output = activation(output, {0, std::numeric_limits::max()}); + } + else { + output = activation(output); + } + } + + if constexpr (IsScalingAndAmaxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_output = amax_op(local_abs_max_output, output); + output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); + } + + inter_accum[m_b][n_b] = ElementCompute(output); + } + } // n_b + + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { + if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { + ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); + local_dBias = add(local_dBias, converted_dBias); + epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); + } + } + } // m_b + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); + } + } + } + +#if defined(_OPENMP) + #pragma omp critical(Abs_Max_Data_Update) +#endif + { + if constexpr (IsScalingAndAmaxOutputNeeded) { + if (epilogue_params.abs_max_D) { + *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); + } + } + + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + if (epilogue_params.abs_max_Aux) { + *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM - General Matrix-Matrix contraction without conjugation options +template < + class MainloopParams, + class EpilogueParams +> +void Gemm3x( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + using namespace cute; + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); + static_assert(cute::rank(typename 
EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) " + "with Batchmode are supported"); + // Lower the Matrix-Multiplication with Groupwise scaling (Gemm3x) to a Tensor Contraction (Gett). + Gett(mainloop_params, epilogue_params); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // cutlass::reference::host + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 68eba52e20..8b2452afd8 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -84,6 +84,28 @@ compute_stage_count_or_override(StageCountAutoCarveout stage_co return (capacity_bytes - carveout_bytes) / stage_bytes; } +// Returns the maximum number of smem tiles that can be used with a given smem capacity in gemm of blockwise/groupwise scale. +template +constexpr int +compute_stage_count_with_blockwise_scale(StageCountAutoCarveout stage_count) { + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr auto scale_bits = cute::sizeof_bits_v; + constexpr int stage_bytes_ = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(scale_bits * ScaleMsPerTile) + // scale of tensor A + cutlass::bits_to_bytes(scale_bits * 1); // scale of tensor B + + constexpr int stage_bytes = cutlass::round_up(stage_bytes_, alignment) + + static_cast(mainloop_pipeline_bytes); + constexpr int carveout_bytes = cutlass::round_up(carveout_bytes_, alignment); + constexpr int capacity_bytes = capacity_bytes_ / alignment * alignment; + + return (capacity_bytes - carveout_bytes) / stage_bytes; +} + // Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. 
template constexpr int @@ -1009,7 +1031,7 @@ template < class TileShape_MNK, class ClusterShape_MNK, class StageCountType, - class KernelScheduleType + int ScaleGranularityM_ > struct CollectiveBuilder< arch::Sm90, @@ -1024,12 +1046,12 @@ struct CollectiveBuilder< TileShape_MNK, ClusterShape_MNK, StageCountType, - KernelScheduleType, + KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum, cute::enable_if_t< - (cute::is_any_of_v) && - not detail::is_use_rmem_A()> + not detail::is_use_rmem_A()> > { + using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; + static_assert(is_static::value); static_assert(is_static::value); #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED @@ -1048,6 +1070,7 @@ struct CollectiveBuilder< // For fp32 types, map to tf32 MMA value type using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + using ElementBlockScale = ElementAccumulator; static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); @@ -1055,7 +1078,7 @@ struct CollectiveBuilder< static constexpr bool IsCooperative = cute::is_any_of_v; + KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>; using AtomLayoutMNK = cute::conditional_t>, Layout>>; @@ -1073,9 +1096,13 @@ struct CollectiveBuilder< static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); - static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8; + static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape_MNK{}) : ScaleGranularityM_; + static constexpr int ScaleMsPerTile = size<0>(TileShape_MNK{}) / ScaleGranularityM; + static_assert((size<0>(TileShape_MNK{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); + + static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8; using SmemCopyAtomA = void; using SmemCopyAtomB = void; diff --git a/include/cutlass/gemm/collective/fp8_accumulation.hpp b/include/cutlass/gemm/collective/fp8_accumulation.hpp index bca742c3c3..bd2a0cb280 100644 --- a/include/cutlass/gemm/collective/fp8_accumulation.hpp +++ b/include/cutlass/gemm/collective/fp8_accumulation.hpp @@ -75,12 +75,22 @@ struct GmmaFP8Accumulation { } // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA). + template < + class EngineScale, + class LayoutScale> CUTLASS_DEVICE - void scale_core(ElementAccumulator const& scale) { + void scale_core(const cute::Tensor &scale) { + using TensorScale = cute::Tensor; + + static_assert(is_static::value, "Scale Layout should be static"); + static_assert(is_rmem::value , "Scale tensor must be rmem resident."); + + static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape."); + warpgroup_wait<0>(); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(accum_); ++i) { - accum_(i) += accum_temp_(i) * scale; + accum_(i) += accum_temp_(i) * scale(i); } } @@ -142,8 +152,11 @@ struct GmmaFP8Accumulation { // /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed. 
+ template < + class EngineScale, + class LayoutScale> CUTLASS_DEVICE - void scale_if_needed(ElementAccumulator const& scale) { + void scale_if_needed(const cute::Tensor &scale) { mma_count_ += mma_count_per_mainloop_iteration_; reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); if (reset_accum_flag_) { @@ -153,8 +166,11 @@ struct GmmaFP8Accumulation { } /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. + template < + class EngineScale, + class LayoutScale> CUTLASS_DEVICE - void scale_residue_if_needed(ElementAccumulator const& scale) { + void scale_residue_if_needed(const cute::Tensor &scale) { if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { scale_core(scale); } diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index 8ba64b28cc..f95b4f9e17 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -57,6 +57,7 @@ template < int Stages, class ClusterShape, class KernelSchedule, + int ScaleGranularityM_, class TileShape_, class ElementA_, class StrideA_, @@ -72,7 +73,7 @@ template < class SmemCopyAtomB_, class TransformB_> struct CollectiveMma< - MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8, + MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8, TileShape_, ElementA_, StrideA_, @@ -91,7 +92,7 @@ struct CollectiveMma< // // Type Aliases // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8; + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8; using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = StrideA_; @@ -118,6 +119,9 @@ struct CollectiveMma< // Two threads per CTA are producers (1 for operand tile and 1 for scales) static constexpr int NumProducerThreadEvents = 2; + static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_; + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); @@ -126,6 +130,8 @@ struct CollectiveMma< static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); + // Tile along modes in a way that maximizes the TMA box size. 
using SmemLayoutA = decltype(tile_to_shape( SmemLayoutAtomA{}, @@ -137,11 +143,13 @@ struct CollectiveMma< cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // Block scaling gmem-to-smem copy atom - using SmemBlockScalingCopyAtom = Copy_Atom, ElementBlockScale>; + using BlockScaleCopyTypeA = cute::uint_byte_t(sizeof(ElementBlockScale)) * ScaleMsPerTile, 16)>; + using SmemBlockScalingCopyAtomA = Copy_Atom, ElementBlockScale>; + using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; // Block scaling smem layout - using SmemLayoutScaleA = Layout>, Stride<_1>>; - using SmemLayoutScaleB = Layout>, Stride<_1>>; + using SmemLayoutScaleA = Layout, Int>>; + using SmemLayoutScaleB = Layout>, Stride<_1>>; // `ScaleNsPerTile` is always 1. static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); static_assert(cute::is_base_of::value && @@ -159,7 +167,7 @@ struct CollectiveMma< struct TensorStorage : cute::aligned_struct<128> { cute::array_aligned> smem_A; // mxk cute::array_aligned> smem_B; // nxk - cute::array_aligned> smem_scale_A; // 1xk + cute::array_aligned> smem_scale_A; // ScaleMsPerTile x k cute::array_aligned> smem_scale_B; // 1xk } tensors; @@ -314,8 +322,8 @@ struct CollectiveMma< Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) // Make the tiled views of scale tensors - auto scaleA_shape = make_shape(get<2>(gA_mkl.shape()), get<3>(gA_mkl.shape()), get<4>(gA_mkl.shape())); // (m,k,l) - auto scale_dA = make_stride(get<3>(gA_mkl.shape()), Int<1>{}, get<2>(gA_mkl.shape()) * get<3>(gA_mkl.shape())); + auto scaleA_shape = make_shape(get<2>(gA_mkl.shape()), Int{}, get<3>(gA_mkl.shape()), get<4>(gA_mkl.shape())); // (m,ScaleMsPerTile,k,l) + auto scale_dA = make_stride(get<3>(gA_mkl.shape()) * Int{}, Int<1>{}, Int{}, get<2>(gA_mkl.shape()) * get<3>(gA_mkl.shape()) * Int{}); auto scaleA_layout = make_layout(scaleA_shape, scale_dA); auto scaleB_shape = make_shape(get<2>(gB_nkl.shape()), get<3>(gB_nkl.shape()), get<4>(gB_nkl.shape())); // (n,k,l) auto scale_dB = make_stride(get<3>(gB_nkl.shape()), Int<1>{}, get<2>(gB_nkl.shape()) * get<3>(gB_nkl.shape())); @@ -323,7 +331,7 @@ struct CollectiveMma< // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl. 
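// Illustrative sketch (not part of this mainloop; names are assumptions for illustration):
// the host-side element count implied by the (m, ScaleMsPerTile, k, l) scale-A layout built
// above, where TileM and TileK stand for the CTA tile extents along M and K. Scale-B is
// unchanged and keeps one value per (n, k) CTA block per batch.
inline size_t groupwise_scale_a_elements(int M, int K, int L,
                                         int TileM, int TileK, int scale_granularity_m) {
  int blocks_m          = (M + TileM - 1) / TileM;       // get<2>(gA_mkl.shape())
  int blocks_k          = (K + TileK - 1) / TileK;       // get<3>(gA_mkl.shape())
  int scale_ms_per_tile = TileM / scale_granularity_m;   // ScaleMsPerTile
  return size_t(blocks_m) * scale_ms_per_tile * blocks_k * size_t(L);
}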
- Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (m,k,l) + Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (m,ScaleMsPerTile,k,l) Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l) return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); @@ -353,7 +361,7 @@ struct CollectiveMma< if (lane_predicate) { Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (k) + Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k) Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) // @@ -379,17 +387,19 @@ struct CollectiveMma< Tensor mScaleA_mkl = get<2>(load_inputs); Tensor mScaleB_nkl = get<3>(load_inputs); - Tensor gScaleA = mScaleA_mkl(m_coord,_,l_coord); // (1,k,1) + Tensor gScaleA = mScaleA_mkl(m_coord,_,_,l_coord); // (1,ScaleMsPerTile,k,1) Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1) - TiledCopy scale_copy = make_tiled_copy(SmemBlockScalingCopyAtom{}, Layout>{}, Layout>{}); // (1,1,1) - ThrCopy thr_scale_copy = scale_copy.get_slice(threadIdx.x); + TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, Layout>{}, Layout>>{}); // (1,ScaleMsPerTile,1) + TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, Layout>{}, Layout>{}); // (1,1,1) + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); - Tensor tAgA_ScaleA = thr_scale_copy.partition_S(gScaleA); - Tensor tAsA_ScaleA = thr_scale_copy.partition_D(sScaleA); + Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA); + Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA); - Tensor tBgB_ScaleB = thr_scale_copy.partition_S(gScaleB); - Tensor tBsB_ScaleB = thr_scale_copy.partition_D(sScaleB); + Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB); + Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB); // Applies the mapping from block_tma_a Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) @@ -435,8 +445,8 @@ struct CollectiveMma< copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); // Copy scale tensors from global memory to shared memory - copy(scale_copy, tAgA_ScaleA(_,*k_tile_iter), tAsA_ScaleA(_,write_stage)); - copy(scale_copy, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage)); + copy(scale_copy_a, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage)); + copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage)); pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); ++k_tile_iter; @@ -493,7 +503,11 @@ struct CollectiveMma< Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) // Block scaling - Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (k) + Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), + Layout< + Shape, Int>, cute::tuple_element_t<1, 
TileShape>, Int>, + Stride, _0, Int> + >{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k) Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) // @@ -517,6 +531,8 @@ struct CollectiveMma< TiledMma tiled_mma; auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) @@ -541,9 +557,12 @@ struct CollectiveMma< PipelineState smem_pipe_release = smem_pipe_read; // Per block scale values for operand A and B - ElementBlockScale scale_a; + + using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout. + using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above + + Tensor tCrScaleAViewAsC = make_tensor(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N) ElementBlockScale scale_b; - ElementBlockScale scale; // Prologue GMMAs int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); @@ -566,9 +585,20 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); // Load per block scale values from shared memory to registers. - scale_a = sScaleA[read_stage]; scale_b = sScaleB[read_stage]; - scale = __shfl_sync(0xffffffff, scale_a * scale_b, 0); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + } + if constexpr (ScaleMsPerTile == 1) { + static_assert(size(RegLayoutScaleAEssential{}) == 1); + tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
+ } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } warpgroup_arrive(); // Unroll the K mode manually to set scale D to 1 @@ -580,8 +610,8 @@ struct CollectiveMma< } warpgroup_commit_batch(); - // Block scale the accumulators with `scale` value - accumulation.scale_if_needed(scale); + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` + accumulation.scale_if_needed(tCrScaleAViewAsC); ++smem_pipe_read; } @@ -603,10 +633,21 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); - // Load per block scale values from shared memory to registers (once per block) - scale_a = sScaleA[read_stage]; + // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N) scale_b = sScaleB[read_stage]; - scale = __shfl_sync(0xffffffff, scale_a * scale_b, 0); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + } + if constexpr (ScaleMsPerTile == 1) { + static_assert(size(RegLayoutScaleAEssential{}) == 1); + tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. + } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } if (accumulation.prepare_if_needed()) { tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; @@ -627,8 +668,8 @@ struct CollectiveMma< warpgroup_wait(); warpgroup_fence_operand(accumulation()); - // Block scale the accumulators with `scale` value - accumulation.scale_if_needed(scale); + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` + accumulation.scale_if_needed(tCrScaleAViewAsC); pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it @@ -637,7 +678,7 @@ struct CollectiveMma< ++smem_pipe_release; } - accumulation.scale_residue_if_needed(scale); + accumulation.scale_residue_if_needed(tCrScaleAViewAsC); warpgroup_fence_operand(accumulation()); } diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 6c98624367..cbddfadd82 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -114,6 +114,9 @@ struct KernelPtrArrayTmaWarpSpecializedCooperative { }; struct KernelPtrArrayTmaWarpSpecializedPingpong { }; // FP8 related policies (including Blocked Scaled Accumulation) +template< + int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M, while zero-value `ScaleGranularityM` indicates that scaling granularity is `size<0>(TileShape_MNK{})` along M. +> struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { }; // Policies to opt into mixed type GEMMs @@ -296,12 +299,13 @@ struct MainloopSm90TmaGmmaWarpSpecializedFP8 template< int Stages_, class ClusterShape_ = Shape<_1,_1,_1>, - class KernelSchedule = KernelTmaWarpSpecialized + class KernelSchedule = KernelTmaWarpSpecialized, + int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M, while zero-value `ScaleGranularityM` indicates that scaling granularity is `size<0>(TileShape_MNK{})` along M. 
 >
 struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8
   : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
   static_assert(
-    cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>,
+    cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM>>,
     "KernelSchedule must be one of the warp specialized policies");
 };

From b4abc3f0cd76c1962055d76a2d4be406aab7df72 Mon Sep 17 00:00:00 2001
From: Haicheng Wu
Date: Fri, 31 Jan 2025 07:29:48 -0800
Subject: [PATCH 2/2] small updates

---
 ...hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu | 4 ++--
 .../reference/host/gemm_with_groupwise_scaling.h              | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
index 74c85b82e6..d6de7f89b7 100644
--- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
+++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * SPDX-License-Identifier: BSD-3-Clause
  *
  * Redistribution and use in source and binary forms, with or without
@@ -737,7 +737,7 @@ int main(int argc, char const **args) {
   CUDA_CHECK(cudaGetDevice(&current_device_id));
   CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
   cudaError_t error = cudaGetDeviceProperties(&props, 0);
-  if (props.major < 9) {
+  if (props.major != 9) {
     std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture or "
               << "later (compute capability 90 or greater).\n";
diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h
index efe4f8c46f..652f72c1be 100644
--- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h
+++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * SPDX-License-Identifier: BSD-3-Clause
  *
  * Redistribution and use in source and binary forms, with or without
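
For reference, a minimal usage sketch of the `ScaleGranularityM` knob added by this patch, assuming a 128x128x128 CTA tile; the granularity value 64 below is an arbitrary illustration rather than anything this patch prescribes:

#include "cutlass/gemm/dispatch_policy.hpp"

// Granularity 0 (the default) keeps the original blockwise behavior: one scale of A per
// CTA tile along M. Any nonzero granularity must evenly divide the tile extent along M.
constexpr int TileM             = 128;
constexpr int ScaleGranularityM = 64;                         // two scale rows per CTA tile
static_assert(TileM % ScaleGranularityM == 0,
              "FP8 scaling granularity must evenly divide tile shape along M.");
constexpr int ScaleMsPerTile    = TileM / ScaleGranularityM;  // == 2

// The granularity travels through the kernel schedule into the mainloop dispatch policy
// and the collective builder specialization introduced in this patch:
using KernelSchedule =
    cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM>;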