adding an option to skip zeroing output tensor for f8f8bf16_rowwise_grouped_dynamic (#3685)

Summary:

X-link: facebookresearch/FBGEMM#761

In certain use cases, the caller of this API does not need the padded area of the output tensor to be zeroed out, so add an option to skip it. Note that currently the actual skipping is only implemented for AMD.
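For illustration, a minimal C++ call-site sketch (not part of this commit; it assumes the caller already has FP8 tensors XQ and WQ, rowwise scales x_scale and w_scale, and a zero_start_index_M tensor, and matches the signature changed below):

    // Default behavior: the padded area of the bf16 output is zero-initialized.
    at::Tensor out_zeroed = f8f8bf16_rowwise_grouped_dynamic(
        XQ, WQ, x_scale, w_scale, zero_start_index_M);

    // New option: skip the zeroing when the caller never reads the padded area.
    // (Currently only the AMD/HIP path actually skips the work.)
    at::Tensor out_fast = f8f8bf16_rowwise_grouped_dynamic(
        XQ, WQ, x_scale, w_scale, zero_start_index_M,
        /*zeroing_output_tensor=*/false);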

Differential Revision: D69380351
mxz297 authored and facebook-github-bot committed Feb 15, 2025
1 parent a4be13a commit e883746
Showing 3 changed files with 23 additions and 10 deletions.
Changed file 1 of 3 (file path not captured in this view; HIP/ROCm implementation):

@@ -206,7 +206,8 @@ __global__ void set_kernel_args_fixed_nk_kernel(
     int M,
     int N,
     int K,
-    int group_count) {
+    int group_count,
+    bool zeroing_output_tensor) {
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   // Each thread is responsible for setting up the arguments for one group.
   if (thread_idx < group_count) {
@@ -227,6 +228,7 @@ __global__ void set_kernel_args_fixed_nk_kernel(
     // Write kernel args to memory.
     kernel_args[thread_idx] = kernel_group_args;
   }
+  if (!zeroing_output_tensor) return;
 
   // Figure out where in memory we are.
   // Each thread sets one float 4 which corresponds to 8 bf16 values.
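The kernel body that follows this hunk (elided here) is the work the new early return skips. A minimal sketch of that zeroing pattern, based only on the context comments above (each thread clears 8 bf16 output values by storing one float4 of zeros); this is an illustration, not the actual elided code:

    // Illustration only, not part of the diff: zero the output in float4-sized
    // chunks, one float4 (8 bf16 values) per thread slot.
    __device__ void zero_output_sketch(
        float4* output,       // output buffer viewed as float4
        int total_float4,     // number of float4 elements covering the output
        int thread_idx) {
      if (thread_idx < total_float4) {
        output[thread_idx] = make_float4(0.f, 0.f, 0.f, 0.f);
      }
    }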
@@ -252,7 +254,8 @@ void set_dynamic_kernel_args(
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor output,
-    at::Tensor zero_start_index_M) {
+    at::Tensor zero_start_index_M,
+    bool zeroing_output_tensor) {
   // Get current cuda stream.
   auto stream = at::cuda::getCurrentHIPStream().stream();
   int group_count = XQ.size(0);
@@ -292,7 +295,8 @@ void set_dynamic_kernel_args(
       M,
       N,
       K,
-      group_count);
+      group_count,
+      zeroing_output_tensor);
 }
 
 template <typename OutputType>
@@ -433,7 +437,8 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
     at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor zero_start_index_M) {
+    at::Tensor zero_start_index_M,
+    bool zeroing_output_tensor = true) {
   // Check that input datatypes are valid.
   // First confirm that there are the same number of groups in all inputs.
   int group_count = XQ.size(0);
@@ -473,7 +478,7 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
       {static_cast<long>(group_count * sizeof(KernelArguments))},
       XQ.options().dtype(at::kByte));
   set_dynamic_kernel_args(
-      kernel_args, XQ, WQ, x_scale, w_scale, Y, zero_start_index_M);
+      kernel_args, XQ, WQ, x_scale, w_scale, Y, zero_start_index_M, zeroing_output_tensor);
 
   RowwiseGroupedKernel<at::Tensor, at::Tensor> selected_kernel =
       rowwise_grouped_heuristic_dispatch<at::Tensor, at::Tensor>(M, N, K);
Changed file 2 of 3 (file path not captured in this view; CUDA implementation):

@@ -682,14 +682,20 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
     at::Tensor WQ, // FP8
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor zero_start_index_M) {
+    at::Tensor zero_start_index_M,
+    bool zeroing_output_tensor = true) {
   at::Tensor Y;
   int group_count = XQ.size(0);
   int M = XQ.size(1);
   int N = WQ.size(1);
   int K = XQ.size(0);
   int total_output_size = group_count * M * N;
-  Y = at::zeros(total_output_size, XQ.options().dtype(at::kBFloat16));
+  if (zeroing_output_tensor) {
+    Y = at::zeros(total_output_size, XQ.options().dtype(at::kBFloat16));
+  } else {
+    Y = at::empty(total_output_size, XQ.options().dtype(at::kBFloat16));
+  }
+
   // Return continuous view of output.
   at::Tensor output = dispatch_fp8_grouped_kernel<at::Tensor>(
       XQ, WQ, x_scale, w_scale, Y, zero_start_index_M);
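A note on the host-side change above: at::zeros initializes the whole buffer, while at::empty only allocates it, so when zeroing_output_tensor is false the padded area of Y holds arbitrary data and callers must not read it. A small standalone illustration of the difference (not part of this commit):

    at::Tensor zeroed = at::zeros({8}, at::kBFloat16); // every element is 0
    at::Tensor uninit = at::empty({8}, at::kBFloat16); // contents are unspecified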
@@ -724,7 +730,8 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
     at::Tensor WQ, // FP8
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor zero_start_index_M) {
+    at::Tensor zero_start_index_M,
+    bool zeroing_output_tensor = true) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
5 changes: 3 additions & 2 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -104,7 +104,8 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
     at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor zero_start_index_M);
+    at::Tensor zero_start_index_M,
+    bool zeroing_output_tensor = true);
 at::Tensor f8f8bf16_blockwise(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -221,7 +222,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "f8f8bf16_rowwise_grouped_stacked(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor(a!)? output=None) -> Tensor");
   m.def(
-      "f8f8bf16_rowwise_grouped_dynamic(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor zero_start_index_M) -> Tensor");
+      "f8f8bf16_rowwise_grouped_dynamic(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor zero_start_index_M, bool zeroing_output_tensor=True) -> Tensor");
   m.def(
       "f8f8bf16_tensorwise(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor");
   m.def("per_tensor_quantize_i8(Tensor X, float scale) -> Tensor");
