[npu] add npu support for gemini and zero (hpcaitech#5067)
* [npu] setup device utils (hpcaitech#5047)

* [npu] add npu device support

* [npu] support low level zero

* [test] update npu zero plugin test

* [hotfix] fix import

* [test] recover tests

* [npu] gemini support npu (hpcaitech#5052)

* [npu] refactor device utils

* [gemini] support npu

* [example] llama2+gemini support npu

* [kernel] add arm cpu adam kernel (hpcaitech#5065)

* [kernel] add arm cpu adam

* [optim] update adam optimizer

* [kernel] arm cpu adam remove bf16 support
ver217 authored Nov 20, 2023
1 parent 8d56c9c commit e5ce4c8
Showing 46 changed files with 989 additions and 228 deletions.
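
The diffs below all route device handling through colossalai.utils.device instead of hard-coded torch.cuda calls. For orientation, here is a minimal sketch of what such a layer could look like: only the names get_current_device, set_device, and IS_NPU_AVAILABLE come from the diffs, while the torch_npu import path and the helper bodies are assumptions, not the repository's actual implementation.

# Hypothetical sketch of a device-abstraction layer in the spirit of colossalai.utils.device.
# Only the public names mirror the diffs below; the bodies are assumed.
import torch

try:
    import torch_npu  # noqa: F401 (assumed package providing torch.npu on Ascend hardware)
    IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
    IS_NPU_AVAILABLE = False


def get_current_device() -> torch.device:
    # Device that the current process should allocate tensors on.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    if IS_NPU_AVAILABLE:
        return torch.device(f"npu:{torch.npu.current_device()}")
    return torch.device("cpu")


def set_device(index: int) -> None:
    # Bind this process to a local accelerator index.
    if torch.cuda.is_available():
        torch.cuda.set_device(index)
    elif IS_NPU_AVAILABLE:
        torch.npu.set_device(index)
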
3 changes: 2 additions & 1 deletion colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
@@ -8,6 +8,7 @@
from torch import Tensor

from colossalai.logging import get_dist_logger
+from colossalai.utils.device import get_current_device

__all__ = ["BaseGradScaler"]

@@ -22,7 +23,7 @@ class BaseGradScaler(ABC):

def __init__(self, initial_scale: float, verbose: bool):
assert initial_scale > 0
-self._scale = torch.cuda.FloatTensor([initial_scale])
+self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
self._verbose = verbose

if self._verbose:
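
This replacement is the pattern the rest of the commit repeats: torch.cuda.FloatTensor pins the allocation to CUDA, whereas torch.tensor with an explicit device argument lands on whatever accelerator get_current_device reports. A small illustration, assuming only the import shown in the diff above (on a machine without CUDA or an NPU the helper would presumably fall back to CPU):

# Device-agnostic allocation, mirroring the grad-scaler change above.
import torch
from colossalai.utils.device import get_current_device

# Old, CUDA-only style; fails on NPU-only or CPU-only machines:
#   scale = torch.cuda.FloatTensor([2.0 ** 16])

# New, backend-neutral style; same value and dtype, placed on the active device:
scale = torch.tensor([2.0 ** 16], device=get_current_device(), dtype=torch.float)
print(scale.device)  # e.g. cuda:0, npu:0, or cpu
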
8 changes: 5 additions & 3 deletions colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
@@ -5,6 +5,8 @@

import torch

+from colossalai.utils.device import get_current_device

from .base_grad_scaler import BaseGradScaler

__all__ = ["DynamicGradScaler"]
@@ -37,12 +39,12 @@ def __init__(
):
super().__init__(initial_scale, verbose)
if min_scale:
-self._min_scale = torch.cuda.FloatTensor([min_scale])
+self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
else:
self._min_scale = None

if max_scale:
-self._max_scale = torch.cuda.FloatTensor([max_scale])
+self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
else:
self._max_scale = None

@@ -115,7 +117,7 @@ def state_dict(self):
return state_dict

def load_state_dict(self, state_dict):
-self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
+self._scale = state_dict["scale"].to(get_current_device())
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._hysteresis = state_dict["hysteresis"]
2 changes: 1 addition & 1 deletion colossalai/auto_parallel/offload/solver.py
@@ -11,7 +11,7 @@
import torch
from torch.fx.node import Node

-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device

from .region import Region
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
10 changes: 8 additions & 2 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -25,6 +25,7 @@
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils import get_current_device
+from colossalai.utils.device import IS_NPU_AVAILABLE
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats

@@ -37,6 +38,7 @@

ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2


def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A mapping from integer param_id to param32 shape.
@@ -53,6 +55,8 @@ def get_param_info(optim: Optimizer):
start_index += len(group["params"])

return param_info


class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
@@ -359,6 +363,8 @@ def __init__(
) -> None:
super().__init__()
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
+if IS_NPU_AVAILABLE:
+    assert placement_policy == "static", "NPU only supports static placement policy"
self.gemini_config = dict(
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
@@ -437,7 +443,7 @@ def control_device(self) -> bool:
return True

def supported_devices(self) -> List[str]:
return ["cuda"]
return ["cuda", "npu"]

def configure(
self,
@@ -485,4 +491,4 @@ def get_checkpoint_io(self) -> CheckpointIO:
return GeminiCheckpointIO()

def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
-raise NotImplementedError
+raise NotImplementedError
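
With the new assert, an NPU run must use the static placement policy; the rest of the booster workflow stays the same. A hedged usage sketch follows: GeminiPlugin, Booster, and the placement constraint come from this file, while the model, optimizer, and launch arguments below are placeholders for illustration.

# Sketch of driving Gemini through the booster API on an NPU (or CUDA) machine.
import colossalai
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# launch_from_torch reads rank/world size from the launcher; per the initialize.py
# diff below, the backend is remapped to hccl automatically when an NPU is detected.
colossalai.launch_from_torch(config={})

model = nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

# "static" is the only placement policy the new assert allows on NPU.
plugin = GeminiPlugin(placement_policy="static", precision="bf16")
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
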
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/low_level_zero_plugin.py
@@ -306,7 +306,7 @@ def control_device(self) -> bool:
return True

def supported_devices(self) -> List[str]:
-return ["cuda"]
+return ["cuda", "npu"]

def configure(
self,
7 changes: 5 additions & 2 deletions colossalai/initialize.py
@@ -11,7 +11,7 @@

from colossalai.context import Config
from colossalai.logging import get_dist_logger
-from colossalai.utils import set_device, set_seed
+from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed


def launch(
@@ -47,12 +47,15 @@ def launch(
if rank == 0:
warnings.warn("`config` is deprecated and will be removed soon.")

+if IS_NPU_AVAILABLE and backend == "nccl":
+    backend = "hccl"

# init default process group
init_method = f"tcp://[{host}]:{port}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)

# set cuda device
-if torch.cuda.is_available():
+if torch.cuda.is_available() or IS_NPU_AVAILABLE:
# if local rank is not given, calculate automatically
set_device(local_rank)

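
Because the remapping happens inside launch, existing scripts that pass backend="nccl" keep working on Ascend machines. Below is a hedged sketch of calling launch directly with placeholder host/port values; the keyword names follow the parameters visible in this diff, and config is still accepted here even though the warning above marks it as deprecated.

# Per-process entry point; rank/world_size/local_rank normally come from the launcher env.
import os
import colossalai

colossalai.launch(
    config={},  # deprecated per the warning in this diff, but still accepted
    rank=int(os.environ.get("RANK", 0)),
    world_size=int(os.environ.get("WORLD_SIZE", 1)),
    host=os.environ.get("MASTER_ADDR", "127.0.0.1"),
    port=int(os.environ.get("MASTER_PORT", "29500")),
    backend="nccl",  # rewritten to "hccl" when IS_NPU_AVAILABLE is true
    local_rank=int(os.environ.get("LOCAL_RANK", 0)),
)
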
2 changes: 2 additions & 0 deletions colossalai/kernel/cuda_native/csrc/cpu_adam.h
@@ -142,6 +142,7 @@ class Adam_Optimizer {
}
}

+#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
if (is_half) {
@@ -159,6 +160,7 @@
SIMD_STORE(ptr, data.data);
}
}
+#endif

void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params,
[Diffs for the remaining 39 changed files are not shown here.]