diff --git a/.github/workflows/cpu-inference.yml b/.github/workflows/cpu-inference.yml index 38dd9bd3efef..d91034270eec 100644 --- a/.github/workflows/cpu-inference.yml +++ b/.github/workflows/cpu-inference.yml @@ -97,5 +97,5 @@ jobs: unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests # LOCAL_SIZE=2 enforce CPU to report 2 devices, this helps run the test on github default runner - LOCAL_SIZE=2 COLUMNS=240 TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' unit/ - LOCAL_SIZE=2 COLUMNS=240 TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference_ops' -m 'inference' unit/ + LOCAL_SIZE=2 COLUMNS=240 HF_HOME=~/tmp/hf_home/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' unit/ + LOCAL_SIZE=2 COLUMNS=240 HF_HOME=~/tmp/hf_home/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference_ops' -m 'inference' unit/ diff --git a/.github/workflows/cpu-torch-latest.yml b/.github/workflows/cpu-torch-latest.yml index 5727ff2e1cde..213421590ad6 100644 --- a/.github/workflows/cpu-torch-latest.yml +++ b/.github/workflows/cpu-torch-latest.yml @@ -50,5 +50,5 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.3" - TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.3" + HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.3" + HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.3" diff --git a/.github/workflows/setup-venv/action.yml b/.github/workflows/setup-venv/action.yml index ce2c458b9e57..9a88e0651860 100644 --- a/.github/workflows/setup-venv/action.yml +++ b/.github/workflows/setup-venv/action.yml @@ -22,7 +22,7 @@ runs: - id: set-env-vars run: | echo TEST_DATA_DIR=/blob/ >> $GITHUB_ENV - echo TRANSFORMERS_CACHE=/blob/transformers_cache/ >> $GITHUB_ENV + echo HF_HOME=/blob/hf_home/ >> $GITHUB_ENV echo TORCH_EXTENSIONS_DIR=./torch-extensions/ >> $GITHUB_ENV echo TORCH_CACHE=/blob/torch_cache/ >> $GITHUB_ENV echo HF_DATASETS_CACHE=/blob/datasets_cache/ >> $GITHUB_ENV diff --git a/MANIFEST.in b/MANIFEST.in index ab79573ef96c..8d84aee0faf4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,8 +2,8 @@ include *.txt README.md include deepspeed/inference/v2/kernels/ragged_ops/libs/*.so include deepspeed/inference/v2/kernels/cutlass_ops/libs/*.so recursive-include requirements *.txt -recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json -recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc +recursive-include deepspeed *.cpp *.h *.hpp *.cu *.hip *.tr *.cuh *.cc *.json +recursive-include csrc *.cpp *.h *.hpp *.cu *.tr *.cuh *.cc recursive-include op_builder *.py recursive-include benchmarks *.py recursive-include accelerator *.py diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py index 9c4a9c903f96..3fed89d7200f 100644 --- a/accelerator/xpu_accelerator.py +++ b/accelerator/xpu_accelerator.py @@ -7,6 +7,7 @@ from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore +import functools class XPU_Accelerator(DeepSpeedAccelerator): @@ -191,31 +192,31 @@ def supported_dtypes(self): @property def BFloat16Tensor(self): - return torch.xpu.BFloat16Tensor + return 
functools.partial(torch.tensor, dtype=torch.bfloat16, device=self._name) @property def ByteTensor(self): - return torch.xpu.ByteTensor + return functools.partial(torch.tensor, dtype=torch.uint8, device=self._name) @property def DoubleTensor(self): - return torch.xpu.DoubleTensor + return functools.partial(torch.tensor, dtype=torch.double, device=self._name) @property def FloatTensor(self): - return torch.xpu.FloatTensor + return functools.partial(torch.tensor, dtype=torch.float, device=self._name) @property def HalfTensor(self): - return torch.xpu.HalfTensor + return functools.partial(torch.tensor, dtype=torch.half, device=self._name) @property def IntTensor(self): - return torch.xpu.IntTensor + return functools.partial(torch.tensor, dtype=torch.int, device=self._name) @property def LongTensor(self): - return torch.xpu.LongTensor + return functools.partial(torch.tensor, dtype=torch.long, device=self._name) def pin_memory(self, tensor, align_bytes=1): if align_bytes == 1: diff --git a/csrc/adagrad/cpu_adagrad.cpp b/csrc/adagrad/cpu_adagrad.cpp index 563255176500..5790e79e2bc2 100644 --- a/csrc/adagrad/cpu_adagrad.cpp +++ b/csrc/adagrad/cpu_adagrad.cpp @@ -5,55 +5,38 @@ #include "cpu_adagrad.h" #include +#include #include +#include #include #include #include -#if defined(__ENABLE_CUDA__) -#include -#include "cublas_v2.h" -#include "cuda.h" -#include "curand.h" -#include "custom_cuda_layers.h" -#endif +using namespace std::string_literals; static std::unordered_map> s_optimizers; // C++ interface -void Adagrad_Optimizer::Step_1(float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) +template +void Adagrad_Optimizer::Step_1(ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>( - &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); + Step_AVX<1>(&rounded_size, _params, grads, _exp_avg_sq, _param_size); #endif if (_param_size > rounded_size) { float step_size = -1 * _alpha; - ds_half_precision_t* grads_cast_h; - ds_half_precision_t* params_cast_h; - if (half_precision) { - grads_cast_h = reinterpret_cast(grads); - params_cast_h = reinterpret_cast(_params); - } for (size_t t = rounded_size; t < _param_size; t += TILE) { size_t copy_size = TILE; if ((t + TILE) > _param_size) copy_size = _param_size - t; size_t offset = copy_size + t; -#if defined(__ENABLE_CUDA__) - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#elif defined(__ENABLE_CANN__) - if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); } -#endif #pragma omp parallel for for (size_t k = t; k < offset; k++) { - float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; - float param = half_precision ? 
(float)params_cast_h[k] : _params[k]; + float grad = (float)grads[k]; + float param = (float)_params[k]; float momentum = grads[k]; float variance = _exp_avg_sq[k]; if (_weight_decay > 0) { grad = param * _weight_decay + grad; } @@ -64,58 +47,30 @@ void Adagrad_Optimizer::Step_1(float* _params, grad += _eps; grad = momentum / grad; param = grad * step_size + param; -#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) - if (dev_params) _doubled_buffer[_buf_index][k - t] = param; -#endif - if (half_precision) - params_cast_h[k] = (ds_half_precision_t)param; - else - _params[k] = param; + _params[k] = param; // STORE UPDATE TERM TO GRAD'S MEMORY grads[k] = grad * step_size; _exp_avg_sq[k] = variance; } -#if defined(__ENABLE_CUDA__) - if (dev_params) { - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); - _buf_index = !_buf_index; - } -#elif defined(__ENABLE_CANN__) - if (dev_params) { - size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]); - aclrtMemcpy(dev_params + t, - memcpy_size, - _doubled_buffer[_buf_index], - memcpy_size, - aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE); - - _buf_index = !_buf_index; - } -#endif } } } -void Adagrad_Optimizer::Step_4(float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) +template +void Adagrad_Optimizer::Step_4(ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>( - &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); + Step_AVX<4>(&rounded_size, _params, grads, _exp_avg_sq, _param_size); #endif if (_param_size > rounded_size) Step_1((_params + rounded_size), (grads + rounded_size), (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); + (_param_size - rounded_size)); } int create_adagrad_optimizer(int optimizer_id, @@ -149,25 +104,77 @@ int create_adagrad_optimizer(int optimizer_id, return 0; } -void Adagrad_Optimizer::Step_8(float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) +template +void Adagrad_Optimizer::Step_8(ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>( - &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); + Step_AVX<8>(&rounded_size, _params, grads, _exp_avg_sq, _param_size); #endif if (_param_size > rounded_size) Step_4((_params + rounded_size), (grads + rounded_size), (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? 
(dev_params + rounded_size) : dev_params), - half_precision); + (_param_size - rounded_size)); +} + +template +void step_invoker(std::shared_ptr opt, + void* _params, + void* grads, + void* _exp_avg_sq, + size_t _param_size) +{ + opt->Step_8((ds_params_percision_t*)(_params), + (ds_params_percision_t*)(grads), + (ds_state_precision_t*)(_exp_avg_sq), + _param_size); +} + +std::map, + std::function, void*, void*, void*, size_t)>> + invokers; + +// Fill map with template functions for each type +template +void create_invoker() +{ + invokers[std::tuple(c10::CppTypeToScalarType(), + c10::CppTypeToScalarType())] = + step_invoker; +} +struct InvokerInitializer { + InvokerInitializer() + { + create_invoker(); + create_invoker(); + create_invoker(); + create_invoker(); + create_invoker(); + } +} _invoker_initializer; + +void invoke(std::shared_ptr opt, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg_sq, + size_t param_size) +{ + c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype()); + c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg_sq.options().dtype()); + + auto it = invokers.find(std::tuple(params_type, state_type)); + if (it == invokers.end()) { + throw std::runtime_error("Adagrad optimizer with param type "s + + c10::toString(params_type) + " and state type "s + + c10::toString(state_type) + + " is not supported on current hardware"s); + } + + it->second(opt, params.data_ptr(), grads.data_ptr(), exp_avg_sq.data_ptr(), param_size); } int ds_adagrad_step(int optimizer_id, @@ -183,58 +190,13 @@ int ds_adagrad_step(int optimizer_id, auto grads_c = grads.contiguous(); auto exp_avg_sq_c = exp_avg_sq.contiguous(); - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - std::shared_ptr opt = std::static_pointer_cast(s_optimizers[optimizer_id]); opt->IncrementStep(step); opt->update_state(lr, epsilon, weight_decay); - opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel()); -#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) - opt->SynchronizeStreams(); -#endif - return 0; -} + invoke(opt, params_c, grads_c, exp_avg_sq_c, params_c.numel()); -int ds_adagrad_step_plus_copy(int optimizer_id, - size_t step, - float lr, - float epsilon, - float weight_decay, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg_sq, - torch::Tensor& gpu_params) -{ -#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) - auto params_c = params.contiguous(); - auto gpu_params_c = gpu_params.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - auto grads_c = grads.contiguous(); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step); - opt->update_state(lr, epsilon, weight_decay); - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_sq_ptr, - params_c.numel(), - gpu_params_ptr, - (params.options().dtype() == at::kHalf)); - - opt->SynchronizeStreams(); -#else - assert(false); -#endif return 0; } @@ -248,9 +210,6 @@ int destroy_adagrad_optimizer(int optimizer_id) PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)"); - m.def("adagrad_update_copy", - 
&ds_adagrad_step_plus_copy, - "DeepSpeed CPU Adagrad update and param copy (C++)"); m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)"); m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)"); } diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index 96809827f3e1..263c443cb4d4 100644 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -8,9 +8,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)"); - m.def("adam_update_copy", - &ds_adam_step_plus_copy, - "DeepSpeed CPU Adam update and param copy (C++)"); m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); } diff --git a/csrc/adam/cpu_adam_impl.cpp b/csrc/adam/cpu_adam_impl.cpp index 9a4a8d956519..15d4e74d69d5 100644 --- a/csrc/adam/cpu_adam_impl.cpp +++ b/csrc/adam/cpu_adam_impl.cpp @@ -5,42 +5,29 @@ #include #include +#include #include +#include #include #include #include #include "cpu_adam.h" -#if defined(__ENABLE_CUDA__) -#include -#include "cublas_v2.h" -#include "cuda.h" -#include "curand.h" -#include "custom_cuda_layers.h" -#endif - +using namespace std::string_literals; static std::unordered_map> s_optimizers; // C++ interface -void Adam_Optimizer::Step_1(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) +template +void Adam_Optimizer::Step_1(ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); + Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size); #endif if (_param_size > rounded_size) { float betta1_minus1 = 1 - _betta1; @@ -48,26 +35,15 @@ void Adam_Optimizer::Step_1(float* _params, float step_size = -1 * _alpha / _bias_correction1; float w_decay = -1 * _alpha * _weight_decay; - ds_half_precision_t* grads_cast_h; - ds_half_precision_t* params_cast_h; - if (half_precision) { - grads_cast_h = reinterpret_cast(grads); - params_cast_h = reinterpret_cast(_params); - } for (size_t t = rounded_size; t < _param_size; t += TILE) { size_t copy_size = TILE; if ((t + TILE) > _param_size) copy_size = _param_size - t; size_t offset = copy_size + t; -#if defined(__ENABLE_CUDA__) - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#elif defined(__ENABLE_CANN__) - if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); } -#endif #pragma omp parallel for for (size_t k = t; k < offset; k++) { - float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; - float param = half_precision ? 
(float)params_cast_h[k] : _params[k]; + float grad = (float)grads[k]; + float param = (float)_params[k]; float momentum = _exp_avg[k]; float variance = _exp_avg_sq[k]; if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } @@ -83,66 +59,31 @@ void Adam_Optimizer::Step_1(float* _params, grad = momentum / grad; if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; } param = grad * step_size + param; -#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) - if (dev_params) _doubled_buffer[_buf_index][k - t] = param; -#endif - if (half_precision) - params_cast_h[k] = (ds_half_precision_t)param; - else - _params[k] = param; + _params[k] = param; _exp_avg[k] = momentum; _exp_avg_sq[k] = variance; } -#if defined(__ENABLE_CUDA__) - if (dev_params) { - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); - - _buf_index = !_buf_index; - } -#elif defined(__ENABLE_CANN__) - if (dev_params) { - size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]); - aclrtMemcpy(dev_params + t, - memcpy_size, - _doubled_buffer[_buf_index], - memcpy_size, - aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE); - - _buf_index = !_buf_index; - } -#endif } } } -void Adam_Optimizer::Step_4(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) +template +void Adam_Optimizer::Step_4(ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); + Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size); #endif if (_param_size > rounded_size) Step_1((_params + rounded_size), (grads + rounded_size), (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); + (_param_size - rounded_size)); } int create_adam_optimizer(int optimizer_id, @@ -185,33 +126,86 @@ int create_adam_optimizer(int optimizer_id, return 0; } -void Adam_Optimizer::Step_8(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) +template +void Adam_Optimizer::Step_8(ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); + Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size); #endif if (_param_size > rounded_size) Step_4((_params + rounded_size), (grads + rounded_size), (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? 
(dev_params + rounded_size) : dev_params), - half_precision); + (_param_size - rounded_size)); +} + +template +void step_invoker(std::shared_ptr opt, + void* _params, + void* grads, + void* _exp_avg, + void* _exp_avg_sq, + size_t _param_size) +{ + opt->Step_8((ds_params_percision_t*)(_params), + (ds_params_percision_t*)(grads), + (ds_state_precision_t*)(_exp_avg), + (ds_state_precision_t*)(_exp_avg_sq), + _param_size); +} + +std::map, + std::function, void*, void*, void*, void*, size_t)>> + invokers; + +// Fill map with template functions for each type +template +void create_invoker() +{ + invokers[std::tuple(c10::CppTypeToScalarType(), + c10::CppTypeToScalarType())] = + step_invoker; +} +struct InvokerInitializer { + InvokerInitializer() + { + create_invoker(); + create_invoker(); + create_invoker(); + create_invoker(); + create_invoker(); + } +} _invoker_initializer; + +void invoke(std::shared_ptr opt, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq, + size_t param_size) +{ + c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype()); + c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg.options().dtype()); + + auto it = invokers.find(std::tuple(params_type, state_type)); + if (it == invokers.end()) { + throw std::runtime_error("Adam optimizer with param type "s + c10::toString(params_type) + + " and state type "s + c10::toString(state_type) + + " is not supported on current hardware"s); + } + + it->second(opt, + params.data_ptr(), + grads.data_ptr(), + exp_avg.data_ptr(), + exp_avg_sq.data_ptr(), + param_size); } int ds_adam_step(int optimizer_id, @@ -232,75 +226,13 @@ int ds_adam_step(int optimizer_id, auto exp_avg_c = exp_avg.contiguous(); auto exp_avg_sq_c = exp_avg_sq.contiguous(); - // assert(params.options().dtype() == grads.options().dtype()); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - std::shared_ptr opt = std::static_pointer_cast(s_optimizers[optimizer_id]); opt->IncrementStep(step, beta1, beta2); opt->update_state(lr, epsilon, weight_decay, bias_correction); - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_ptr, - exp_avg_sq_ptr, - params_c.numel(), - nullptr, - (params.options().dtype() == at::kHalf)); + invoke(opt, params_c, grads_c, exp_avg_c, exp_avg_sq_c, params_c.numel()); -#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) - opt->SynchronizeStreams(); -#endif - return 0; -} - -int ds_adam_step_plus_copy(int optimizer_id, - size_t step, - float lr, - float beta1, - float beta2, - float epsilon, - float weight_decay, - bool bias_correction, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg, - torch::Tensor& exp_avg_sq, - torch::Tensor& device_params) -{ -#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) - auto params_c = params.contiguous(); - auto device_params_c = device_params.contiguous(); - auto exp_avg_c = exp_avg.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - auto grads_c = grads.contiguous(); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - ds_half_precision_t* device_params_ptr = (ds_half_precision_t*)device_params_c.data_ptr(); - float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - 
std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step, beta1, beta2); - opt->update_state(lr, epsilon, weight_decay, bias_correction); - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_ptr, - exp_avg_sq_ptr, - params_c.numel(), - device_params_ptr, - (params.options().dtype() == at::kHalf)); - - opt->SynchronizeStreams(); -#else - assert(false); -#endif return 0; } diff --git a/csrc/common/custom_cuda_kernel.cu b/csrc/common/custom_cuda_kernel.cu deleted file mode 100644 index f46bf303125c..000000000000 --- a/csrc/common/custom_cuda_kernel.cu +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// SPDX-License-Identifier: Apache-2.0 - -// DeepSpeed Team - -#include "custom_cuda_layers.h" - -__global__ void param_update_kernel(const float* input, __half* output, int size) -{ - int id = blockIdx.x * blockDim.x + threadIdx.x; - - if (id < size) { output[id] = (__half)input[id]; } -} - -void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream) -{ - int threads = 1024; - - dim3 grid_dim((size - 1) / threads + 1); - dim3 block_dim(threads); - - param_update_kernel<<>>(input, output, size); -} - -__global__ void param_update_kernel_half(const float* input, __half* output, int size) -{ - int id = blockIdx.x * blockDim.x + threadIdx.x; - __half2* output_cast = reinterpret_cast<__half2*>(output); - if (id < size) { - float input_f = input[id]; - __half2* input_h = reinterpret_cast<__half2*>(&input_f); - output_cast[id] = *input_h; - } -} - -void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream) -{ - int threads = 1024; - size /= 2; - dim3 grid_dim((size - 1) / threads + 1); - dim3 block_dim(threads); - - param_update_kernel_half<<>>(input, output, size); -} diff --git a/csrc/includes/cpu_adagrad.h b/csrc/includes/cpu_adagrad.h index e60984d64b76..c06d3a6b35e9 100644 --- a/csrc/includes/cpu_adagrad.h +++ b/csrc/includes/cpu_adagrad.h @@ -9,84 +9,35 @@ // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c #include +#include #include #include "simd.h" -#if defined(__ENABLE_CUDA__) -#include -#include -#include "cuda.h" -#include "custom_cuda_layers.h" -typedef __half ds_half_precision_t; -#elif defined(__ENABLE_CANN__) -#include "acl/acl.h" -#include "torch_npu/csrc/core/npu/NPUStream.h" -typedef c10::Half ds_half_precision_t; -#else -typedef unsigned short ds_half_precision_t; -#endif - -#define STEP(SPAN) \ - void Step_##SPAN(float* _params, \ - float* grads, \ - float* _exp_avg_sq, \ - size_t _param_size, \ - ds_half_precision_t* dev_param = nullptr, \ - bool half_precision = false); +#define STEP(SPAN) \ + template \ + void Step_##SPAN(ds_params_percision_t* _params, \ + ds_params_percision_t* grads, \ + ds_state_precision_t* _exp_avg_sq, \ + size_t _param_size); class Adagrad_Optimizer { public: Adagrad_Optimizer(float alpha = 1e-2, float eps = 1e-8, float weight_decay = 0) : _alpha(alpha), _eps(eps), _weight_decay(weight_decay) { -#if defined(__ENABLE_CUDA__) - cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); - cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); - - _streams[0] = TrainingContext::Instance().GetCurrentStream(); - _streams[1] = TrainingContext::Instance().GetNewStream(); - _buf_index = false; -#elif defined(__ENABLE_CANN__) - aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); - aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); - - _buf_index = false; -#endif - } - 
~Adagrad_Optimizer() - { -#if defined(__ENABLE_CUDA__) - cudaFreeHost(_doubled_buffer[0]); - cudaFreeHost(_doubled_buffer[1]); -#elif defined(__ENABLE_CANN__) - aclrtFreeHost(_doubled_buffer[0]); - aclrtFreeHost(_doubled_buffer[1]); -#endif } + ~Adagrad_Optimizer() {} #if defined(__AVX512__) or defined(__AVX256__) - template + template void Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg_sq, - size_t param_size, - ds_half_precision_t* dev_param = nullptr, - bool half_precision = false); + ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg_sq, + size_t param_size); #endif STEP(1) STEP(4) STEP(8) -#if defined(__ENABLE_CUDA__) - inline void SynchronizeStreams() - { - for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]); - } -#elif defined(__ENABLE_CANN__) - inline void SynchronizeStreams() - { - for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream()); - } -#endif inline void IncrementStep(size_t step) { _step++; @@ -107,29 +58,22 @@ class Adagrad_Optimizer { float _betta1_t; float _betta2_t; size_t _step; - -#if defined(__ENABLE_CUDA__) - bool _buf_index; - float* _doubled_buffer[2]; - cudaStream_t _streams[2]; -#elif defined(__ENABLE_CANN__) - float* _doubled_buffer[2]; - c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(), - c10_npu::getNPUStreamFromPool()}; - bool _buf_index; -#endif }; #if defined(__AVX512__) or defined(__AVX256__) -template +template void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) + ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) { +#if !defined(__AVX512__) + if (std::is_same_v || + std::is_same_v) { + return; + } +#endif size_t new_rounded_size = 0; AVX_Data eps_4; eps_4.data = SIMD_SET(_eps); @@ -145,24 +89,19 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, size_t copy_size = TILE; if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; size_t offset = copy_size + t; -#if defined(__ENABLE_CUDA__) - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#elif defined(__ENABLE_CANN__) - if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); } -#endif #pragma omp parallel for for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { AVX_Data grad_4[span]; - simd_load(grad_4, grads + i, half_precision); + simd_load(grad_4, grads + i); AVX_Data momentum_4[span]; - simd_load(momentum_4, grads + i, false); + simd_load(momentum_4, grads + i); AVX_Data variance_4[span]; - simd_load(variance_4, _exp_avg_sq + i, false); + simd_load(variance_4, _exp_avg_sq + i); AVX_Data param_4[span]; - simd_load(param_4, _params + i, half_precision); + simd_load(param_4, _params + i); if (_weight_decay > 0) { simd_fma(grad_4, param_4, weight_decay4, grad_4); } @@ -172,38 +111,9 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, simd_div(grad_4, momentum_4, grad_4); simd_fma(param_4, grad_4, step_size_4, param_4); - simd_store(_params + i, param_4, half_precision); -#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) - if (dev_params) { - simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); - } -#endif - simd_store(_exp_avg_sq + i, variance_4, false); - } -#if defined(__ENABLE_CUDA__) - if (dev_params) { - if (half_precision) - launch_param_update_half( - 
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - else - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - - _buf_index = !_buf_index; + simd_store(_params + i, param_4); + simd_store(_exp_avg_sq + i, variance_4); } -#elif defined(__ENABLE_CANN__) - if (dev_params) { - size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]); - if (half_precision) memcpy_size /= 2; - aclrtMemcpy(dev_params + t, - memcpy_size, - _doubled_buffer[_buf_index], - memcpy_size, - aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE); - - _buf_index = !_buf_index; - } -#endif } *rounded_size = new_rounded_size; } diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index b1a104b2571d..faf99020aee5 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -13,29 +13,13 @@ #include #include "simd.h" -#if defined(__ENABLE_CUDA__) -#include -#include -#include "cuda.h" -#include "custom_cuda_layers.h" -typedef __half ds_half_precision_t; -#elif defined(__ENABLE_CANN__) -#include "acl/acl.h" -#include "torch_npu/csrc/core/npu/NPUStream.h" -typedef c10::Half ds_half_precision_t; -#else -#include -typedef unsigned short ds_half_precision_t; -#endif - -#define STEP(SPAN) \ - void Step_##SPAN(float* _params, \ - float* grads, \ - float* _exp_avg, \ - float* _exp_avg_sq, \ - size_t _param_size, \ - ds_half_precision_t* dev_param = nullptr, \ - bool half_precision = false); +#define STEP(SPAN) \ + template \ + void Step_##SPAN(ds_params_percision_t* _params, \ + ds_params_percision_t* grads, \ + ds_state_precision_t* _exp_avg, \ + ds_state_precision_t* _exp_avg_sq, \ + size_t _param_size); class Adam_Optimizer { public: @@ -55,56 +39,21 @@ class Adam_Optimizer { _step(0), _adamw_mode(adamw_mode) { -#if defined(__ENABLE_CUDA__) - cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); - cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); - - _streams[0] = TrainingContext::Instance().GetCurrentStream(); - _streams[1] = TrainingContext::Instance().GetNewStream(); - _buf_index = false; -#elif defined(__ENABLE_CANN__) - aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); - aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); - - _buf_index = false; -#endif - } - ~Adam_Optimizer() - { -#if defined(__ENABLE_CUDA__) - cudaFreeHost(_doubled_buffer[0]); - cudaFreeHost(_doubled_buffer[1]); -#elif defined(__ENABLE_CANN__) - aclrtFreeHost(_doubled_buffer[0]); - aclrtFreeHost(_doubled_buffer[1]); -#endif } + ~Adam_Optimizer() {} #if defined(__AVX512__) or defined(__AVX256__) - template + template void Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t param_size, - ds_half_precision_t* dev_param = nullptr, - bool half_precision = false); + ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg, + ds_state_precision_t* _exp_avg_sq, + size_t param_size); #endif STEP(1) STEP(4) STEP(8) -#if defined(__ENABLE_CUDA__) - inline void SynchronizeStreams() - { - for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]); - } -#elif defined(__ENABLE_CANN__) - inline void SynchronizeStreams() - { - for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream()); - } -#endif inline void IncrementStep(size_t step, float beta1, float beta2) { if (beta1 != _betta1 || beta2 != _betta2) { @@ -154,32 +103,24 @@ class Adam_Optimizer { float _bias_correction2; bool 
_adamw_mode; - -#if defined(__ENABLE_CUDA__) - float* _doubled_buffer[2]; - cudaStream_t _streams[2]; - bool _buf_index; -#elif defined(__ENABLE_CANN__) - float* _doubled_buffer[2]; - c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(), - c10_npu::getNPUStreamFromPool()}; - bool _buf_index; -#endif }; #if defined(__AVX512__) or defined(__AVX256__) -template +template void Adam_Optimizer::Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) + ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) { +#if !defined(__AVX512__) + if (std::is_same_v || + std::is_same_v) { + return; + } +#endif size_t new_rounded_size = 0; - int rshft = half_precision ? 1 : 0; AVX_Data betta1_4; betta1_4.data = SIMD_SET(_betta1); @@ -212,24 +153,19 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, size_t copy_size = TILE; if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; size_t offset = copy_size + t; -#if defined(__ENABLE_CUDA__) - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#elif defined(__ENABLE_CANN__) - if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); } -#endif #pragma omp parallel for for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { AVX_Data grad_4[span]; - simd_load(grad_4, grads + (i >> rshft), half_precision); + simd_load(grad_4, grads + i); AVX_Data momentum_4[span]; - simd_load(momentum_4, _exp_avg + i, false); + simd_load(momentum_4, _exp_avg + i); AVX_Data variance_4[span]; - simd_load(variance_4, _exp_avg_sq + i, false); + simd_load(variance_4, _exp_avg_sq + i); AVX_Data param_4[span]; - simd_load(param_4, _params + (i >> rshft), half_precision); + simd_load(param_4, _params + i); if (_weight_decay > 0 && !_adamw_mode) { simd_fma(grad_4, param_4, weight_decay4, grad_4); @@ -250,39 +186,10 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, simd_fma(param_4, grad_4, step_size_4, param_4); - simd_store(_params + (i >> rshft), param_4, half_precision); -#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) - if (dev_params) { - simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); - } -#endif - simd_store(_exp_avg + i, momentum_4, false); - simd_store(_exp_avg_sq + i, variance_4, false); + simd_store(_params + i, param_4); + simd_store(_exp_avg + i, momentum_4); + simd_store(_exp_avg_sq + i, variance_4); } -#if defined(__ENABLE_CUDA__) - if (dev_params) { - if (half_precision) - launch_param_update_half( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - else - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - - _buf_index = !_buf_index; - } -#elif defined(__ENABLE_CANN__) - if (dev_params) { - size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]); - if (half_precision) memcpy_size /= 2; - aclrtMemcpy(dev_params + t, - memcpy_size, - _doubled_buffer[_buf_index], - memcpy_size, - aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE); - - _buf_index = !_buf_index; - } -#endif } *rounded_size = new_rounded_size; } @@ -310,18 +217,4 @@ int ds_adam_step(int optimizer_id, torch::Tensor& exp_avg, torch::Tensor& exp_avg_sq); -int ds_adam_step_plus_copy(int optimizer_id, - size_t step, - float lr, - float beta1, - float beta2, - float epsilon, - float 
weight_decay, - bool bias_correction, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg, - torch::Tensor& exp_avg_sq, - torch::Tensor& gpu_params); - int destroy_adam_optimizer(int optimizer_id); diff --git a/csrc/includes/cpu_lion.h b/csrc/includes/cpu_lion.h index 34c29eec47db..62b304923222 100644 --- a/csrc/includes/cpu_lion.h +++ b/csrc/includes/cpu_lion.h @@ -13,28 +13,12 @@ #include #include "simd.h" -#if defined(__ENABLE_CUDA__) -#include -#include -#include "cuda.h" -#include "custom_cuda_layers.h" -typedef __half ds_half_precision_t; -#elif defined(__ENABLE_CANN__) -#include "acl/acl.h" -#include "torch_npu/csrc/core/npu/NPUStream.h" -typedef c10::Half ds_half_precision_t; -#else -#include -typedef unsigned short ds_half_precision_t; -#endif - -#define STEP(SPAN) \ - void Step_##SPAN(float* _params, \ - float* grads, \ - float* _exp_avg, \ - size_t _param_size, \ - ds_half_precision_t* dev_param = nullptr, \ - bool half_precision = false); +#define STEP(SPAN) \ + template \ + void Step_##SPAN(ds_params_percision_t* _params, \ + ds_params_percision_t* grads, \ + ds_state_precision_t* _exp_avg, \ + size_t _param_size); class Lion_Optimizer { public: @@ -44,55 +28,21 @@ class Lion_Optimizer { float weight_decay = 0) : _alpha(alpha), _betta1(betta1), _betta2(betta2), _weight_decay(weight_decay), _step(0) { -#if defined(__ENABLE_CUDA__) - cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); - cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); - - _streams[0] = TrainingContext::Instance().GetCurrentStream(); - _streams[1] = TrainingContext::Instance().GetNewStream(); - _buf_index = false; -#elif defined(__ENABLE_CANN__) - aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); - aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); - - _buf_index = false; -#endif - } - ~Lion_Optimizer() - { -#if defined(__ENABLE_CUDA__) - cudaFreeHost(_doubled_buffer[0]); - cudaFreeHost(_doubled_buffer[1]); -#elif defined(__ENABLE_CANN__) - aclrtFreeHost(_doubled_buffer[0]); - aclrtFreeHost(_doubled_buffer[1]); -#endif } + ~Lion_Optimizer() {} #if defined(__AVX512__) or defined(__AVX256__) - template + template void Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg, - size_t param_size, - ds_half_precision_t* dev_param = nullptr, - bool half_precision = false); + ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg, + size_t param_size); #endif STEP(1) STEP(4) STEP(8) -#if defined(__ENABLE_CUDA__) - inline void SynchronizeStreams() - { - for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]); - } -#elif defined(__ENABLE_CANN__) - inline void SynchronizeStreams() - { - for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream()); - } -#endif + inline void IncrementStep(size_t step, float beta1, float beta2) { _step++; @@ -114,31 +64,23 @@ class Lion_Optimizer { float _betta2; float _weight_decay; size_t _step; - -#if defined(__ENABLE_CUDA__) - float* _doubled_buffer[2]; - cudaStream_t _streams[2]; - bool _buf_index; -#elif defined(__ENABLE_CANN__) - float* _doubled_buffer[2]; - c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(), - c10_npu::getNPUStreamFromPool()}; - bool _buf_index; -#endif }; #if defined(__AVX512__) or defined(__AVX256__) -template +template void Lion_Optimizer::Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg, - size_t _param_size, - ds_half_precision_t* dev_params, - 
bool half_precision) + ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg, + size_t _param_size) { +#if !defined(__AVX512__) + if (std::is_same_v || + std::is_same_v) { + return; + } +#endif size_t new_rounded_size = 0; - int rshft = half_precision ? 1 : 0; constexpr float neg1 = -1.0f; AVX_Data neg1_4; @@ -169,21 +111,17 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size, size_t copy_size = TILE; if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; size_t offset = copy_size + t; -#if defined(__ENABLE_CUDA__) - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#elif defined(__ENABLE_CANN__) - if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); } -#endif + #pragma omp parallel for for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { AVX_Data grad_4[span]; - simd_load(grad_4, grads + (i >> rshft), half_precision); + simd_load(grad_4, grads + i); AVX_Data momentum_4[span]; - simd_load(momentum_4, _exp_avg + i, false); + simd_load(momentum_4, _exp_avg + i); AVX_Data param_4[span]; - simd_load(param_4, _params + (i >> rshft), half_precision); + simd_load(param_4, _params + i); AVX_Data tmp_4[span]; @@ -201,38 +139,9 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size, simd_mul(momentum_4, momentum_4, betta2_4); simd_fma(momentum_4, grad_4, betta2_minus1_4, momentum_4); - simd_store(_params + (i >> rshft), param_4, half_precision); -#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) - if (dev_params) { - simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); - } -#endif - simd_store(_exp_avg + i, momentum_4, false); - } -#if defined(__ENABLE_CUDA__) - if (dev_params) { - if (half_precision) - launch_param_update_half( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - else - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - - _buf_index = !_buf_index; + simd_store(_params + i, param_4); + simd_store(_exp_avg + i, momentum_4); } -#elif defined(__ENABLE_CANN__) - if (dev_params) { - size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]); - if (half_precision) memcpy_size /= 2; - aclrtMemcpy(dev_params + t, - memcpy_size, - _doubled_buffer[_buf_index], - memcpy_size, - aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE); - - _buf_index = !_buf_index; - } -#endif } *rounded_size = new_rounded_size; } @@ -255,15 +164,4 @@ int ds_lion_step(int optimizer_id, torch::Tensor& grads, torch::Tensor& exp_avg); -int ds_lion_step_plus_copy(int optimizer_id, - size_t step, - float lr, - float beta1, - float beta2, - float weight_decay, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg, - torch::Tensor& gpu_params); - int destroy_lion_optimizer(int optimizer_id); diff --git a/csrc/includes/cublas_wrappers.h b/csrc/includes/cublas_wrappers.h index b57ff79923fc..2721fb990c7e 100644 --- a/csrc/includes/cublas_wrappers.h +++ b/csrc/includes/cublas_wrappers.h @@ -17,6 +17,7 @@ #include #endif #include +#include int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, @@ -29,7 +30,9 @@ int cublas_gemm_ex(cublasHandle_t handle, const float* A, const float* B, float* C, -#ifdef __HIP_PLATFORM_AMD__ +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo algo = rocblas_gemm_algo_standard); #else 
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); @@ -46,7 +49,8 @@ int cublas_gemm_ex(cublasHandle_t handle, const __half* A, const __half* B, __half* C, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo algo = rocblas_gemm_algo_standard); #else cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -67,7 +71,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, int stride_B, int stride_C, int batch, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo algo = rocblas_gemm_algo_standard); #else cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); @@ -88,7 +93,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, int stride_B, int stride_C, int batch, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo algo = rocblas_gemm_algo_standard); #else cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); diff --git a/csrc/includes/custom_cuda_layers.h b/csrc/includes/custom_cuda_layers.h index 265eb7b12444..21f19749d4cf 100644 --- a/csrc/includes/custom_cuda_layers.h +++ b/csrc/includes/custom_cuda_layers.h @@ -272,9 +272,6 @@ void launch_fuse_transpose_bias_kernel(const T* inp, int cols, cudaStream_t stream); -void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream); -void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream); - void launch_token_sort(int32_t* indices, int layers, int batch_size, diff --git a/csrc/includes/ds_kernel_utils.h b/csrc/includes/ds_kernel_utils.h index 8e4888109fcd..f8b16ee6a315 100644 --- a/csrc/includes/ds_kernel_utils.h +++ b/csrc/includes/ds_kernel_utils.h @@ -23,7 +23,7 @@ used throughout the codebase. 
#ifdef __HIP_PLATFORM_AMD__ // constexpr variant of warpSize for templating -constexpr int hw_warp_size = 64; +constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE; #define HALF_PRECISION_AVAILABLE = 1 #include #include diff --git a/csrc/includes/feed_forward.h b/csrc/includes/feed_forward.h index 46e3ba748d52..d2056403d265 100644 --- a/csrc/includes/feed_forward.h +++ b/csrc/includes/feed_forward.h @@ -48,7 +48,9 @@ class FeedForward { weights, input_ptr, out, -#ifdef __HIP_PLATFORM_AMD__ +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo(config_.gemm_algos[0])); #else cublasGemmAlgo_t(config_.gemm_algos[0])); @@ -77,7 +79,8 @@ class FeedForward { input_ptr, out_grad, weights_grad, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo(config_.gemm_algos[1])); #else cublasGemmAlgo_t(config_.gemm_algos[1])); @@ -94,7 +97,8 @@ class FeedForward { weights, out_grad, inp_grad_out, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo(config_.gemm_algos[2])); #else cublasGemmAlgo_t(config_.gemm_algos[2])); diff --git a/csrc/includes/gemm_test.h b/csrc/includes/gemm_test.h index 278515174523..de5b55cd3df1 100644 --- a/csrc/includes/gemm_test.h +++ b/csrc/includes/gemm_test.h @@ -67,7 +67,9 @@ class GemmTest { B, A, C, -#ifdef __HIP_PLATFORM_AMD__ +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) static_cast(algo)); #else static_cast(algo)); @@ -86,7 +88,8 @@ class GemmTest { A, C, B, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) static_cast(algo)); #else static_cast(algo)); @@ -105,7 +108,8 @@ class GemmTest { B, C, A, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) static_cast(algo)); #else static_cast(algo)); @@ -121,8 +125,11 @@ class GemmTest { float fast_latency = (std::numeric_limits::max)(); int fast_algo = 0; -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard; +#elif defined(__HIP_PLATFORM_AMD__) + for (int algo = (int)HIPBLAS_GEMM_DEFAULT; algo <= (int)HIPBLAS_GEMM_DEFAULT; #else for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; @@ -211,7 +218,8 @@ class StridedGemmTest { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) static_cast(algo)); #else static_cast(algo)); @@ -245,7 +253,8 @@ class StridedGemmTest { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) static_cast(algo)); #else static_cast(algo)); @@ -276,7 +285,8 @@ class 
StridedGemmTest { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) static_cast(algo)); #else static_cast(algo)); @@ -292,11 +302,17 @@ class StridedGemmTest { float fast_latency = (std::numeric_limits::max)(); int fast_algo = 0; -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard; +#else +#ifdef __HIP_PLATFORM_AMD__ + for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + algo <= (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; #else for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; +#endif #endif algo++) { int warm_up = 5; diff --git a/csrc/includes/simd.h b/csrc/includes/simd.h index 59237b0261c1..73e41216a652 100644 --- a/csrc/includes/simd.h +++ b/csrc/includes/simd.h @@ -13,6 +13,19 @@ #define TILE (128 * 1024 * 1024) #if defined(__AVX512__) or defined(__AVX256__) +template +inline T readAs(const void* src) +{ + T res; + std::memcpy(&res, src, sizeof(T)); + return res; +} +template +inline void writeAs(void* dst, const T& val) +{ + std::memcpy(dst, &val, sizeof(T)); +} + #define ROUND_DOWN(size, step) ((size) & ~((step)-1)) #if defined(__AVX512__) @@ -30,11 +43,52 @@ #define SIMD_XOR(x, y) _mm512_xor_ps(x, y) #define SIMD_WIDTH 16 -#define SIMD_LOAD2(x, h) \ - ((h) ? _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) : _mm512_loadu_ps(x)) -#define SIMD_STORE2(x, d, h) \ - ((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ - : _mm512_storeu_ps(x, d)) +static __m512 load_16_bf16_as_f32(const void* data) +{ + __m256i a = readAs<__m256i>(data); // use memcpy to avoid aliasing + __m512i b = _mm512_cvtepu16_epi32(a); // convert 8 u16 to 8 u32 + __m512i c = _mm512_slli_epi32(b, 16); // logical shift left of all u32 by + // 16 bits (representing bf16->f32) + return readAs<__m512>(&c); // use memcpy to avoid aliasing +} + +static void store_16_f32_as_bf16_nearest(__m512 v, void* data) +{ + __m512i u32 = readAs<__m512i>(&v); + + // flow assuming non-nan: + + // uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + __m512i b = _mm512_srli_epi32(u32, 16); + __m512i lsb_mask = _mm512_set1_epi32(0x00000001); + __m512i c = _mm512_and_si512(b, lsb_mask); + __m512i bias_constant = _mm512_set1_epi32(0x00007fff); + __m512i rounding_bias = _mm512_add_epi32(c, bias_constant); + + // uint16_t res = static_cast((U32 + rounding_bias) >> 16); + __m512i d = _mm512_add_epi32(u32, rounding_bias); + __m512i e = _mm512_srli_epi32(d, 16); + __m256i non_nan_res = _mm512_cvtusepi32_epi16(e); + + // handle nan (exp is all 1s and mantissa != 0) + // if ((x & 0x7fffffffU) > 0x7f800000U) + __m512i mask_out_sign = _mm512_set1_epi32(0x7fffffff); + __m512i non_sign_bits = _mm512_and_si512(u32, mask_out_sign); + __m512i nan_threshold = _mm512_set1_epi32(0x7f800000); + __mmask16 nan_mask = _mm512_cmp_epi32_mask(non_sign_bits, nan_threshold, _MM_CMPINT_GT); + + // mix in results with nans as needed + __m256i nans = _mm256_set1_epi16(0x7fc0); + __m256i res = _mm256_mask_mov_epi16(non_nan_res, nan_mask, nans); + + writeAs(data, res); +} +#define SIMD_LOAD_BF16(x) load_16_bf16_as_f32(x) +#define SIMD_STORE_BF16(x, d) store_16_f32_as_bf16_nearest(d, x) + +#define SIMD_LOAD_FP16(x) 
_mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) +#define SIMD_STORE_FP16(x, d) \ + _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) #define INTV __m256i #elif defined(__AVX256__) @@ -52,11 +106,11 @@ #define SIMD_XOR(x, y) _mm256_xor_ps(x, y) #define SIMD_WIDTH 8 -#define SIMD_LOAD2(x, h) \ - ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x))) : _mm256_loadu_ps(x)) -#define SIMD_STORE2(x, d, h) \ - ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ - : _mm256_storeu_ps(x, d)) +#define SIMD_LOAD_BF16(x) static_assert(false && "AVX256 does not support BFloat16") +#define SIMD_STORE_BF16(x, d) static_assert(false && "AVX256 does not support BFloat16") +#define SIMD_LOAD_FP16(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) +#define SIMD_STORE_FP16(x, d) \ + _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) #define INTV __m128i #endif @@ -70,20 +124,66 @@ union AVX_Data { // float data_f[16]; }; -template -inline void simd_store(float* dst, AVX_Data* src, bool half_precision) +template +inline typename std::enable_if_t, void> simd_store(T* dst, + AVX_Data* src) { - size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH); + size_t width = SIMD_WIDTH; #pragma unroll - for (size_t i = 0; i < span; ++i) { SIMD_STORE2(dst + width * i, src[i].data, half_precision); } + for (size_t i = 0; i < span; ++i) { SIMD_STORE_FP16((float*)(dst + width * i), src[i].data); } } -template -inline void simd_load(AVX_Data* dst, float* src, bool half_precision) + +template +inline typename std::enable_if_t, void> simd_store(T* dst, + AVX_Data* src) { - size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH); +#ifdef __AVX512__ + size_t width = SIMD_WIDTH; #pragma unroll - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD2(src + width * i, half_precision); } + for (size_t i = 0; i < span; ++i) { SIMD_STORE_BF16((float*)(dst + width * i), src[i].data); } +#else + throw std::runtime_error("AVX512 required for BFloat16"); +#endif +} + +template +inline typename std::enable_if_t, void> simd_store(T* dst, AVX_Data* src) +{ + size_t width = SIMD_WIDTH; +#pragma unroll + for (size_t i = 0; i < span; ++i) { SIMD_STORE(dst + width * i, src[i].data); } } + +template +inline typename std::enable_if_t, void> simd_load(AVX_Data* dst, + T* src) +{ + size_t width = SIMD_WIDTH; +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD_FP16((float*)(src + width * i)); } +} + +template +inline typename std::enable_if_t, void> simd_load(AVX_Data* dst, + T* src) +{ +#ifdef __AVX512__ + size_t width = SIMD_WIDTH; +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD_BF16((float*)(src + width * i)); } +#else + throw std::runtime_error("AVX512 required for BFloat16"); +#endif +} + +template +inline typename std::enable_if_t, void> simd_load(AVX_Data* dst, T* src) +{ + size_t width = SIMD_WIDTH; +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD(src + width * i); } +} + template inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a) { diff --git a/csrc/includes/strided_batch_gemm.h b/csrc/includes/strided_batch_gemm.h index 86d1e3dea11a..9767fcf589b8 100644 --- a/csrc/includes/strided_batch_gemm.h +++ b/csrc/includes/strided_batch_gemm.h @@ -77,7 +77,9 @@ class StridedBatchGemm { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_AMD__ +// TODO HIP: Remove backward compatibility for 
torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo(_config.gemm_algos[0])); #else cublasGemmAlgo_t(_config.gemm_algos[0])); @@ -105,7 +107,8 @@ class StridedBatchGemm { stride_b, stride_c, _config.batch_size, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo(_config.gemm_algos[0])); #else cublasGemmAlgo_t(_config.gemm_algos[0])); @@ -149,7 +152,8 @@ class StridedBatchGemm { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo(_config.gemm_algos[1])); #else cublasGemmAlgo_t(_config.gemm_algos[1])); @@ -178,7 +182,8 @@ class StridedBatchGemm { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo(_config.gemm_algos[2])); #else cublasGemmAlgo_t(_config.gemm_algos[2])); diff --git a/csrc/lion/cpu_lion.cpp b/csrc/lion/cpu_lion.cpp index a0562eac9c4a..c5cf3e9e9235 100644 --- a/csrc/lion/cpu_lion.cpp +++ b/csrc/lion/cpu_lion.cpp @@ -8,9 +8,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("lion_update", &ds_lion_step, "DeepSpeed CPU Lion update (C++)"); - m.def("lion_update_copy", - &ds_lion_step_plus_copy, - "DeepSpeed CPU Lion update and param copy (C++)"); m.def("create_lion", &create_lion_optimizer, "DeepSpeed CPU Lion (C++)"); m.def("destroy_lion", &destroy_lion_optimizer, "DeepSpeed CPU Lion destroy (C++)"); } diff --git a/csrc/lion/cpu_lion_impl.cpp b/csrc/lion/cpu_lion_impl.cpp index 28314cf5b6e1..85896ba86e19 100644 --- a/csrc/lion/cpu_lion_impl.cpp +++ b/csrc/lion/cpu_lion_impl.cpp @@ -6,34 +6,28 @@ #include #include #include +#include #include +#include #include #include #include #include "cpu_lion.h" -#if defined(__ENABLE_CUDA__) -#include -#include "cublas_v2.h" -#include "cuda.h" -#include "curand.h" -#include "custom_cuda_layers.h" -#endif - +using namespace std::string_literals; static std::unordered_map> s_optimizers; // C++ interface -void Lion_Optimizer::Step_1(float* _params, - float* grads, - float* _exp_avg, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) +template +void Lion_Optimizer::Step_1(ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg, + size_t _param_size) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision); + Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _param_size); #endif if (_param_size > rounded_size) { float betta1_minus1 = 1 - _betta1; @@ -41,26 +35,15 @@ void Lion_Optimizer::Step_1(float* _params, float alpha = _alpha; float after_decay = 1 - alpha * _weight_decay; - ds_half_precision_t* grads_cast_h; - ds_half_precision_t* params_cast_h; - if (half_precision) { - grads_cast_h = reinterpret_cast(grads); - params_cast_h = reinterpret_cast(_params); - } for (size_t t = rounded_size; t < _param_size; t += TILE) { size_t copy_size = TILE; if ((t + TILE) > _param_size) copy_size = _param_size - t; size_t offset = copy_size + t; -#if defined(__ENABLE_CUDA__) - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#elif 
defined(__ENABLE_CANN__) - if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); } -#endif #pragma omp parallel for for (size_t k = t; k < offset; k++) { - float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; - float param = half_precision ? (float)params_cast_h[k] : _params[k]; + float grad = (float)grads[k]; + float param = (float)_params[k]; float momentum = _exp_avg[k]; float tmp = momentum * _betta1; tmp = grad * betta1_minus1 + tmp; @@ -74,56 +57,28 @@ void Lion_Optimizer::Step_1(float* _params, } momentum = momentum * _betta2; momentum = grad * betta2_minus1 + momentum; -#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) - if (dev_params) _doubled_buffer[_buf_index][k - t] = param; -#endif - if (half_precision) - params_cast_h[k] = (ds_half_precision_t)param; - else - _params[k] = param; + _params[k] = param; _exp_avg[k] = momentum; } -#if defined(__ENABLE_CUDA__) - if (dev_params) { - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); - - _buf_index = !_buf_index; - } -#elif defined(__ENABLE_CANN__) - if (dev_params) { - size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]); - aclrtMemcpy(dev_params + t, - memcpy_size, - _doubled_buffer[_buf_index], - memcpy_size, - aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE); - - _buf_index = !_buf_index; - } -#endif } } } -void Lion_Optimizer::Step_4(float* _params, - float* grads, - float* _exp_avg, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) +template +void Lion_Optimizer::Step_4(ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg, + size_t _param_size) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision); + Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _param_size); #endif if (_param_size > rounded_size) Step_1((_params + rounded_size), (grads + rounded_size), (_exp_avg + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); + (_param_size - rounded_size)); } int create_lion_optimizer(int optimizer_id, @@ -162,24 +117,76 @@ int create_lion_optimizer(int optimizer_id, return 0; } -void Lion_Optimizer::Step_8(float* _params, - float* grads, - float* _exp_avg, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) +template +void Lion_Optimizer::Step_8(ds_params_percision_t* _params, + ds_params_percision_t* grads, + ds_state_precision_t* _exp_avg, + size_t _param_size) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision); + Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _param_size); #endif if (_param_size > rounded_size) Step_4((_params + rounded_size), (grads + rounded_size), (_exp_avg + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? 
(dev_params + rounded_size) : dev_params), - half_precision); + (_param_size - rounded_size)); +} + +template +void step_invoker(std::shared_ptr opt, + void* _params, + void* grads, + void* _exp_avg, + size_t _param_size) +{ + opt->Step_8((ds_params_percision_t*)(_params), + (ds_params_percision_t*)(grads), + (ds_state_precision_t*)(_exp_avg), + _param_size); +} + +std::map, + std::function, void*, void*, void*, size_t)>> + invokers; + +// Fill map with template functions for each type +template +void create_invoker() +{ + invokers[std::tuple(c10::CppTypeToScalarType(), + c10::CppTypeToScalarType())] = + step_invoker; +} +struct InvokerInitializer { + InvokerInitializer() + { + create_invoker(); + create_invoker(); + create_invoker(); + create_invoker(); + create_invoker(); + } +} _invoker_initializer; + +void invoke(std::shared_ptr opt, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + size_t param_size) +{ + c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype()); + c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg.options().dtype()); + + auto it = invokers.find(std::tuple(params_type, state_type)); + if (it == invokers.end()) { + throw std::runtime_error("Lion optimizer with param type "s + c10::toString(params_type) + + " and state type "s + c10::toString(state_type) + + " is not supported on current hardware"s); + } + + it->second(opt, params.data_ptr(), grads.data_ptr(), exp_avg.data_ptr(), param_size); } int ds_lion_step(int optimizer_id, @@ -196,67 +203,13 @@ int ds_lion_step(int optimizer_id, auto grads_c = grads.contiguous(); auto exp_avg_c = exp_avg.contiguous(); - // assert(params.options().dtype() == grads.options().dtype()); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); - std::shared_ptr opt = std::static_pointer_cast(s_optimizers[optimizer_id]); opt->IncrementStep(step, beta1, beta2); opt->update_state(lr, weight_decay); - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_ptr, - params_c.numel(), - nullptr, - (params.options().dtype() == at::kHalf)); + invoke(opt, params_c, grads_c, exp_avg_c, params_c.numel()); -#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) - opt->SynchronizeStreams(); -#endif - return 0; -} - -int ds_lion_step_plus_copy(int optimizer_id, - size_t step, - float lr, - float beta1, - float beta2, - float weight_decay, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg, - torch::Tensor& gpu_params) -{ -#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__) - auto params_c = params.contiguous(); - auto gpu_params_c = gpu_params.contiguous(); - auto exp_avg_c = exp_avg.contiguous(); - auto grads_c = grads.contiguous(); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr(); - float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step, beta1, beta2); - opt->update_state(lr, weight_decay); - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_ptr, - params_c.numel(), - gpu_params_ptr, - (params.options().dtype() == at::kHalf)); - - opt->SynchronizeStreams(); -#else - assert(false); -#endif return 0; } diff --git a/csrc/random_ltd/token_sort.cu b/csrc/random_ltd/token_sort.cu index 3049471cfe34..3c1dff49429f 100644 --- 
a/csrc/random_ltd/token_sort.cu +++ b/csrc/random_ltd/token_sort.cu @@ -16,7 +16,7 @@ constexpr int mem_vals = granularity / sizeof(int32_t); constexpr int max_buffer_size = (threads + 1) * mem_vals; #ifdef __HIP_PLATFORM_AMD__ -constexpr int warp_size = 64; +constexpr int warp_size = ROCM_WAVEFRONT_SIZE; #else constexpr int warp_size = 32; #endif diff --git a/csrc/transformer/cublas_wrappers.cu b/csrc/transformer/cublas_wrappers.cu index 7821a8759ab0..d982e65b8a81 100644 --- a/csrc/transformer/cublas_wrappers.cu +++ b/csrc/transformer/cublas_wrappers.cu @@ -5,7 +5,9 @@ #include "cublas_wrappers.h" -#ifdef __HIP_PLATFORM_AMD__ +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, @@ -33,7 +35,8 @@ int cublas_gemm_ex(cublasHandle_t handle, cublasGemmAlgo_t algo) #endif { -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_status status = rocblas_gemm_ex(handle, transa, transb, @@ -67,20 +70,39 @@ int cublas_gemm_ex(cublasHandle_t handle, k, (const void*)alpha, (const void*)A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif (transa == CUBLAS_OP_N) ? m : k, (const void*)B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif (transb == CUBLAS_OP_N) ? k : n, (const void*)beta, C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif m, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { @@ -96,7 +118,8 @@ int cublas_gemm_ex(cublasHandle_t handle, return 0; } -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, @@ -124,7 +147,8 @@ int cublas_gemm_ex(cublasHandle_t handle, cublasGemmAlgo_t algo) #endif { -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_status status = rocblas_gemm_ex(handle, transa, transb, @@ -158,20 +182,39 @@ int cublas_gemm_ex(cublasHandle_t handle, k, (const void*)alpha, (const void*)A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else CUDA_R_16F, +#endif (transa == CUBLAS_OP_N) ? m : k, (const void*)B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else CUDA_R_16F, +#endif (transb == CUBLAS_OP_N) ? 
k : n, (const void*)beta, (void*)C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else CUDA_R_16F, +#endif m, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { @@ -187,7 +230,8 @@ int cublas_gemm_ex(cublasHandle_t handle, return 0; } -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, @@ -223,7 +267,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, cublasGemmAlgo_t algo) #endif { -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_status status = rocblas_gemm_strided_batched_ex(handle, op_A, @@ -263,24 +308,43 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, k, alpha, A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif m, stride_C, batch, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { @@ -297,7 +361,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, return 0; } -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, @@ -333,7 +398,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, cublasGemmAlgo_t algo) #endif { -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_status status = rocblas_gemm_strided_batched_ex(handle, op_A, @@ -373,24 +439,43 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, k, alpha, A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else CUDA_R_16F, +#endif (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else CUDA_R_16F, +#endif (op_B == CUBLAS_OP_N) ? 
k : n, stride_B, beta, C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else CUDA_R_16F, +#endif m, stride_C, batch, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu index a06dbb48fd33..25a494111c54 100644 --- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu +++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu @@ -99,17 +99,9 @@ __global__ void apply_rotary_pos_half(T* mixed_query, rope_theta, \ max_out_tokens); -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64 #define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ - if (threads_per_head == 4) { \ - LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \ - } else if (threads_per_head == 8) { \ - LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \ - } else if (threads_per_head == 16) { \ - LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \ - } else if (threads_per_head == 32) { \ - LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \ - } else if (threads_per_head == 64) { \ + if (threads_per_head == 64) { \ LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \ } else { \ assert(false); \ diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index b7277d1e1678..1b9f91cd9c88 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -163,7 +163,9 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W) (T*)W.data_ptr(), (T*)Q.data_ptr(), (T*)O.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -216,7 +218,8 @@ void attention_unfused(at::Tensor& prev_key_cont, seq_len * k, seq_len * soft_len, bsz * heads, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -253,7 +256,8 @@ void attention_unfused(at::Tensor& prev_key_cont, seq_len * soft_len, seq_len * k, bsz * heads, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -388,7 +392,8 @@ void attention_unfused(T* prev_key_cont, seq_len * k, seq_len * soft_len, bsz * heads, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -421,7 +426,8 @@ void attention_unfused(T* prev_key_cont, seq_len * soft_len, seq_len * k, bsz * heads, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else 
CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -886,7 +892,8 @@ void quantized_gemm(void* output, weight16, (T*)input, (T*)output, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -931,7 +938,8 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, (T*)weight.data_ptr(), workspace, (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1003,7 +1011,8 @@ std::vector ds_rms_qkv(at::Tensor& input, (T*)weight.data_ptr(), (T*)rms_norm.data_ptr(), (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1089,7 +1098,8 @@ void quantized_gemm(at::Tensor& output, (T*)weight16.data_ptr(), (T*)input.data_ptr(), (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1135,7 +1145,8 @@ at::Tensor ds_linear_layer(at::Tensor& input, (T*)weight.data_ptr(), (T*)input_cont.data_ptr(), (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1353,7 +1364,8 @@ at::Tensor ds_vector_matmul(at::Tensor& input, (T*)weight.data_ptr(), (T*)input.data_ptr(), (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1439,7 +1451,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, (T*)weight.data_ptr(), inp_norm, intermediate, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1483,7 +1496,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, (T*)weight1.data_ptr(), intermediate, (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1617,7 +1631,8 @@ std::vector ds_rms_mlp_gemm(at::Tensor& input, (T*)weight_interm.data_ptr(), (T*)inp_norm.data_ptr(), intermediate_ptr, -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1680,7 +1695,8 @@ std::vector ds_rms_mlp_gemm(at::Tensor& input, (T*)weight_out.data_ptr(), intermediate_ptr, (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard, #else 
CUBLAS_GEMM_DEFAULT_TENSOR_OP, @@ -1742,7 +1758,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, (T*)weight.data_ptr(), (T*)input.data_ptr(), (T*)intermediate.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1776,7 +1793,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, (T*)weight_out.data_ptr(), (T*)intermediate.data_ptr(), (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); diff --git a/csrc/transformer/inference/includes/inference_cublas_wrappers.h b/csrc/transformer/inference/includes/inference_cublas_wrappers.h index 640751b12c8f..40c3e443941d 100644 --- a/csrc/transformer/inference/includes/inference_cublas_wrappers.h +++ b/csrc/transformer/inference/includes/inference_cublas_wrappers.h @@ -18,7 +18,9 @@ #endif #include -#ifdef __HIP_PLATFORM_AMD__ +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, @@ -49,7 +51,8 @@ int cublas_gemm_ex(cublasHandle_t handle, #endif { const int ldb = (b_stride == -1) ? ((transb == CUBLAS_OP_N) ? k : n) : b_stride; -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_status status = rocblas_gemm_ex(handle, transa, transb, @@ -83,20 +86,39 @@ int cublas_gemm_ex(cublasHandle_t handle, k, (const void*)alpha, (const void*)A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif (transa == CUBLAS_OP_N) ? m : k, (const void*)B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif ldb, (const void*)beta, C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif m, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { @@ -113,7 +135,8 @@ int cublas_gemm_ex(cublasHandle_t handle, } template -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, @@ -144,7 +167,8 @@ int cublas_gemm_ex(cublasHandle_t handle, #endif { const int ldb = (b_stride == -1) ? ((transb == CUBLAS_OP_N) ? k : n) : b_stride; -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) constexpr auto rocblas_dtype_16 = std::is_same::value ? 
rocblas_datatype_f16_r : rocblas_datatype_bf16_r; rocblas_status status = rocblas_gemm_ex(handle, @@ -171,8 +195,12 @@ int cublas_gemm_ex(cublasHandle_t handle, algo, 0, 0); +#else +#ifdef __HIP_PLATFORM_AMD__ + constexpr auto cublas_dtype_16 = std::is_same::value ? HIPBLAS_R_16F : HIPBLAS_R_16B; #else constexpr auto cublas_dtype_16 = std::is_same::value ? CUDA_R_16F : CUDA_R_16BF; +#endif cublasStatus_t status = cublasGemmEx(handle, transa, transb, @@ -190,11 +218,18 @@ int cublas_gemm_ex(cublasHandle_t handle, (void*)C, cublas_dtype_16, m, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { @@ -210,7 +245,8 @@ int cublas_gemm_ex(cublasHandle_t handle, return 0; } -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, @@ -246,7 +282,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, cublasGemmAlgo_t algo) #endif { -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_status status = rocblas_gemm_strided_batched_ex(handle, op_A, @@ -286,24 +323,43 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, k, alpha, A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif m, stride_C, batch, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { @@ -321,7 +377,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, } template -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, @@ -357,7 +414,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, cublasGemmAlgo_t algo) #endif { -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) constexpr auto rocblas_dtype_16 = std::is_same::value ? rocblas_datatype_f16_r : rocblas_datatype_bf16_r; rocblas_status status = @@ -390,8 +448,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, algo, 0, 0); +#else +#ifdef __HIP_PLATFORM_AMD__ + constexpr auto cublas_dtype_16 = std::is_same::value ? HIPBLAS_R_16F : HIPBLAS_R_16B; #else constexpr auto cublas_dtype_16 = std::is_same::value ? 
CUDA_R_16F : CUDA_R_16BF; +#endif cublasStatus_t status = cublasGemmStridedBatchedEx(handle, op_A, op_B, @@ -413,11 +475,18 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, m, stride_C, batch, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py index 4f826edab3d6..85b7fab2c548 100644 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -618,7 +618,7 @@ def init_distributed(dist_backend=None, auto_mpi_discovery Optional (bool). if distributed environment variables are not set, attempt to discover them from MPI distributed_port: Optional (int). torch distributed backend port verbose: Optional (bool). verbose logging - timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes. + timeout: Optional (timedelta). Timeout for operations executed against the process group. The default value of 30 minutes can be overridden by the environment variable `DEEPSPEED_TIMEOUT`. init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified. config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling) rank: Optional (int). The current manually specified rank. Some init_method like “tcp://” need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization) diff --git a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h index c02cc76905e0..294db7528699 100644 --- a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h +++ b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h @@ -55,7 +55,9 @@ class BlasContext { enum class BlasType { FP32, FP16, BF16 }; -#ifdef __HIP_PLATFORM_AMD__ +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_operation get_trans_op(bool do_trans) { return (do_trans) ? rocblas_operation_transpose : rocblas_operation_none; @@ -76,9 +78,15 @@ cublasOperation_t get_trans_op(bool do_trans) { return (do_trans) ? 
CUBLAS_OP_T cublasDataType_t get_datatype(BlasType type) { switch (type) { +#ifdef __HIP_PLATFORM_AMD__ + case BlasType::FP32: return HIPBLAS_R_32F; + case BlasType::FP16: return HIPBLAS_R_16F; + case BlasType::BF16: return HIPBLAS_R_16B; +#else case BlasType::FP32: return CUDA_R_32F; case BlasType::FP16: return CUDA_R_16F; case BlasType::BF16: return CUDA_R_16BF; +#endif default: throw std::runtime_error("Unsupported BlasType"); } } @@ -99,7 +107,8 @@ int blas_gemm_ex(void* C, const float* beta, BlasType type) { -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_operation_t transa_op = get_trans_op(transa); rocblas_operation_t transb_op = get_trans_op(transb); @@ -151,11 +160,18 @@ int blas_gemm_ex(void* C, C, abc_type, ldc, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { @@ -190,7 +206,8 @@ int blas_strided_batched_gemm(void* C, int batch, BlasType type) { -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_operation_t transa_op = get_trans_op(transa); rocblas_operation_t transb_op = get_trans_op(transb); @@ -253,11 +270,18 @@ int blas_strided_batched_gemm(void* C, ldc, stride_C, batch, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu index cfa62f94596a..fc14b1831361 100644 --- a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu +++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu @@ -17,7 +17,7 @@ constexpr int access_size = 16; constexpr int threads = 1024; template -float gated_act_fn(float x, float y); +DS_D_INLINE float gated_act_fn(float x, float y); template <> DS_D_INLINE float gated_act_fn(float x, float y) diff --git a/deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h b/deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h index 8e4888109fcd..f8b16ee6a315 100644 --- a/deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h +++ b/deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h @@ -23,7 +23,7 @@ used throughout the codebase. 
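The HIP changes in token_sort.cu and in this header replace the hard-coded 64-lane warp size with a ROCM_WAVEFRONT_SIZE define, because not every ROCm target uses a 64-wide wavefront (RDNA parts can run wave32). As an editor's sketch only, not a description of DeepSpeed's op_builder, such a define could be derived at build time roughly like this; the `warp_size` device property and the 64 fallback are assumptions:

```python
# Hedged sketch: derive the wavefront size for a ROCm build and pass it to the
# extension compiler as -DROCM_WAVEFRONT_SIZE=<n>. The warp_size device property
# may not exist on every PyTorch version, hence the getattr fallback.
import torch

def rocm_wavefront_size(default: int = 64) -> int:
    if torch.cuda.is_available() and torch.version.hip is not None:
        props = torch.cuda.get_device_properties(0)
        return int(getattr(props, "warp_size", default))
    return default

extra_cflags = [f"-DROCM_WAVEFRONT_SIZE={rocm_wavefront_size()}"]
```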
#ifdef __HIP_PLATFORM_AMD__ // constexpr variant of warpSize for templating -constexpr int hw_warp_size = 64; +constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE; #define HALF_PRECISION_AVAILABLE = 1 #include #include diff --git a/deepspeed/linear/optimized_linear.py b/deepspeed/linear/optimized_linear.py index 138bd493ffc7..e982785a8122 100644 --- a/deepspeed/linear/optimized_linear.py +++ b/deepspeed/linear/optimized_linear.py @@ -85,7 +85,7 @@ def __init__(self, self.bias = bias self.lora_config = lora_config self.quantization_config = quantization_config - device = get_accelerator().current_device() if device is None else device + device = get_accelerator().current_device_name() if device is None else device assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config" self.zero_shards = self.lora_config.base_weight_sharding diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 88f7086518e8..3429ceb0a4ee 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -13,7 +13,7 @@ from deepspeed import comm as dist from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce from deepspeed.accelerator import get_accelerator -from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw +from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_chunk_mlp from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list @@ -133,7 +133,8 @@ def is_load_module(module): load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm] load_layer_names = [ "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear", - "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm" + "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding", + "Phi3RMSNorm" ] return module.__class__ in load_layers or module._get_name() in load_layer_names @@ -306,6 +307,8 @@ def tp_parser(model): # Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce. elif 'w2' in layer and 'Mixtral' in str(type(module)): gem_list = gem_list + [layer] + elif 'self_attn.dense' in layer and 'Phi' in str(type(module)): + gem_list = gem_list + [layer] layer_list = [] if gem_list != []: @@ -328,6 +331,10 @@ def _replace(self, child, name, conv_linear_layer): # For mixtral-7x8b, need to skip MoE gate linear replace. if name == "block_sparse_moe.gate": return child + # for phi3. 
+ if 'gate_up_proj' in name: + weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size()) + return LinearLayer(weight=weight, bias=bias) if name in self.all_reduce_linears: # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size] # else [weight_shape[0], weight_shape[1] // mp_size] diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py index cf087c16da8a..33d36fbfae54 100644 --- a/deepspeed/module_inject/fusedqkv_utils.py +++ b/deepspeed/module_inject/fusedqkv_utils.py @@ -4,7 +4,7 @@ # DeepSpeed Team import torch from deepspeed.utils.logging import warning_once -from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd, get_num_attention_heads def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0): @@ -42,6 +42,7 @@ def prepare_tp_fused_qkvw(module, src, mp_size, gpu_index): "FalconDecoderLayer": 'bloomtype', "GPTBigCodeBlock": 'bigcodetype', "DecoderLayer": 'glmtype', + "Phi3DecoderLayer": "phi3type" } def _codegen_type_transpose(input, mp_size, codegen_mp_num=4): @@ -93,6 +94,20 @@ def _bigcode_type_transpose(input, mp_size): split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0) return torch.cat((split_q[gpu_index], kv), dim=0) + def _phi3_type_transpose(input, mp_size): + num_kv_heads = get_num_kv_heads() + num_heads = get_num_attention_heads() + hidden_size = input.shape[1] + head_dim = hidden_size // num_heads + q_pos = input.shape[0] - 2 * num_kv_heads * head_dim + q = input[:q_pos] + k = input[q_pos:q_pos + num_kv_heads * head_dim] + v = input[q_pos + num_kv_heads * head_dim:] + split_q = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0) + split_k = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0) + split_v = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0) + return torch.cat((split_q[gpu_index], split_k[gpu_index], split_v[gpu_index]), dim=0) + def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None): # suppose num_heads=n, q(n)_w means the n-th q head linear weight, the weight format are as following @@ -110,6 +125,8 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None): return _qwen_type_transpose(src, mp_size, module) elif fused_qkv_type == 'bigcodetype': return _bigcode_type_transpose(src, mp_size) + elif fused_qkv_type == 'phi3type': + return _phi3_type_transpose(src, mp_size) raise ValueError("unknown fused_qkv_type") @@ -123,3 +140,24 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None): warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type," f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors") return _bloom_type_transpose(src, mp_size) + + +# For phi3 with chunk mlp, adjust the weight order. 
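Phi-3 fuses the MLP gate and up projections into a single gate_up_proj weight, so tensor-parallel sharding has to split each half separately and keep matching slices on every rank; that is what `shard_chunk_mlp` below does. A minimal sketch of the idea with an even split (the real helper sizes the splits with `get_shard_size_list` and also handles the bias):

```python
# Editor's sketch (not the DeepSpeed helper itself): shard a fused [gate; up]
# MLP weight across ranks so each rank keeps matching gate/up slices.
import torch

def shard_fused_gate_up(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    gate, up = weight.chunk(2, dim=0)                  # fused layout: gate rows first, then up rows
    gate_shard = gate.chunk(world_size, dim=0)[rank]   # even split, for illustration only
    up_shard = up.chunk(world_size, dim=0)[rank]
    return torch.cat((gate_shard, up_shard), dim=0)    # rank-local [gate_i; up_i]

w = torch.arange(8.0).reshape(8, 1)                    # toy weight: rows 0-3 gate, rows 4-7 up
print(shard_fused_gate_up(w, rank=0, world_size=2).flatten())  # -> tensor([0., 1., 4., 5.])
```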
+def shard_chunk_mlp( + weight, + bias, + rank, + world_size, +): + weight_gate, weight_states = weight.chunk(2, dim=0) + total_size = weight_gate.shape[0] + split_weight_gate = weight_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + split_weight_states = weight_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + shard_weight = torch.cat((split_weight_gate[rank], split_weight_states[rank]), dim=0) + if bias is not None: + bias_gate, bias_states = bias.chunk(2, dim=0) + split_bias_gate = bias_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + split_bias_states = bias_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + return shard_weight, torch.cat((split_bias_gate[rank], split_bias_states[rank]), dim=0) + + return shard_weight, None diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index e1703562d180..3029a79698dc 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -16,7 +16,7 @@ from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading from deepspeed import comm as dist -from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd +from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads from .load_checkpoint import load_model_with_checkpoint import time @@ -290,6 +290,10 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): # 4.2 set n_embd set_n_embd(n_embd) + # 4.3 set attention_heads + if hasattr(model_config, 'num_attention_heads'): + set_num_attention_heads(getattr(model_config, 'num_attention_heads')) + # 5. Set linear policies _autotp.update_linear_policies() diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py index 79c19b5f1272..6758c7a657f6 100644 --- a/deepspeed/module_inject/tp_shard.py +++ b/deepspeed/module_inject/tp_shard.py @@ -12,6 +12,11 @@ def set_num_kv_heads(num): num_kv_heads = num +def set_num_attention_heads(num): + global num_attention_heads + num_attention_heads = num + + def set_n_embd(num): global n_embd n_embd = num @@ -22,6 +27,11 @@ def get_num_kv_heads(): return num_kv_heads +def get_num_attention_heads(): + global num_attention_heads + return num_attention_heads + + def get_shard_size(total_size, mp_size, name=None, rank=None): global num_kv_heads last_linear = ["lm_head", "embed_out"] diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py index bd2782279c01..96eab5e2ab17 100644 --- a/deepspeed/moe/sharded_moe.py +++ b/deepspeed/moe/sharded_moe.py @@ -220,7 +220,7 @@ def top1gating(logits: Tensor, tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu) new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype) # Make sure the capacity value does not exceed the number of tokens. - capacity = min(new_capacity, torch.tensor(mask1.size(0))) + capacity = min(new_capacity, torch.tensor(mask1.size(0)).to(new_capacity.device)) # Compute l_aux me = torch.mean(gates, dim=0) diff --git a/deepspeed/monitor/comet.py b/deepspeed/monitor/comet.py new file mode 100644 index 000000000000..d8bc4017800f --- /dev/null +++ b/deepspeed/monitor/comet.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import TYPE_CHECKING, Any, Tuple, List, Dict, Optional + +from .utils import check_comet_availability +from .monitor import Monitor + +import deepspeed.comm as dist + +if TYPE_CHECKING: + import comet_ml + from .config import CometConfig + +Name = str +Value = Any +GlobalSamples = int +Event = Tuple[Name, Value, GlobalSamples] + + +class CometMonitor(Monitor): + + def __init__(self, comet_config: "CometConfig"): + super().__init__(comet_config) + check_comet_availability() + import comet_ml + + self.enabled = comet_config.enabled + self._samples_log_interval = comet_config.samples_log_interval + self._experiment: Optional["comet_ml.ExperimentBase"] = None + + if self.enabled and dist.get_rank() == 0: + self._experiment = comet_ml.start( + api_key=comet_config.api_key, + project=comet_config.project, + workspace=comet_config.workspace, + experiment_key=comet_config.experiment_key, + mode=comet_config.mode, + online=comet_config.online, + ) + + if comet_config.experiment_name is not None: + self._experiment.set_name(comet_config.experiment_name) + + self._events_log_scheduler = EventsLogScheduler(comet_config.samples_log_interval) + + @property + def experiment(self) -> Optional["comet_ml.ExperimentBase"]: + return self._experiment + + @property + def samples_log_interval(self) -> int: + return self._samples_log_interval + + def write_events(self, event_list: List[Event]) -> None: + if not self.enabled or dist.get_rank() != 0: + return None + + for event in event_list: + name = event[0] + value = event[1] + engine_global_samples = event[2] + + if self._events_log_scheduler.needs_logging(name, engine_global_samples): + self._experiment.__internal_api__log_metric__( + name=name, + value=value, + step=engine_global_samples, + ) + + +class EventsLogScheduler: + + def __init__(self, samples_log_interval: int): + self._samples_log_interval = samples_log_interval + self._last_logged_events_samples: Dict[str, int] = {} + + def needs_logging(self, name: str, current_sample: int) -> bool: + if name not in self._last_logged_events_samples: + self._last_logged_events_samples[name] = current_sample + return True + + last_logged_sample = self._last_logged_events_samples[name] + samples_delta = current_sample - last_logged_sample + + if samples_delta >= self._samples_log_interval: + self._last_logged_events_samples[name] = current_sample + return True + + return False diff --git a/deepspeed/monitor/config.py b/deepspeed/monitor/config.py index 5a8ca6ecf5cd..d422d3b1b9bb 100644 --- a/deepspeed/monitor/config.py +++ b/deepspeed/monitor/config.py @@ -3,12 +3,14 @@ # DeepSpeed Team +from typing import Optional + from deepspeed.pydantic_v1 import root_validator from deepspeed.runtime.config_utils import DeepSpeedConfigModel def get_monitor_config(param_dict): - monitor_dict = {key: param_dict.get(key, {}) for key in ("tensorboard", "wandb", "csv_monitor")} + monitor_dict = {key: param_dict.get(key, {}) for key in ("tensorboard", "wandb", "csv_monitor", "comet")} return DeepSpeedMonitorConfig(**monitor_dict) @@ -60,12 +62,75 @@ class CSVConfig(DeepSpeedConfigModel): """ Name for the current job. This will become a new directory inside `output_path`. """ +class CometConfig(DeepSpeedConfigModel): + """ + Sets parameters for Comet monitor. For logging data Comet uses + experiment object. + https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment/ + """ + + enabled: bool = False + """ Whether logging to Comet is enabled. 
Requires `comet_ml` package is installed. """ + + samples_log_interval: int = 100 + """ Metrics will be submitted to Comet after processing every `samples_log_interval` samples.""" + + project: Optional[str] = None + """ + Comet project name. Can be set through .comet.config file or environment variable COMET_PROJECT_NAME + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + workspace: Optional[str] = None + """ + Comet workspace name. Can be set through .comet.config file or environment variable COMET_WORKSPACE + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + api_key: Optional[str] = None + """ + Comet API key. Can be set through .comet.config file or environment variable COMET_API_KEY + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + experiment_name: Optional[str] = None + """ + The name for the Comet experiment to be used for logging. + Can be set through .comet.config file or environment variable COMET_EXPERIMENT_NAME + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + experiment_key: Optional[str] = None + """ + The key for the Comet experiment to be used for logging. Must be an alphanumeric string whose length is between 32 and 50 characters. + Can be set through .comet.config or environment variable COMET_EXPERIMENT_KEY + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + online: Optional[bool] = None + """ + If True, the data will be logged to the Comet server, otherwise it will be stored locally in an offline experiment. + Defaults to True. + """ + + mode: Optional[str] = None + """ + Control how the Comet experiment is started. Three options are possible: + - "get": Continue logging to an existing experiment identified by the `experiment_key` value. + - "create": Always creates a new experiment, useful for HPO sweeps. + - "get_or_create" (default): Starts a fresh experiment if required, or persists logging to an existing one. + """ + + class DeepSpeedMonitorConfig(DeepSpeedConfigModel): """Sets parameters for various monitoring methods.""" tensorboard: TensorBoardConfig = {} """ TensorBoard monitor, requires `tensorboard` package is installed. """ + comet: CometConfig = {} + """ Comet monitor, requires `comet_ml` package is installed. """ + wandb: WandbConfig = {} """ WandB monitor, requires `wandb` package is installed. 
""" @@ -75,5 +140,5 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel): @root_validator def check_enabled(cls, values): values["enabled"] = values.get("tensorboard").enabled or values.get("wandb").enabled or values.get( - "csv_monitor").enabled + "csv_monitor").enabled or values.get("comet") return values diff --git a/deepspeed/monitor/monitor.py b/deepspeed/monitor/monitor.py index 5a32b8bbcadd..e7e26dc483d9 100644 --- a/deepspeed/monitor/monitor.py +++ b/deepspeed/monitor/monitor.py @@ -24,6 +24,7 @@ def write_events(self, event_list): from .wandb import WandbMonitor from .tensorboard import TensorBoardMonitor from .csv_monitor import csvMonitor +from .comet import CometMonitor class MonitorMaster(Monitor): @@ -33,6 +34,7 @@ def __init__(self, monitor_config): self.tb_monitor = None self.wandb_monitor = None self.csv_monitor = None + self.comet_monitor = None self.enabled = monitor_config.enabled if dist.get_rank() == 0: @@ -42,6 +44,8 @@ def __init__(self, monitor_config): self.wandb_monitor = WandbMonitor(monitor_config.wandb) if monitor_config.csv_monitor.enabled: self.csv_monitor = csvMonitor(monitor_config.csv_monitor) + if monitor_config.comet.enabled: + self.comet_monitor = CometMonitor(monitor_config.comet) def write_events(self, event_list): if dist.get_rank() == 0: @@ -51,3 +55,5 @@ def write_events(self, event_list): self.wandb_monitor.write_events(event_list) if self.csv_monitor is not None: self.csv_monitor.write_events(event_list) + if self.comet_monitor is not None: + self.comet_monitor.write_events(event_list) diff --git a/deepspeed/monitor/utils.py b/deepspeed/monitor/utils.py index 265fc9811553..f5530e8532e1 100644 --- a/deepspeed/monitor/utils.py +++ b/deepspeed/monitor/utils.py @@ -3,6 +3,8 @@ # DeepSpeed Team +from packaging import version as pkg_version + def check_tb_availability(): try: @@ -22,3 +24,14 @@ def check_wandb_availability(): 'If you want to use wandb logging, please `pip install wandb` and follow the instructions at https://docs.wandb.ai/quickstart' ) raise + + +def check_comet_availability(): + try: + import comet_ml + comet_version = pkg_version.parse(comet_ml.__version__) + if comet_version < pkg_version.Version("3.41.0"): + raise ImportError("`comet_ml` must have at least version 3.41.0") + except ImportError: + print('If you want to use comet logging, please `pip install "comet_ml>=3.41.0"`') + raise diff --git a/deepspeed/ops/adagrad/cpu_adagrad.py b/deepspeed/ops/adagrad/cpu_adagrad.py index c356a52777f2..dbde6d95f652 100755 --- a/deepspeed/ops/adagrad/cpu_adagrad.py +++ b/deepspeed/ops/adagrad/cpu_adagrad.py @@ -34,7 +34,7 @@ def __setstate__(self, state): group.setdefault('amsgrad', False) @torch.no_grad() - def step(self, closure=None, fp16_param_groups=None): + def step(self, closure=None): """Update the model parameters. .. note:: @@ -46,8 +46,6 @@ def step(self, closure=None, fp16_param_groups=None): Args: closure (callable, optional): closure to compute the loss. Defaults to ``None``. - fp16_param_groups: FP16 GPU parameters to update. Performing the - copy here reduces communication time. Defaults to ``None``. Returns: loss: if ``closure`` is provided. Otherwise ``None``. 
@@ -94,16 +92,7 @@ def step(self, closure=None, fp16_param_groups=None): sparse_exp_avg_sq.values()) p[sparse_param.indices()] = sparse_param.values() state['exp_avg_sq'][sparse_exp_avg_sq.indices()] = sparse_exp_avg_sq.values() - if fp16_param_groups is not None: - fp16_param_groups[group_id][param_id][sparse_param.indices()] = sparse_param.values() else: - if fp16_param_groups is not None: - self.ds_opt_adagrad.adagrad_update_copy(self.opt_id, state['step'], group['lr'], group['eps'], - group['weight_decay'], p.data, p.grad.data, - state['exp_avg_sq'], - fp16_param_groups[group_id][param_id].data) - else: - self.ds_opt_adagrad.adagrad_update(self.opt_id, state['step'], group['lr'], group['eps'], - group['weight_decay'], p.data, p.grad.data, - state['exp_avg_sq']) + self.ds_opt_adagrad.adagrad_update(self.opt_id, state['step'], group['lr'], group['eps'], + group['weight_decay'], p.data, p.grad.data, state['exp_avg_sq']) return loss diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py index 10b8c15f970b..e0a72a494257 100755 --- a/deepspeed/ops/adam/cpu_adam.py +++ b/deepspeed/ops/adam/cpu_adam.py @@ -107,7 +107,7 @@ def __setstate__(self, state): group.setdefault('amsgrad', False) @torch.no_grad() - def step(self, closure=None, fp16_param_groups=None): + def step(self, closure=None): """Update the model parameters. .. note:: @@ -119,8 +119,6 @@ def step(self, closure=None, fp16_param_groups=None): Args: closure (callable, optional): closure to compute the loss. Defaults to ``None``. - fp16_param_groups: FP16 GPU parameters to update. Performing the - copy here reduces communication time. Defaults to ``None``. Returns: loss: if ``closure`` is provided. Otherwise ``None``. @@ -134,13 +132,6 @@ def step(self, closure=None, fp16_param_groups=None): # intended device for step device = torch.device('cpu') - # converting the fp16 params to a group of parameter - if type(fp16_param_groups) is list: - if type(fp16_param_groups[0]) is not list: - fp16_param_groups = [fp16_param_groups] - elif fp16_param_groups is not None: - fp16_param_groups = [[fp16_param_groups]] - for group_id, group in enumerate(self.param_groups): for param_id, p in enumerate(group['params']): @@ -169,13 +160,7 @@ def step(self, closure=None, fp16_param_groups=None): state['step'] += 1 beta1, beta2 = group['betas'] - if fp16_param_groups is not None: - self.ds_opt_adam.adam_update_copy(self.opt_id, state['step'], group['lr'], beta1, beta2, - group['eps'], group['weight_decay'], group['bias_correction'], - p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], - fp16_param_groups[group_id][param_id].data) - else: - self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'], - group['weight_decay'], group['bias_correction'], p.data, p.grad.data, - state['exp_avg'], state['exp_avg_sq']) + self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'], + group['weight_decay'], group['bias_correction'], p.data, p.grad.data, + state['exp_avg'], state['exp_avg_sq']) return loss diff --git a/deepspeed/ops/lion/cpu_lion.py b/deepspeed/ops/lion/cpu_lion.py index a91a00643873..03342a3fcd34 100755 --- a/deepspeed/ops/lion/cpu_lion.py +++ b/deepspeed/ops/lion/cpu_lion.py @@ -69,7 +69,7 @@ def __setstate__(self, state): group.setdefault('amsgrad', False) @torch.no_grad() - def step(self, closure=None, fp16_param_groups=None): + def step(self, closure=None): """Update the model parameters. .. 
note:: @@ -81,8 +81,6 @@ def step(self, closure=None, fp16_param_groups=None): Args: closure (callable, optional): closure to compute the loss. Defaults to ``None``. - fp16_param_groups: FP16 GPU parameters to update. Performing the - copy here reduces communication time. Defaults to ``None``. Returns: loss: if ``closure`` is provided. Otherwise ``None``. @@ -96,13 +94,6 @@ def step(self, closure=None, fp16_param_groups=None): # intended device for step device = torch.device('cpu') - # converting the fp16 params to a group of parameter - if type(fp16_param_groups) is list: - if type(fp16_param_groups[0]) is not list: - fp16_param_groups = [fp16_param_groups] - elif fp16_param_groups is not None: - fp16_param_groups = [[fp16_param_groups]] - for group_id, group in enumerate(self.param_groups): for param_id, p in enumerate(group['params']): @@ -131,11 +122,6 @@ def step(self, closure=None, fp16_param_groups=None): state['step'] += 1 beta1, beta2 = group['betas'] - if fp16_param_groups is not None: - self.ds_opt_lion.lion_update_copy(self.opt_id, state['step'], group['lr'], beta1, beta2, - group['weight_decay'], p.data, p.grad.data, state['exp_avg'], - fp16_param_groups[group_id][param_id].data) - else: - self.ds_opt_lion.lion_update(self.opt_id, state['step'], group['lr'], beta1, beta2, - group['weight_decay'], p.data, p.grad.data, state['exp_avg']) + self.ds_opt_lion.lion_update(self.opt_id, state['step'], group['lr'], beta1, beta2, + group['weight_decay'], p.data, p.grad.data, state['exp_avg']) return loss diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py index b5e4e33425d0..4f828d978613 100644 --- a/deepspeed/runtime/compiler.py +++ b/deepspeed/runtime/compiler.py @@ -83,84 +83,85 @@ def validate_enabled(cls, field_value, values): return field_value -class CompiledModuleWrapper(torch.nn.Module): - - def __init__(self, module, compile_config: Union[CompileConfig, None] = None): - super().__init__() - - assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch." - - modules = self.__dict__.get('_modules') - modules['wrapped'] = module - self.__dict__['wrapped'] = module - self._is_compiled = False - self._backend = get_backend_fn(compile_config.backend) - self._compile_kwargs = compile_config.kwargs - self._compiler_fn = None - - def __getattr__(self, name): - return getattr(self.__dict__['wrapped'], name) - - def set_backend(self, backend: Union[str, Callable]): - """Set the backend for torch.compile. - - Args: - backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module. - You can directly pass a function that works as a backend. - See also `backend` field in `CompileConfig` for more details. - """ - self._backend = get_backend_fn(backend) - - def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None: - """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten. - You can also pass a backend name with "backend" key to change the backend. - - Args: - kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile. - """ - - if "backend" in kwargs: - raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.") - self._compile_kwargs.update(kwargs) - - def set_compiler_fn(self, compiler_fn: Callable) -> None: - """Set a function to be used for compiling the module. - This function should take a torch.nn.Module as input and return a compiled module. 
- Note that other compile options are ignored when a compiler_fn is set. - - Example: - ```python - def my_compiler_fn(module: torch.nn.Module): - ... - return torch.compile(module, ...) - - engine.set_compiler_fn(my_compiler_fn) - ``` - """ - self._compiler_fn = compiler_fn - - def forward(self, *args, **kwargs) -> Any: - if not self.is_compiled: - if self._compiler_fn is None: - self.__dict__['wrapped'] = torch.compile(self.wrapped, backend=self._backend, **self._compile_kwargs) - else: - self.__dict__['wrapped'] = self._compiler_fn(self.wrapped) - self._is_compiled = True - - return self.__dict__['wrapped'](*args, **kwargs) - - @property - def is_compiled(self) -> bool: - return self._is_compiled - - @property - def backend(self) -> Union[str, Callable]: - return self._backend - - @property - def torch_compile_kwargs(self) -> Dict[str, Any]: - return self._compile_kwargs - - @property - def compiler_fn(self) -> Union[Callable, None]: - return self._compiler_fn +def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None): + + class wrapper(mod.__class__): + + def __init__(self, module, compile_config: Union[CompileConfig, None] = None): + self.__dict__ = {k: module.__dict__[k] for k in module.__dict__ if not k in self.__class__.__dict__} + + assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch." + + self.__dict__['wrapped'] = module + self._is_compiled = False + self._backend = get_backend_fn(compile_config.backend) + self._compile_kwargs = compile_config.kwargs + self._compiler_fn = None + + def set_backend(self, backend: Union[str, Callable]): + """Set the backend for torch.compile. + + Args: + backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module. + You can directly pass a function that works as a backend. + See also `backend` field in `CompileConfig` for more details. + """ + self._backend = get_backend_fn(backend) + + def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None: + """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten. + You can also pass a backend name with "backend" key to change the backend. + + Args: + kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile. + """ + + if "backend" in kwargs: + raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.") + self._compile_kwargs.update(kwargs) + + def set_compiler_fn(self, compiler_fn: Callable) -> None: + """Set a function to be used for compiling the module. + This function should take a torch.nn.Module as input and return a compiled module. + Note that other compile options are ignored when a compiler_fn is set. + + Example: + ```python + def my_compiler_fn(module: torch.nn.Module): + ... + return torch.compile(module, ...) 
+ + engine.set_compiler_fn(my_compiler_fn) + ``` + """ + self._compiler_fn = compiler_fn + + def forward(self, *args, **kwargs) -> Any: + if not self.is_compiled: + if self._compiler_fn is None: + self.__dict__['wrapped'] = torch.compile(self.wrapped, + backend=self._backend, + **self._compile_kwargs) + else: + self.__dict__['wrapped'] = self._compiler_fn(self.wrapped) + self._is_compiled = True + + return self.__dict__['wrapped'](*args, **kwargs) + + @property + def is_compiled(self) -> bool: + return self._is_compiled + + @property + def backend(self) -> Union[str, Callable]: + return self._backend + + @property + def torch_compile_kwargs(self) -> Dict[str, Any]: + return self._compile_kwargs + + @property + def compiler_fn(self) -> Union[Callable, None]: + return self._compiler_fn + + return wrapper(mod, compile_config) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 9a2b943b0992..13f335cae6d5 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -469,13 +469,6 @@ def __getattr__(self, name): return getattr(self, name) elif name in dir(_module): return getattr(_module, name) - elif isinstance(_module, CompiledModuleWrapper): - try: - return getattr(_module, name) - except AttributeError: - raise AttributeError( - f"None of {type(self).__name__}, CompiledModuleWrapper, or the wrapped model has the attribute '{name}'" - ) else: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") @@ -1270,7 +1263,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters): else: self.optimizer = basic_optimizer - log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()), ranks=[0]) + log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer.__class__.__name__), ranks=[0]) self.compression_scheduler = self._configure_compression_scheduler() self.quantizer = self._configure_quantization() diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index bf1693307ea7..49093bb73c8f 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -241,7 +241,7 @@ def _get_norm_mask_idx(self, group): group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx]) grad_flat_st_idx = grad_flat_en_idx - return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device()) + return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device_name()) def step(self, closure=None): """ diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 1dda7f1aad32..be8fe1a368c6 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -67,9 +67,7 @@ class PipelineEngine(DeepSpeedEngine): def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): super().__init__(*super_args, **super_kwargs) - assert isinstance(self.module, PipelineModule) \ - or (hasattr(self.module, 'wrapped') and isinstance(self.module.wrapped, PipelineModule)), \ - "model must base PipelineModule" + assert isinstance(self.module, PipelineModule), "model must base PipelineModule" assert self.zero_optimization_stage( ) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism" diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 7744b2ee8b98..2c01c3475a70 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -171,7 +171,7 @@ def get_norm_with_moe_layers_fast(all_groups_norm, group): # This 
implementation standardizes the grad_norm across ranks. A more precise implementation can be found in 'get_norm_with_moe_layers'. # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=group)) - scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device(), dtype=torch.float) + scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device_name(), dtype=torch.float) dist.all_reduce(scaled_norm_tensor, group=group) all_groups_norm = scaled_norm_tensor.item() #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}") @@ -424,9 +424,11 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=No # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool) # # for mask_idx in grad_norm_mask[idx]: # # mask_tensor_[mask_idx[0]:mask_idx[1]] = True - cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(), + cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device_name(), dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1) - mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype) + mask_tensor = torch.zeros(p.shape[0] + 1, + device=get_accelerator().current_device_name(), + dtype=p.dtype) mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1), cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1] diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 76583c129cb9..2089d59dbce4 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -20,6 +20,7 @@ "stage": [0|1|2], "stage3_max_live_parameters" : 1000000000, "stage3_max_reuse_distance" : 1000000000, + "stage3_use_all_reduce_for_fetch_params": [true|false], "allgather_partitions": [true|false], "use_multi_rank_bucket_allreduce": [true|false], "allgather_bucket_size": 500000000, @@ -234,6 +235,12 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): this option is enabled and then saves the fp16 model weights. """ + use_all_reduce_for_fetch_params: bool = Field(False, alias="stage3_use_all_reduce_for_fetch_params") + """ + Use all_reduce op when fetching module parameters at stage3. This improves performance by reducing + the overhead of concatenation and slicing on the host. 
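For context, the new `stage3_use_all_reduce_for_fetch_params` option above switches parameter fetching from the per-dtype all-gather plus host-side concatenation path to a single `all_reduce`: each rank copies its own partition into a zero-filled flat buffer and the reduction reconstructs the full parameters. A minimal, hypothetical config sketch showing where the flag lives (the surrounding keys and the `deepspeed.initialize` call are placeholders, not part of this patch):

```python
# Hypothetical ZeRO-3 config fragment; only "stage3_use_all_reduce_for_fetch_params" is new here.
ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {
        "stage": 3,
        "stage3_use_all_reduce_for_fetch_params": True,
    },
}
# model_engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config)
```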
+ """ + stage3_gather_fp16_weights_on_model_save: bool = Field(False, deprecated=True, new_param="gather_16bit_weights_on_model_save") diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index c8099791f882..09e72a695db3 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -23,7 +23,7 @@ from deepspeed.utils import groups import deepspeed -from ..utils import see_memory_usage +from ..utils import see_memory_usage, get_only_unique_item from deepspeed.runtime.zero.config import DeepSpeedZeroConfig from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks, is_zero_param from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum @@ -56,7 +56,7 @@ def __init__(self, param: Parameter) -> None: self.__param = param def wait(self) -> None: - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): get_accelerator().current_stream().synchronize() self.__param.ds_status = ZeroParamStatus.AVAILABLE @@ -82,7 +82,7 @@ def wait(self) -> None: if self.__complete: return - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): get_accelerator().current_stream().synchronize() for param in self.__params: assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" @@ -716,6 +716,31 @@ def wait(self) -> None: handle.wait() +class AllReduceCoalescedHandle: + + def __init__(self, handle, params: List[Parameter]) -> None: + self.handle = handle + self.params = params + self.complete = False + + for param in self.params: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"expected param {param.ds_summary()} to not be available") + + @instrument_w_nvtx + def wait(self) -> None: + if self.complete: + return + + instrument_w_nvtx(self.handle.wait)() + + for param in self.params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" + param.ds_status = ZeroParamStatus.AVAILABLE + + self.complete = True + + class QuantizationInfo: # a placeholder object to store all quant related vars used in handles def __init__(self) -> None: @@ -1003,6 +1028,11 @@ def __init__( if not self.use_all_gather_into_tensor: logger.info(f"all_gather_into_tensor API is not available in torch {torch.__version__}") + self.use_all_reduce_for_fetch_params = get_config_default(DeepSpeedZeroConfig, + "use_all_reduce_for_fetch_params") + if _ds_config is not None: + self.use_all_reduce_for_fetch_params = _ds_config.zero_config.use_all_reduce_for_fetch_params + def _update_persist_config(self, ds_config): Init.apply_param_persistence = True Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold @@ -1250,75 +1280,99 @@ def all_gather_coalesced(params: Iterable[Parameter], return AllGatherHandle(handle, param, quantization=quant_info) else: - if not quantize: - dtype_params = defaultdict(list) - for p in params: - dtype_params[p.ds_tensor.dtype].append(p) - handles = [] - for dtype, params in dtype_params.items(): - handles.append(_all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group)) + if self.use_all_reduce_for_fetch_params and not quantize and not use_secondary_tensor: + # Use all_reduce instead of all_gather to fetch the module params + flat_buffer_size = sum(p.ds_numel_aligned for p in params) + flat_tensor = 
torch.zeros(flat_buffer_size, + dtype=get_only_unique_item(p.ds_tensor.dtype for p in params), + device=get_accelerator().current_device_name(), + requires_grad=False) + start_param = 0 + for param in params: + param.data = flat_tensor.narrow(0, start_param, param.ds_numel).view(param.ds_shape) + start = start_param + param.ds_tensor.ds_numel * self.get_partition_rank() + flat_tensor.narrow(0, start, param.ds_tensor.ds_numel).copy_(param.ds_tensor) - return MultipleAllGatherHandles(handles) + start_param += param.ds_numel + handle = dist.all_reduce(flat_tensor, group=ds_process_group, async_op=True) + + return AllReduceCoalescedHandle(handle=handle, params=params) else: - partition_sz = sum(p.ds_tensor.ds_numel for p in params) + if not quantize: + dtype_params = defaultdict(list) + for p in params: + dtype_params[p.ds_tensor.dtype].append(p) + handles = [] + for dtype, params in dtype_params.items(): + handles.append( + _all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group)) - if use_secondary_tensor: - partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params) + return MultipleAllGatherHandles(handles) - flat_tensor = torch.empty(partition_sz * world_size, - dtype=torch.int8, - device=get_accelerator().current_device_name(), - requires_grad=False) - - if use_secondary_tensor: - if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"): - quantized_param = instrument_w_nvtx(torch.cat)([ - p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) for p in params - ]) - scales = instrument_w_nvtx(torch.cat)([ - p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) - for p in params - ]) - else: - quantized_param, scales = self.quantizer_module.quantize( - instrument_w_nvtx(torch.cat)([ - p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params - ])) else: - if hasattr(params[0].ds_tensor, "ds_quant_scale"): - quantized_param = instrument_w_nvtx(torch.cat)( - [p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params]) - scales = instrument_w_nvtx(torch.cat)([ - p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) for p in params - ]) + partition_sz = sum(p.ds_tensor.ds_numel for p in params) + + if use_secondary_tensor: + partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups + for p in params) + + flat_tensor = torch.empty(partition_sz * world_size, + dtype=torch.int8, + device=get_accelerator().current_device_name(), + requires_grad=False) + + if use_secondary_tensor: + if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"): + quantized_param = instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) + for p in params + ]) + scales = instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) + for p in params + ]) + else: + quantized_param, scales = self.quantizer_module.quantize( + instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.to(get_accelerator().current_device_name()) + for p in params + ])) else: - quantized_param, scales = self.quantizer_module.quantize( - instrument_w_nvtx(torch.cat)( - [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params])) - quant_scale_buffer = torch.empty( - scales.numel() * world_size, - dtype=torch.float32, - device=get_accelerator().current_device_name(), - requires_grad=False, - ) - handle = _dist_allgather_fn(quantized_param, flat_tensor, 
ds_process_group) - quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group) - quant_info = QuantizationInfo() - quant_info.quantized_param = flat_tensor - quant_info.backend = self.quantizer_module - quant_info.quant_handle = quant_handle - quant_info.scale_buffer = quant_scale_buffer - quant_info.partition_sz = partition_sz - quant_info.world_size = world_size - return AllGatherCoalescedHandle( - allgather_handle=handle, - params=params, - partitions=None, - world_size=world_size, - use_secondary_tensor=use_secondary_tensor, - quantization=quant_info, - ) + if hasattr(params[0].ds_tensor, "ds_quant_scale"): + quantized_param = instrument_w_nvtx(torch.cat)( + [p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params]) + scales = instrument_w_nvtx(torch.cat)([ + p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) + for p in params + ]) + else: + quantized_param, scales = self.quantizer_module.quantize( + instrument_w_nvtx(torch.cat)( + [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params])) + quant_scale_buffer = torch.empty( + scales.numel() * world_size, + dtype=torch.float32, + device=get_accelerator().current_device_name(), + requires_grad=False, + ) + handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group) + quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group) + quant_info = QuantizationInfo() + quant_info.quantized_param = flat_tensor + quant_info.backend = self.quantizer_module + quant_info.quant_handle = quant_handle + quant_info.scale_buffer = quant_scale_buffer + quant_info.partition_sz = partition_sz + quant_info.world_size = world_size + return AllGatherCoalescedHandle( + allgather_handle=handle, + params=params, + partitions=None, + world_size=world_size, + use_secondary_tensor=use_secondary_tensor, + quantization=quant_info, + ) def partition(param_list=None, hierarchy=0, has_been_updated=False): cls = param @@ -1554,6 +1608,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): param.ds_tensor.ds_numel = partition_size param.ds_tensor.status = PartitionedParamStatus.AVAILABLE param.ds_tensor.final_location = final_location + param.ds_numel_aligned = tensor_size start = partition_size * self.get_partition_rank() end = start + partition_size @@ -1682,7 +1737,8 @@ def _allgather_param(self, param, async_op=False, hierarchy=0): f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ', force=False) - get_accelerator().synchronize() + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() print_rank_0( f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}" @@ -1815,7 +1871,8 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0, quantize=False): param.data = gathered_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data # guarantee the communication to be completed - get_accelerator().synchronize() + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() return None diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index c6ff216edfcb..3f43e865fa72 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1409,7 +1409,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): norm_is_nan = total_norm.isnan() inf_or_nan = norm_is_nan.logical_or(norm_is_inf) - 
err = torch.tensor(-1.0, device=self.device, dtype=torch.float) + err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float) total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm return total_norm @@ -1575,7 +1575,7 @@ def set_none_gradients_to_zero(self, i, partition_id): for param_id in self.is_grad_computed[i][partition_id]: param = self.param_dict[param_id] if param.grad is None: - param.grad = torch.zero_like(param) + param.grad = torch.zeros_like(param) ######################Reduction Related Methods############################## diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 225c085f6f2b..3d5ff5e6b43e 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1474,7 +1474,7 @@ def set_none_gradients_to_zero(self, i, partition_id): for param_id in self.is_grad_computed[i][partition_id]: param = self.param_dict[param_id] if param.grad is None: - param.grad = torch.zero_like(param) + param.grad = torch.zeros_like(param) ######################Reduction Related Methods############################## def allreduce_bucket(self, bucket, rank=None, log=None, divide=True, process_group=None): diff --git a/docs/_data/navigation.yml b/docs/_data/navigation.yml index 217d56c14812..3bd3e451ab49 100755 --- a/docs/_data/navigation.yml +++ b/docs/_data/navigation.yml @@ -41,7 +41,7 @@ lnav: - title: 'Flops Profiler' url: /docs/config-json/#flops-profiler - title: 'Monitoring' - url: /docs/config-json/#monitoring-module-tensorboard-wandb-csv + url: /docs/config-json/#monitoring-module - title: 'Communication Logging' url: /docs/config-json/#communication-logging - title: 'Model Compression' diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index abe314cbb1a6..adb2f1679ea0 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -1139,15 +1139,16 @@ DeepSpeed Data Efficiency Library includes two techniques: curriculum learning a | ---------------------------------------------------------------------------------------------------------------------------- | ------- | | List of which step to change difficulty level. One of the `schedule_config` when the `fixed_discrete` schedule_type is used. | N/A | -### Monitoring Module (TensorBoard, WandB, CSV) +### Monitoring Module **Note:** Deepspeed logs to TensorBoard through PyTorch. Logging to TensorBoard requires that the `tensorboard` package is installed (read more in the [PyTorch documentation](https://pytorch.org/docs/1.8.0/tensorboard.html)). {: .notice--warning} **Note:** Logging to WandB requires that the `wandb` package is installed (read more in the [WandB documentation](https://docs.wandb.ai/quickstart)). {: .notice--warning} +**Note:** Logging to Comet requires that the `comet_ml` package is installed (read more in the [Comet documentation](https://www.comet.com/docs/v2/guides/quickstart/#1-install-and-configure-the-comet-ml-sdk)). +{: .notice--warning} - -Deepspeed's Monitor module can log training details into a [Tensorboard](https://www.tensorflow.org/tensorboard)-compatible file, to [WandB](https://wandb.ai/site), or to simple CSV files. Below is an overview of what DeepSpeed will log automatically. +Deepspeed's Monitor module can log training details into a [Tensorboard](https://www.tensorflow.org/tensorboard)-compatible file, to [WandB](https://wandb.ai/site), to [Comet](https://www.comet.com/site/?utm_source=deepseed&utm_medium=docs&utm_content=docs) or to simple CSV files. 
Below is an overview of what DeepSpeed will log automatically.

 | Field | Description |Conditions |
 | ------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----- |
@@ -1201,6 +1202,36 @@ Example of **wandb** configuration:
 }
 ```
+**comet**: [dictionary]
+
+| Fields | Value | Default |
+|--- |--- |--- |
+| enabled | Whether logging to [Comet](https://www.comet.com/site/) is enabled. | `false` |
+| workspace | Comet workspace name. | `None` |
+| project | Comet project name. | `None` |
+| samples_log_interval | Metrics will be submitted to Comet after processing every `samples_log_interval` samples. | `100` |
+| experiment_name | The name for the Comet experiment used for logging. | `None` |
+| api_key | Comet API key. It's not recommended to save the Comet API key in code. | `None` |
+| experiment_key | The key for the Comet experiment used for logging. Must be an alphanumeric string whose length is between 32 and 50 characters. | `None` |
+| online | If `True`, the data will be logged to the Comet server, otherwise it will be stored locally in an offline experiment. Default is `True`. | `None` |
+| mode | Control how the Comet experiment is started. "get": Continue logging to an existing experiment identified by the `experiment_key` value. "create": Always creates a new experiment, useful for HPO sweeps. "get_or_create" (default): Starts a fresh experiment if required, or persists logging to an existing one. | `None` |
+
+
+Example of **comet** configuration:
+
+```json
+"comet": {
+    "enabled": true,
+    "workspace": "my_workspace",
+    "project": "my_project",
+    "samples_log_interval": 50,
+    "experiment_name": "llama-fine-tuning",
+    "experiment_key": "0c4a1c4a90664f2a8084e600b19a9d7",
+    "online": false,
+    "mode": "get"
+}
+```
+
 **csv_monitor**: [dictionary]

 | Fields | Value |Default |
diff --git a/docs/_tutorials/monitor.md b/docs/_tutorials/monitor.md
index a9c111f8eeec..572e3f4558a7 100644
--- a/docs/_tutorials/monitor.md
+++ b/docs/_tutorials/monitor.md
@@ -11,7 +11,7 @@ In this tutorial, we introduce the DeepSpeed Monitor and provide examples of its

 ## Overview

-Monitoring model and system metrics during training is vital to ensure hardware resources are fully utilized. The DeepSpeed Monitor enables live logging of metrics through one or more monitoring backends such as PyTorch's [TensorBoard](https://pytorch.org/docs/1.8.0/tensorboard.html), [WandB](https://docs.wandb.ai/quickstart), and simple CSV files.
+Monitoring model and system metrics during training is vital to ensure hardware resources are fully utilized. The DeepSpeed Monitor enables live logging of metrics through one or more monitoring backends such as PyTorch's [TensorBoard](https://pytorch.org/docs/1.8.0/tensorboard.html), [WandB](https://docs.wandb.ai/quickstart), [Comet](https://www.comet.com/site/?utm_source=deepseed&utm_medium=docs&utm_content=tutorial), and simple CSV files.
Below is a live monitoring view for TensorBoard: @@ -21,16 +21,20 @@ Below is a live monitoring view for WandB: ![WandB Example Output](/assets/images/wandb_monitor.PNG){: .align-center} +Below is a live monitoring view for Comet: + +![CometML Example Output](/assets/images/comet_monitor.png){: .align-center} + ## Usage -The DeepSpeed Monitor is configured within the deepspeed [configuration file](/docs/config-json/#monitoring-module-tensorboard-wandb-csv). DeepSpeed will automatically monitor key training metrics, including those tracked with the `wall_clock_breakdown` configuration option. In addition, users can log their own custom events and metrics. +The DeepSpeed Monitor is configured within the deepspeed [configuration file](/docs/config-json/#monitoring-module). DeepSpeed will automatically monitor key training metrics, including those tracked with the `wall_clock_breakdown` configuration option. In addition, users can log their own custom events and metrics. - [Automatic Monitoring](#automatic-monitoring) - [Custom Monitoring](#custom-monitoring) ### Automatic Monitoring -When using DeepSpeed for model training, the Monitor can be configured in the DeepSpeed [configuration file](/docs/config-json/#monitoring-module-tensorboard-wandb-csv). No explicit API calls are needed to use the Monitor. The Monitor can be enabled by adding the following field to DeepSpeed's configuration json file. Refer to [Monitoring](/docs/config-json/#monitoring-module-tensorboard-wandb-csv) for details. +When using DeepSpeed for model training, the Monitor can be configured in the DeepSpeed [configuration file](/docs/config-json/#monitoring-module). No explicit API calls are needed to use the Monitor. The Monitor can be enabled by adding the following field to DeepSpeed's configuration json file. Refer to [Monitoring](/docs/config-json/#monitoring-module) for details. ```json { @@ -45,6 +49,11 @@ When using DeepSpeed for model training, the Monitor can be configured in the De "group": "my_group", "project": "my_project" } + "comet": { + "enabled": true, + "project": "my_project", + "experiment_name": "my_experiment" + } "csv_monitor": { "enabled": true, "output_path": "output/ds_logs/", diff --git a/docs/assets/images/comet_monitor.png b/docs/assets/images/comet_monitor.png new file mode 100644 index 000000000000..83564cd5f1eb Binary files /dev/null and b/docs/assets/images/comet_monitor.png differ diff --git a/docs/code-docs/source/monitor.rst b/docs/code-docs/source/monitor.rst index d286af23f09e..694c72b9b870 100644 --- a/docs/code-docs/source/monitor.rst +++ b/docs/code-docs/source/monitor.rst @@ -29,6 +29,11 @@ WandB .. _WandbConfig: .. autopydantic_model:: deepspeed.monitor.config.WandbConfig +Comet +----- +.. _CometConfig: +.. autopydantic_model:: deepspeed.monitor.config.CometConfig + CSV Monitor ----------- .. 
_CSVConfig: diff --git a/op_builder/builder.py b/op_builder/builder.py index 8dc825c7926d..4c4978c29575 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -107,6 +107,8 @@ def assert_no_cuda_mismatch(name=""): class OpBuilder(ABC): _rocm_version = None + _rocm_gpu_arch = None + _rocm_wavefront_size = None _is_rocm_pytorch = None _is_sycl_enabled = None _loaded_ops = {} @@ -229,6 +231,32 @@ def installed_rocm_version(): OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR)) return OpBuilder._rocm_version + @staticmethod + def get_rocm_gpu_arch(): + if OpBuilder._rocm_gpu_arch: + return OpBuilder._rocm_gpu_arch + rocm_gpu_arch_cmd = "/opt/rocm/bin/rocminfo | grep -o -m 1 'gfx.*'" + try: + result = subprocess.check_output(rocm_gpu_arch_cmd, shell=True) + rocm_gpu_arch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + rocm_gpu_arch = "" + OpBuilder._rocm_gpu_arch = rocm_gpu_arch + return OpBuilder._rocm_gpu_arch + + @staticmethod + def get_rocm_wavefront_size(): + if OpBuilder._rocm_wavefront_size: + return OpBuilder._rocm_wavefront_size + rocm_wavefront_size_cmd = "/opt/rocm/bin/rocminfo | grep -Eo -m1 'Wavefront Size:[[:space:]]+[0-9]+' | grep -Eo '[0-9]+'" + try: + result = subprocess.check_output(rocm_wavefront_size_cmd, shell=True) + rocm_wavefront_size = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + rocm_wavefront_size = "32" + OpBuilder._rocm_wavefront_size = rocm_wavefront_size + return OpBuilder._rocm_wavefront_size + def include_paths(self): ''' Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) @@ -517,9 +545,12 @@ def jit_load(self, verbose=True): nvcc_args.append("-DBF16_AVAILABLE") nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__") nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__") + nvcc_args.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") if self.is_rocm_pytorch(): cxx_args.append("-D__HIP_PLATFORM_AMD__=1") + os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch() + cxx_args.append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) op_module = load(name=self.name, sources=self.strip_empty_entries(sources), @@ -650,6 +681,12 @@ def builder(self): if self.is_rocm_pytorch(): compile_args['cxx'].append("-D__HIP_PLATFORM_AMD__=1") + #cxx compiler args are required to compile cpp files + compile_args['cxx'].append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) + #nvcc compiler args are required to compile hip files + compile_args['nvcc'].append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) + if self.get_rocm_gpu_arch(): + os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch() cuda_ext = ExtensionBuilder(name=self.absolute_name(), sources=self.strip_empty_entries(self.sources()), diff --git a/op_builder/cpu/builder.py b/op_builder/cpu/builder.py index d2bc8eacfa25..d881842ad0b1 100644 --- a/op_builder/cpu/builder.py +++ b/op_builder/cpu/builder.py @@ -30,7 +30,11 @@ def builder(self): return cpp_ext def cxx_args(self): - return ['-O3', '-g', '-Wno-reorder'] + args = ['-O3', '-g', '-Wno-reorder'] + CPU_ARCH = self.cpu_arch() + SIMD_WIDTH = self.simd_width() + args += [CPU_ARCH, '-fopenmp', SIMD_WIDTH] + return args def libraries_args(self): return [] diff --git a/op_builder/cpu_adagrad.py b/op_builder/cpu_adagrad.py index d3f163f7464a..c05f71488950 100644 --- a/op_builder/cpu_adagrad.py +++ b/op_builder/cpu_adagrad.py @@ -3,7 +3,6 @@ # DeepSpeed Team -import os from .builder import TorchCPUOpBuilder @@ -18,26 +17,11 @@ 
def absolute_name(self): return f'deepspeed.ops.adagrad.{self.NAME}_op' def sources(self): - if self.build_for_cpu: - return ['csrc/adagrad/cpu_adagrad.cpp'] - - return ['csrc/adagrad/cpu_adagrad.cpp', 'csrc/common/custom_cuda_kernel.cu'] + return ['csrc/adagrad/cpu_adagrad.cpp'] def libraries_args(self): args = super().libraries_args() - if self.build_for_cpu: - return args - - if not self.is_rocm_pytorch(): - args += ['curand'] return args def include_paths(self): - import torch - if self.build_for_cpu: - CUDA_INCLUDE = [] - elif not self.is_rocm_pytorch(): - CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] - else: - CUDA_INCLUDE = [] - return ['csrc/includes'] + CUDA_INCLUDE + return ['csrc/includes'] diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py index 7c34c4ce43a1..7f4c0847a8c4 100644 --- a/op_builder/cpu_adam.py +++ b/op_builder/cpu_adam.py @@ -3,7 +3,6 @@ # DeepSpeed Team -import os from .builder import TorchCPUOpBuilder @@ -18,27 +17,11 @@ def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' def sources(self): - if self.build_for_cpu: - return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] - - return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp', 'csrc/common/custom_cuda_kernel.cu'] + return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] def libraries_args(self): args = super().libraries_args() - if self.build_for_cpu: - return args - - if not self.is_rocm_pytorch(): - args += ['curand'] - return args def include_paths(self): - import torch - if self.build_for_cpu: - CUDA_INCLUDE = [] - elif not self.is_rocm_pytorch(): - CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] - else: - CUDA_INCLUDE = [] - return ['csrc/includes'] + CUDA_INCLUDE + return ['csrc/includes'] diff --git a/op_builder/cpu_lion.py b/op_builder/cpu_lion.py index 5c16d10ebb44..9a60d99773b3 100644 --- a/op_builder/cpu_lion.py +++ b/op_builder/cpu_lion.py @@ -3,7 +3,6 @@ # DeepSpeed Team -import os from .builder import TorchCPUOpBuilder @@ -18,31 +17,11 @@ def absolute_name(self): return f'deepspeed.ops.lion.{self.NAME}_op' def sources(self): - if self.build_for_cpu: - return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp'] - - return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp', 'csrc/common/custom_cuda_kernel.cu'] + return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp'] def libraries_args(self): args = super().libraries_args() - if self.build_for_cpu: - return args - - if not self.is_rocm_pytorch(): - args += ['curand'] - return args def include_paths(self): - import torch - if self.build_for_cpu: - CUDA_INCLUDE = [] - elif not self.is_rocm_pytorch(): - CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] - else: - CUDA_INCLUDE = [ - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"), - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"), - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"), - ] - return ['csrc/includes'] + CUDA_INCLUDE + return ['csrc/includes'] diff --git a/op_builder/hpu/builder.py b/op_builder/hpu/builder.py index 3c86128fffd6..c176a586ba49 100644 --- a/op_builder/hpu/builder.py +++ b/op_builder/hpu/builder.py @@ -31,7 +31,11 @@ def builder(self): return cpp_ext def cxx_args(self): - return ['-O3', '-g', '-Wno-reorder'] + args = ['-O3', '-g', '-Wno-reorder'] + CPU_ARCH = self.cpu_arch() + SIMD_WIDTH = self.simd_width() + args += [CPU_ARCH, '-fopenmp', SIMD_WIDTH] + 
return args def libraries_args(self): return [] diff --git a/op_builder/hpu/cpu_adam.py b/op_builder/hpu/cpu_adam.py index 2f3b7aefe705..58eea2698ebb 100644 --- a/op_builder/hpu/cpu_adam.py +++ b/op_builder/hpu/cpu_adam.py @@ -20,11 +20,6 @@ def absolute_name(self): def sources(self): return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] - def cxx_args(self): - args = super().cxx_args() - args += ['-DENABLE_BFLOAT16'] - return args - def libraries_args(self): args = super().libraries_args() return args diff --git a/op_builder/hpu/fused_adam.py b/op_builder/hpu/fused_adam.py index d77228317ddb..5acb121668e3 100644 --- a/op_builder/hpu/fused_adam.py +++ b/op_builder/hpu/fused_adam.py @@ -4,10 +4,88 @@ # DeepSpeed Team -from .builder import CPUOpBuilder +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder +try: + import torch + import math +except ImportError as e: + pass -class FusedAdamBuilder(CPUOpBuilder): + +class HPUFusedAdam: + htcore = None + is_lazy_mode = None + + @staticmethod + def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode, + bias_correction, weight_decay, *args): + + if HPUFusedAdam.htcore is None: + from habana_frameworks.torch import core as htcore + from habana_frameworks.torch.utils.internal import is_lazy + HPUFusedAdam.htcore = htcore + HPUFusedAdam.is_lazy_mode = is_lazy() + + htcore = HPUFusedAdam.htcore + + htcore.step_closure._mark_step_if_lazy() + step_size = lr + if bias_correction: + bias_correction1 = 1.0 - pow(beta1, step) + bias_correction2 = 1.0 - pow(beta2, step) + step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 + + neg_step = -step_size + neg_step_t = (torch.tensor([neg_step], dtype=torch.float, + requires_grad=False).to(tensor_lists[1][0].dtype).to(tensor_lists[1][0].device, + non_blocking=True)) + + weight_decay = weight_decay if adam_w_mode else 0 + + # since lr is fed into the kernel as tensor, perform the scalar multiplication of wd here + # NOTE: TODO if lr is updated every step, then we need to convert it as tensor and + # perform weight decay unconditonally. 
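For intuition on the `modified_wd` factor computed just below: with decoupled (AdamW-style) weight decay, pre-computing the scalar `1 - weight_decay * lr` on the host lets the fused HPU op apply the decay as a single in-place scale of the parameters. A rough eager-mode sketch of that one step, illustrative only and not the Habana kernel:

```python
import torch

def apply_decoupled_weight_decay(param: torch.Tensor, lr: float, weight_decay: float) -> None:
    # Same effect as handing modified_wd = 1.0 - weight_decay * lr to the fused op:
    # decoupled weight decay reduces to one in-place scale before the Adam update.
    param.mul_(1.0 - weight_decay * lr)
```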
+ modified_wd = 1.0 - weight_decay * lr + + if HPUFusedAdam.is_lazy_mode: + torch.ops.hpu.optimizer_adamw( + tensor_lists[0], + tensor_lists[1], + tensor_lists[2], + tensor_lists[3], + neg_step_t, + beta1, + beta2, + epsilon, + modified_wd, + ) + else: + modified_wd_t = (torch.tensor([modified_wd], dtype=torch.float, requires_grad=False).to( + tensor_lists[1][0].dtype).to(tensor_lists[1][0].device, non_blocking=True)) + torch.ops.hpu.optimizer_adamw( + tensor_lists[0], + tensor_lists[1], + tensor_lists[2], + tensor_lists[3], + neg_step_t, + beta1, + beta2, + epsilon, + modified_wd_t, + modified_wd != 1.0, + ) + + htcore.step_closure._mark_step_if_lazy() + + +class FusedAdamBuilder(OpBuilder): BUILD_VAR = "DS_BUILD_FUSED_ADAM" NAME = "fused_adam" @@ -18,12 +96,10 @@ def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' def sources(self): - return ['csrc/cpu/adam/fused_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] - - def cxx_args(self): - args = super().cxx_args() - args += ['-DENABLE_BFLOAT16'] - return args + return [] def include_paths(self): - return ['csrc/includes'] + return [] + + def load(self, verbose=True): + return HPUFusedAdam diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index eb6bfc811e85..c0fc5dba9d33 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,5 +1,6 @@ accelerate clang-format==16.0.2 +comet_ml>=3.41.0 deepspeed-kernels ; sys_platform == 'linux' docutils<0.18 future diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 80c9f9b3287a..05f88337f3a9 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,10 +1,10 @@ hjson ninja numpy +nvidia-ml-py packaging>=20.0 psutil py-cpuinfo pydantic -pynvml torch tqdm diff --git a/tests/perf/adam_test1.py b/tests/perf/adam_test1.py index b35477afb4fe..bde1d53e5179 100755 --- a/tests/perf/adam_test1.py +++ b/tests/perf/adam_test1.py @@ -6,12 +6,10 @@ import torch from deepspeed.ops.adam import DeepSpeedCPUAdam import time -from deepspeed.accelerator import get_accelerator device = 'cpu' model_size = 1 * 1024**3 param = torch.nn.Parameter(torch.ones(model_size, device=device)) -param_fp16 = torch.nn.Parameter(torch.ones(model_size, dtype=torch.half, device=get_accelerator().device_name(0))) optimizer = DeepSpeedCPUAdam([param]) #torch.set_num_threads(128) @@ -19,7 +17,7 @@ avg = 0 for i in range(100): start = time.time() - optimizer.step(fp16_param_groups=[param_fp16]) + optimizer.step() stop = time.time() avg += (stop - start) param.grad = torch.ones(model_size, device=device) * 2 diff --git a/tests/unit/common.py b/tests/unit/common.py index a2593e703aef..58bb26ca18b4 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -82,8 +82,12 @@ def set_accelerator_visible(): if match: num_accelerators += 1 elif get_accelerator().device_name() == 'hpu': - hl_smi = subprocess.check_output(['hl-smi', "-L"]) - num_accelerators = re.findall(r"Module ID\s+:\s+(\d+)", hl_smi.decode()) + try: + hl_smi = subprocess.check_output(['hl-smi', "-L"]) + num_accelerators = re.findall(r"Module ID\s+:\s+(\d+)", hl_smi.decode()) + except FileNotFoundError: + sim_list = subprocess.check_output(['ls', '-1', '/dev/accel']) + num_accelerators = re.findall(r"accel(\d+)", sim_list.decode()) num_accelerators = sorted(num_accelerators, key=int) os.environ["HABANA_VISIBLE_MODULES"] = ",".join(num_accelerators) elif get_accelerator().device_name() == 'npu': diff --git a/tests/unit/elasticity/test_elastic.py 
b/tests/unit/elasticity/test_elastic.py index 63633a51914b..1f7cbbbca214 100644 --- a/tests/unit/elasticity/test_elastic.py +++ b/tests/unit/elasticity/test_elastic.py @@ -12,7 +12,7 @@ from deepspeed.ops.op_builder import FusedAdamBuilder, FusedLambBuilder if not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]: - pytest.skip("This op had not been implemented on this system.", allow_module_level=True) + pytest.skip("This op has not been implemented on this system.", allow_module_level=True) @pytest.fixture @@ -150,6 +150,8 @@ def test_proper_mbsz(ds_config): class TestNonElasticBatchParams(DistributedTest): world_size = 2 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], + reason="FusedLambBuilder has not been implemented on this system.") def test(self): config_dict = { "train_batch_size": 2, @@ -182,9 +184,9 @@ def test(self): class TestNonElasticBatchParamsWithOverride(DistributedTest): world_size = 2 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], + reason="FusedLambBuilder has not been implemented on this system.") def test(self): - if not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME]: - pytest.skip("This op had not been implemented on this system.", allow_module_level=True) config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -215,6 +217,8 @@ def test(self): class TestElasticConfigChanged(DistributedTest): world_size = 2 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], + reason="FusedLambBuilder has not been implemented on this system.") def test(self): config_dict = { "train_batch_size": 2, diff --git a/tests/unit/inference/test_checkpoint_sharding.py b/tests/unit/inference/test_checkpoint_sharding.py index 564b3fab6bf4..5bae9a151a27 100644 --- a/tests/unit/inference/test_checkpoint_sharding.py +++ b/tests/unit/inference/test_checkpoint_sharding.py @@ -110,7 +110,7 @@ def write_checkpoints_json(model_name, class_tmpdir): cached_repo_dir = snapshot_download( model_name, local_files_only=is_offline_mode(), - cache_dir=os.getenv("TRANSFORMERS_CACHE", None), + cache_dir=os.getenv("HF_HOME", None), ignore_patterns=["*.safetensors", "*.msgpack", "*.h5"], ) file_list = [str(entry) for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]") if entry.is_file()] diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 4e203a71db60..36003319856c 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -84,7 +84,7 @@ class ModelInfo: def _hf_model_list() -> List[ModelInfo]: """ Caches HF model list to avoid repeated API calls """ - cache_dir = os.getenv("TRANSFORMERS_CACHE", "~/.cache/huggingface") + cache_dir = os.getenv("HF_HOME", "~/.cache/huggingface") cache_file_path = os.path.join(cache_dir, "DS_model_cache.pkl") cache_expiration_seconds = 60 * 60 * 24 # 1 day diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py index d39f9fe3d651..fdff9430a4e6 100644 --- a/tests/unit/moe/test_moe.py +++ b/tests/unit/moe/test_moe.py @@ -177,7 +177,7 @@ class TestTopk(DistributedTest): world_size = 2 def test(self): - device = get_accelerator().current_device() + device = get_accelerator().current_device_name() if dist.get_rank() == 0: logits = torch.rand(2, 2, device=device) elif dist.get_rank() == 1: diff --git a/tests/unit/monitor/test_monitor.py b/tests/unit/monitor/test_monitor.py index 3e04bebfb6c1..d4b3cf43921d 100644 --- a/tests/unit/monitor/test_monitor.py +++ 
b/tests/unit/monitor/test_monitor.py @@ -7,10 +7,14 @@ from deepspeed.monitor.wandb import WandbMonitor from deepspeed.monitor.csv_monitor import csvMonitor from deepspeed.monitor.config import DeepSpeedMonitorConfig +from deepspeed.monitor.comet import CometMonitor from unit.common import DistributedTest +from unittest.mock import Mock, patch from deepspeed.runtime.config import DeepSpeedConfig +import deepspeed.comm as dist + class TestTensorBoard(DistributedTest): world_size = 2 @@ -97,3 +101,66 @@ def test_empty_csv_monitor(self): assert csv_monitor.enabled == defaults.enabled assert csv_monitor.output_path == defaults.output_path assert csv_monitor.job_name == defaults.job_name + + +class TestCometMonitor(DistributedTest): + world_size = 2 + + def test_comet_monitor(self): + import comet_ml + mock_experiment = Mock() + mock_start = Mock(return_value=mock_experiment) + + config_dict = { + "train_batch_size": 2, + "comet": { + "enabled": True, + "samples_log_interval": 42, + "workspace": "some-workspace", + "project": "some-project", + "api_key": "some-api-key", + "experiment_name": "some-experiment-name", + "experiment_key": "some-experiment-key", + "mode": "get_or_create", + "online": True + } + } + + ds_config = DeepSpeedConfig(config_dict) + + with patch.object(comet_ml, "start", mock_start): + comet_monitor = CometMonitor(ds_config.monitor_config.comet) + + assert comet_monitor.enabled is True + assert comet_monitor.samples_log_interval == 42 + + # experiment should be initialized via comet_ml.start only if rank == 0 + if dist.get_rank() == 0: + mock_start.assert_called_once_with( + api_key="some-api-key", + project="some-project", + workspace="some-workspace", + experiment_key="some-experiment-key", + mode="get_or_create", + online=True, + ) + + mock_experiment.set_name.assert_called_once_with("some-experiment-name") + assert comet_monitor.experiment is mock_experiment + else: + mock_start.assert_not_called() + + def test_empty_comet(self): + import comet_ml + mock_start = Mock() + + config_dict = {"train_batch_size": 2, "comet": {}} + ds_config = DeepSpeedConfig(config_dict) + + with patch.object(comet_ml, "start", mock_start): + comet_monitor = CometMonitor(ds_config.monitor_config.comet) + + defaults = DeepSpeedMonitorConfig().comet + assert comet_monitor.enabled == defaults.enabled + assert comet_monitor.samples_log_interval == defaults.samples_log_interval + mock_start.assert_not_called() diff --git a/tests/unit/ops/accelerators/test_accelerator_backward.py b/tests/unit/ops/accelerators/test_accelerator_backward.py index 48e5fbbe7475..4b1b392e933a 100644 --- a/tests/unit/ops/accelerators/test_accelerator_backward.py +++ b/tests/unit/ops/accelerators/test_accelerator_backward.py @@ -9,12 +9,14 @@ import random import copy import os +import deepspeed from torch import nn from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from deepspeed.accelerator import get_accelerator from unit.modeling import BertConfig, BertLayerNorm, BertEncoder as BertEncoderPostln from unit.modelingpreln import BertEncoder as BertEncoderPreln from unit.common import DistributedTest, is_rocm_pytorch +from deepspeed.ops.op_builder import TransformerBuilder if torch.half not in get_accelerator().supported_dtypes(): pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) @@ -257,6 +259,8 @@ class TestCUDABackward(DistributedTest): #This is to flush denorms in forward pass. 
Please refer to https://github.com/pytorch/pytorch/blob/main/docs/source/notes/numerical_accuracy.rst#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices os.environ['ROCBLAS_INTERNAL_FP16_ALT_IMPL'] = '1' + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[TransformerBuilder.NAME], + reason="TransformerBuilder has not been implemented on this system.") def test_backward(self, is_preln, use_fp16, batch_size, hidden_size, seq_len, heads, num_layers, atol): # Only run fp16 test cases on devices with FP16 capability. if not get_accelerator().is_fp16_supported() and (use_fp16 is True or is_preln is False): diff --git a/tests/unit/ops/accelerators/test_accelerator_forward.py b/tests/unit/ops/accelerators/test_accelerator_forward.py index ee9464f63aa1..e2f4ac177f1b 100644 --- a/tests/unit/ops/accelerators/test_accelerator_forward.py +++ b/tests/unit/ops/accelerators/test_accelerator_forward.py @@ -8,12 +8,14 @@ import pytest import random import copy +import deepspeed from torch import nn from unit.modelingpreln import BertEncoder as BertEncoderPreln from unit.modeling import BertLayerNorm, BertConfig, BertEncoder as BertEncoderPostln from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest +from deepspeed.ops.op_builder import TransformerBuilder if torch.half not in get_accelerator().supported_dtypes(): pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) @@ -260,6 +262,8 @@ def test_forward(self, batch_size, hidden_size, seq_len, heads, num_layers, is_p class TestCUDAForwardSmallBatchSize(DistributedTest): world_size = 1 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[TransformerBuilder.NAME], + reason="TransformerBuilder has not been implemented on this system.") def test_forward_with_small_bsz(self, batch_size, small_bsz, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16): # Only run fp16 test cases on devices with FP16 capability. 
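The accelerator test hunks above, and several test hunks below, apply the same guard: skip a test when the required op builder cannot be compiled on the current system. A compact sketch of the pattern, shown at module scope for brevity (the patch itself attaches the marker to individual test methods):

```python
import pytest
import deepspeed
from deepspeed.ops.op_builder import TransformerBuilder

# Skip every test in this module when the transformer kernels are unavailable on this system.
pytestmark = pytest.mark.skipif(
    not deepspeed.ops.__compatible_ops__[TransformerBuilder.NAME],
    reason="TransformerBuilder has not been implemented on this system.")
```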
diff --git a/tests/unit/ops/adagrad/test_cpu_adagrad.py b/tests/unit/ops/adagrad/test_cpu_adagrad.py index 99e934e2efda..0c675ecd6a85 100644 --- a/tests/unit/ops/adagrad/test_cpu_adagrad.py +++ b/tests/unit/ops/adagrad/test_cpu_adagrad.py @@ -18,8 +18,8 @@ def check_equal(first, second, atol=1e-2, verbose=False): - x = first.detach().numpy() - y = second.detach().numpy() + x = first.detach().float().numpy() + y = second.detach().float().numpy() if verbose: print("x = {}".format(x.flatten())) print("y = {}".format(y.flatten())) diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py index 9a6ff6689446..851485440428 100644 --- a/tests/unit/ops/adam/test_cpu_adam.py +++ b/tests/unit/ops/adam/test_cpu_adam.py @@ -11,7 +11,7 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.adam import FusedAdam -from deepspeed.ops.op_builder import CPUAdamBuilder +from deepspeed.ops.op_builder import CPUAdamBuilder, FusedAdamBuilder from unit.common import DistributedTest if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: @@ -21,8 +21,8 @@ def check_equal(first, second, atol=1e-2, verbose=False): - x = first.detach().numpy() - y = second.detach().numpy() + x = first.detach().float().numpy() + y = second.detach().float().numpy() print("ATOL", atol) if verbose: print("x = {}".format(x.flatten())) @@ -43,7 +43,7 @@ def _compare_optimizers(model_size, param1, optimizer1, param2, optimizer2): check_equal(param1.float().norm(), param2.float().cpu().norm(), atol=tolerance, verbose=True) -@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"]) +@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"]) @pytest.mark.parametrize('model_size', [ (64), @@ -62,7 +62,12 @@ class TestCPUAdam(DistributedTest): set_dist_env = False @pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.") + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME], + reason="FusedAdam is not compatible") def test_fused_adam_equal(self, dtype, model_size): + if dtype not in get_accelerator().supported_dtypes(): + pytest.skip(f"dtype {dtype} not supported in current accelerator") + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): pytest.skip("cpu-adam with half precision not supported on AMD CPUs") @@ -89,6 +94,8 @@ def test_fused_adam_equal(self, dtype, model_size): def test_torch_adamw_equal(self, dtype, model_size): if get_accelerator().is_available(): + if dtype == torch.half: + pytest.skip("torch.optim.AdamW with half precision inf/nan output.") if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): pytest.skip("cpu-adam with half precision not supported on AMD CPUs") ref_param_device = get_accelerator().device_name() @@ -97,20 +104,20 @@ def test_torch_adamw_equal(self, dtype, model_size): pytest.skip("torch.optim.AdamW with half precision only supported in CUDA environments.") ref_param_device = 'cpu' - from deepspeed.ops.adam import DeepSpeedCPUAdam + from deepspeed.ops.adam import DeepSpeedCPUAdam - cpu_data = torch.randn(model_size, device='cpu').to(dtype) - cpu_param = torch.nn.Parameter(cpu_data) - ref_param = torch.nn.Parameter(cpu_data.to(ref_param_device)) + cpu_data = torch.randn(model_size, device='cpu').to(dtype) + cpu_param = torch.nn.Parameter(cpu_data) + ref_param = torch.nn.Parameter(cpu_data.to(ref_param_device)) - cpu_optimizer = DeepSpeedCPUAdam([cpu_param]) - ref_optimizer = 
torch.optim.AdamW([ref_param]) + cpu_optimizer = DeepSpeedCPUAdam([cpu_param]) + ref_optimizer = torch.optim.AdamW([ref_param]) - _compare_optimizers(model_size=model_size, - param1=cpu_param, - optimizer1=cpu_optimizer, - param2=ref_param, - optimizer2=ref_optimizer) + _compare_optimizers(model_size=model_size, + param1=cpu_param, + optimizer1=cpu_optimizer, + param2=ref_param, + optimizer2=ref_optimizer) class TestCPUAdamGPUError(DistributedTest): diff --git a/tests/unit/ops/adam/test_hybrid_adam.py b/tests/unit/ops/adam/test_hybrid_adam.py index c7ef4890b322..652090d5b9d5 100644 --- a/tests/unit/ops/adam/test_hybrid_adam.py +++ b/tests/unit/ops/adam/test_hybrid_adam.py @@ -12,7 +12,7 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam -from deepspeed.ops.op_builder import CPUAdamBuilder +from deepspeed.ops.op_builder import CPUAdamBuilder, FusedAdamBuilder from unit.common import DistributedTest if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: @@ -22,8 +22,8 @@ def check_equal(first, second, atol=1e-2, verbose=False): - x = first.detach().numpy() - y = second.detach().numpy() + x = first.detach().float().numpy() + y = second.detach().float().numpy() print("ATOL", atol) if verbose: print("x = {}".format(x.flatten())) @@ -32,7 +32,7 @@ def check_equal(first, second, atol=1e-2, verbose=False): np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) -@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"]) +@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"]) @pytest.mark.parametrize('model_size', [8, 16]) class TestHybridAdam(DistributedTest): world_size = 1 @@ -43,6 +43,8 @@ class TestHybridAdam(DistributedTest): set_dist_env = False @pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.") + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME], + reason="FusedAdam is not compatible") def test_hybrid_adam_equal(self, dtype, model_size): if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): pytest.skip("cpu-adam with half precision not supported on AMD CPUs") diff --git a/tests/unit/ops/lion/test_cpu_lion.py b/tests/unit/ops/lion/test_cpu_lion.py index 61a069af3257..dce027e286fb 100644 --- a/tests/unit/ops/lion/test_cpu_lion.py +++ b/tests/unit/ops/lion/test_cpu_lion.py @@ -14,15 +14,12 @@ from deepspeed.ops.op_builder import CPULionBuilder from unit.common import DistributedTest -if not deepspeed.ops.__compatible_ops__[CPULionBuilder.NAME]: - pytest.skip("cpu-lion is not compatible", allow_module_level=True) - pytest.cpu_vendor = get_cpu_info()["vendor_id_raw"].lower() def check_equal(first, second, atol=1e-2, verbose=False): - x = first.detach().numpy() - y = second.detach().numpy() + x = first.detach().float().numpy() + y = second.detach().float().numpy() print("ATOL", atol) if verbose: print("x = {}".format(x.flatten())) @@ -43,7 +40,7 @@ def _compare_optimizers(model_size, param1, optimizer1, param2, optimizer2): check_equal(param1.float().norm(), param2.float().cpu().norm(), atol=tolerance, verbose=True) -@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"]) +@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"]) @pytest.mark.parametrize('model_size', [ (64), @@ -62,6 +59,8 @@ class TestCPULion(DistributedTest): set_dist_env = False 
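A note on why these test helpers switch to `.detach().float().numpy()`: NumPy has no bfloat16 dtype, so calling `.numpy()` directly on a bf16 tensor raises, and the newly parametrized bf16 cases need an upcast to float32 before comparison. A tiny sketch of the failure mode being avoided:

```python
import torch

x = torch.ones(4, dtype=torch.bfloat16)
# x.numpy() raises a TypeError because NumPy has no native bfloat16 dtype;
# upcasting first keeps the comparison against the reference optimizer valid.
x_np = x.detach().float().numpy()
```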
     @pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.")
+    @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[CPULionBuilder.NAME],
+                        reason="CPULionBuilder has not been implemented on this system.")
     def test_fused_lion_equal(self, dtype, model_size):
         if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
             pytest.skip("cpu-lion with half precision not supported on AMD CPUs")
@@ -84,6 +83,8 @@ def test_fused_lion_equal(self, dtype, model_size):
 
 
 class TestCPULionGPUError(DistributedTest):
+    @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[CPULionBuilder.NAME],
+                        reason="CPULionBuilder has not been implemented on this system.")
     def test_cpu_lion_gpu_error(self):
         model_size = 64
         from deepspeed.ops.lion import DeepSpeedCPULion
diff --git a/tests/unit/ops/lion/test_lion.py b/tests/unit/ops/lion/test_lion.py
index b2c3ac2f52df..507ff72ea51a 100644
--- a/tests/unit/ops/lion/test_lion.py
+++ b/tests/unit/ops/lion/test_lion.py
@@ -12,6 +12,7 @@
 from unit.common import DistributedTest
 from unit.simple_model import SimpleModel
 from deepspeed.accelerator import get_accelerator
+from deepspeed.ops.op_builder import CPULionBuilder
 
 if torch.half not in get_accelerator().supported_dtypes():
     pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)
@@ -27,6 +28,7 @@ class TestLionConfigs(DistributedTest):
     world_size = 1
     reuse_dist_env = True
 
+    @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[CPULionBuilder.NAME], reason="CPULionBuilder has not been implemented on this system.")
     def test(self,
              optimizer,
              zero_offload,
diff --git a/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py b/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py
index f350e08e68a7..4b263172261c 100644
--- a/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py
+++ b/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py
@@ -10,6 +10,7 @@
 import numpy as np
 from unit.common import DistributedTest
 from unit.simple_model import SimpleModel
+from deepspeed.ops.op_builder import FusedLambBuilder
 
 
 def run_model_step(model, gradient_list):
@@ -152,6 +153,8 @@ def test_some_overflow(self):
         assert optim.cur_iter == expected_iteration
 
 
+@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME],
+                    reason="FusedLambBuilder has not been implemented on this system.")
 class TestUnfused(DistributedTest):
     world_size = 1
 
diff --git a/tests/unit/runtime/half_precision/test_fp16.py b/tests/unit/runtime/half_precision/test_fp16.py
index 5b300053d2a8..dba15a969459 100644
--- a/tests/unit/runtime/half_precision/test_fp16.py
+++ b/tests/unit/runtime/half_precision/test_fp16.py
@@ -12,7 +12,7 @@
 from unit.simple_model import SimpleModel, SimpleOptimizer, random_dataloader, SimpleMoEModel, sequence_dataloader
 from deepspeed.utils.torch import required_torch_version
 from deepspeed.accelerator import get_accelerator
-from deepspeed.ops.op_builder import CPUAdamBuilder
+from deepspeed.ops.op_builder import CPUAdamBuilder, FusedLambBuilder
 from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
 
 try:
@@ -22,10 +22,15 @@
     _amp_available = False
 amp_available = pytest.mark.skipif(not _amp_available, reason="apex/amp is not installed")
 
+if torch.half not in get_accelerator().supported_dtypes():
+    pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)
+
 
 class TestLambFP32GradClip(DistributedTest):
     world_size = 2
 
+    @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME],
+                        reason="FusedLambBuilder has not been implemented on this system.")
     def test(self):
         if not get_accelerator().is_fp16_supported():
             pytest.skip("fp16 is not supported")
@@ -58,6 +63,8 @@ def test(self):
 class TestLambFP16(DistributedTest):
     world_size = 2
 
+    @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME],
+                        reason="FusedLambBuilder has not been implemented on this system.")
     def test__basic(self):
         if not get_accelerator().is_fp16_supported():
             pytest.skip("fp16 is not supported")
@@ -85,6 +92,8 @@ def test__basic(self):
         model.backward(loss)
         model.step()
 
+    @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME],
+                        reason="FusedLambBuilder has not been implemented on this system.")
     def test_empty_grad(self):
         if not get_accelerator().is_fp16_supported():
             pytest.skip("fp16 is not supported")
@@ -232,6 +241,8 @@ def mock_unscale_and_clip_grads(grads_groups_flat, total_norm, apply_scale=True)
             engine.step()
 
     @pytest.mark.parametrize("fused_lamb_legacy", [(False), (True)])
+    @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME],
+                        reason="FusedLambBuilder has not been implemented on this system.")
     def test_lamb_gradnorm(self, monkeypatch, fused_lamb_legacy: bool):
         if not get_accelerator().is_fp16_supported():
             pytest.skip("fp16 is not supported")
@@ -495,6 +506,8 @@ def test_adam_basic(self):
         model.backward(loss)
         model.step()
 
+    @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME],
+                        reason="FusedLambBuilder has not been implemented on this system")
     def test_lamb_basic(self):
         if not get_accelerator().is_fp16_supported():
             pytest.skip("fp16 is not supported")
diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py
index 169096a6d4e5..9ff99f169f7a 100644
--- a/tests/unit/runtime/test_ds_initialize.py
+++ b/tests/unit/runtime/test_ds_initialize.py
@@ -20,6 +20,7 @@
 from deepspeed.runtime.utils import see_memory_usage
 from deepspeed.utils.torch import required_torch_version
 from deepspeed.accelerator import get_accelerator
+from deepspeed.ops.op_builder import FusedAdamBuilder
 
 
 @pytest.mark.parametrize('zero_stage', [0, 3])
@@ -67,6 +68,9 @@ def test(self, optimizer_type):
         def _optimizer_callable(params) -> Optimizer:
             return AdamW(params=params)
 
+        if (optimizer_type is None) and (not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]):
+            pytest.skip("FusedAdam is not compatible")
+
         hidden_dim = 10
         model = SimpleModel(hidden_dim)
 
@@ -95,6 +99,8 @@ def _optimizer_callable(params) -> Optimizer:
 class TestConfigOptimizer(DistributedTest):
     world_size = 1
 
+    @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME],
+                        reason="FusedAdam is not compatible")
     def test(self, client_parameters):
         ds_config = {"train_batch_size": 1, "optimizer": {"type": "Adam", "params": {"lr": 0.001}}}