Merge branch 'master' into loadams/rocm-6
loadams authored May 17, 2024
2 parents 631ebd0 + d3dd8e7 commit 98cfd73
Showing 16 changed files with 326 additions and 80 deletions.
14 changes: 10 additions & 4 deletions csrc/includes/cublas_wrappers.h
@@ -17,6 +17,7 @@
#include <rocblas/rocblas.h>
#endif
#include <stdio.h>
#include <torch/version.h>

int cublas_gemm_ex(cublasHandle_t handle,
cublasOperation_t transa,
@@ -29,7 +30,9 @@ int cublas_gemm_ex(cublasHandle_t handle,
const float* A,
const float* B,
float* C,
#ifdef __HIP_PLATFORM_AMD__
// TODO HIP: Remove backward compatibility for torch<=2.0 in future
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
@@ -46,7 +49,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
const __half* A,
const __half* B,
__half* C,
#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -67,7 +71,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
int stride_B,
int stride_C,
int batch,
#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
@@ -88,7 +93,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
int stride_B,
int stride_C,
int batch,
#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
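
For context, here is a minimal, self-contained sketch (not part of the commit) of the preprocessor gate the diff above adds to every GEMM wrapper: on ROCm builds against torch <= 2.0 the wrappers keep the legacy rocblas_gemm_algo enum, while CUDA builds, and ROCm builds against newer torch, fall through to the cublasGemmAlgo_t path. TORCH_VERSION_MAJOR and TORCH_VERSION_MINOR normally come from <torch/version.h>; the fallback defines below are assumed values so the snippet compiles and runs without PyTorch installed.

// version_gate_sketch.cpp (illustration only, not from the commit)
#include <cstdio>

#ifndef TORCH_VERSION_MAJOR
#define TORCH_VERSION_MAJOR 2  // assumed fallback for illustration
#define TORCH_VERSION_MINOR 1  // assumed fallback for illustration
#endif

#if defined(__HIP_PLATFORM_AMD__) && \
    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
// ROCm build against torch <= 2.0: keep the legacy rocBLAS enum and its default.
static const char* kAlgoFamily = "rocblas_gemm_algo (default: rocblas_gemm_algo_standard)";
#else
// CUDA build, or ROCm build against torch > 2.0 (the cuBLAS names are hipified to hipBLAS there).
static const char* kAlgoFamily = "cublasGemmAlgo_t (default: CUBLAS_GEMM_DEFAULT)";
#endif

int main()
{
    std::printf("GEMM algo enum selected by the gate: %s\n", kAlgoFamily);
    return 0;
}
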
10 changes: 7 additions & 3 deletions csrc/includes/feed_forward.h
@@ -48,7 +48,9 @@ class FeedForward {
weights,
input_ptr,
out,
#ifdef __HIP_PLATFORM_AMD__
// TODO HIP: Remove backward compatibility for torch<=2.0 in future
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
rocblas_gemm_algo(config_.gemm_algos[0]));
#else
cublasGemmAlgo_t(config_.gemm_algos[0]));
@@ -77,7 +79,8 @@ class FeedForward {
input_ptr,
out_grad,
weights_grad,
#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
rocblas_gemm_algo(config_.gemm_algos[1]));
#else
cublasGemmAlgo_t(config_.gemm_algos[1]));
@@ -94,7 +97,8 @@ class FeedForward {
weights,
out_grad,
inp_grad_out,
#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
rocblas_gemm_algo(config_.gemm_algos[2]));
#else
cublasGemmAlgo_t(config_.gemm_algos[2]));
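
The FeedForward hunks above all follow the same call-site pattern: the config stores the chosen algorithm as a plain int in gemm_algos[], and each GEMM launch casts it to whichever enum the gate selects. A hedged sketch of that cast is below; the helper name select_algo is hypothetical (the real header inlines the cast directly into the cublas_gemm_ex call), and the cuBLAS include is assumed to be hipified to its hipBLAS equivalent on ROCm builds with newer torch.

#include <torch/version.h>

#if defined(__HIP_PLATFORM_AMD__) && \
    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
#include <rocblas/rocblas.h>
// Legacy ROCm path: the stored id indexes the rocBLAS enum.
static inline rocblas_gemm_algo select_algo(int stored_id)
{
    return static_cast<rocblas_gemm_algo>(stored_id);
}
#else
#include <cublas_v2.h>  // assumed to be hipified to hipBLAS on ROCm with torch > 2.0
// CUDA path (and ROCm with newer torch): the stored id indexes the cuBLAS-style enum.
static inline cublasGemmAlgo_t select_algo(int stored_id)
{
    return static_cast<cublasGemmAlgo_t>(stored_id);
}
#endif

// A call site would then read roughly:
//   cublas_gemm_ex(handle, ..., select_algo(config_.gemm_algos[0]));
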
32 changes: 24 additions & 8 deletions csrc/includes/gemm_test.h
@@ -67,7 +67,9 @@ class GemmTest {
B,
A,
C,
#ifdef __HIP_PLATFORM_AMD__
// TODO HIP: Remove backward compatibility for torch<=2.0 in future
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
static_cast<rocblas_gemm_algo>(algo));
#else
static_cast<cublasGemmAlgo_t>(algo));
@@ -86,7 +88,8 @@ class GemmTest {
A,
C,
B,
#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
static_cast<rocblas_gemm_algo>(algo));
#else
static_cast<cublasGemmAlgo_t>(algo));
@@ -105,7 +108,8 @@ class GemmTest {
B,
C,
A,
#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
static_cast<rocblas_gemm_algo>(algo));
#else
static_cast<cublasGemmAlgo_t>(algo));
@@ -121,8 +125,11 @@ class GemmTest {
float fast_latency = (std::numeric_limits<float>::max)();
int fast_algo = 0;

#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard;
#elif defined(__HIP_PLATFORM_AMD__)
for (int algo = (int)HIPBLAS_GEMM_DEFAULT; algo <= (int)HIPBLAS_GEMM_DEFAULT;
#else
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
@@ -211,7 +218,8 @@ class StridedGemmTest {
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
static_cast<rocblas_gemm_algo>(algo));
#else
static_cast<cublasGemmAlgo_t>(algo));
@@ -245,7 +253,8 @@ class StridedGemmTest {
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
static_cast<rocblas_gemm_algo>(algo));
#else
static_cast<cublasGemmAlgo_t>(algo));
@@ -276,7 +285,8 @@ class StridedGemmTest {
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
static_cast<rocblas_gemm_algo>(algo));
#else
static_cast<cublasGemmAlgo_t>(algo));
@@ -292,11 +302,17 @@ class StridedGemmTest {
float fast_latency = (std::numeric_limits<float>::max)();
int fast_algo = 0;

#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard;
#else
#ifdef __HIP_PLATFORM_AMD__
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
algo <= (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
#else
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
#endif
#endif
algo++) {
int warm_up = 5;
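
The gemm_test.h hunks also change the autotuning range: the legacy ROCm path only tries rocblas_gemm_algo_standard, ROCm with newer torch only tries HIPBLAS_GEMM_DEFAULT (or CUBLAS_GEMM_DEFAULT_TENSOR_OP in the strided case), while CUDA still sweeps CUBLAS_GEMM_DEFAULT_TENSOR_OP through CUBLAS_GEMM_ALGO15_TENSOR_OP. A generic sketch of the timing loop itself follows; run_and_time_ms is a hypothetical callback standing in for the warm-up and timed GEMM launches in the real test classes.

#include <functional>
#include <limits>

// Sweep a range of candidate algorithm ids, time each one, and keep the fastest.
int pick_fastest_gemm_algo(int first_algo,
                           int last_algo,
                           const std::function<float(int)>& run_and_time_ms)
{
    float fast_latency = (std::numeric_limits<float>::max)();
    int fast_algo = first_algo;
    for (int algo = first_algo; algo <= last_algo; ++algo) {
        const float latency = run_and_time_ms(algo);  // e.g. a few warm-up runs, then timed runs
        if (latency < fast_latency) {
            fast_latency = latency;
            fast_algo = algo;
        }
    }
    return fast_algo;
}
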
13 changes: 9 additions & 4 deletions csrc/includes/strided_batch_gemm.h
@@ -77,7 +77,9 @@ class StridedBatchGemm {
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_AMD__
// TODO HIP: Remove backward compatibility for torch<=2.0 in future
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
rocblas_gemm_algo(_config.gemm_algos[0]));
#else
cublasGemmAlgo_t(_config.gemm_algos[0]));
@@ -105,7 +107,8 @@ class StridedBatchGemm {
stride_b,
stride_c,
_config.batch_size,
#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
rocblas_gemm_algo(_config.gemm_algos[0]));
#else
cublasGemmAlgo_t(_config.gemm_algos[0]));
@@ -149,7 +152,8 @@ class StridedBatchGemm {
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
rocblas_gemm_algo(_config.gemm_algos[1]));
#else
cublasGemmAlgo_t(_config.gemm_algos[1]));
@@ -178,7 +182,8 @@ class StridedBatchGemm {
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_AMD__
#if defined(__HIP_PLATFORM_AMD__) && \
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
rocblas_gemm_algo(_config.gemm_algos[2]));
#else
cublasGemmAlgo_t(_config.gemm_algos[2]));
(The remaining 12 of the 16 changed files are not shown here.)

0 comments on commit 98cfd73
