Moves utility functions into a standalone file.
Summary:
X-link: facebookresearch/FBGEMM#749

Moves functions to better modularize code.

facebookresearch/FBGEMM#749
#3671

Reviewed By: jianyuh

Differential Revision: D69377391
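
For downstream callers the effect of the move is only an import-path change; a minimal, illustrative sketch (the "after" import mirrors the lines added in the diff below, nothing else is implied):

# Before this change, the helpers were defined in (and importable from) fp8_gemm.py:
# from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import map_dtype_to_triton, TmaAutoTuneHelper

# After this change, they live in the standalone utils module:
from fbgemm_gpu.experimental.gemm.triton_gemm.utils import (
    map_dtype_to_triton,
    TmaAutoTuneHelper,
)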
levendlee authored and facebook-github-bot committed Feb 10, 2025
1 parent 3182ea5 commit 87bb8c5
Showing 2 changed files with 131 additions and 116 deletions.
120 changes: 4 additions & 116 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -6,7 +6,6 @@

# pyre-unsafe
import logging
import sys
from typing import List, Optional, Tuple, Union

import torch
@@ -18,6 +17,10 @@
early_config_prune,
estimate_matmul_time,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.utils import (
map_dtype_to_triton,
TmaAutoTuneHelper,
)
from torch._tensor import Tensor

from triton import Config # @manual
@@ -59,28 +62,6 @@ def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper
return tl_reinterpret(tensor, dtype=dtype)


def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
"""
Maps torch dtype to triton dtype.
Args:
dtype (torch.dtype): input dtype.
Returns:
tl.dtype: triton dtype.
"""
if dtype == torch.float16:
return tl.float16
elif dtype == torch.bfloat16:
return tl.bfloat16
elif dtype == torch.float32:
return tl.float32
elif dtype == torch.int32:
return tl.int32
else:
raise ValueError(f"Unsupported dtype {dtype}")


def init_to_zero(name):
return lambda nargs: nargs[name].zero_()

@@ -1125,99 +1106,6 @@ def _kernel_matmul_fp8_row_tma_persistent_ws_cooperative(
tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn])


# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)

if HAS_TMA_DESC:
print(
"TMA benchmarks will be running with experimental grid constant TMA descriptor.",
file=sys.stderr,
)
else:
print(
"TMA benchmarks will be running without grid constant TMA descriptor.",
file=sys.stderr,
)


class TmaAutoTuneHelper:

# duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
class KernelParamWrapper:
def __init__(self, desc):
self.desc = desc

def tma_desc_cpu_ptr(self):
return self.desc.data_ptr()

TMA_SIZE = 128

def __init__(self):
self.fill_1d_tma_descriptor_inner = (
triton.runtime.driver.active.utils.fill_1d_tma_descriptor
)
self.fill_2d_tma_descriptor_inner = (
triton.runtime.driver.active.utils.fill_2d_tma_descriptor
)
if HAS_TMA_DESC:
self.descriptors = {}
else:
self.cuda_descriptors = {}

# Call this method outside of the lambda function for grid size
def init_tma_descriptor(self, name):
if HAS_TMA_DESC:
self.descriptors[name] = torch.empty(
TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8
)
else:
self.cuda_descriptors[name] = torch.empty(
TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8
)

# Call this method inside the lambda function for grid size
def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, buf_x.data_ptr()
)
desc_x.copy_(buf_x, non_blocking=True)

# Call this method inside the lambda function for grid size
def fill_2d_tma_descriptor(
self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()
)
desc_x.copy_(buf_x, non_blocking=True)

def get_tma_descriptor_kernel_param(self, name):
if HAS_TMA_DESC:
assert self.descriptors[name] is not None
return self.KernelParamWrapper(self.descriptors[name])
else:
assert self.cuda_descriptors[name] is not None
return self.cuda_descriptors[name]


@torch._library.triton_op("triton::matmul_fp8_row", mutates_args=())
def matmul_fp8_row(
a: torch.Tensor,
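
For quick reference, a small sketch of what the relocated map_dtype_to_triton does; the mappings and the ValueError match the function body shown in utils.py below, while the torch.int64 call is just an illustrative unsupported input:

import torch
import triton.language as tl  # @manual

from fbgemm_gpu.experimental.gemm.triton_gemm.utils import map_dtype_to_triton

# The four supported torch -> triton dtype mappings.
assert map_dtype_to_triton(torch.float16) == tl.float16
assert map_dtype_to_triton(torch.bfloat16) == tl.bfloat16
assert map_dtype_to_triton(torch.float32) == tl.float32
assert map_dtype_to_triton(torch.int32) == tl.int32

# Anything else raises ValueError.
try:
    map_dtype_to_triton(torch.int64)
except ValueError as err:
    print(err)  # Unsupported dtype torch.int64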
127 changes: 127 additions & 0 deletions fbgemm_gpu/experimental/gemm/triton_gemm/utils.py
@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

import torch
import triton # @manual

import triton.language as tl # @manual


def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
"""
Maps torch dtype to triton dtype.
Args:
dtype (torch.dtype): input dtype.
Returns:
tl.dtype: triton dtype.
"""
if dtype == torch.float16:
return tl.float16
elif dtype == torch.bfloat16:
return tl.bfloat16
elif dtype == torch.float32:
return tl.float32
elif dtype == torch.int32:
return tl.int32
else:
raise ValueError(f"Unsupported dtype {dtype}")


# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)

if HAS_TMA_DESC:
print(
"TMA benchmarks will be running with experimental grid constant TMA descriptor.",
file=sys.stderr,
)
else:
print(
"TMA benchmarks will be running without grid constant TMA descriptor.",
file=sys.stderr,
)


class TmaAutoTuneHelper:

# duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
class KernelParamWrapper:
def __init__(self, desc):
self.desc = desc

def tma_desc_cpu_ptr(self):
return self.desc.data_ptr()

TMA_SIZE = 128

def __init__(self):
self.fill_1d_tma_descriptor_inner = (
triton.runtime.driver.active.utils.fill_1d_tma_descriptor
)
self.fill_2d_tma_descriptor_inner = (
triton.runtime.driver.active.utils.fill_2d_tma_descriptor
)
if HAS_TMA_DESC:
self.descriptors = {}
else:
self.cuda_descriptors = {}

# Call this method outside of the lambda function for grid size
def init_tma_descriptor(self, name):
if HAS_TMA_DESC:
self.descriptors[name] = torch.empty(
TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8
)
else:
self.cuda_descriptors[name] = torch.empty(
TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8
)

# Call this method inside the lambda function for grid size
def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, buf_x.data_ptr()
)
desc_x.copy_(buf_x, non_blocking=True)

# Call this method inside the lambda function for grid size
def fill_2d_tma_descriptor(
self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()
)
desc_x.copy_(buf_x, non_blocking=True)

def get_tma_descriptor_kernel_param(self, name):
if HAS_TMA_DESC:
assert self.descriptors[name] is not None
return self.KernelParamWrapper(self.descriptors[name])
else:
assert self.cuda_descriptors[name] is not None
return self.cuda_descriptors[name]
