From 6c2f8f2fb877415685f06f53b9c6d82105a6c5e9 Mon Sep 17 00:00:00 2001 From: Manish Gupta Date: Fri, 3 Sep 2021 10:26:15 -0700 Subject: [PATCH] CUTLASS 2.6.1 - functional and performance enhancements to strided DGRAD, fixes, and tuning * cutlass 2.6 update * remove debug prints * cutlass 2.6.1 (minor update) * Updated CHANGELOG. * Minor edit to readme to indicate patch version. * Minor edit to readme. Co-authored-by: Haicheng Wu , Andrew Kerr --- CHANGELOG.md | 9 +++- CMakeLists.txt | 8 ++++ README.md | 7 ++- .../13_two_tensor_op_fusion/device/b2b_gemm.h | 8 ---- .../device/b2b_implicit_gemm_convolution.h | 8 ---- .../threadblock/b2b_mma_multistage.h | 30 ++++--------- include/cutlass/conv/conv2d_problem_size.h | 22 ++++++++++ .../conv/device/implicit_gemm_convolution.h | 8 ---- .../implicit_gemm_convolution_strided_dgrad.h | 11 +++-- ...t_gradient_tile_access_iterator_analytic.h | 11 +++-- .../conv/threadblock/conv2d_tile_iterator.h | 29 ++++++++++++- include/cutlass/cutlass.h | 6 +++ .../epilogue/thread/linear_combination.h | 2 +- .../threadblock/default_epilogue_tensor_op.h | 5 ++- .../threadblock/predicated_tile_iterator.h | 18 +++++--- .../predicated_tile_iterator_strided_dgrad.h | 10 +++-- .../warp/tile_iterator_tensor_op_mixed.h | 15 +------ include/cutlass/gemm/device/gemm.h | 8 ---- include/cutlass/gemm/device/gemm_array.h | 8 ---- include/cutlass/gemm/device/gemm_batched.h | 8 ---- include/cutlass/gemm/device/gemm_complex.h | 8 ---- include/cutlass/gemm/device/gemm_sparse.h | 8 ---- .../gemm/device/gemm_splitk_parallel.h | 8 ---- .../cutlass/gemm/device/gemm_universal_base.h | 8 ---- .../kernel/default_gemm_with_k_reduction.h | 6 +-- include/cutlass/gemm/kernel/gemm.h | 1 + .../cutlass/gemm/threadblock/default_mma.h | 1 - .../default_mma_core_with_reduction.h | 4 +- .../threadblock/default_mma_with_reduction.h | 5 ++- .../cutlass/gemm/threadblock/mma_multistage.h | 18 +++----- .../cutlass/gemm/threadblock/mma_pipelined.h | 12 ++---- .../mma_planar_complex_multistage.h | 30 ++++++------- .../mma_planar_complex_pipelined.h | 22 ++++------ .../gemm/threadblock/mma_singlestage.h | 12 ++---- .../gemm/threadblock/mma_sparse_multistage.h | 24 ++++------- .../mma_with_reduction_multistage.h | 28 +++++------- .../warp/mma_tensor_op_fragment_iterator.h | 17 +++++--- include/cutlass/matrix.h | 6 +-- .../predicated_tile_access_iterator.h | 22 +++++----- ...icated_tile_access_iterator_2dthreadtile.h | 8 ++-- .../threadblock/predicated_tile_iterator.h | 24 +++++------ .../predicated_tile_iterator_2dthreadtile.h | 10 ++--- media/docs/profiler.md | 1 - media/docs/quickstart.md | 9 +++- test/unit/common/cutlass_unit_test.h | 4 +- test/unit/conv/device/conv2d_problems.h | 43 +++++++++++++++++++ ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu | 20 +++------ test/unit/nvrtc/cutlass/nvrtc/environment.h | 1 + tools/library/scripts/generator.py | 9 +++- tools/library/scripts/manifest.py | 3 +- .../src/reference/gemm_reference_operation.h | 20 +++------ tools/profiler/src/gemm_operation_profiler.cu | 2 +- tools/profiler/src/gpu_timer.h | 1 + tools/profiler/src/options.cu | 4 +- .../util/include/cutlass/util/command_line.h | 2 + 55 files changed, 317 insertions(+), 315 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ac752046..448b1134d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,12 @@ # CUTLASS 2.x +## [2.6.1](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.1) (2021-09-03) + * Arbitrary padding and striding for CUTLASS Strided DGRAD Convolution operator (Analytic Iterators) + * 
Tuning for GEMMs fused with partial reductions + * Corrections and bug fixes reported by the CUTLASS community + * Thank you for filing these issues! + ## [2.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.0) (2021-07-22) * Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) * Adopt the new L2 prefetch feature in [cp.async](/include/cutlass/arch/memory.h) and [global load](/include/cutlass/arch/memory_sm80.h) @@ -23,7 +29,8 @@ * Many improvements to the epilogue. * Provide an [option](/include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations * Performance improvement for FP16 tensor core kernels - * Bug fixes + * Bug fixes + * Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere. * Updated minimum CUDA Toolkit requirement to 10.2 * [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) recommended * Corrections and bug fixes reported by the CUTLASS community diff --git a/CMakeLists.txt b/CMakeLists.txt index a27f32f1d..a16d77c0f 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -168,6 +168,11 @@ if (${CUTLASS_NVCC_VERBOSE}) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -v) endif() +# +# CUTLASS NAMESPACE +# +set(CUTLASS_NAMESPACE "cutlass" CACHE STRING "Top level namespace of CUTLASS") + set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.") set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.") set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.") @@ -383,6 +388,8 @@ function(cutlass_apply_standard_compile_options TARGET) set(_FLAGS_DEBUG ${__CUTLASS_CUDA_FLAGS_DEBUG} ${__CUTLASS_CUDA_NVCC_FLAGS_DEBUG}) endif() + target_link_libraries(${TARGET} PRIVATE CUTLASS) + target_compile_options( ${TARGET} PRIVATE @@ -425,6 +432,7 @@ set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/util/includ include_directories(${CUTLASS_INCLUDE_DIR}) target_compile_features(CUTLASS INTERFACE cxx_std_11) +target_compile_definitions(CUTLASS INTERFACE CUTLASS_NAMESPACE=${CUTLASS_NAMESPACE}) if (NOT DEFINED CUTLASS_REVISION) diff --git a/README.md b/README.md index 64f552c2a..0079012cf 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # CUTLASS 2.6 -_CUTLASS 2.6 - July 2021_ +_CUTLASS 2.6.1 - September 2021_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA. @@ -34,6 +34,8 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly. See the [functionality listing](/media/docs/functionality.md) for the list of operations supported at each level of the execution model hierarchy. +See the [CHANGELOG](CHANGELOG.md) for descriptions of recent updates. 
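Editor's note: this patch makes the top-level namespace configurable. `CMakeLists.txt` above caches `CUTLASS_NAMESPACE` and forwards it as a compile definition, and `include/cutlass/cutlass.h` (later in this patch) rewrites the `cutlass` token to that name. A minimal sketch of how a downstream translation unit might use it; the namespace `my_cutlass` and the toy `main` are illustrative assumptions, not part of the patch:

```cpp
// Build with the CUTLASS include path and, mirroring what
// target_compile_definitions(CUTLASS INTERFACE CUTLASS_NAMESPACE=...) does:
//   nvcc -I<cutlass>/include -DCUTLASS_NAMESPACE=my_cutlass example.cu
// or define the macro before any CUTLASS header, as below.
#define CUTLASS_NAMESPACE my_cutlass  // illustrative custom name

#include "cutlass/cutlass.h"        // '#define cutlass CUTLASS_NAMESPACE' takes effect here
#include "cutlass/numeric_types.h"  // every 'namespace cutlass' now opens 'namespace my_cutlass'

int main() {
  // The 'cutlass' token below is rewritten by the preprocessor, so two binaries
  // embedding differently-named CUTLASS copies can link without symbol clashes.
  cutlass::half_t x(2.0f);          // actually refers to my_cutlass::half_t
  return (float(x) == 2.0f) ? 0 : 1;
}
```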
+ # What's New in CUTLASS 2.6 CUTLASS 2.6 is a minor update to CUTLASS adding: - Fused [broadcast](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu) and [reductions](/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu) in the epilogues of GEMM and Convolution @@ -41,11 +43,12 @@ CUTLASS 2.6 is a minor update to CUTLASS adding: - [New strided Dgrad](test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation offers up to 4x performance improvements over previous strided Dgrad - 64-bit strides for large tensor allocations - [General affine layouts](/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) fp64 tensor core and simt GEMM +- [Batched GEMV](/test/unit/gemm/device/gemv.cu) preview implementation - Enhanced functionality, boosted performance, and bug fixes in the epilogue. - Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) - Adopt new L2 prefetch feature in [ptx instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-isa-version-7-4). +- Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere. - Numerous updates from the community (thanks!) -- See the [CHANGELOG](CHANGELOG.md) for more details # What's New in CUTLASS 2.5 CUTLASS 2.5 is a minor update to CUTLASS adding: diff --git a/examples/13_two_tensor_op_fusion/device/b2b_gemm.h b/examples/13_two_tensor_op_fusion/device/b2b_gemm.h index 559808c7c..564a440a8 100644 --- a/examples/13_two_tensor_op_fusion/device/b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/device/b2b_gemm.h @@ -390,14 +390,6 @@ class B2bGemm { if (result != cudaSuccess) { return Status::kErrorInternal; } - - result = cudaFuncSetAttribute( - Kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); - - if (result != cudaSuccess) { - return Status::kErrorInternal; - } } cutlass::Kernel<<>>(params_); diff --git a/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h b/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h index 3ead9eb5f..1d539a0f1 100644 --- a/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h +++ b/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h @@ -197,14 +197,6 @@ class B2bImplicitGemmConvolution { if (result != cudaSuccess) { return Status::kErrorInternal; } - - result = cudaFuncSetAttribute( - cutlass::Kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); - - if (result != cudaSuccess) { - return Status::kErrorInternal; - } } return Status::kSuccess; diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h index 24af6c676..edb7fbe86 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h @@ -395,10 +395,8 @@ class B2bMmaMultistage : for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations_0) { - if (gemm_k_iterations_0 == 0) { - iterator_A0.clear_mask(); - iterator_B0.clear_mask(); - } + iterator_A0.clear_mask(gemm_k_iterations_0 == 0); + iterator_B0.clear_mask(gemm_k_iterations_0 == 0); iterator_A0.set_iteration_index(0); this->smem_iterator_A0_.set_iteration_index(0); @@ -490,10 +488,8 @@ class B2bMmaMultistage : ++this->warp_tile_iterator_A0_; 
++this->warp_tile_iterator_B0_;
 
-        if (gemm_k_iterations_0 == 0) {
-          iterator_A0.clear_mask();
-          iterator_B0.clear_mask();
-        }
+        iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
+        iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
 
         int smem_write_stage_idx = Base::kStages - 1;
         int smem_read_stage_idx = 0;
@@ -601,10 +597,8 @@ class B2bMmaMultistage :
       }
 
       --gemm_k_iterations_0;
-      if (gemm_k_iterations_0 == 0) {
-        iterator_A0.clear_mask();
-        iterator_B0.clear_mask();
-      }
+      iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
+      iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
     }
 
     // Do any conversions feeding the first stage at the end of the loop so
@@ -634,9 +628,7 @@ class B2bMmaMultistage :
     for (int stage = 0; stage < Base::kStages - 1;
          ++stage, --gemm_k_iterations_1) {
 
-      if (gemm_k_iterations_1 == 0) {
-        iterator_B1.clear_mask();
-      }
+      iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
 
       iterator_B1.set_iteration_index(0);
       this->smem_iterator_B1_.set_iteration_index(0);
@@ -694,9 +686,7 @@ class B2bMmaMultistage :
     ++warp_tile_iterator_A1_;
     ++this->warp_tile_iterator_B1_;
 
-    if (gemm_k_iterations_1 == 0) {
-      iterator_B1.clear_mask();
-    }
+    iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
 
     smem_write_stage_idx = Base::kStages - 1;
     smem_read_stage_idx = 0;
@@ -793,9 +783,7 @@ class B2bMmaMultistage :
         ++smem_read_stage_idx;
       }
 
-      if (gemm_k_iterations_1 == 1) {
-        iterator_B1.clear_mask();
-      }
+      iterator_B1.clear_mask(gemm_k_iterations_1 == 1);
     }
 
     // Do any conversions feeding the first stage at the end of the loop so
diff --git a/include/cutlass/conv/conv2d_problem_size.h b/include/cutlass/conv/conv2d_problem_size.h
index e9ac10419..90c59ca99 100644
--- a/include/cutlass/conv/conv2d_problem_size.h
+++ b/include/cutlass/conv/conv2d_problem_size.h
@@ -47,6 +47,7 @@
 #include "cutlass/gemm/gemm.h"
 #include "cutlass/matrix_coord.h"
 #include "cutlass/conv/convolution.h"
+#include "cutlass/functional.h"
 
 namespace cutlass {
 namespace conv {
@@ -485,6 +486,27 @@ int strided_dgrad_tile_m_per_filter(
   return tile_m_per_filter;
 }
 
+// Computes the starting Dx coordinate (h, w) for a given starting filter position
+CUTLASS_HOST_DEVICE
+void strided_dgrad_starting_coords(
+  Conv2dProblemSize const &problem_size,
+  FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
+  int r, int s,
+  int &start_h, int &start_w) {
+
+  // function locals for remainder by fast divmod
+  int pad_h_rem_, pad_w_rem_;
+
+  // start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
+  stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h);
+  int r_ = std::abs(problem_size.stride_h - (pad_h_rem_ - r));
+  stride_h_divmod.divmod(start_h, r_);
+
+  // start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
+  stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w);
+  int s_ = std::abs(problem_size.stride_w - (pad_w_rem_ - s));
+  stride_w_divmod.divmod(start_w, s_);
+}
 
 } // namespace conv
 } // namespace cutlass
diff --git a/include/cutlass/conv/device/implicit_gemm_convolution.h b/include/cutlass/conv/device/implicit_gemm_convolution.h
index 158ce178a..eed01b5d8 100644
--- a/include/cutlass/conv/device/implicit_gemm_convolution.h
+++ b/include/cutlass/conv/device/implicit_gemm_convolution.h
@@ -217,14 +217,6 @@ class ImplicitGemmConvolution {
       if (result != cudaSuccess) {
         return Status::kErrorInternal;
       }
-
-      result = cudaFuncSetAttribute(
-          cutlass::Kernel<ImplicitGemmKernel>,
-          cudaFuncAttributePreferredSharedMemoryCarveout, 100);
-
-      if (result
!= cudaSuccess) { - return Status::kErrorInternal; - } } return Status::kSuccess; diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h index 8602a96f6..5ad1304fe 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h @@ -199,7 +199,8 @@ struct ImplicitGemmConvolutionStridedDgrad { struct Params { ConvProblemSize problem_size; cutlass::gemm::GemmCoord grid_tiled_shape; - FastDivmod filter_s_divmod; + FastDivmod stride_h_divmod; + FastDivmod stride_w_divmod; int gemm_k_iterations; typename Mma::IteratorA::Params iterator_A; typename Mma::IteratorA::Element const *ptr_A; @@ -227,7 +228,8 @@ struct ImplicitGemmConvolutionStridedDgrad { int *semaphore = nullptr ): problem_size(args.problem_size), - filter_s_divmod(args.problem_size.stride_w), + stride_h_divmod(args.problem_size.stride_h), + stride_w_divmod(args.problem_size.stride_w), iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), ptr_A(args.ref_A.data()), iterator_B(args.problem_size, args.ref_B.layout()), @@ -297,7 +299,7 @@ struct ImplicitGemmConvolutionStridedDgrad { // int start_s = filter_tile_m % (params.problem_size.stride_w); int start_r, start_s; - params.filter_s_divmod(start_r, start_s, filter_tile_m); + params.stride_w_divmod(start_r, start_s, filter_tile_m); typename Mma::FragmentC accumulators; @@ -320,6 +322,7 @@ struct ImplicitGemmConvolutionStridedDgrad { params.problem_size, params.ptr_A, thread_idx, + params.stride_h_divmod, params.stride_w_divmod, start_r, start_s, MatrixCoord( threadblock_tile_idx.m() * Mma::Shape::kM, @@ -386,6 +389,7 @@ struct ImplicitGemmConvolutionStridedDgrad { params.ptr_D, ConvOutputIteratorParameter::extent(params.problem_size), thread_idx, + params.stride_h_divmod, params.stride_w_divmod, start_r, start_s, threadblock_offset ); @@ -396,6 +400,7 @@ struct ImplicitGemmConvolutionStridedDgrad { params.ptr_C, ConvOutputIteratorParameter::extent(params.problem_size), thread_idx, + params.stride_h_divmod, params.stride_w_divmod, start_r, start_s, threadblock_offset ); diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h index dd5116385..f9e9083e2 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -130,7 +130,6 @@ class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < int offset_p_[ThreadMap::Iterations::kStrided]; int offset_q_[ThreadMap::Iterations::kStrided]; - public: CUTLASS_HOST_DEVICE @@ -139,6 +138,7 @@ class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < Conv2dProblemSize const &problem_size, Element const *ptr, int thread_idx, + FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, int start_r, int start_s, MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles ): @@ -164,9 +164,12 @@ class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < } // Starting h, w positions for filter position in gemm_k=0 - int start_h = std::abs((problem_size_.pad_h - filter_r) % problem_size_.stride_h); - int start_w = std::abs((problem_size_.pad_w - filter_s) % problem_size_.stride_w); - + int 
start_h, start_w; + strided_dgrad_starting_coords( + problem_size_, + stride_h_divmod, stride_w_divmod, + filter_r, filter_s, + start_h, start_w); // Effective P and Q for filter position required for remapping NHW rows int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h; diff --git a/include/cutlass/conv/threadblock/conv2d_tile_iterator.h b/include/cutlass/conv/threadblock/conv2d_tile_iterator.h index 5f75645a7..7a9228c39 100644 --- a/include/cutlass/conv/threadblock/conv2d_tile_iterator.h +++ b/include/cutlass/conv/threadblock/conv2d_tile_iterator.h @@ -200,7 +200,27 @@ class TileIteratorStridedDgrad { public: - /// Constructor + /// Constructor (output gradient (Dy) OperandA ctor) + CUTLASS_HOST_DEVICE + TileIteratorStridedDgrad( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, + int start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + tile_access_iterator_( + params, + problem_size, + ptr, + thread_idx, + stride_h_divmod, stride_w_divmod, + start_r, start_s, + threadblock_offset) { } + + /// Constructor (filter (w) OperandB ctor) CUTLASS_HOST_DEVICE TileIteratorStridedDgrad( Params const ¶ms, @@ -210,7 +230,12 @@ class TileIteratorStridedDgrad { int start_r, int start_s, MatrixCoord const &threadblock_offset = MatrixCoord() ): - tile_access_iterator_(params, problem_size, ptr, thread_idx, start_r, start_s, threadblock_offset) { } + tile_access_iterator_(params, + problem_size, + ptr, + thread_idx, + start_r, start_s, + threadblock_offset) { } CUTLASS_HOST_DEVICE static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) { diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index 8e28408ec..b7e447dd5 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -31,6 +31,12 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef CUTLASS_NAMESPACE +#define cutlass CUTLASS_NAMESPACE +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + #define CUTLASS_UNUSED(expr) do { (void)(expr); } while (0) #if defined(_MSC_VER) diff --git a/include/cutlass/epilogue/thread/linear_combination.h b/include/cutlass/epilogue/thread/linear_combination.h index 989d5af44..703de5893 100644 --- a/include/cutlass/epilogue/thread/linear_combination.h +++ b/include/cutlass/epilogue/thread/linear_combination.h @@ -174,12 +174,12 @@ class LinearCombination { // Convert to destination numeric type NumericArrayConverter destination_converter; + ComputeFragment converted_source = source_converter(source); ComputeFragment converted_accumulator = accumulator_converter(accumulator); if (Scale == ScaleType::Nothing) return destination_converter(converted_accumulator); - ComputeFragment converted_source = source_converter(source); // Perform binary operations ComputeFragment intermediate; diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h index dce467d7a..8b1d803f5 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h @@ -309,9 +309,12 @@ struct DefaultEpilogueTensorOp { kElementsPerAccess >::Type; + static bool const UseCUDAStore = platform::is_same::value; + using 
OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< OutputTileThreadMap, - ElementOutput + ElementOutput, + UseCUDAStore >; using AccumulatorFragmentIterator = typename std::conditional::value, diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index f14b67886..294770d57 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -62,7 +62,8 @@ namespace threadblock { /// template < typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) - typename Element_ ///< Element data type + typename Element_, ///< Element data type + bool UseCUDAStore = false > class PredicatedTileIterator { public: @@ -341,10 +342,17 @@ class PredicatedTileIterator { bool guard = row_guard && mask_.predicates[column]; - cutlass::arch::global_store( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); + if (UseCUDAStore) { + if (guard) { + memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } } if (row + 1 < ThreadMap::Iterations::kRow) { diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h index 1699ce393..103fe2bec 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h @@ -222,6 +222,7 @@ class PredicatedTileIteratorStridedDgrad { Element *pointer, TensorCoord extent, int thread_idx, + FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, int start_r, int start_s, TensorCoord threadblock_offset = TensorCoord() ): @@ -238,9 +239,12 @@ class PredicatedTileIteratorStridedDgrad { s = (params_.problem_size.S - 1 - s); } - // check if start_h_ and start_w_ are always positive - start_h_ = std::abs((params_.problem_size.pad_h - r) % params_.problem_size.stride_h); - start_w_ = std::abs((params_.problem_size.pad_w - s) % params_.problem_size.stride_w); + // compute starting coordinates in Dx start_h_ and start_w_ + strided_dgrad_starting_coords( + params_.problem_size, + stride_h_divmod, stride_w_divmod, + r, s, + start_h_, start_w_); p_ = (params_.problem_size.H - start_h_ + params_.problem_size.stride_h - 1) / params_.problem_size.stride_h; q_ = (params_.problem_size.W - start_w_ + params_.problem_size.stride_w - 1) / params_.problem_size.stride_w; diff --git a/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h b/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h index 9de8fae79..a0fd92c88 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h +++ b/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h @@ -256,20 +256,7 @@ class TileIteratorTensorOpMixed { int offset = n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess; -#if 0 - // Using inline PTX to avoid generic memory - AccessType *smem_ptr = pointers_[ptr_idx]; - smem_ptr[offset] = frag_ptr[n]; -#else - uint32_t smem_addr = 
arch::cutlass_get_smem_pointer(ptr); - uint32_t const *data = reinterpret_cast(frag_ptr + n); - uint32_t offset_in_bytes = offset * sizeof(AccessType); - - asm volatile( - "{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n" - : : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1]) - ); -#endif + ptr[offset] = frag_ptr[n]; } } diff --git a/include/cutlass/gemm/device/gemm.h b/include/cutlass/gemm/device/gemm.h index 6e1362a57..aeb4da4ad 100644 --- a/include/cutlass/gemm/device/gemm.h +++ b/include/cutlass/gemm/device/gemm.h @@ -455,14 +455,6 @@ class Gemm { if (result != cudaSuccess) { return Status::kErrorInternal; } - - result = cudaFuncSetAttribute( - Kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); - - if (result != cudaSuccess) { - return Status::kErrorInternal; - } } cutlass::Kernel<<>>(params_); diff --git a/include/cutlass/gemm/device/gemm_array.h b/include/cutlass/gemm/device/gemm_array.h index 0baeb6e3b..e85f4591b 100644 --- a/include/cutlass/gemm/device/gemm_array.h +++ b/include/cutlass/gemm/device/gemm_array.h @@ -445,14 +445,6 @@ class GemmArray { if (result != cudaSuccess) { return Status::kErrorInternal; } - - result = cudaFuncSetAttribute( - Kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); - - if (result != cudaSuccess) { - return Status::kErrorInternal; - } } cutlass::Kernel<<>>(params_); diff --git a/include/cutlass/gemm/device/gemm_batched.h b/include/cutlass/gemm/device/gemm_batched.h index 63515e145..afa22fe94 100644 --- a/include/cutlass/gemm/device/gemm_batched.h +++ b/include/cutlass/gemm/device/gemm_batched.h @@ -423,14 +423,6 @@ class GemmBatched { if (result != cudaSuccess) { return Status::kErrorInternal; } - - result = cudaFuncSetAttribute( - Kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); - - if (result != cudaSuccess) { - return Status::kErrorInternal; - } } cutlass::Kernel<<>>(params_); diff --git a/include/cutlass/gemm/device/gemm_complex.h b/include/cutlass/gemm/device/gemm_complex.h index e9de60fcb..be4e19b6c 100644 --- a/include/cutlass/gemm/device/gemm_complex.h +++ b/include/cutlass/gemm/device/gemm_complex.h @@ -437,14 +437,6 @@ class GemmComplex { if (result != cudaSuccess) { return Status::kErrorInternal; } - - result = cudaFuncSetAttribute( - Kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); - - if (result != cudaSuccess) { - return Status::kErrorInternal; - } } cutlass::Kernel<<>>(params_); diff --git a/include/cutlass/gemm/device/gemm_sparse.h b/include/cutlass/gemm/device/gemm_sparse.h index a954ad5e2..581ca41dc 100644 --- a/include/cutlass/gemm/device/gemm_sparse.h +++ b/include/cutlass/gemm/device/gemm_sparse.h @@ -438,14 +438,6 @@ class SparseGemm { if (result != cudaSuccess) { return Status::kErrorInternal; } - - result = cudaFuncSetAttribute( - Kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); - - if (result != cudaSuccess) { - return Status::kErrorInternal; - } } return Status::kSuccess; diff --git a/include/cutlass/gemm/device/gemm_splitk_parallel.h b/include/cutlass/gemm/device/gemm_splitk_parallel.h index 416ea9069..acf2cefae 100644 --- a/include/cutlass/gemm/device/gemm_splitk_parallel.h +++ b/include/cutlass/gemm/device/gemm_splitk_parallel.h @@ -352,14 +352,6 @@ class GemmSplitKParallel { if (result != cudaSuccess) { return Status::kErrorInternal; } - - result = cudaFuncSetAttribute( - Kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); - - if (result != cudaSuccess) { - return 
Status::kErrorInternal; - } } Kernel<<>>(gemm_params_); diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index 4b1b15862..aa85a5149 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -325,14 +325,6 @@ class GemmUniversalBase { if (result != cudaSuccess) { return Status::kErrorInternal; } - - result = cudaFuncSetAttribute( - Kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); - - if (result != cudaSuccess) { - return Status::kErrorInternal; - } } return Status::kSuccess; diff --git a/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h b/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h index 6cb098416..114351d14 100644 --- a/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h +++ b/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h @@ -103,8 +103,8 @@ template < int Stages, /// Operation performed by GEMM typename Operator, - /// Use zfill or predicate for SM80 out-of-bound cp.async - bool UseZfill = false, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, /// typename Enable = void> struct DefaultGemmWithKReduction { @@ -116,7 +116,7 @@ struct DefaultGemmWithKReduction { ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kReduceKForA, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, Stages, - Operator, false, UseZfill>::ThreadblockMma; + Operator, false, SharedMemoryClear>::ThreadblockMma; static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; diff --git a/include/cutlass/gemm/kernel/gemm.h b/include/cutlass/gemm/kernel/gemm.h index 96ed1e701..b781dbcc3 100644 --- a/include/cutlass/gemm/kernel/gemm.h +++ b/include/cutlass/gemm/kernel/gemm.h @@ -34,6 +34,7 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/matrix_coord.h" #include "cutlass/semaphore.h" +#include "cutlass/arch/arch.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/default_mma.h b/include/cutlass/gemm/threadblock/default_mma.h index 4c429ac39..40b9b34bf 100644 --- a/include/cutlass/gemm/threadblock/default_mma.h +++ b/include/cutlass/gemm/threadblock/default_mma.h @@ -130,7 +130,6 @@ struct DefaultMma { - static_assert(platform::is_same::value || platform::is_same>::value, "simt epilogue must be row major"); diff --git a/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h b/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h index e3f798c8b..217786603 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h @@ -141,8 +141,8 @@ struct DefaultMmaWithReductionCore { using SmemLayoutB = typename Base::SmemLayoutB; using WarpCount = typename Base::WarpCount; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; // Define the warp-level tensor op using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaWithReductionTensorOp< diff --git 
a/include/cutlass/gemm/threadblock/default_mma_with_reduction.h b/include/cutlass/gemm/threadblock/default_mma_with_reduction.h index 7d42408db..e98475c7f 100644 --- a/include/cutlass/gemm/threadblock/default_mma_with_reduction.h +++ b/include/cutlass/gemm/threadblock/default_mma_with_reduction.h @@ -82,9 +82,10 @@ template < /// when output layout is interleaved. bool AccumulatorsInRowMajor = false, /// Use zfill or predicate for SM80 out-of-bound cp.async - bool UseZfill = false + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone > struct DefaultMmaWithReduction { + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global @@ -122,7 +123,7 @@ struct DefaultMmaWithReduction { typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages, UseZfill>; + typename MmaCore::MmaPolicy, Stages, SharedMemoryClear>; }; //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/mma_multistage.h b/include/cutlass/gemm/threadblock/mma_multistage.h index 8d51b9b91..4c8e92cf1 100644 --- a/include/cutlass/gemm/threadblock/mma_multistage.h +++ b/include/cutlass/gemm/threadblock/mma_multistage.h @@ -303,10 +303,8 @@ class MmaMultistage : for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { - if (gemm_k_iterations == 0) { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - } + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); iterator_A.set_iteration_index(0); this->smem_iterator_A_.set_iteration_index(0); @@ -447,10 +445,8 @@ class MmaMultistage : ++this->warp_tile_iterator_A_; ++this->warp_tile_iterator_B_; - if (gemm_k_iterations == 0) { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - } + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); int smem_write_stage_idx = Base::kStages - 1; int smem_read_stage_idx = 0; @@ -558,10 +554,8 @@ class MmaMultistage : } --gemm_k_iterations; - if (gemm_k_iterations == 0) { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - } + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); } // Do any conversions feeding the first stage at the end of the loop so diff --git a/include/cutlass/gemm/threadblock/mma_pipelined.h b/include/cutlass/gemm/threadblock/mma_pipelined.h index ec294fee7..76513ee36 100644 --- a/include/cutlass/gemm/threadblock/mma_pipelined.h +++ b/include/cutlass/gemm/threadblock/mma_pipelined.h @@ -231,10 +231,8 @@ class MmaPipelined : public MmaBase { int smem_write_stage_idx = 1; // Avoid reading out of bounds - if (gemm_k_iterations <= 1) { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - } + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing // shared memory loads (which have the tighest latency requirement). 
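Editor's note: the hunks in `default_mma_with_reduction.h` above and `mma_with_reduction_multistage.h` below replace the boolean `UseZfill` template parameter with the `SharedMemoryClearOption` enum, which selects between `cp_async` and `cp_async_zfill` for predicated-off copies. A host-side emulation of the semantic difference; the `*_emulated` helpers are illustrative stand-ins, not the real `cutlass::arch` intrinsics:

```cpp
#include <cassert>
#include <cstring>

// Emulates the two predicated-copy flavors selected by SharedMemoryClearOption:
//   kNone  -> cp_async:       an invalid (out-of-bounds) access is skipped and
//                             the destination keeps whatever it held before;
//   kZfill -> cp_async_zfill: an invalid access writes zeros, guaranteeing
//                             deterministic shared-memory contents.
void cp_async_emulated(void *dst, void const *src, size_t bytes, bool valid) {
  if (valid) { std::memcpy(dst, src, bytes); }  // predicated off: leave dst untouched
}

void cp_async_zfill_emulated(void *dst, void const *src, size_t bytes, bool valid) {
  if (valid) { std::memcpy(dst, src, bytes); }
  else       { std::memset(dst, 0, bytes); }    // predicated off: zero-fill dst
}

int main() {
  int src = 7, plain = -1, zfill = -1;
  cp_async_emulated(&plain, &src, sizeof(int), /*valid=*/false);
  cp_async_zfill_emulated(&zfill, &src, sizeof(int), /*valid=*/false);
  assert(plain == -1 && zfill == 0);
  return 0;
}
```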
@@ -302,10 +300,8 @@ class MmaPipelined : public MmaBase { ++iterator_B; // Avoid reading out of bounds if this was the last loop iteration - if (gemm_k_iterations <= 2) { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - } + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); } warp_mma(accum, warp_frag_A[warp_mma_k % 2], diff --git a/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h b/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h index 94fcc4403..47cdf3b1a 100644 --- a/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h +++ b/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h @@ -370,12 +370,10 @@ class MmaPlanarComplexMultistage : for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { - if (gemm_k_iterations == 0) { - iterator_A_real.clear_mask(); - iterator_A_imag.clear_mask(); - iterator_B_real.clear_mask(); - iterator_B_imag.clear_mask(); - } + iterator_A_real.clear_mask(gemm_k_iterations == 0); + iterator_A_imag.clear_mask(gemm_k_iterations == 0); + iterator_B_real.clear_mask(gemm_k_iterations == 0); + iterator_B_imag.clear_mask(gemm_k_iterations == 0); iterator_A_real.set_iteration_index(0); iterator_A_imag.set_iteration_index(0); @@ -501,12 +499,10 @@ class MmaPlanarComplexMultistage : ++this->warp_tile_iterator_A_; ++this->warp_tile_iterator_B_; - if (gemm_k_iterations == 0) { - iterator_A_real.clear_mask(); - iterator_A_imag.clear_mask(); - iterator_B_real.clear_mask(); - iterator_B_imag.clear_mask(); - } + iterator_A_real.clear_mask(gemm_k_iterations == 0); + iterator_A_imag.clear_mask(gemm_k_iterations == 0); + iterator_B_real.clear_mask(gemm_k_iterations == 0); + iterator_B_imag.clear_mask(gemm_k_iterations == 0); // Start issuing the first group of the next stage outside of the mainloop copy_tiles_and_advance(iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag); @@ -611,12 +607,10 @@ class MmaPlanarComplexMultistage : } --gemm_k_iterations; - if (gemm_k_iterations == 0) { - iterator_A_real.clear_mask(); - iterator_A_imag.clear_mask(); - iterator_B_real.clear_mask(); - iterator_B_imag.clear_mask(); - } + iterator_A_real.clear_mask(gemm_k_iterations == 0); + iterator_A_imag.clear_mask(gemm_k_iterations == 0); + iterator_B_real.clear_mask(gemm_k_iterations == 0); + iterator_B_imag.clear_mask(gemm_k_iterations == 0); } warp_mma_planar_complex( diff --git a/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h b/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h index d7eb08795..d2da8e6d3 100644 --- a/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h +++ b/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h @@ -308,13 +308,11 @@ class MmaPlanarComplexPipelined : int smem_write_stage_idx = 1; // Avoid reading out of bounds - if (gemm_k_iterations <= 1) { - iterator_A_real.clear_mask(); - iterator_A_imag.clear_mask(); - - iterator_B_real.clear_mask(); - iterator_B_imag.clear_mask(); - } + iterator_A_real.clear_mask(gemm_k_iterations <= 1); + iterator_A_imag.clear_mask(gemm_k_iterations <= 1); + + iterator_B_real.clear_mask(gemm_k_iterations <= 1); + iterator_B_imag.clear_mask(gemm_k_iterations <= 1); // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing // shared memory loads (which have the tighest latency requirement). 
@@ -392,12 +390,10 @@ class MmaPlanarComplexPipelined : ++iterator_B_imag; // Avoid reading out of bounds if this was the last loop iteration - if (gemm_k_iterations <= 2) { - iterator_A_real.clear_mask(); - iterator_A_imag.clear_mask(); - iterator_B_real.clear_mask(); - iterator_B_imag.clear_mask(); - } + iterator_A_real.clear_mask(gemm_k_iterations <= 2); + iterator_A_imag.clear_mask(gemm_k_iterations <= 2); + iterator_B_real.clear_mask(gemm_k_iterations <= 2); + iterator_B_imag.clear_mask(gemm_k_iterations <= 2); } warp_mma_planar_complex( diff --git a/include/cutlass/gemm/threadblock/mma_singlestage.h b/include/cutlass/gemm/threadblock/mma_singlestage.h index f0d00525e..57aaf731b 100644 --- a/include/cutlass/gemm/threadblock/mma_singlestage.h +++ b/include/cutlass/gemm/threadblock/mma_singlestage.h @@ -196,10 +196,8 @@ class MmaSingleStage : public MmaBase { Operator warp_mma; // Avoid reading out of bounds - if (gemm_k_iterations <= 1) { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - } + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); // // Mainloop @@ -247,10 +245,8 @@ class MmaSingleStage : public MmaBase { ++iterator_B; // Avoid reading out of bounds if this was the last loop iteration - if (gemm_k_iterations <= 2) { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - } + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); } } diff --git a/include/cutlass/gemm/threadblock/mma_sparse_multistage.h b/include/cutlass/gemm/threadblock/mma_sparse_multistage.h index 09d9345e9..a92c13cae 100644 --- a/include/cutlass/gemm/threadblock/mma_sparse_multistage.h +++ b/include/cutlass/gemm/threadblock/mma_sparse_multistage.h @@ -379,11 +379,9 @@ class SparseMmaMultistage : for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { - if (gemm_k_iterations == 0) { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - iterator_E.clear_mask(); - } + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_E.clear_mask(gemm_k_iterations == 0); iterator_A.set_iteration_index(0); this->smem_iterator_A_.set_iteration_index(0); @@ -500,11 +498,9 @@ class SparseMmaMultistage : ++this->warp_tile_iterator_B_; ++this->warp_tile_iterator_E_; - if (gemm_k_iterations == 0) { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - iterator_E.clear_mask(); - } + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_E.clear_mask(gemm_k_iterations == 0); int smem_write_stage_idx = Base::kStages - 1; int smem_read_stage_idx = 0; @@ -637,11 +633,9 @@ class SparseMmaMultistage : } --gemm_k_iterations; - if (gemm_k_iterations == 0) { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - iterator_E.clear_mask(); - } + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_E.clear_mask(gemm_k_iterations == 0); } // Do any conversions feeding the first stage at the end of the loop so diff --git a/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h b/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h index 75ed225ed..6159db44c 100644 --- a/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h +++ b/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h @@ -78,7 +78,7 @@ template < /// Number of stages, int Stages, /// Use zfill or predicate for out-of-bound cp.async - bool UseZfill = false, + 
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, /// Used for partial specialization typename Enable = bool> class MmaWithReductionMultistage : @@ -234,7 +234,7 @@ class MmaWithReductionMultistage : for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_A.get(); - if (UseZfill) { + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { cutlass::arch::cp_async_zfill( dst_ptr + v, gmem_ptr, iterator_A.valid()); } else { @@ -269,7 +269,7 @@ class MmaWithReductionMultistage : for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_B.get(); - if (UseZfill) { + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { cutlass::arch::cp_async_zfill( dst_ptr + v, gmem_ptr, iterator_B.valid()); } else { @@ -302,16 +302,14 @@ class MmaWithReductionMultistage : // // Prologue // - // Issue several complete stages + CUTLASS_PRAGMA_UNROLL for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { - if (gemm_k_iterations == 0) { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - } + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); iterator_A.set_iteration_index(0); this->smem_iterator_A_.set_iteration_index(0); @@ -403,10 +401,8 @@ class MmaWithReductionMultistage : ++this->warp_tile_iterator_A_; ++this->warp_tile_iterator_B_; - if (gemm_k_iterations == 0) { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - } + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); int smem_write_stage_idx = Base::kStages - 1; int smem_read_stage_idx = 0; @@ -515,10 +511,8 @@ class MmaWithReductionMultistage : } --gemm_k_iterations; - if (gemm_k_iterations == 0) { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - } + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); } // Do any conversions feeding the first stage at the end of the loop so @@ -532,7 +526,7 @@ class MmaWithReductionMultistage : } - if (UseZfill) { + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop cutlass::arch::cp_async_fence(); cutlass::arch::cp_async_wait<0>(); diff --git a/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h b/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h index bbcedfcfb..aa0439f01 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h @@ -49,7 +49,6 @@ class MmaTensorOpFragmentIterator; // Partial specialization for col-major accumulator tile -// And Element type is the same as Accumulator Element type template < /// Shape of warp tile to load (concept: MatrixShape) @@ -58,13 +57,15 @@ template < typename AccumulatorShape_, /// KBlocks columns to compute residual int KBlocksColumn_, + /// Accumulator Element type + typename ElementAccumulator_, /// Element type typename Element_, /// Shape of one matrix product operation (concept: MatrixShape) typename InstructionShape_, /// Output operation on fragment typename OutputOp_> -class MmaTensorOpFragmentIterator { public: @@ -78,6 +79,9 @@ class MmaTensorOpFragmentIterator; /// Accumulator Fragment object - using AccumulatorFragment = Array; + using AccumulatorFragment = Array; private: /// Internal access type - using AccessType = Array; + using AccessType = Array; + using FragmentAccessType = Array; private: // @@ -203,10 +208,10 @@ class 
MmaTensorOpFragmentIterator(&frag); + FragmentAccessType *frag_ptr = reinterpret_cast(&frag); int index = index_ * MmaIterations::kCount; diff --git a/include/cutlass/matrix.h b/include/cutlass/matrix.h index 40dcac6e5..0c4aa7b6a 100644 --- a/include/cutlass/matrix.h +++ b/include/cutlass/matrix.h @@ -14030,15 +14030,15 @@ struct Matrix { /// Returns a perspective projection matrix typical of OpenGL applications CUTLASS_HOST_DEVICE - static Matrix perspective(Element near, Element far, Element fovH, Element fovV) { + static Matrix perspective(Element near_plane, Element far_plane, Element fovH, Element fovV) { Element aspect = fovH / fovV; Element f = Element(cos(fovV)) / Element(fovH); - Element Q = near - far; + Element Q = near_plane - far_plane; return Matrix( f / aspect, 0, 0, 0, 0, f, 0, 0, - 0, 0, (near + far) / Q, Element(2) * far * near / Q, + 0, 0, (near_plane + far_plane) / Q, Element(2) * far_plane * near_plane / Q, 0, 0, -1, 0 ); } diff --git a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h index d622c0dfa..0049874fc 100644 --- a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h +++ b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h @@ -245,10 +245,10 @@ class PredicatedTileAccessIteratorPredicates { /// Clears the predicate set efficiently CUTLASS_HOST_DEVICE - void clear_mask() { + void clear_mask(bool enable = true) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = 0u; + predicates_[i] = enable ? 0u : predicates_[i]; } } @@ -551,8 +551,8 @@ class PredicatedTileAccessIterator, /// Clears the predicate set efficiently CUTLASS_HOST_DEVICE - void clear_mask() { the_predicates.clear_mask(); } + void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); } /// Clears the predicate set efficiently CUTLASS_HOST_DEVICE @@ -1401,7 +1401,7 @@ class PredicatedTileAccessIterator, AdvanceRa /// Clears the predicate set efficiently CUTLASS_HOST_DEVICE - void clear_mask() { address_iterator_.clear_mask(); } + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } /// Clears the predicate set efficiently CUTLASS_HOST_DEVICE @@ -1184,8 +1184,8 @@ class PredicatedTileIterator If true, profiling is actually conducted. - Verification: --verification-enabled= Whether to perform verification checks. diff --git a/media/docs/quickstart.md b/media/docs/quickstart.md index 05b0a493f..3378f32f4 100644 --- a/media/docs/quickstart.md +++ b/media/docs/quickstart.md @@ -206,9 +206,12 @@ $ cmake .. -DCUTLASS_NVCC_ARCHS="50;53" # compiles for NVIDIA Maxwell G ## Clang -For experimental purposes, CUTLASS may be compiled with -[clang 8.0](https://github.com/llvm/llvm-project/releases/download/llvmorg-8.0.1/clang+llvm-8.0.1-amd64-unknown-freebsd11.tar.xz) using the +For experimental purposes, CUTLASS has been verified to compile with the following versions of Clang and CUDA. + +* [clang 8.0](https://github.com/llvm/llvm-project/releases/download/llvmorg-8.0.1/clang+llvm-8.0.1-amd64-unknown-freebsd11.tar.xz) using the [CUDA 10.0 Toolkit](https://developer.nvidia.com/cuda-10.0-download-archive). +* [clang release/13.x](https://github.com/llvm/llvm-project/tree/release/13.x) using [CUDA 11.4](https://developer.nvidia.com/cuda-toolkit-archive) + At this time, compiling with clang enables the CUTLASS SIMT GEMM kernels (sgemm, dgemm, hgemm, igemm) but does not enable TensorCores. 
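Editor's note: the `clear_mask(bool enable = true)` overload added to `PredicatedTileAccessIteratorPredicates` above is what lets every mainloop in this patch replace the branchy `if (gemm_k_iterations == 0) { iterator.clear_mask(); }` with one unconditional, predicated call. A standalone toy illustrating the idiom; `ToyPredicates` is a sketch, not the real iterator class:

```cpp
#include <cstdint>

// Mirrors the predicate-clearing pattern from predicated_tile_access_iterator.h:
// the ternary select keeps the (fully unrollable) loop free of divergent
// branches, instead of guarding the whole call site with an 'if'.
struct ToyPredicates {
  static int const kPredicateWordCount = 4;
  uint32_t predicates_[kPredicateWordCount] = {~0u, ~0u, ~0u, ~0u};

  void clear_mask(bool enable = true) {
    for (int i = 0; i < kPredicateWordCount; ++i) {
      predicates_[i] = enable ? 0u : predicates_[i];
    }
  }
};

int main() {
  ToyPredicates p;
  int gemm_k_iterations = 1;
  --gemm_k_iterations;
  p.clear_mask(gemm_k_iterations == 0);  // masks all further accesses on the last iteration
  return (p.predicates_[0] == 0u && p.predicates_[3] == 0u) ? 0 : 1;
}
```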
@@ -216,6 +219,8 @@ but does not enable TensorCores. $ mkdir build && cd build $ cmake -DCUDA_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ .. +# Add -DCMAKE_CXX_FLAGS=-D__NV_NO_HOST_COMPILER_CHECK=1 -DCMAKE_CUDA_FLAGS=-D__NV_NO_HOST_COMPILER_CHECK=1 if compiler +# checks fail during CMake configuration. $ make test_unit -j ``` diff --git a/test/unit/common/cutlass_unit_test.h b/test/unit/common/cutlass_unit_test.h index 5d2ed7c2f..3259a3ba8 100644 --- a/test/unit/common/cutlass_unit_test.h +++ b/test/unit/common/cutlass_unit_test.h @@ -26,9 +26,9 @@ #pragma once #pragma warning (disable : 4068 ) /* disable unknown pragma warnings for vistual studio */ -#pragma diag_suppress boolean_controlling_expr_is_constant +#pragma nv_diag_suppress boolean_controlling_expr_is_constant #include -#pragma diag_warning boolean_controlling_expr_is_constant +#pragma nv_diag_warning boolean_controlling_expr_is_constant #pragma warning( disable : 4503) ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv2d_problems.h b/test/unit/conv/device/conv2d_problems.h index fb8d2cb8c..3503ee316 100644 --- a/test/unit/conv/device/conv2d_problems.h +++ b/test/unit/conv/device/conv2d_problems.h @@ -281,6 +281,22 @@ struct TestbedConv2dProblemSizes { {1, 1} // dilation (dilation_h, dilation_w) )); + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, 8}, // input size (NHWC) + {8, 3, 3, 8}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 16, 8}, // input size (NHWC) + {8, 3, 3, 8}, // filter size (KRSC) + {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + //////////////////////////////////////////////////////////////////////////////////// // Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1) //////////////////////////////////////////////////////////////////////////////////// @@ -389,6 +405,25 @@ struct TestbedConv2dProblemSizes { {1, 1} // dilation (dilation_h, dilation_w) )); + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size padding > stride, asymmetric filter, padding and striding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 31, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {5, 5, 7, 7}, // padding (pad_h, _, pad_w, _) + {3, 4}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 35, 256}, // input size (NHWC) + {512, 7, 5, 256}, // filter size (KRSC) + {11, 11, 7, 7}, // padding (pad_h, _, pad_w, _) + {3, 5}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + //////////////////////////////////////////////////////////////////////////////////// // Medium input size *mixed* stride (1, 2) and (2, 1), // filter (3, 3), default padding @@ -419,6 +454,14 @@ struct TestbedConv2dProblemSizes { {2, 2}, // stride (stride_h, stride_w) {1, 1} // dilation (dilation_h, dilation_w) )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 32, 32, 16}, // input 
size (NHWC) + {32, 3, 3, 16}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {6, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( {32, 32, 32, 32}, // input size (NHWC) diff --git a/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu index c42a6b03f..3b36fd6c3 100644 --- a/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu @@ -78,23 +78,15 @@ TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32n test::conv::device::Conv2dProblemVector problem_size_list; -#if 0 // run specific problem size in the unit test first - problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( - {1, 56, 56, 8}, // input size (NHWC) - {8, 1, 1, 8}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); +#if 0 // run specific problem size in the unit test first problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( - {1, 55, 55, 8}, // input size (NHWC) - {8, 1, 1, 8}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) + {1, 4, 4, 8}, // input size (NHWC) + {8, 3, 3, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) )); - #endif /// Run all unit test sizes with device-level Conv2d instance diff --git a/test/unit/nvrtc/cutlass/nvrtc/environment.h b/test/unit/nvrtc/cutlass/nvrtc/environment.h index 281845091..bbd3ce00e 100644 --- a/test/unit/nvrtc/cutlass/nvrtc/environment.h +++ b/test/unit/nvrtc/cutlass/nvrtc/environment.h @@ -25,6 +25,7 @@ #pragma once #include +#include "cutlass/cutlass.h" namespace cutlass { namespace nvrtc { diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index b9bd1b41f..57d2a4ca4 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -1312,6 +1312,7 @@ def GenerateSM80_TensorOp_16816(manifest, args): TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), @@ -1392,7 +1393,7 @@ def GenerateSM80_SparseTensorOp_16832(manifest, args): max_cc = 1024 max_cc_smem_limited = 80 - alignment_constraints = [8, 4, 2] + alignment_constraints = [8] for math_inst in math_instructions: tile_descriptions = [ @@ -1967,6 +1968,8 @@ def GenerateSM80_TensorOp_1688(manifest, args): TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), 
+ TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), @@ -2051,6 +2054,8 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args): TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), @@ -2100,7 +2105,7 @@ def GenerateSM80_SparseTensorOp_16816_fast_math(manifest, args): max_cc = 1024 max_cc_smem_limited = 80 - alignment_constraints = [4, 2, 1] + alignment_constraints = [4] for math_inst in math_instructions: tile_descriptions = [ diff --git a/tools/library/scripts/manifest.py b/tools/library/scripts/manifest.py index 536f97bfa..f11ce06d9 100644 --- a/tools/library/scripts/manifest.py +++ b/tools/library/scripts/manifest.py @@ -20,7 +20,6 @@ def __init__(self, generated_path, kind, args): self.generated_path = generated_path self.kind = kind self.args = args - self.emitters = { OperationKind.Gemm: EmitGemmConfigurationLibrary , OperationKind.Conv2d: EmitConv2dConfigurationLibrary @@ -347,7 +346,7 @@ def emit(self, target = GeneratorTarget.Library): with interface_emitters[target](generated_path, self.operation_count, self.args) as iface_emitter: for operation_kind, configurations in self.operations.items(): - iface_emitter.emit(OperationKindNames[operation_kind]) + iface_emitter.emit(OperationKindNames[operation_kind]) source_files += iface_emitter.source_files diff --git a/tools/library/src/reference/gemm_reference_operation.h b/tools/library/src/reference/gemm_reference_operation.h index 385aa8dca..422ac352e 100644 --- a/tools/library/src/reference/gemm_reference_operation.h +++ b/tools/library/src/reference/gemm_reference_operation.h @@ -186,12 +186,6 @@ class GemmReferenceOperation : public Operation { GemmUniversalConfiguration const &config = *static_cast(host_workspace); GemmUniversalArguments const &args = *static_cast(arguments); - ElementCompute alpha; - ElementCompute beta; - - alpha = *static_cast(args.alpha); - beta = *static_cast(args.beta); - TensorRefA ref_A{static_cast(const_cast(args.A)), LayoutA(int(config.lda))}; TensorRefB ref_B{static_cast(const_cast(args.B)), LayoutB(int(config.ldb))}; TensorRefC ref_C{static_cast(const_cast(args.C)), LayoutC(int(config.ldc))}; @@ -212,16 +206,16 @@ class GemmReferenceOperation : public Operation { InnerProductOp >( config.problem_size, - alpha, + *static_cast(args.alpha), ref_A, kTransformA, ref_B, kTransformB, - beta, + *static_cast(args.beta), ref_C, ref_D, ElementAccumulator(), - config.batch_count, + ((config.mode == library::GemmUniversalMode::kBatched) ? 
config.batch_count : 1), args.batch_stride_A, args.batch_stride_B, args.batch_stride_C, @@ -245,16 +239,16 @@ class GemmReferenceOperation : public Operation { InnerProductOp >( config.problem_size, - alpha, + *static_cast(args.alpha), ref_A, kTransformA, ref_B, kTransformB, - beta, + *static_cast(args.beta), ref_C, ref_D, ElementAccumulator(), - config.batch_count, + ((config.mode == library::GemmUniversalMode::kBatched) ? config.batch_count : 1), args.batch_stride_A, args.batch_stride_B, args.batch_stride_C, @@ -263,7 +257,7 @@ class GemmReferenceOperation : public Operation { return Status::kSuccess; } - + return Status::kErrorNotSupported; } }; diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index 4d5afc011..fab38346e 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -791,7 +791,7 @@ bool GemmOperationProfiler::verify_with_reference_( handle.set_provider(provider); Status status = handle.gemm_universal( - library::GemmUniversalMode::kGemm, + problem_.mode, gemm_workspace_.configuration.problem_size.m(), gemm_workspace_.configuration.problem_size.n(), gemm_workspace_.configuration.problem_size.k(), diff --git a/tools/profiler/src/gpu_timer.h b/tools/profiler/src/gpu_timer.h index 0e9923dfa..935171988 100644 --- a/tools/profiler/src/gpu_timer.h +++ b/tools/profiler/src/gpu_timer.h @@ -29,6 +29,7 @@ #pragma once #include +#include "cutlass/cutlass.h" namespace cutlass { namespace profiler { diff --git a/tools/profiler/src/options.cu b/tools/profiler/src/options.cu index 44c1ad489..a21ce50e4 100644 --- a/tools/profiler/src/options.cu +++ b/tools/profiler/src/options.cu @@ -425,7 +425,9 @@ void Options::Profiling::print_usage(std::ostream &out) const { << " Number of ms to sleep between profiling periods (ms).\n\n" << " --profiling-enabled= " - << " If true, profiling is actually conducted.\n\n"; + << " If true, profiling is actually conducted.\n\n" + + ; } void Options::Profiling::print_options(std::ostream &out, int indent) const { diff --git a/tools/util/include/cutlass/util/command_line.h b/tools/util/include/cutlass/util/command_line.h index 2c8182934..cf7b66f9d 100644 --- a/tools/util/include/cutlass/util/command_line.h +++ b/tools/util/include/cutlass/util/command_line.h @@ -32,6 +32,8 @@ #include +#include "cutlass/cutlass.h" + namespace cutlass { /******************************************************************************
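Editor's note: the starting-coordinate formula that the new `strided_dgrad_starting_coords()` (in `conv2d_problem_size.h` above) evaluates with `FastDivmod` can be checked with plain integer arithmetic. A host-side sketch; `starting_coord` is a hypothetical helper transcribed from the commented formula in the patch, not a library function:

```cpp
#include <cassert>
#include <cstdlib>

// Plain-integer version of the per-filter-tap starting coordinate used by
// strided Dgrad; the library code computes the same value with FastDivmod
// instead of '%' so it stays cheap on the device.
int starting_coord(int pad, int stride, int filter_offset) {
  int pad_rem = pad % stride;
  return std::abs(stride - (pad_rem - filter_offset)) % stride;
}

int main() {
  // pad_h = 1, stride_h = 3: each filter row r maps to the first Dx row it touches.
  assert(starting_coord(1, 3, 0) == 2);  // |3 - (1 - 0)| % 3
  assert(starting_coord(1, 3, 1) == 0);  // |3 - (1 - 1)| % 3
  assert(starting_coord(1, 3, 2) == 1);  // |3 - (1 - 2)| % 3
  return 0;
}
```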