Groupwise scaling along M for FP8 gemm #2037

Merged (2 commits) on Jan 31, 2025
@@ -123,7 +123,7 @@ using ArchTag = cutlass::arch::Sm90; // T
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;

using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
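
The schedule tag is now a template on the scaling granularity along M; an empty argument list keeps the previous blockwise behavior. A minimal illustrative sketch (assuming the 128-row tile used in this example; the alias names are hypothetical):

// Blockwise (default): one scale of A per 128-row tile, matching the old behavior.
using BlockwiseSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<>;
// Groupwise: one scale of A per 64 rows, i.e. two scales per 128-row tile.
using GroupwiseSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<64>;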

Large diffs are not rendered by default.

@@ -30,3 +30,8 @@ cutlass_example_add_executable(
67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu
)

cutlass_example_add_executable(
67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling
67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
)

Large diffs are not rendered by default.

45 changes: 36 additions & 9 deletions include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl
@@ -84,6 +84,28 @@ compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes_> stage_co
return (capacity_bytes - carveout_bytes) / stage_bytes;
}

// Returns the maximum number of smem tiles that can be used with a given smem capacity in a GEMM with blockwise/groupwise scaling.
template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int carveout_bytes_, int alignment = 128>
constexpr int
compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_> stage_count) {
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>;
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>;
constexpr auto scale_bits = cute::sizeof_bits_v<ElementBlockScale>;
constexpr int stage_bytes_ =
cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
cutlass::bits_to_bytes(scale_bits * ScaleMsPerTile) + // scale of tensor A
cutlass::bits_to_bytes(scale_bits * 1); // scale of tensor B

constexpr int stage_bytes = cutlass::round_up(stage_bytes_, alignment) +
static_cast<int>(mainloop_pipeline_bytes);
constexpr int carveout_bytes = cutlass::round_up(carveout_bytes_, alignment);
constexpr int capacity_bytes = capacity_bytes_ / alignment * alignment;

return (capacity_bytes - carveout_bytes) / stage_bytes;
}
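
For intuition, a standalone host-side sketch of this arithmetic (not part of the PR), assuming a 128x128x128 tile, fp8 A and B, float scales, ScaleGranularityM = 1 (so ScaleMsPerTile = 128), and a 16-byte stand-in for sizeof(PipelineTmaAsync<1>::SharedStorage):

#include <cstdio>

constexpr int round_up(int x, int a) { return (x + a - 1) / a * a; }

int main() {
  constexpr int alignment      = 128;
  constexpr int a_tile_bytes   = 8 * 128 * 128 / 8;   // fp8 A tile: 16384 bytes
  constexpr int b_tile_bytes   = 8 * 128 * 128 / 8;   // fp8 B tile: 16384 bytes
  constexpr int scale_a_bytes  = 32 * 128 / 8;        // 128 float row scales for A: 512 bytes
  constexpr int scale_b_bytes  = 32 * 1 / 8;          // 1 float scale for B: 4 bytes
  constexpr int pipeline_bytes = 16;                  // assumed pipeline barrier storage
  constexpr int stage_bytes =
      round_up(a_tile_bytes + b_tile_bytes + scale_a_bytes + scale_b_bytes, alignment)
      + pipeline_bytes;                               // 33408 + 16 = 33424
  constexpr int capacity_bytes = 232448;              // SM90 shared memory capacity
  constexpr int carveout_bytes = round_up(0, alignment);
  std::printf("stages = %d\n", (capacity_bytes - carveout_bytes) / stage_bytes);  // prints 6
  return 0;
}

Under those assumptions a stage costs roughly 33 KB of shared memory, so about six stages fit in SM90's 227 KB before any epilogue carveout.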

// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int stages, int alignment = 128>
constexpr int
@@ -1009,7 +1031,7 @@ template <
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class KernelScheduleType
int ScaleGranularityM_
>
struct CollectiveBuilder<
arch::Sm90,
@@ -1024,12 +1046,12 @@ struct CollectiveBuilder<
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
KernelScheduleType,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>,
cute::enable_if_t<
(cute::is_any_of_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>) &&
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
> {
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>;

static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
@@ -1048,14 +1070,15 @@ struct CollectiveBuilder<
// For fp32 types, map to tf32 MMA value type
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
using ElementBlockScale = ElementAccumulator;

static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>;
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>>;
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

@@ -1073,9 +1096,13 @@ struct CollectiveBuilder<
static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape_MNK{}) : ScaleGranularityM_;
static constexpr int ScaleMsPerTile = size<0>(TileShape_MNK{}) / ScaleGranularityM;
static_assert((size<0>(TileShape_MNK{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");

static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
ElementAMma, ElementBMma, ElementBlockScale, TileShape_MNK, ScaleMsPerTile>(StageCountType{});
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM_>;

using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
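
A hypothetical builder instantiation with groupwise scaling (a sketch, not code from this PR), assuming fp8 e4m3 operands with 16-element alignment, no epilogue smem carveout, and one A scale per 64 rows of the 128-row tile (ScaleMsPerTile = 2):

using GroupwiseMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::float_e4m3_t, cutlass::layout::RowMajor,    16,  // A: M x K, K-major
    cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 16,  // B: K x N, K-major
    float,                                                    // accumulator / scale element
    cute::Shape<cute::_128, cute::_128, cute::_128>,          // threadblock tile (M, N, K)
    cute::Shape<cute::_1, cute::_2, cute::_1>,                // cluster shape
    cutlass::gemm::collective::StageCountAutoCarveout<0>,
    cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<64>
  >::CollectiveOp;

Passing 0 (or the defaulted <>) for the granularity falls back to one scale per tile along M, i.e. the existing blockwise behavior.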
24 changes: 20 additions & 4 deletions include/cutlass/gemm/collective/fp8_accumulation.hpp
@@ -75,12 +75,22 @@ struct GmmaFP8Accumulation {
}

// `multiply` the partial accumulators by the scale and `add` them to the main accumulator (FFMA).
template <
class EngineScale,
class LayoutScale>
CUTLASS_DEVICE
void scale_core(ElementAccumulator const& scale) {
void scale_core(const cute::Tensor<EngineScale, LayoutScale> &scale) {
using TensorScale = cute::Tensor<EngineScale, LayoutScale>;

static_assert(is_static<LayoutScale>::value, "Scale Layout should be static");
static_assert(is_rmem<TensorScale>::value , "Scale tensor must be rmem resident.");

static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape.");

warpgroup_wait<0>();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accum_); ++i) {
accum_(i) += accum_temp_(i) * scale;
accum_(i) += accum_temp_(i) * scale(i);
}
}

@@ -142,8 +152,11 @@ struct GmmaFP8Accumulation {
//

/// Scale (multiply_add) the results from the MMA accumulators into the main accumulator if needed.
template <
class EngineScale,
class LayoutScale>
CUTLASS_DEVICE
void scale_if_needed(ElementAccumulator const& scale) {
void scale_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
mma_count_ += mma_count_per_mainloop_iteration_;
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
if (reset_accum_flag_) {
@@ -153,8 +166,11 @@ struct GmmaFP8Accumulation {
}

/// Scale (multiply_add) the residual results from the MMA accumulators into the main accumulator if needed.
template <
class EngineScale,
class LayoutScale>
CUTLASS_DEVICE
void scale_residue_if_needed(ElementAccumulator const& scale) {
void scale_residue_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
scale_core(scale);
}
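
The promotion path now consumes one scale per accumulator element instead of a single scalar per tile. A minimal device-side sketch of that inner loop (plain float arrays stand in for the rmem-resident cute tensors; not the PR's code):

__device__ void scale_core_sketch(float* accum, float const* accum_temp,
                                  float const* scale, int n) {
  #pragma unroll
  for (int i = 0; i < n; ++i) {
    accum[i] += accum_temp[i] * scale[i];  // FFMA: fold the fp8 partial sums into the main accumulator
  }
}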