CUTLASS 2.6.1 - functional and performance enhancements to strided DGRAD, fixes, and tuning

* cutlass 2.6 update

* remove debug prints

* cutlass 2.6.1 (minor update)

* Updated CHANGELOG.

* Minor edit to readme to indicate patch version.

* Minor edit to readme.

Co-authored-by:  Haicheng Wu <[email protected]>, Andrew Kerr <[email protected]>
Manish Gupta and kerrmudgeon authored Sep 3, 2021
1 parent a01feb9 commit 6c2f8f2
Showing 55 changed files with 317 additions and 315 deletions.
9 changes: 8 additions & 1 deletion CHANGELOG.md
@@ -2,6 +2,12 @@

# CUTLASS 2.x

## [2.6.1](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.1) (2021-09-03)
* Arbitrary padding and striding for CUTLASS Strided DGRAD Convolution operator (Analytic Iterators)
* Tuning for GEMMs fused with partial reductions
* Corrections and bug fixes reported by the CUTLASS community
* Thank you for filing these issues!

## [2.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.0) (2021-07-22)
* Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
* Adopt the new L2 prefetch feature in [cp.async](/include/cutlass/arch/memory.h) and [global load](/include/cutlass/arch/memory_sm80.h)
@@ -23,7 +29,8 @@
* Many improvements to the epilogue.
* Provide an [option](/include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
* Performance improvement for FP16 tensor core kernels
* Bug fixes
* Bug fixes
* Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere.
* Updated minimum CUDA Toolkit requirement to 10.2
* [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) recommended
* Corrections and bug fixes reported by the CUTLASS community

8 changes: 8 additions & 0 deletions CMakeLists.txt
@@ -168,6 +168,11 @@ if (${CUTLASS_NVCC_VERBOSE})
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -v)
endif()

#
# CUTLASS NAMESPACE
#
set(CUTLASS_NAMESPACE "cutlass" CACHE STRING "Top level namespace of CUTLASS")

set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
@@ -383,6 +388,8 @@ function(cutlass_apply_standard_compile_options TARGET)
set(_FLAGS_DEBUG ${__CUTLASS_CUDA_FLAGS_DEBUG} ${__CUTLASS_CUDA_NVCC_FLAGS_DEBUG})
endif()

target_link_libraries(${TARGET} PRIVATE CUTLASS)

target_compile_options(
${TARGET}
PRIVATE
@@ -425,6 +432,7 @@ set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/util/includ
include_directories(${CUTLASS_INCLUDE_DIR})

target_compile_features(CUTLASS INTERFACE cxx_std_11)
target_compile_definitions(CUTLASS INTERFACE CUTLASS_NAMESPACE=${CUTLASS_NAMESPACE})

if (NOT DEFINED CUTLASS_REVISION)

7 changes: 5 additions & 2 deletions README.md
@@ -2,7 +2,7 @@

# CUTLASS 2.6

_CUTLASS 2.6 - July 2021_
_CUTLASS 2.6.1 - September 2021_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
@@ -34,18 +34,21 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
See the [functionality listing](/media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.

See the [CHANGELOG](CHANGELOG.md) for descriptions of recent updates.

# What's New in CUTLASS 2.6
CUTLASS 2.6 is a minor update to CUTLASS adding:
- Fused [broadcast](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu) and [reductions](/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu) in the epilogues of GEMM and Convolution
- [Quaternion-valued GEMM](/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](/examples/22_quaternion_conv/quaternion_conv.cu) in single-precision
- [New strided Dgrad](test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation offers up to 4x performance improvements over previous strided Dgrad
- 64-bit strides for large tensor allocations
- [General affine layouts](/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) fp64 tensor core and simt GEMM
- [Batched GEMV](/test/unit/gemm/device/gemv.cu) preview implementation
- Enhanced functionality, boosted performance, and bug fixes in the epilogue.
- Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
- Adopt new L2 prefetch feature in [ptx instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-isa-version-7-4).
- Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere.
- Numerous updates from the community (thanks!)
- See the [CHANGELOG](CHANGELOG.md) for more details

# What's New in CUTLASS 2.5
CUTLASS 2.5 is a minor update to CUTLASS adding:

8 changes: 0 additions & 8 deletions examples/13_two_tensor_op_fusion/device/b2b_gemm.h
@@ -390,14 +390,6 @@ class B2bGemm {
if (result != cudaSuccess) {
return Status::kErrorInternal;
}

result = cudaFuncSetAttribute(
Kernel<B2bGemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);

if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}

cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);

@@ -197,14 +197,6 @@ class B2bImplicitGemmConvolution {
if (result != cudaSuccess) {
return Status::kErrorInternal;
}

result = cudaFuncSetAttribute(
cutlass::Kernel<B2bImplicitGemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);

if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}

return Status::kSuccess;

30 changes: 9 additions & 21 deletions examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h
@@ -395,10 +395,8 @@ class B2bMmaMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations_0) {

if (gemm_k_iterations_0 == 0) {
iterator_A0.clear_mask();
iterator_B0.clear_mask();
}
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);

iterator_A0.set_iteration_index(0);
this->smem_iterator_A0_.set_iteration_index(0);
@@ -490,10 +488,8 @@ class B2bMmaMultistage :
++this->warp_tile_iterator_A0_;
++this->warp_tile_iterator_B0_;

if (gemm_k_iterations_0 == 0) {
iterator_A0.clear_mask();
iterator_B0.clear_mask();
}
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);

int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
@@ -601,10 +597,8 @@ class B2bMmaMultistage :
}

--gemm_k_iterations_0;
if (gemm_k_iterations_0 == 0) {
iterator_A0.clear_mask();
iterator_B0.clear_mask();
}
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
}

// Do any conversions feeding the first stage at the end of the loop so
@@ -634,9 +628,7 @@ class B2bMmaMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations_1) {

if (gemm_k_iterations_1 == 0) {
iterator_B1.clear_mask();
}
iterator_B1.clear_mask(gemm_k_iterations_1 == 0);

iterator_B1.set_iteration_index(0);
this->smem_iterator_B1_.set_iteration_index(0);
@@ -694,9 +686,7 @@ class B2bMmaMultistage :
++warp_tile_iterator_A1_;
++this->warp_tile_iterator_B1_;

if (gemm_k_iterations_1 == 0) {
iterator_B1.clear_mask();
}
iterator_B1.clear_mask(gemm_k_iterations_1 == 0);

smem_write_stage_idx = Base::kStages - 1;
smem_read_stage_idx = 0;
@@ -793,9 +783,7 @@ class B2bMmaMultistage :
++smem_read_stage_idx;
}

if (gemm_k_iterations_1 == 1) {
iterator_B1.clear_mask();
}
iterator_B1.clear_mask(gemm_k_iterations_1 == 1);
}

// Do any conversions feeding the first stage at the end of the loop so
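
The hunks above replace a guarded `clear_mask()` call with a single call that takes the condition as an argument. The sketch below shows, on a toy type, why the two forms leave the mask in the same state. `ToyIterator`, its `mask` field, and the select-based body are assumptions made for illustration; the actual CUTLASS iterator implementation is not part of this diff.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a tile iterator's predicate mask (not the CUTLASS class).
struct ToyIterator {
  std::uint32_t mask = 0xFu;

  // Clears the mask only when `enable` is true. The condition is folded into
  // a select, so callers do not need an `if` around the call.
  void clear_mask(bool enable = true) {
    mask = enable ? 0u : mask;
  }
};

// Old call-site shape: branch first, then clear unconditionally.
inline void clear_with_branch(ToyIterator &it, int k_iterations) {
  if (k_iterations == 0) {
    it.clear_mask();
  }
}

// New call-site shape: pass the condition straight through.
inline void clear_with_predicate(ToyIterator &it, int k_iterations) {
  it.clear_mask(k_iterations == 0);
}
```

Both helpers produce the same mask for any `k_iterations`; the second simply moves the decision inside the callee.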

22 changes: 22 additions & 0 deletions include/cutlass/conv/conv2d_problem_size.h
@@ -47,6 +47,7 @@
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/functional.h"

namespace cutlass {
namespace conv {
@@ -485,6 +486,27 @@ int strided_dgrad_tile_m_per_filter(
return tile_m_per_filter;
}

// Computes starting Dx coord (h, w) for given starting filter position
CUTLASS_HOST_DEVICE
void strided_dgrad_starting_coords(
Conv2dProblemSize const &problem_size,
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
int r, int s,
int &start_h, int &start_w) {

// function locals for remainder by fast divmod
int pad_h_rem_, pad_w_rem_;

// start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h);
int r_ = std::abs(problem_size.stride_h - (pad_h_rem_ - r));
stride_h_divmod.divmod(start_h, r_);

//start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w);
int s_ = std::abs(problem_size.stride_w - (pad_w_rem_ - s));
stride_w_divmod.divmod(start_w, s_);
}

} // namespace conv
} // namespace cutlass
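
The commented-out expressions in `strided_dgrad_starting_coords` spell out the arithmetic that the `FastDivmod` calls implement. A minimal host-side sketch with plain `%`, placed next to the expression this commit removes from the analytic iterator, shows an input where the two disagree. The function names are hypothetical, and the alignment check assumes unit dilation and the relation h = p * stride_h - pad_h + r.

```cpp
#include <cassert>
#include <cstdlib>

// Expression removed from the analytic iterator in this commit.
inline int old_start(int pad, int stride, int filter_pos) {
  return std::abs((pad - filter_pos) % stride);
}

// Expression from the comment in strided_dgrad_starting_coords,
// written with % instead of FastDivmod.
inline int new_start(int pad, int stride, int filter_pos) {
  return std::abs(stride - ((pad % stride) - filter_pos)) % stride;
}

// Assumed validity check: coord + pad - filter_pos must be a multiple of stride.
inline bool is_aligned(int coord, int pad, int stride, int filter_pos) {
  return (coord + pad - filter_pos) % stride == 0;
}
```

With pad 1, stride 3, and filter position 0, `old_start` yields 1 and `new_start` yields 2, and only 2 passes `is_aligned`. With stride 2 the two expressions agree, which is consistent with the changelog entry about arbitrary padding and striding.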

8 changes: 0 additions & 8 deletions include/cutlass/conv/device/implicit_gemm_convolution.h
@@ -217,14 +217,6 @@ class ImplicitGemmConvolution {
if (result != cudaSuccess) {
return Status::kErrorInternal;
}

result = cudaFuncSetAttribute(
cutlass::Kernel<ImplicitGemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);

if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}

return Status::kSuccess;

@@ -199,7 +199,8 @@ struct ImplicitGemmConvolutionStridedDgrad {
struct Params {
ConvProblemSize problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
FastDivmod filter_s_divmod;
FastDivmod stride_h_divmod;
FastDivmod stride_w_divmod;
int gemm_k_iterations;
typename Mma::IteratorA::Params iterator_A;
typename Mma::IteratorA::Element const *ptr_A;
@@ -227,7 +228,8 @@ struct ImplicitGemmConvolutionStridedDgrad {
int *semaphore = nullptr
):
problem_size(args.problem_size),
filter_s_divmod(args.problem_size.stride_w),
stride_h_divmod(args.problem_size.stride_h),
stride_w_divmod(args.problem_size.stride_w),
iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())),
ptr_A(args.ref_A.data()),
iterator_B(args.problem_size, args.ref_B.layout()),
@@ -297,7 +299,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
// int start_s = filter_tile_m % (params.problem_size.stride_w);

int start_r, start_s;
params.filter_s_divmod(start_r, start_s, filter_tile_m);
params.stride_w_divmod(start_r, start_s, filter_tile_m);

typename Mma::FragmentC accumulators;

@@ -320,6 +322,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
params.problem_size,
params.ptr_A,
thread_idx,
params.stride_h_divmod, params.stride_w_divmod,
start_r, start_s,
MatrixCoord(
threadblock_tile_idx.m() * Mma::Shape::kM,
@@ -386,6 +389,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
params.ptr_D,
ConvOutputIteratorParameter::extent(params.problem_size),
thread_idx,
params.stride_h_divmod, params.stride_w_divmod,
start_r, start_s,
threadblock_offset
);
@@ -396,6 +400,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
params.ptr_C,
ConvOutputIteratorParameter::extent(params.problem_size),
thread_idx,
params.stride_h_divmod, params.stride_w_divmod,
start_r, start_s,
threadblock_offset
);
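
The call `params.stride_w_divmod(start_r, start_s, filter_tile_m)` splits the filter tile index into a pair of starting filter positions. The commented line in the hunk gives `start_s` as `filter_tile_m % stride_w`; treating `start_r` as the matching quotient is an assumption here, as is the helper name `split_filter_tile`. A plain-integer sketch of that split:

```cpp
#include <cassert>

// Hypothetical helper: quotient and remainder of filter_tile_m by stride_w,
// computed with ordinary integer division rather than FastDivmod.
inline void split_filter_tile(int &start_r, int &start_s,
                              int filter_tile_m, int stride_w) {
  start_r = filter_tile_m / stride_w;  // assumed meaning of the quotient
  start_s = filter_tile_m % stride_w;  // matches the comment in the hunk
}
```

A precomputed divider is typically preferred in a kernel because integer division by a runtime value is comparatively expensive on the device; the sketch trades that for readability.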

@@ -130,7 +130,6 @@ class Conv2dDgradOutputGradientTileAccessIteratorAnalytic <
int offset_p_[ThreadMap::Iterations::kStrided];
int offset_q_[ThreadMap::Iterations::kStrided];


public:

CUTLASS_HOST_DEVICE
@@ -139,6 +138,7 @@ class Conv2dDgradOutputGradientTileAccessIteratorAnalytic <
Conv2dProblemSize const &problem_size,
Element const *ptr,
int thread_idx,
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
int start_r, int start_s,
MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles
):
@@ -164,9 +164,12 @@ class Conv2dDgradOutputGradientTileAccessIteratorAnalytic <
}

// Starting h, w positions for filter position in gemm_k=0
int start_h = std::abs((problem_size_.pad_h - filter_r) % problem_size_.stride_h);
int start_w = std::abs((problem_size_.pad_w - filter_s) % problem_size_.stride_w);

int start_h, start_w;
strided_dgrad_starting_coords(
problem_size_,
stride_h_divmod, stride_w_divmod,
filter_r, filter_s,
start_h, start_w);

// Effective P and Q for filter position required for remapping NHW rows
int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h;

29 changes: 27 additions & 2 deletions include/cutlass/conv/threadblock/conv2d_tile_iterator.h
@@ -200,7 +200,27 @@ class TileIteratorStridedDgrad {

public:

/// Constructor
/// Constructor (output gradient (Dy) OperandA ctor)
CUTLASS_HOST_DEVICE
TileIteratorStridedDgrad(
Params const &params,
ConvProblemSize const &problem_size,
Element const *ptr,
int thread_idx,
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
int start_r, int start_s,
MatrixCoord const &threadblock_offset = MatrixCoord()
):
tile_access_iterator_(
params,
problem_size,
ptr,
thread_idx,
stride_h_divmod, stride_w_divmod,
start_r, start_s,
threadblock_offset) { }

/// Constructor (filter (w) OperandB ctor)
CUTLASS_HOST_DEVICE
TileIteratorStridedDgrad(
Params const &params,
@@ -210,7 +230,12 @@ class TileIteratorStridedDgrad {
int start_r, int start_s,
MatrixCoord const &threadblock_offset = MatrixCoord()
):
tile_access_iterator_(params, problem_size, ptr, thread_idx, start_r, start_s, threadblock_offset) { }
tile_access_iterator_(params,
problem_size,
ptr,
thread_idx,
start_r, start_s,
threadblock_offset) { }

CUTLASS_HOST_DEVICE
static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) {

6 changes: 6 additions & 0 deletions include/cutlass/cutlass.h
@@ -31,6 +31,12 @@

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef CUTLASS_NAMESPACE
#define cutlass CUTLASS_NAMESPACE
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

#define CUTLASS_UNUSED(expr) do { (void)(expr); } while (0)

#if defined(_MSC_VER)
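
The new block in `cutlass.h` renames the top-level namespace through the preprocessor: once `cutlass` is defined as a macro, every later `namespace cutlass` and `cutlass::` token expands to the configured name. The sketch below reproduces the mechanism with hypothetical names (`MYLIB_NAMESPACE`, `mylib`, `mylib_v2`); it is not CUTLASS code.

```cpp
#include <cassert>

// Normally supplied on the compiler command line, for example
// -DMYLIB_NAMESPACE=mylib_v2 (assumed name, defined inline here for the sketch).
#define MYLIB_NAMESPACE mylib_v2

#ifdef MYLIB_NAMESPACE
#define mylib MYLIB_NAMESPACE
#endif

namespace mylib {  // the preprocessor turns this into: namespace mylib_v2
inline int answer() { return 42; }
}  // namespace mylib
```

After expansion the function lives in `mylib_v2`, and source that still spells `mylib::answer()` resolves to it. One plausible use is letting two differently configured copies of a header-only library coexist in one program without symbol clashes.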

2 changes: 1 addition & 1 deletion include/cutlass/epilogue/thread/linear_combination.h
@@ -174,12 +174,12 @@ class LinearCombination {
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);

if (Scale == ScaleType::Nothing)
return destination_converter(converted_accumulator);

ComputeFragment converted_source = source_converter(source);

// Perform binary operations
ComputeFragment intermediate;
Expand Down