-
Notifications
You must be signed in to change notification settings - Fork 537
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Moves utility functions into a standalone file.
Summary: X-link: facebookresearch/FBGEMM#749 Moves functions to better modularize code. facebookresearch/FBGEMM#749 #3671 Reviewed By: jianyuh Differential Revision: D69377391
- Loading branch information
1 parent
3182ea5
commit 87bb8c5
Showing
2 changed files
with
131 additions
and
116 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import sys | ||
|
||
import torch | ||
import triton # @manual | ||
|
||
import triton.language as tl # @manual | ||
|
||
|
||
def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype: | ||
""" | ||
Maps torch dtype to triton dtype. | ||
Args: | ||
dtype (torch.dtype): input dtype. | ||
Returns: | ||
tl.dtype: triton dtype. | ||
""" | ||
if dtype == torch.float16: | ||
return tl.float16 | ||
elif dtype == torch.bfloat16: | ||
return tl.bfloat16 | ||
elif dtype == torch.float32: | ||
return tl.float32 | ||
elif dtype == torch.int32: | ||
return tl.int32 | ||
else: | ||
raise ValueError(f"Unsupported dtype {dtype}") | ||
|
||
|
||
# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498). | ||
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) | ||
|
||
if HAS_TMA_DESC: | ||
print( | ||
"TMA benchmarks will be running with experimental grid constant TMA descriptor.", | ||
file=sys.stderr, | ||
) | ||
else: | ||
print( | ||
"TMA benchmarks will be running without grid constant TMA descriptor.", | ||
file=sys.stderr, | ||
) | ||
|
||
|
||
class TmaAutoTuneHelper: | ||
|
||
# duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 | ||
class KernelParamWrapper: | ||
def __init__(self, desc): | ||
self.desc = desc | ||
|
||
def tma_desc_cpu_ptr(self): | ||
return self.desc.data_ptr() | ||
|
||
TMA_SIZE = 128 | ||
|
||
def __init__(self): | ||
self.fill_1d_tma_descriptor_inner = ( | ||
triton.runtime.driver.active.utils.fill_1d_tma_descriptor | ||
) | ||
self.fill_2d_tma_descriptor_inner = ( | ||
triton.runtime.driver.active.utils.fill_2d_tma_descriptor | ||
) | ||
if HAS_TMA_DESC: | ||
self.descriptors = {} | ||
else: | ||
self.cuda_descriptors = {} | ||
|
||
# Call this method outside of the lambda function for grid size | ||
def init_tma_descriptor(self, name): | ||
if HAS_TMA_DESC: | ||
self.descriptors[name] = torch.empty( | ||
TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8 | ||
) | ||
else: | ||
self.cuda_descriptors[name] = torch.empty( | ||
TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8 | ||
) | ||
|
||
# Call this method inside the lambda function for grid size | ||
def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): | ||
if HAS_TMA_DESC: | ||
desc_x = self.descriptors[name] | ||
assert desc_x.data_ptr() % 64 == 0 | ||
self.fill_1d_tma_descriptor_inner( | ||
ptr, dim, block_dim, element_size, desc_x.data_ptr() | ||
) | ||
else: | ||
desc_x = self.cuda_descriptors[name] | ||
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) | ||
self.fill_1d_tma_descriptor_inner( | ||
ptr, dim, block_dim, element_size, buf_x.data_ptr() | ||
) | ||
desc_x.copy_(buf_x, non_blocking=True) | ||
|
||
# Call this method inside the lambda function for grid size | ||
def fill_2d_tma_descriptor( | ||
self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size | ||
): | ||
if HAS_TMA_DESC: | ||
desc_x = self.descriptors[name] | ||
assert desc_x.data_ptr() % 64 == 0 | ||
self.fill_2d_tma_descriptor_inner( | ||
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr() | ||
) | ||
else: | ||
desc_x = self.cuda_descriptors[name] | ||
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) | ||
self.fill_2d_tma_descriptor_inner( | ||
ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr() | ||
) | ||
desc_x.copy_(buf_x, non_blocking=True) | ||
|
||
def get_tma_descriptor_kernel_param(self, name): | ||
if HAS_TMA_DESC: | ||
assert self.descriptors[name] is not None | ||
return self.KernelParamWrapper(self.descriptors[name]) | ||
else: | ||
assert self.cuda_descriptors[name] is not None | ||
return self.cuda_descriptors[name] |