diff --git a/csrc/adagrad/cpu_adagrad.cpp b/csrc/adagrad/cpu_adagrad.cpp
index 563255176500a..7ef1037512119 100644
--- a/csrc/adagrad/cpu_adagrad.cpp
+++ b/csrc/adagrad/cpu_adagrad.cpp
@@ -21,6 +21,7 @@ static std::unordered_map<int, std::shared_ptr<Adagrad_Optimizer>> s_optimizers;
 
 // C++ interface
 
+template <typename T>
 void Adagrad_Optimizer::Step_1(float* _params,
                                float* grads,
                                float* _exp_avg_sq,
@@ -30,7 +31,7 @@ void Adagrad_Optimizer::Step_1(float* _params,
 {
     size_t rounded_size = 0;
 #if defined(__AVX512__) or defined(__AVX256__)
-    Step_AVX<1>(
+    Step_AVX<1, T>(
         &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
 #endif
     if (_param_size > rounded_size) {
@@ -97,6 +98,7 @@ void Adagrad_Optimizer::Step_1(float* _params,
     }
 }
 
+template <typename T>
 void Adagrad_Optimizer::Step_4(float* _params,
                                float* grads,
                                float* _exp_avg_sq,
@@ -106,11 +108,11 @@ void Adagrad_Optimizer::Step_4(float* _params,
 {
     size_t rounded_size = 0;
 #if defined(__AVX512__) or defined(__AVX256__)
-    Step_AVX<4>(
+    Step_AVX<4, T>(
         &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
 #endif
     if (_param_size > rounded_size)
-        Step_1((_params + rounded_size),
+        Step_1<T>((_params + rounded_size),
                (grads + rounded_size),
                (_exp_avg_sq + rounded_size),
                (_param_size - rounded_size),
@@ -149,6 +151,7 @@ int create_adagrad_optimizer(int optimizer_id,
     return 0;
 }
 
+template <typename T>
 void Adagrad_Optimizer::Step_8(float* _params,
                                float* grads,
                                float* _exp_avg_sq,
@@ -158,11 +161,11 @@ void Adagrad_Optimizer::Step_8(float* _params,
 {
     size_t rounded_size = 0;
 #if defined(__AVX512__) or defined(__AVX256__)
-    Step_AVX<8>(
+    Step_AVX<8, T>(
         &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
 #endif
     if (_param_size > rounded_size)
-        Step_4((_params + rounded_size),
+        Step_4<T>((_params + rounded_size),
                (grads + rounded_size),
                (_exp_avg_sq + rounded_size),
                (_param_size - rounded_size),
@@ -191,7 +194,12 @@ int ds_adagrad_step(int optimizer_id,
         std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
     opt->IncrementStep(step);
     opt->update_state(lr, epsilon, weight_decay);
-    opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel());
+    if (params.options().dtype() == at::kHalf)
+        opt->Step_8<c10::Half>(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel(), nullptr, true);
+    else if (params.options().dtype() == at::kBFloat16)
+        opt->Step_8<c10::BFloat16>(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel(), nullptr, true);
+    else
+        opt->Step_8<float>(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel(), nullptr, false);
 
 #if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
     opt->SynchronizeStreams();
@@ -224,7 +232,23 @@ int ds_adagrad_step_plus_copy(int optimizer_id,
         std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
     opt->IncrementStep(step);
     opt->update_state(lr, epsilon, weight_decay);
-    opt->Step_8(params_ptr,
+
+    if (params.options().dtype() == at::kHalf)
+        opt->Step_8<c10::Half>(params_ptr,
+                               grads_ptr,
+                               exp_avg_sq_ptr,
+                               params_c.numel(),
+                               gpu_params_ptr,
+                               true);
+    else if (params.options().dtype() == at::kBFloat16)
+        opt->Step_8<c10::BFloat16>(params_ptr,
+                                   grads_ptr,
+                                   exp_avg_sq_ptr,
+                                   params_c.numel(),
+                                   gpu_params_ptr,
+                                   true);
+    else
+        opt->Step_8<float>(params_ptr,
                 grads_ptr,
                 exp_avg_sq_ptr,
                 params_c.numel(),
diff --git a/csrc/adam/cpu_adam_impl.cpp b/csrc/adam/cpu_adam_impl.cpp
index 9a4a8d9565198..b3cf81192ee72 100644
--- a/csrc/adam/cpu_adam_impl.cpp
+++ b/csrc/adam/cpu_adam_impl.cpp
@@ -23,6 +23,7 @@ static std::unordered_map<int, std::shared_ptr<Adam_Optimizer>> s_optimizers;
 
 // C++ interface
 
+template <typename T>
 void Adam_Optimizer::Step_1(float* _params,
                             float* grads,
                             float* _exp_avg,
@@ -33,7 +34,7 @@ void Adam_Optimizer::Step_1(float* _params,
 {
     size_t rounded_size = 0;
 #if defined(__AVX512__) or defined(__AVX256__)
-    Step_AVX<1>(&rounded_size,
+    Step_AVX<1, T>(&rounded_size,
                 _params,
                 grads,
                 _exp_avg,
@@ -116,6 +117,7 @@ void Adam_Optimizer::Step_1(float* _params,
     }
 }
 
+template <typename T>
 void Adam_Optimizer::Step_4(float* _params,
                             float* grads,
                             float* _exp_avg,
@@ -126,7 +128,7 @@ void Adam_Optimizer::Step_4(float* _params,
 {
     size_t rounded_size = 0;
 #if defined(__AVX512__) or defined(__AVX256__)
-    Step_AVX<4>(&rounded_size,
+    Step_AVX<4, T>(&rounded_size,
                 _params,
                 grads,
                 _exp_avg,
@@ -136,7 +138,7 @@ void Adam_Optimizer::Step_4(float* _params,
                 half_precision);
 #endif
     if (_param_size > rounded_size)
-        Step_1((_params + rounded_size),
+        Step_1<T>((_params + rounded_size),
                (grads + rounded_size),
                (_exp_avg + rounded_size),
                (_exp_avg_sq + rounded_size),
@@ -185,6 +187,7 @@ int create_adam_optimizer(int optimizer_id,
     return 0;
 }
 
+template <typename T>
 void Adam_Optimizer::Step_8(float* _params,
                             float* grads,
                             float* _exp_avg,
@@ -195,7 +198,7 @@ void Adam_Optimizer::Step_8(float* _params,
 {
     size_t rounded_size = 0;
 #if defined(__AVX512__) or defined(__AVX256__)
-    Step_AVX<8>(&rounded_size,
+    Step_AVX<8, T>(&rounded_size,
                 _params,
                 grads,
                 _exp_avg,
@@ -205,7 +208,7 @@ void Adam_Optimizer::Step_8(float* _params,
                 half_precision);
 #endif
     if (_param_size > rounded_size)
-        Step_4((_params + rounded_size),
+        Step_4<T>((_params + rounded_size),
                (grads + rounded_size),
                (_exp_avg + rounded_size),
                (_exp_avg_sq + rounded_size),
@@ -244,13 +247,15 @@ int ds_adam_step(int optimizer_id,
     opt->IncrementStep(step, beta1, beta2);
     opt->update_state(lr, epsilon, weight_decay, bias_correction);
 
-    opt->Step_8(params_ptr,
-                grads_ptr,
-                exp_avg_ptr,
-                exp_avg_sq_ptr,
-                params_c.numel(),
-                nullptr,
-                (params.options().dtype() == at::kHalf));
+    if (params.options().dtype() == at::kHalf)
+        opt->Step_8<c10::Half>(
+            params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.numel(), nullptr, true);
+    else if (params.options().dtype() == at::kBFloat16)
+        opt->Step_8<c10::BFloat16>(
+            params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.numel(), nullptr, true);
+    else
+        opt->Step_8<float>(
+            params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.numel(), nullptr, false);
 
 #if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
     opt->SynchronizeStreams();
@@ -289,13 +294,30 @@ int ds_adam_step_plus_copy(int optimizer_id,
         std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
     opt->IncrementStep(step, beta1, beta2);
     opt->update_state(lr, epsilon, weight_decay, bias_correction);
-    opt->Step_8(params_ptr,
-                grads_ptr,
-                exp_avg_ptr,
-                exp_avg_sq_ptr,
-                params_c.numel(),
-                device_params_ptr,
-                (params.options().dtype() == at::kHalf));
+    if (params.options().dtype() == at::kHalf)
+        opt->Step_8<c10::Half>(params_ptr,
+                               grads_ptr,
+                               exp_avg_ptr,
+                               exp_avg_sq_ptr,
+                               params_c.numel(),
+                               device_params_ptr,
+                               true);
+    else if (params.options().dtype() == at::kBFloat16)
+        opt->Step_8<c10::BFloat16>(params_ptr,
+                                   grads_ptr,
+                                   exp_avg_ptr,
+                                   exp_avg_sq_ptr,
+                                   params_c.numel(),
+                                   device_params_ptr,
+                                   true);
+    else
+        opt->Step_8<float>(params_ptr,
+                           grads_ptr,
+                           exp_avg_ptr,
+                           exp_avg_sq_ptr,
+                           params_c.numel(),
+                           device_params_ptr,
+                           false);
 
     opt->SynchronizeStreams();
 #else
diff --git a/csrc/common/custom_cuda_kernel.cu b/csrc/common/custom_cuda_kernel.cu
index f46bf303125c6..af353eac8f97d 100644
--- a/csrc/common/custom_cuda_kernel.cu
+++ b/csrc/common/custom_cuda_kernel.cu
@@ -42,3 +42,43 @@ void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream)
param_update_kernel_half<<>>(input, output, size); } + +#ifdef BF16_AVAILABLE +__global__ void param_update_kernel(const float* input, __nv_bfloat16* output, int size) +{ + int id = blockIdx.x * blockDim.x + threadIdx.x; + + if (id < size) { output[id] = (__nv_bfloat16)input[id]; } +} + +void launch_param_update(const float* input, __nv_bfloat16* output, int size, cudaStream_t stream) +{ + int threads = 1024; + + dim3 grid_dim((size - 1) / threads + 1); + dim3 block_dim(threads); + + param_update_kernel<<>>(input, output, size); +} + +__global__ void param_update_kernel_half(const float* input, __nv_bfloat16* output, int size) +{ + int id = blockIdx.x * blockDim.x + threadIdx.x; + __nv_bfloat162* output_cast = reinterpret_cast<__nv_bfloat162*>(output); + if (id < size) { + float input_f = input[id]; + __nv_bfloat162* input_h = reinterpret_cast<__nv_bfloat162*>(&input_f); + output_cast[id] = *input_h; + } +} + +void launch_param_update_half(const float* input, __nv_bfloat16* output, int size, cudaStream_t stream) +{ + int threads = 1024; + size /= 2; + dim3 grid_dim((size - 1) / threads + 1); + dim3 block_dim(threads); + + param_update_kernel_half<<>>(input, output, size); +} +#endif diff --git a/csrc/includes/cpu_adagrad.h b/csrc/includes/cpu_adagrad.h index e60984d64b76e..784744732efd8 100644 --- a/csrc/includes/cpu_adagrad.h +++ b/csrc/includes/cpu_adagrad.h @@ -12,21 +12,29 @@ #include #include "simd.h" +#ifndef HALF_DTYPE + #error Must provide compiler option -DHALF_DTYPE= +#endif + #if defined(__ENABLE_CUDA__) #include +#ifdef BF16_AVAILABLE +#include +#endif #include #include "cuda.h" #include "custom_cuda_layers.h" -typedef __half ds_half_precision_t; #elif defined(__ENABLE_CANN__) #include "acl/acl.h" #include "torch_npu/csrc/core/npu/NPUStream.h" -typedef c10::Half ds_half_precision_t; #else typedef unsigned short ds_half_precision_t; #endif +typedef HALF_DTYPE ds_half_precision_t; + #define STEP(SPAN) \ + template \ void Step_##SPAN(float* _params, \ float* grads, \ float* _exp_avg_sq, \ @@ -64,7 +72,7 @@ class Adagrad_Optimizer { #endif } #if defined(__AVX512__) or defined(__AVX256__) - template + template void Step_AVX(size_t* rounded_size, float* _params, float* grads, @@ -121,7 +129,7 @@ class Adagrad_Optimizer { }; #if defined(__AVX512__) or defined(__AVX256__) -template +template void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, float* _params, float* grads, @@ -130,6 +138,11 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, ds_half_precision_t* dev_params, bool half_precision) { +#if !defined(__AVX512__) + if (std::is_same_v) { + return; + } +#endif size_t new_rounded_size = 0; AVX_Data eps_4; eps_4.data = SIMD_SET(_eps); @@ -153,16 +166,16 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, #pragma omp parallel for for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { AVX_Data grad_4[span]; - simd_load(grad_4, grads + i, half_precision); + simd_load(grad_4, grads + i); AVX_Data momentum_4[span]; - simd_load(momentum_4, grads + i, false); + simd_load(momentum_4, grads + i); AVX_Data variance_4[span]; - simd_load(variance_4, _exp_avg_sq + i, false); + simd_load(variance_4, _exp_avg_sq + i); AVX_Data param_4[span]; - simd_load(param_4, _params + i, half_precision); + simd_load(param_4, _params + i); if (_weight_decay > 0) { simd_fma(grad_4, param_4, weight_decay4, grad_4); } @@ -172,13 +185,13 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, simd_div(grad_4, momentum_4, grad_4); simd_fma(param_4, grad_4, step_size_4, param_4); - 
simd_store(_params + i, param_4, half_precision); + simd_store(_params + i, param_4); #if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) if (dev_params) { - simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); + simd_store(_doubled_buffer[_buf_index] + (i - t), param_4); } #endif - simd_store(_exp_avg_sq + i, variance_4, false); + simd_store(_exp_avg_sq + i, variance_4); } #if defined(__ENABLE_CUDA__) if (dev_params) { diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index b1a104b2571dc..42f18d7a757fb 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -13,22 +13,29 @@ #include #include "simd.h" +#ifndef HALF_DTYPE + #error Must provide compiler option -DHALF_DTYPE= +#endif + #if defined(__ENABLE_CUDA__) #include +#ifdef BF16_AVAILABLE +#include +#endif #include #include "cuda.h" #include "custom_cuda_layers.h" -typedef __half ds_half_precision_t; #elif defined(__ENABLE_CANN__) #include "acl/acl.h" #include "torch_npu/csrc/core/npu/NPUStream.h" -typedef c10::Half ds_half_precision_t; #else #include -typedef unsigned short ds_half_precision_t; #endif +typedef HALF_DTYPE ds_half_precision_t; + #define STEP(SPAN) \ + template \ void Step_##SPAN(float* _params, \ float* grads, \ float* _exp_avg, \ @@ -81,7 +88,7 @@ class Adam_Optimizer { } #if defined(__AVX512__) or defined(__AVX256__) - template + template void Step_AVX(size_t* rounded_size, float* _params, float* grads, @@ -168,7 +175,7 @@ class Adam_Optimizer { }; #if defined(__AVX512__) or defined(__AVX256__) -template +template void Adam_Optimizer::Step_AVX(size_t* rounded_size, float* _params, float* grads, @@ -178,6 +185,11 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, ds_half_precision_t* dev_params, bool half_precision) { +#if !defined(__AVX512__) + if (std::is_same_v) { + return; + } +#endif size_t new_rounded_size = 0; int rshft = half_precision ? 
1 : 0; @@ -220,16 +232,16 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, #pragma omp parallel for for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { AVX_Data grad_4[span]; - simd_load(grad_4, grads + (i >> rshft), half_precision); + simd_load(grad_4, grads + (i >> rshft)); AVX_Data momentum_4[span]; - simd_load(momentum_4, _exp_avg + i, false); + simd_load(momentum_4, _exp_avg + i); AVX_Data variance_4[span]; - simd_load(variance_4, _exp_avg_sq + i, false); + simd_load(variance_4, _exp_avg_sq + i); AVX_Data param_4[span]; - simd_load(param_4, _params + (i >> rshft), half_precision); + simd_load(param_4, _params + (i >> rshft)); if (_weight_decay > 0 && !_adamw_mode) { simd_fma(grad_4, param_4, weight_decay4, grad_4); @@ -250,14 +262,14 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, simd_fma(param_4, grad_4, step_size_4, param_4); - simd_store(_params + (i >> rshft), param_4, half_precision); + simd_store(_params + (i >> rshft), param_4); #if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) if (dev_params) { - simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); + simd_store(_doubled_buffer[_buf_index] + (i - t), param_4); } #endif - simd_store(_exp_avg + i, momentum_4, false); - simd_store(_exp_avg_sq + i, variance_4, false); + simd_store(_exp_avg + i, momentum_4); + simd_store(_exp_avg_sq + i, variance_4); } #if defined(__ENABLE_CUDA__) if (dev_params) { diff --git a/csrc/includes/cpu_lion.h b/csrc/includes/cpu_lion.h index 34c29eec47db2..e9ee3e6fc557e 100644 --- a/csrc/includes/cpu_lion.h +++ b/csrc/includes/cpu_lion.h @@ -13,22 +13,29 @@ #include #include "simd.h" +#ifndef HALF_DTYPE + #error Must provide compiler option -DHALF_DTYPE= +#endif + #if defined(__ENABLE_CUDA__) #include +#ifdef BF16_AVAILABLE +#include +#endif #include #include "cuda.h" #include "custom_cuda_layers.h" -typedef __half ds_half_precision_t; #elif defined(__ENABLE_CANN__) #include "acl/acl.h" #include "torch_npu/csrc/core/npu/NPUStream.h" -typedef c10::Half ds_half_precision_t; #else #include -typedef unsigned short ds_half_precision_t; #endif +typedef HALF_DTYPE ds_half_precision_t; + #define STEP(SPAN) \ + template \ void Step_##SPAN(float* _params, \ float* grads, \ float* _exp_avg, \ @@ -70,7 +77,7 @@ class Lion_Optimizer { } #if defined(__AVX512__) or defined(__AVX256__) - template + template void Step_AVX(size_t* rounded_size, float* _params, float* grads, @@ -128,7 +135,7 @@ class Lion_Optimizer { }; #if defined(__AVX512__) or defined(__AVX256__) -template +template void Lion_Optimizer::Step_AVX(size_t* rounded_size, float* _params, float* grads, @@ -137,6 +144,11 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size, ds_half_precision_t* dev_params, bool half_precision) { +#if !defined(__AVX512__) + if (std::is_same_v) { + return; + } +#endif size_t new_rounded_size = 0; int rshft = half_precision ? 
1 : 0; @@ -177,13 +189,13 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size, #pragma omp parallel for for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { AVX_Data grad_4[span]; - simd_load(grad_4, grads + (i >> rshft), half_precision); + simd_load(grad_4, grads + (i >> rshft)); AVX_Data momentum_4[span]; - simd_load(momentum_4, _exp_avg + i, false); + simd_load(momentum_4, _exp_avg + i); AVX_Data param_4[span]; - simd_load(param_4, _params + (i >> rshft), half_precision); + simd_load(param_4, _params + (i >> rshft)); AVX_Data tmp_4[span]; @@ -201,13 +213,13 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size, simd_mul(momentum_4, momentum_4, betta2_4); simd_fma(momentum_4, grad_4, betta2_minus1_4, momentum_4); - simd_store(_params + (i >> rshft), param_4, half_precision); + simd_store(_params + (i >> rshft), param_4); #if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) if (dev_params) { - simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); + simd_store(_doubled_buffer[_buf_index] + (i - t), param_4); } #endif - simd_store(_exp_avg + i, momentum_4, false); + simd_store(_exp_avg + i, momentum_4); } #if defined(__ENABLE_CUDA__) if (dev_params) { diff --git a/csrc/includes/custom_cuda_layers.h b/csrc/includes/custom_cuda_layers.h index 265eb7b124440..8a8b9187fc028 100644 --- a/csrc/includes/custom_cuda_layers.h +++ b/csrc/includes/custom_cuda_layers.h @@ -9,6 +9,9 @@ #include #include +#ifdef BF16_AVAILABLE +#include +#endif #include #include #include @@ -274,6 +277,10 @@ void launch_fuse_transpose_bias_kernel(const T* inp, void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream); void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream); +#ifdef BF16_AVAILABLE +void launch_param_update(const float* input, __nv_bfloat16* output, int size, cudaStream_t stream); +void launch_param_update_half(const float* input, __nv_bfloat16* output, int size, cudaStream_t stream); +#endif void launch_token_sort(int32_t* indices, int layers, diff --git a/csrc/includes/simd.h b/csrc/includes/simd.h index 59237b0261c18..3bf6a79076727 100644 --- a/csrc/includes/simd.h +++ b/csrc/includes/simd.h @@ -30,11 +30,52 @@ #define SIMD_XOR(x, y) _mm512_xor_ps(x, y) #define SIMD_WIDTH 16 -#define SIMD_LOAD2(x, h) \ - ((h) ? _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) : _mm512_loadu_ps(x)) -#define SIMD_STORE2(x, d, h) \ - ((h) ? 
_mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
-         : _mm512_storeu_ps(x, d))
+static __m512 load_16_bf16_as_f32(const void* data)
+{
+    __m256i a = readAs<__m256i>(data);     // use memcpy to avoid aliasing
+    __m512i b = _mm512_cvtepu16_epi32(a);  // convert 8 u16 to 8 u32
+    __m512i c = _mm512_slli_epi32(b, 16);  // logical shift left of all u32 by
+                                           // 16 bits (representing bf16->f32)
+    return readAs<__m512>(&c);             // use memcpy to avoid aliasing
+}
+
+static void store_16_f32_as_bf16_nearest(__m512 v, void* data)
+{
+    __m512i u32 = readAs<__m512i>(&v);
+
+    // flow assuming non-nan:
+
+    // uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
+    __m512i b = _mm512_srli_epi32(u32, 16);
+    __m512i lsb_mask = _mm512_set1_epi32(0x00000001);
+    __m512i c = _mm512_and_si512(b, lsb_mask);
+    __m512i bias_constant = _mm512_set1_epi32(0x00007fff);
+    __m512i rounding_bias = _mm512_add_epi32(c, bias_constant);
+
+    // uint16_t res = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
+    __m512i d = _mm512_add_epi32(u32, rounding_bias);
+    __m512i e = _mm512_srli_epi32(d, 16);
+    __m256i non_nan_res = _mm512_cvtusepi32_epi16(e);
+
+    // handle nan (exp is all 1s and mantissa != 0)
+    // if ((x & 0x7fffffffU) > 0x7f800000U)
+    __m512i mask_out_sign = _mm512_set1_epi32(0x7fffffff);
+    __m512i non_sign_bits = _mm512_and_si512(u32, mask_out_sign);
+    __m512i nan_threshold = _mm512_set1_epi32(0x7f800000);
+    __mmask16 nan_mask = _mm512_cmp_epi32_mask(non_sign_bits, nan_threshold, _MM_CMPINT_GT);
+
+    // mix in results with nans as needed
+    __m256i nans = _mm256_set1_epi16(0x7fc0);
+    __m256i res = _mm256_mask_mov_epi16(non_nan_res, nan_mask, nans);
+
+    writeAs(data, res);
+}
+#define SIMD_LOAD_BF16(x) load_16_bf16_as_f32(x)
+#define SIMD_STORE_BF16(x, d) store_16_f32_as_bf16_nearest(d, x)
+
+#define SIMD_LOAD_FP16(x) _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x)))
+#define SIMD_STORE_FP16(x, d) \
+    _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
 
 #define INTV __m256i
 
 #elif defined(__AVX256__)
@@ -52,11 +93,11 @@ #define SIMD_XOR(x, y) _mm256_xor_ps(x, y)
 #define SIMD_WIDTH 8
 
-#define SIMD_LOAD2(x, h) \
-    ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x))) : _mm256_loadu_ps(x))
-#define SIMD_STORE2(x, d, h) \
-    ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
-         : _mm256_storeu_ps(x, d))
+#define SIMD_LOAD_BF16(x) static_assert(false && "AVX256 does not support BFloat16")
+#define SIMD_STORE_BF16(x, d) static_assert(false && "AVX256 does not support BFloat16")
+#define SIMD_LOAD_FP16(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x))
+#define SIMD_STORE_FP16(x, d) \
+    _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
 
 #define INTV __m128i
 
 #endif
@@ -70,20 +111,85 @@ union AVX_Data {
     // float data_f[16];
 };
 
-template <int span>
-inline void simd_store(float* dst, AVX_Data* src, bool half_precision)
+template <int span, typename T>
+inline typename std::enable_if_t<std::is_same_v<T, c10::Half>, void> simd_store(
+    float* dst,
+    AVX_Data* src)
 {
-    size_t width = (half_precision ?
SIMD_WIDTH / 2 : SIMD_WIDTH); + size_t width = SIMD_WIDTH / 2; #pragma unroll - for (size_t i = 0; i < span; ++i) { SIMD_STORE2(dst + width * i, src[i].data, half_precision); } + for (size_t i = 0; i < span; ++i) { SIMD_STORE_FP16(dst + width * i, src[i].data); } } -template -inline void simd_load(AVX_Data* dst, float* src, bool half_precision) + +template +inline typename std::enable_if_t, void> simd_store( + float* dst, + AVX_Data* src) +{ + #ifdef __AVX512__ + size_t width = SIMD_WIDTH / 2; + #pragma unroll + for (size_t i = 0; i < span; ++i) { SIMD_STORE_BF16(dst + width * i, src[i].data); } + #else + assert(false && "AVX512 required for BFloat16"); + #endif + +} + +template +inline typename std::enable_if_t, void> simd_store( + float* dst, + AVX_Data* src) { - size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH); + size_t width = SIMD_WIDTH; #pragma unroll - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD2(src + width * i, half_precision); } + for (size_t i = 0; i < span; ++i) { SIMD_STORE(dst + width * i, src[i].data); } } + +template +inline typename std::enable_if_t, void> simd_load( + AVX_Data* dst, + float* src) +{ + size_t width = SIMD_WIDTH / 2; +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD_FP16(src + width * i); } +} + +template +inline typename std::enable_if_t, void> simd_load( + AVX_Data* dst, + float* src) +{ + #ifdef __AVX512__ + size_t width = SIMD_WIDTH / 2; + #pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD_BF16(src + width * i); } + #else + assert(false && "AVX512 required for BFloat16"); + #endif +} + +template +inline typename std::enable_if_t, void> simd_load( + AVX_Data* dst, + float* src) +{ + size_t width = SIMD_WIDTH; +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD(src + width * i); } +} + + + + + + + + + + + template inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a) { diff --git a/csrc/lion/cpu_lion_impl.cpp b/csrc/lion/cpu_lion_impl.cpp index 28314cf5b6e1b..a7b50793ffbbd 100644 --- a/csrc/lion/cpu_lion_impl.cpp +++ b/csrc/lion/cpu_lion_impl.cpp @@ -24,6 +24,7 @@ static std::unordered_map> s_optimizers; // C++ interface +template void Lion_Optimizer::Step_1(float* _params, float* grads, float* _exp_avg, @@ -33,7 +34,7 @@ void Lion_Optimizer::Step_1(float* _params, { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision); + Step_AVX<1, T>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision); #endif if (_param_size > rounded_size) { float betta1_minus1 = 1 - _betta1; @@ -106,6 +107,7 @@ void Lion_Optimizer::Step_1(float* _params, } } +template void Lion_Optimizer::Step_4(float* _params, float* grads, float* _exp_avg, @@ -115,10 +117,10 @@ void Lion_Optimizer::Step_4(float* _params, { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision); + Step_AVX<4, T>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision); #endif if (_param_size > rounded_size) - Step_1((_params + rounded_size), + Step_1((_params + rounded_size), (grads + rounded_size), (_exp_avg + rounded_size), (_param_size - rounded_size), @@ -162,6 +164,7 @@ int create_lion_optimizer(int optimizer_id, return 0; } +template void Lion_Optimizer::Step_8(float* _params, float* grads, 
float* _exp_avg, @@ -171,10 +174,10 @@ void Lion_Optimizer::Step_8(float* _params, { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision); + Step_AVX<8, T>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision); #endif if (_param_size > rounded_size) - Step_4((_params + rounded_size), + Step_4((_params + rounded_size), (grads + rounded_size), (_exp_avg + rounded_size), (_param_size - rounded_size), @@ -207,12 +210,29 @@ int ds_lion_step(int optimizer_id, opt->IncrementStep(step, beta1, beta2); opt->update_state(lr, weight_decay); - opt->Step_8(params_ptr, + + + if (params.options().dtype() == at::kHalf) + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_ptr, + params_c.numel(), + nullptr, + true); + else if (params.options().dtype() == at::kBFloat16) + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_ptr, + params_c.numel(), + nullptr, + true); + else + opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, params_c.numel(), nullptr, - (params.options().dtype() == at::kHalf)); + false); #if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) opt->SynchronizeStreams(); @@ -246,12 +266,29 @@ int ds_lion_step_plus_copy(int optimizer_id, std::static_pointer_cast(s_optimizers[optimizer_id]); opt->IncrementStep(step, beta1, beta2); opt->update_state(lr, weight_decay); - opt->Step_8(params_ptr, + + + if (params.options().dtype() == at::kHalf) + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_ptr, + params_c.numel(), + gpu_params_ptr, + true); + else if (params.options().dtype() == at::kBFloat16) + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_ptr, + params_c.numel(), + gpu_params_ptr, + true); + else + opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, params_c.numel(), gpu_params_ptr, - (params.options().dtype() == at::kHalf)); + false); opt->SynchronizeStreams(); #else diff --git a/deepspeed/ops/adagrad/cpu_adagrad.py b/deepspeed/ops/adagrad/cpu_adagrad.py index c356a52777f25..90d8c141b2442 100755 --- a/deepspeed/ops/adagrad/cpu_adagrad.py +++ b/deepspeed/ops/adagrad/cpu_adagrad.py @@ -19,7 +19,7 @@ def __init__(self, model_params, lr=1e-2, eps=1e-10, weight_decay=0, amsgrad=Fal self.opt_id = DeepSpeedCPUAdagrad.optimizer_id DeepSpeedCPUAdagrad.optimizer_id = DeepSpeedCPUAdagrad.optimizer_id + 1 self.fp32_optimizer_states = fp32_optimizer_states - self.ds_opt_adagrad = CPUAdagradBuilder().load() + self.ds_opt_adagrad = CPUAdagradBuilder().set_dtype(self.param_groups[0]['params'][0].dtype).load() self.ds_opt_adagrad.create_adagrad(self.opt_id, lr, eps, weight_decay, should_log_le("info")) diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py index 10b8c15f970b8..218c9a1a1fd27 100755 --- a/deepspeed/ops/adam/cpu_adam.py +++ b/deepspeed/ops/adam/cpu_adam.py @@ -91,7 +91,7 @@ def __init__(self, DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1 self.adam_w_mode = adamw_mode self.fp32_optimizer_states = fp32_optimizer_states - self.ds_opt_adam = CPUAdamBuilder().load() + self.ds_opt_adam = CPUAdamBuilder().set_dtype(self.param_groups[0]['params'][0].dtype).load() self.ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, should_log_le("info")) diff --git a/deepspeed/ops/adam/fused_adam.py b/deepspeed/ops/adam/fused_adam.py index 53f859e9cc87b..fb5a36f257f23 100644 --- a/deepspeed/ops/adam/fused_adam.py +++ b/deepspeed/ops/adam/fused_adam.py @@ -91,7 +91,7 @@ def __init__(self, self.adam_w_mode = 1 
if adam_w_mode else 0 self.set_grad_none = set_grad_none - fused_adam_cuda = FusedAdamBuilder().load() + fused_adam_cuda = FusedAdamBuilder().set_dtype(self.param_groups[0]['params'][0].dtype).load() # Skip buffer self._dummy_overflow_buf = get_accelerator().IntTensor([0]) self.multi_tensor_adam = fused_adam_cuda.multi_tensor_adam diff --git a/deepspeed/ops/lion/cpu_lion.py b/deepspeed/ops/lion/cpu_lion.py index a91a00643873d..f4df030e3c141 100755 --- a/deepspeed/ops/lion/cpu_lion.py +++ b/deepspeed/ops/lion/cpu_lion.py @@ -54,7 +54,7 @@ def __init__(self, model_params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0, fp self.opt_id = DeepSpeedCPULion.optimizer_id DeepSpeedCPULion.optimizer_id = DeepSpeedCPULion.optimizer_id + 1 self.fp32_optimizer_states = fp32_optimizer_states - self.ds_opt_lion = CPULionBuilder().load() + self.ds_opt_lion = CPULionBuilder().set_dtype(self.param_groups[0]['params'][0].dtype).load() self.ds_opt_lion.create_lion(self.opt_id, lr, betas[0], betas[1], weight_decay, should_log_le("info")) diff --git a/op_builder/builder.py b/op_builder/builder.py index 8dc825c7926da..707dc8cf04d3a 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -517,6 +517,7 @@ def jit_load(self, verbose=True): nvcc_args.append("-DBF16_AVAILABLE") nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__") nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__") + nvcc_args.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") if self.is_rocm_pytorch(): cxx_args.append("-D__HIP_PLATFORM_AMD__=1") diff --git a/op_builder/cpu/builder.py b/op_builder/cpu/builder.py index d2bc8eacfa254..d881842ad0b18 100644 --- a/op_builder/cpu/builder.py +++ b/op_builder/cpu/builder.py @@ -30,7 +30,11 @@ def builder(self): return cpp_ext def cxx_args(self): - return ['-O3', '-g', '-Wno-reorder'] + args = ['-O3', '-g', '-Wno-reorder'] + CPU_ARCH = self.cpu_arch() + SIMD_WIDTH = self.simd_width() + args += [CPU_ARCH, '-fopenmp', SIMD_WIDTH] + return args def libraries_args(self): return [] diff --git a/op_builder/cpu/cpu_adam.py b/op_builder/cpu/cpu_adam.py index 0c8438aea40d4..33a86c3d4cc68 100644 --- a/op_builder/cpu/cpu_adam.py +++ b/op_builder/cpu/cpu_adam.py @@ -12,6 +12,13 @@ class CPUAdamBuilder(CPUOpBuilder): def __init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' @@ -23,5 +30,17 @@ def libraries_args(self): args = super().libraries_args() return args + def cxx_args(self): + import torch + args = super().cxx_args() + assert self.dtype is not None, "dype not set" + if self.dtype == torch.bfloat16: + args += ['-DHALF_DTYPE=c10::BFloat16'] + elif self.dtype == torch.half: + args += ['-DHALF_DTYPE=c10::Half'] + else: + args += ['-DHALF_DTYPE=float'] + return args + def include_paths(self): return ['csrc/includes'] diff --git a/op_builder/cpu/fused_adam.py b/op_builder/cpu/fused_adam.py index 34b43825b0902..5ee86ca3014d7 100644 --- a/op_builder/cpu/fused_adam.py +++ b/op_builder/cpu/fused_adam.py @@ -12,6 +12,13 @@ class FusedAdamBuilder(CPUOpBuilder): def __init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' @@ -19,5 +26,17 @@ def absolute_name(self): def 
sources(self): return ['csrc/cpu/adam/fused_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] + def cxx_args(self): + import torch + args = super().cxx_args() + assert self.dtype is not None, "dype not set" + if self.dtype == torch.bfloat16: + args += ['-DHALF_DTYPE=c10::BFloat16'] + elif self.dtype == torch.half: + args += ['-DHALF_DTYPE=c10::Half'] + else: + args += ['-DHALF_DTYPE=float'] + return args + def include_paths(self): return ['csrc/includes'] diff --git a/op_builder/cpu_adagrad.py b/op_builder/cpu_adagrad.py index d3f163f7464aa..a934d9535ac8c 100644 --- a/op_builder/cpu_adagrad.py +++ b/op_builder/cpu_adagrad.py @@ -13,6 +13,13 @@ class CPUAdagradBuilder(TorchCPUOpBuilder): def __init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.adagrad.{self.NAME}_op' @@ -41,3 +48,15 @@ def include_paths(self): else: CUDA_INCLUDE = [] return ['csrc/includes'] + CUDA_INCLUDE + + def cxx_args(self): + import torch + args = super().cxx_args() + assert self.dtype is not None, "dype not set" + if self.dtype == torch.bfloat16: + args += ['-DHALF_DTYPE=c10::BFloat16'] + elif self.dtype == torch.half: + args += ['-DHALF_DTYPE=c10::Half'] + else: + args += ['-DHALF_DTYPE=float'] + return args diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py index 7c34c4ce43a16..43b29ba905336 100644 --- a/op_builder/cpu_adam.py +++ b/op_builder/cpu_adam.py @@ -13,6 +13,13 @@ class CPUAdamBuilder(TorchCPUOpBuilder): def __init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' @@ -33,6 +40,18 @@ def libraries_args(self): return args + def cxx_args(self): + import torch + assert self.dtype is not None, "dype not set" + args = super().cxx_args() + if self.dtype == torch.bfloat16: + args += ['-DHALF_DTYPE=__nv_bfloat16'] + elif self.dtype == torch.half: + args += ['-DHALF_DTYPE=__half'] + else: + args += ['-DHALF_DTYPE=float'] + return args + def include_paths(self): import torch if self.build_for_cpu: diff --git a/op_builder/cpu_lion.py b/op_builder/cpu_lion.py index 5c16d10ebb445..d1956b47a750a 100644 --- a/op_builder/cpu_lion.py +++ b/op_builder/cpu_lion.py @@ -13,6 +13,13 @@ class CPULionBuilder(TorchCPUOpBuilder): def __init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.lion.{self.NAME}_op' @@ -23,6 +30,18 @@ def sources(self): return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp', 'csrc/common/custom_cuda_kernel.cu'] + def cxx_args(self): + import torch + args = super().cxx_args() + assert self.dtype is not None, "dype not set" + if self.dtype == torch.bfloat16: + args += ['-DHALF_DTYPE=c10::BFloat16'] + elif self.dtype == torch.half: + args += ['-DHALF_DTYPE=c10::Half'] + else: + args += ['-DHALF_DTYPE=float'] + return args + def libraries_args(self): args = super().libraries_args() if self.build_for_cpu: diff --git a/op_builder/fused_adam.py b/op_builder/fused_adam.py index ac6e4eeaaea5d..fcb99fa505f83 100644 --- a/op_builder/fused_adam.py +++ 
b/op_builder/fused_adam.py @@ -14,6 +14,13 @@ class FusedAdamBuilder(CUDAOpBuilder): def __init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' @@ -25,7 +32,16 @@ def include_paths(self): return ['csrc/includes', 'csrc/adam'] def cxx_args(self): + import torch args = super().cxx_args() + assert self.dtype is not None, "dype not set" + if self.dtype == torch.bfloat16: + args += ['-DHALF_DTYPE=c10::BFloat16'] + elif self.dtype == torch.half: + args += ['-DHALF_DTYPE=c10::Half'] + else: + args += ['-DHALF_DTYPE=float'] + return args + self.version_dependent_macros() def nvcc_args(self): diff --git a/op_builder/hpu/builder.py b/op_builder/hpu/builder.py index 3c86128fffd6e..c176a586ba495 100644 --- a/op_builder/hpu/builder.py +++ b/op_builder/hpu/builder.py @@ -31,7 +31,11 @@ def builder(self): return cpp_ext def cxx_args(self): - return ['-O3', '-g', '-Wno-reorder'] + args = ['-O3', '-g', '-Wno-reorder'] + CPU_ARCH = self.cpu_arch() + SIMD_WIDTH = self.simd_width() + args += [CPU_ARCH, '-fopenmp', SIMD_WIDTH] + return args def libraries_args(self): return [] diff --git a/op_builder/hpu/cpu_adam.py b/op_builder/hpu/cpu_adam.py index 2f3b7aefe7059..7abe83fae18da 100644 --- a/op_builder/hpu/cpu_adam.py +++ b/op_builder/hpu/cpu_adam.py @@ -13,6 +13,13 @@ class CPUAdamBuilder(CPUOpBuilder): def __init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' @@ -21,8 +28,15 @@ def sources(self): return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] def cxx_args(self): + import torch args = super().cxx_args() - args += ['-DENABLE_BFLOAT16'] + assert self.dtype is not None, "dype not set" + if self.dtype == torch.bfloat16: + args += ['-DHALF_DTYPE=c10::BFloat16'] + elif self.dtype == torch.half: + args += ['-DHALF_DTYPE=c10::Half'] + else: + args += ['-DHALF_DTYPE=float'] return args def libraries_args(self): diff --git a/op_builder/hpu/fused_adam.py b/op_builder/hpu/fused_adam.py index d77228317ddb4..52bb87cb8ffcb 100644 --- a/op_builder/hpu/fused_adam.py +++ b/op_builder/hpu/fused_adam.py @@ -13,6 +13,13 @@ class FusedAdamBuilder(CPUOpBuilder): def __init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' @@ -21,8 +28,15 @@ def sources(self): return ['csrc/cpu/adam/fused_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] def cxx_args(self): + import torch args = super().cxx_args() - args += ['-DENABLE_BFLOAT16'] + assert self.dtype is not None, "dype not set" + if self.dtype == torch.bfloat16: + args += ['-DHALF_DTYPE=c10::BFloat16'] + elif self.dtype == torch.half: + args += ['-DHALF_DTYPE=c10::Half'] + else: + args += ['-DHALF_DTYPE=float'] return args def include_paths(self): diff --git a/op_builder/npu/cpu_adagrad.py b/op_builder/npu/cpu_adagrad.py index 161bc82efe1ca..cb9e87bd7bb9b 100644 --- a/op_builder/npu/cpu_adagrad.py +++ b/op_builder/npu/cpu_adagrad.py @@ -12,6 +12,13 @@ class CPUAdagradBuilder(NPUOpBuilder): def 
__init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.adagrad.{self.NAME}_op' @@ -23,3 +30,15 @@ def include_paths(self): args = super().include_paths() args += ['csrc/includes'] return args + + def cxx_args(self): + import torch + args = super().cxx_args() + assert self.dtype is not None, "dype not set" + if self.dtype == torch.bfloat16: + args += ['-DHALF_DTYPE=c10::BFloat16'] + elif self.dtype == torch.half: + args += ['-DHALF_DTYPE=c10::Half'] + else: + args += ['-DHALF_DTYPE=float'] + return args diff --git a/op_builder/npu/cpu_adam.py b/op_builder/npu/cpu_adam.py index a4e9569c0f336..e1b2580b2a166 100644 --- a/op_builder/npu/cpu_adam.py +++ b/op_builder/npu/cpu_adam.py @@ -12,10 +12,29 @@ class CPUAdamBuilder(NPUOpBuilder): def __init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' + def cxx_args(self): + import torch + args = super().cxx_args() + assert self.dtype is not None, "dype not set" + if self.dtype == torch.bfloat16: + args += ['-DHALF_DTYPE=c10::BFloat16'] + elif self.dtype == torch.half: + args += ['-DHALF_DTYPE=c10::Half'] + else: + args += ['-DHALF_DTYPE=float'] + return args + def sources(self): return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] diff --git a/op_builder/npu/cpu_lion.py b/op_builder/npu/cpu_lion.py index 6917e0fd03d08..32efb9bfe7f57 100644 --- a/op_builder/npu/cpu_lion.py +++ b/op_builder/npu/cpu_lion.py @@ -12,6 +12,13 @@ class CPULionBuilder(NPUOpBuilder): def __init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.lion.{self.NAME}_op' @@ -19,6 +26,18 @@ def absolute_name(self): def sources(self): return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp'] + def cxx_args(self): + import torch + args = super().cxx_args() + assert self.dtype is not None, "dype not set" + if self.dtype == torch.bfloat16: + args += ['-DHALF_DTYPE=c10::BFloat16'] + elif self.dtype == torch.half: + args += ['-DHALF_DTYPE=c10::Half'] + else: + args += ['-DHALF_DTYPE=float'] + return args + def include_paths(self): args = super().include_paths() args += ['csrc/includes'] diff --git a/op_builder/npu/fused_adam.py b/op_builder/npu/fused_adam.py index fc1bc83c7cc7c..7109985bdf983 100644 --- a/op_builder/npu/fused_adam.py +++ b/op_builder/npu/fused_adam.py @@ -60,6 +60,13 @@ class FusedAdamBuilder(NPUOpBuilder): def __init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' diff --git a/op_builder/xpu/cpu_adagrad.py b/op_builder/xpu/cpu_adagrad.py index 18f80848e1b80..26b8b08fcb9bb 100644 --- a/op_builder/xpu/cpu_adagrad.py +++ b/op_builder/xpu/cpu_adagrad.py @@ -12,6 +12,13 @@ class CPUAdagradBuilder(SYCLOpBuilder): def __init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def 
set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.adagrad.{self.NAME}_op' diff --git a/op_builder/xpu/cpu_adam.py b/op_builder/xpu/cpu_adam.py index 4c7d4d1198398..53be8c359cf90 100644 --- a/op_builder/xpu/cpu_adam.py +++ b/op_builder/xpu/cpu_adam.py @@ -12,10 +12,29 @@ class CPUAdamBuilder(SYCLOpBuilder): def __init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' + def cxx_args(self): + import torch + args = super().cxx_args() + assert self.dtype is not None, "dype not set" + if self.dtype == torch.bfloat16: + args += ['-DHALF_DTYPE=sycl::ext::oneapi::bfloat16'] + elif self.dtype == torch.half: + args += ['-DHALF_DTYPE=sycl::half'] + else: + args += ['-DHALF_DTYPE=float'] + return args + def sources(self): if self.build_for_cpu: return ['csrc/xpu/adam/cpu_adam.cpp', 'csrc/xpu/adam/cpu_adam_impl.cpp'] diff --git a/op_builder/xpu/fused_adam.py b/op_builder/xpu/fused_adam.py index 0e0f1a66f8e64..81be8e22ebdb4 100644 --- a/op_builder/xpu/fused_adam.py +++ b/op_builder/xpu/fused_adam.py @@ -11,6 +11,13 @@ class FusedAdamBuilder(SYCLOpBuilder): def __init__(self): super().__init__(name=self.NAME) + self.dtype = None + + def set_dtype(self, dtype): + import torch + assert (dtype in [torch.bfloat16, torch.half, torch.float32]) + self.dtype = dtype + return self def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' @@ -22,5 +29,13 @@ def include_paths(self): return ['csrc/xpu/includes', 'csrc/xpu/adam'] def cxx_args(self): + import torch args = super().cxx_args() + assert self.dtype is not None, "dype not set" + if self.dtype == torch.bfloat16: + args += ['-DHALF_DTYPE=c10::BFloat16'] + elif self.dtype == torch.half: + args += ['-DHALF_DTYPE=c10::Half'] + else: + args += ['-DHALF_DTYPE=float'] return args + self.version_dependent_macros() diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py index 9a6ff6689446f..e238791e07e69 100644 --- a/tests/unit/ops/adam/test_cpu_adam.py +++ b/tests/unit/ops/adam/test_cpu_adam.py @@ -43,7 +43,7 @@ def _compare_optimizers(model_size, param1, optimizer1, param2, optimizer2): check_equal(param1.float().norm(), param2.float().cpu().norm(), atol=tolerance, verbose=True) -@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"]) +@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"]) @pytest.mark.parametrize('model_size', [ (64), @@ -63,6 +63,9 @@ class TestCPUAdam(DistributedTest): @pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.") def test_fused_adam_equal(self, dtype, model_size): + if dtype not in get_accelerator().supported_dtypes(): + pytest.skip(f"dtype {dtype} not supported in current accelerator") + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): pytest.skip("cpu-adam with half precision not supported on AMD CPUs") @@ -97,20 +100,20 @@ def test_torch_adamw_equal(self, dtype, model_size): pytest.skip("torch.optim.AdamW with half precision only supported in CUDA environments.") ref_param_device = 'cpu' - from deepspeed.ops.adam import DeepSpeedCPUAdam + from deepspeed.ops.adam import DeepSpeedCPUAdam - cpu_data = 
torch.randn(model_size, device='cpu').to(dtype) - cpu_param = torch.nn.Parameter(cpu_data) - ref_param = torch.nn.Parameter(cpu_data.to(ref_param_device)) + cpu_data = torch.randn(model_size, device='cpu').to(dtype) + cpu_param = torch.nn.Parameter(cpu_data) + ref_param = torch.nn.Parameter(cpu_data.to(ref_param_device)) - cpu_optimizer = DeepSpeedCPUAdam([cpu_param]) - ref_optimizer = torch.optim.AdamW([ref_param]) + cpu_optimizer = DeepSpeedCPUAdam([cpu_param]) + ref_optimizer = torch.optim.AdamW([ref_param]) - _compare_optimizers(model_size=model_size, - param1=cpu_param, - optimizer1=cpu_optimizer, - param2=ref_param, - optimizer2=ref_optimizer) + _compare_optimizers(model_size=model_size, + param1=cpu_param, + optimizer1=cpu_optimizer, + param2=ref_param, + optimizer2=ref_optimizer) class TestCPUAdamGPUError(DistributedTest): diff --git a/tests/unit/ops/adam/test_hybrid_adam.py b/tests/unit/ops/adam/test_hybrid_adam.py index c7ef4890b3220..d278fd6d6c1cd 100644 --- a/tests/unit/ops/adam/test_hybrid_adam.py +++ b/tests/unit/ops/adam/test_hybrid_adam.py @@ -32,7 +32,7 @@ def check_equal(first, second, atol=1e-2, verbose=False): np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) -@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"]) +@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"]) @pytest.mark.parametrize('model_size', [8, 16]) class TestHybridAdam(DistributedTest): world_size = 1 diff --git a/tests/unit/ops/lion/test_cpu_lion.py b/tests/unit/ops/lion/test_cpu_lion.py index 61a069af32576..0a8c2dca874ba 100644 --- a/tests/unit/ops/lion/test_cpu_lion.py +++ b/tests/unit/ops/lion/test_cpu_lion.py @@ -43,7 +43,7 @@ def _compare_optimizers(model_size, param1, optimizer1, param2, optimizer2): check_equal(param1.float().norm(), param2.float().cpu().norm(), atol=tolerance, verbose=True) -@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"]) +@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"]) @pytest.mark.parametrize('model_size', [ (64),