Use HIPBLAS_COMPUTE_32F if HIPBLAS_V2
rraminen committed Apr 18, 2024
1 parent 79c51b7 commit e74d5bf
Showing 2 changed files with 23 additions and 11 deletions.
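
Background for the change: hipBLAS V2 replaces the hipblasDatatype_t compute-type argument of hipblasGemmEx with a dedicated hipblasComputeType_t, mirroring cuBLAS. Builds compiled with HIPBLAS_V2 therefore need HIPBLAS_COMPUTE_32F where pre-V2 hipBLAS accepted HIPBLAS_R_32F; CUDA builds keep CUDA_R_32F. A minimal sketch of the three-way selection this commit repeats at every GemmEx call site (the GEMM_COMPUTE_32F macro name is illustrative, not part of the patch):

// Hypothetical helper macro showing the selection logic in isolation.
#if defined(__HIP_PLATFORM_AMD__) && HIPBLAS_V2
#define GEMM_COMPUTE_32F HIPBLAS_COMPUTE_32F  // hipBLAS V2: hipblasComputeType_t
#elif defined(__HIP_PLATFORM_AMD__)
#define GEMM_COMPUTE_32F HIPBLAS_R_32F        // pre-V2 hipBLAS: hipblasDatatype_t
#else
#define GEMM_COMPUTE_32F CUDA_R_32F           // CUDA: cudaDataType
#endif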
csrc/transformer/inference/includes/inference_cublas_wrappers.h (18 changes: 13 additions & 5 deletions)
@@ -104,7 +104,9 @@ int cublas_gemm_ex(cublasHandle_t handle,
     CUDA_R_32F,
 #endif
     m,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && HIPBLAS_V2
+    HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
     HIPBLAS_R_32F,
 #else
     CUDA_R_32F,
@@ -210,8 +212,10 @@ int cublas_gemm_ex(cublasHandle_t handle,
     (void*)C,
     cublas_dtype_16,
     m,
-#ifdef __HIP_PLATFORM_AMD__
-    HIPBLAS_R_32F,
+#if defined(__HIP_PLATFORM_AMD__) && HIPBLAS_V2
+    HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
+    HIPBLAS_R_32F,
 #else
     CUDA_R_32F,
 #endif
@@ -335,7 +339,9 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
     m,
     stride_C,
     batch,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && HIPBLAS_V2
+    HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
     HIPBLAS_R_32F,
 #else
     CUDA_R_32F,
@@ -457,7 +463,9 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
     m,
     stride_C,
     batch,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && HIPBLAS_V2
+    HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
     HIPBLAS_R_32F,
 #else
     CUDA_R_32F,
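
For context, here is a hedged sketch of how the selected constant is consumed. It follows the FP16-inputs, FP32-accumulation pattern of the wrappers above, but the function name, the fixed CUBLAS_OP_N operands, and the leading dimensions are assumptions for illustration; it also assumes a hipified build on ROCm, where the cublas/CUDA names map to their hipBLAS equivalents.

#include <cublas_v2.h>  // hipified to the hipBLAS header on ROCm builds
#include <cuda_fp16.h>  // __half

// Illustrative wrapper: C (m x n) = alpha * A (m x k) * B (k x n) + beta * C,
// with half-precision storage and single-precision accumulation.
cublasStatus_t gemm_fp16_fp32acc(cublasHandle_t handle, int m, int n, int k,
                                 const float* alpha, const __half* A,
                                 const __half* B, const float* beta, __half* C)
{
    return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                        (const void*)alpha,
                        (const void*)A, CUDA_R_16F, m,  // lda = m, A not transposed
                        (const void*)B, CUDA_R_16F, k,  // ldb = k, B not transposed
                        (const void*)beta,
                        (void*)C, CUDA_R_16F, m,        // ldc = m
#if defined(__HIP_PLATFORM_AMD__) && HIPBLAS_V2
                        HIPBLAS_COMPUTE_32F,            // FP32 accumulation, V2 enum
#elif defined(__HIP_PLATFORM_AMD__)
                        HIPBLAS_R_32F,                  // FP32 accumulation, pre-V2 enum
#else
                        CUDA_R_32F,                     // FP32 accumulation, CUDA enum
#endif
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP); // permit Tensor Core kernels
}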
deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h (16 changes: 10 additions & 6 deletions)
@@ -105,7 +105,7 @@ int blas_gemm_ex(void* C,
     const float* beta,
     BlasType type)
 {
-#ifdef __HIP_PLATFORM_AMD__ && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
+#if defined(__HIP_PLATFORM_AMD__) && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
     rocblas_operation_t transa_op = get_trans_op(transa);
     rocblas_operation_t transb_op = get_trans_op(transb);
 
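The #if defined rewrite above is not cosmetic. #ifdef takes exactly one identifier, so everything after __HIP_PLATFORM_AMD__ on the old line was ignored (compilers only warn about the extra tokens) and the torch-version test was never evaluated; #if defined(...) evaluates the whole expression. A standalone illustration with hypothetical macros:

#define FEATURE 1
#define VERSION 3

#ifdef FEATURE && VERSION <= 2        // extra tokens ignored; branch taken
int compiled_even_though_version_is_3;
#endif

#if defined(FEATURE) && VERSION <= 2  // full expression evaluated; branch skipped
int not_compiled;
#endif
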
@@ -157,15 +157,17 @@ int blas_gemm_ex(void* C,
     C,
     abc_type,
     ldc,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && HIPBLAS_V2
+    HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
     HIPBLAS_R_32F,
 #else
     CUDA_R_32F,
 #endif
     CUBLAS_GEMM_DEFAULT_TENSOR_OP);
 #endif
 
-#ifdef __HIP_PLATFORM_AMD__ && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
+#if defined(__HIP_PLATFORM_AMD__) && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
     if (status != rocblas_status_success) {
 #else
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -200,7 +202,7 @@ int blas_strided_batched_gemm(void* C,
     int batch,
     BlasType type)
 {
-#ifdef __HIP_PLATFORM_AMD__ && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
+#if defined(__HIP_PLATFORM_AMD__) && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
     rocblas_operation_t transa_op = get_trans_op(transa);
     rocblas_operation_t transb_op = get_trans_op(transb);
 
@@ -263,15 +265,17 @@ int blas_strided_batched_gemm(void* C,
     ldc,
     stride_C,
     batch,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && HIPBLAS_V2
+    HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
     HIPBLAS_R_32F,
 #else
     CUDA_R_32F,
 #endif
     CUBLAS_GEMM_DEFAULT_TENSOR_OP);
 #endif
 
-#ifdef __HIP_PLATFORM_AMD__ && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
+#if defined(__HIP_PLATFORM_AMD__) && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
     if (status != rocblas_status_success) {
 #else
     if (status != CUBLAS_STATUS_SUCCESS) {
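
One more preprocessor detail the new guards rely on: in an #if expression, an identifier that is not a defined macro evaluates to 0, so on builds where HIPBLAS_V2 is absent the V2 branch is skipped cleanly and the pre-V2 HIPBLAS_R_32F path is kept. A standalone demo with hypothetical macro names:

#define PLATFORM_X 1  // stands in for __HIP_PLATFORM_AMD__

#if defined(PLATFORM_X) && API_V2       // API_V2 undefined, so it evaluates to 0
const char* selected = "v2 path";       // not compiled
#else
const char* selected = "fallback path"; // compiled
#endif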
