feature: mlu (#70)
* feature: support MLU

* fix

* fix

* fix: env

* fix

* fix: to_device

* move device

* fix

* fix: launch_runner
cnstark authored Oct 25, 2022
1 parent 828a3b2 commit 983b725
Showing 18 changed files with 271 additions and 114 deletions.
3 changes: 1 addition & 2 deletions easytorch/__init__.py
@@ -4,6 +4,5 @@
from .version import __version__

__all__ = [
'Config', 'import_config', 'Runner', 'Runner', 'AvgMeter', 'MeterPool', 'launch_runner',
'launch_training', '__version__'
'Config', 'import_config', 'Runner', 'AvgMeter', 'MeterPool', 'launch_runner', 'launch_training', '__version__'
]
16 changes: 9 additions & 7 deletions easytorch/core/checkpoint.py
@@ -7,6 +7,8 @@
import torch

from ..utils import get_logger, get_local_rank
from ..device import get_device_type


DEFAULT_LOGGER = get_logger('easytorch-checkpoint')

@@ -28,16 +30,14 @@ def get_last_ckpt_path(ckpt_save_dir: str, name_pattern: str = r'^.+_[\d]*.pt$')
return os.path.join(ckpt_save_dir, ckpt_list[-1])


def load_ckpt(ckpt_save_dir: str, ckpt_path: str = None, use_gpu: bool = True,
logger: Logger = DEFAULT_LOGGER) -> Dict:
def load_ckpt(ckpt_save_dir: str, ckpt_path: str = None, logger: Logger = DEFAULT_LOGGER) -> Dict:
"""Load checkpoint
if param `ckpt_path` is None, load the last checkpoint in `ckpt_save_dir`,
else load checkpoint from `ckpt_path`
Args:
ckpt_save_dir (str): checkpoint save directory
ckpt_path (str): checkpoint path, default is None
use_gpu (bool): set to ``True`` to load checkpoint to GPU
logger (Logger): logger, default is Logger('easytorch')
Returns:
@@ -46,10 +46,12 @@ def load_ckpt(ckpt_save_dir: str, ckpt_path: str = None, use_gpu: bool = True,

if ckpt_path is None:
ckpt_path = get_last_ckpt_path(ckpt_save_dir)
if use_gpu:
map_location = 'cuda:{}'.format(get_local_rank())
else:
map_location = 'cpu'
map_location = {
'gpu': 'cuda:{}'.format(get_local_rank()),
'mlu': None,
'cpu': 'cpu'
}[get_device_type()]

logger.info('Loading Checkpoint from \'{}\''.format(ckpt_path))
return torch.load(ckpt_path, map_location=map_location)

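For reference, a minimal sketch of the new map_location selection (the helper name below is illustrative and not part of the diff): GPU checkpoints are remapped to the process-local CUDA device, MLU leaves map_location as None so torch.load uses its default behavior, and CPU forces a CPU load.

def _resolve_map_location(device_type: str, local_rank: int):
    # Mirrors the dict lookup inside load_ckpt above.
    return {
        'gpu': 'cuda:{}'.format(local_rank),  # remap to this process's GPU
        'mlu': None,                          # defer to torch.load's default
        'cpu': 'cpu',                         # force a CPU load
    }[device_type]

print(_resolve_map_location('gpu', 1))  # cuda:1
print(_resolve_map_location('cpu', 0))  # cpu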
25 changes: 4 additions & 21 deletions easytorch/core/runner.py
@@ -18,6 +18,7 @@
from .optimizer_builder import build_optim, build_lr_scheduler
from ..config import Config, get_ckpt_save_dir
from ..utils import TimePredictor, get_logger, get_local_rank, is_master, master_only, set_env
from ..device import to_device


class Runner(metaclass=ABCMeta):
@@ -32,7 +33,6 @@ def __init__(self, cfg: Config):
set_env(cfg.get('ENV', {}))

# param
self.use_gpu = cfg.get('GPU_NUM', 0) != 0
self.model_name = cfg['MODEL.NAME']
self.ckpt_save_dir = get_ckpt_save_dir(cfg)
self.logger.info('Set ckpt save dir: \'{}\''.format(self.ckpt_save_dir))
@@ -86,22 +86,6 @@ def init_logger(self, logger: logging.Logger = None, logger_name: str = None,
else:
raise TypeError('At least one of logger and logger_name is not None')

def to_running_device(self, src: Union[torch.Tensor, nn.Module]) -> Union[torch.Tensor, nn.Module]:
"""Move `src` to the running device. If `self.use_gpu` is ```True```,
the running device is GPU, else the running device is CPU.
Args:
src (Union[torch.Tensor, nn.Module]): source
Returns:
target (Union[torch.Tensor, nn.Module])
"""

if self.use_gpu:
return src.cuda()
else:
return src.cpu()

@staticmethod
@abstractmethod
def define_model(cfg: Config) -> nn.Module:
@@ -198,7 +182,7 @@ def build_model(self, cfg: Config) -> nn.Module:

self.logger.info('Building model.')
model = self.define_model(cfg)
model = self.to_running_device(model)
model = to_device(model)
if torch.distributed.is_initialized():
model = DDP(
model,
@@ -273,7 +257,7 @@ def load_model_resume(self, strict: bool = True):
"""

try:
checkpoint_dict = load_ckpt(self.ckpt_save_dir, use_gpu=self.use_gpu, logger=self.logger)
checkpoint_dict = load_ckpt(self.ckpt_save_dir, logger=self.logger)
if isinstance(self.model, DDP):
self.model.module.load_state_dict(checkpoint_dict['model_state_dict'], strict=strict)
else:
@@ -301,8 +285,7 @@ def load_model(self, ckpt_path: str = None, strict: bool = True):
"""

try:
checkpoint_dict = load_ckpt(self.ckpt_save_dir, ckpt_path=ckpt_path, use_gpu=self.use_gpu,
logger=self.logger)
checkpoint_dict = load_ckpt(self.ckpt_save_dir, ckpt_path=ckpt_path, logger=self.logger)
if isinstance(self.model, DDP):
self.model.module.load_state_dict(checkpoint_dict['model_state_dict'], strict=strict)
else:
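The removed to_running_device has a drop-in replacement in the new device module; a minimal sketch of the equivalence ('cpu' is used so the snippet runs anywhere, whereas real runs would set 'gpu' or 'mlu' as the launcher does):

import torch
from easytorch.device import set_device_type, to_device

set_device_type('cpu')
x = torch.randn(4, 4)

# Before: self.to_running_device(x) returned x.cuda() if self.use_gpu else x.cpu()
# After:  to_device(x) dispatches on the globally configured device type
assert torch.equal(to_device(x), x.cpu())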
74 changes: 74 additions & 0 deletions easytorch/device.py
@@ -0,0 +1,74 @@
from typing import Union

import torch
from torch import nn

__all__ = [
'get_device_type', 'set_device_type', 'get_device_count', 'set_device', 'to_device', 'set_device_manual_seed'
]

_DEVICE_TYPE = 'gpu'


def get_device_type() -> str:
return _DEVICE_TYPE


def set_device_type(device_type: str):
global _DEVICE_TYPE
if device_type not in ['gpu', 'mlu', 'cpu']:
raise ValueError('Unknown device type!')
if device_type == 'mlu':
__import__('torch_mlu')
_DEVICE_TYPE = device_type


def get_device_count() -> int:
if _DEVICE_TYPE == 'gpu':
return torch.cuda.device_count()
elif _DEVICE_TYPE == 'mlu':
torch_mlu = __import__('torch_mlu')
return torch_mlu.mlu_model.device_count()
elif _DEVICE_TYPE == 'cpu':
return 0
else:
raise ValueError('Unknown device type!')


def set_device(device_id: int):
if _DEVICE_TYPE == 'gpu':
torch.cuda.set_device(device_id)
elif _DEVICE_TYPE == 'mlu':
torch_mlu = __import__('torch_mlu')
torch_mlu.mlu_model.set_device(device_id)
else:
raise ValueError('Unknown device type!')


def to_device(src: Union[torch.Tensor, nn.Module], device_id: int = None) -> Union[torch.Tensor, nn.Module]:
if _DEVICE_TYPE == 'gpu':
if device_id is None:
return src.cuda()
else:
return src.to('cuda:{:d}'.format(device_id))
elif _DEVICE_TYPE == 'mlu':
__import__('torch_mlu')
if device_id is None:
return src.mlu()
else:
return src.to('mlu:{:d}'.format(device_id))
elif _DEVICE_TYPE == 'cpu':
return src.cpu()
else:
raise ValueError('Unknown device type!')


def set_device_manual_seed(seed: int):
torch.manual_seed(seed)
if _DEVICE_TYPE == 'gpu':
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
elif _DEVICE_TYPE == 'mlu':
torch_mlu = __import__('torch_mlu')
torch_mlu.mlu_model.manual_seed(seed)
torch_mlu.mlu_model.manual_seed_all(seed)
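A minimal usage sketch of the new module (values are illustrative; 'cpu' is chosen so the snippet runs without a GPU or an MLU, and 'mlu' additionally requires the torch_mlu package):

import torch
from easytorch.device import (
    set_device_type, get_device_count, set_device_manual_seed, to_device)

set_device_type('cpu')            # or 'gpu' / 'mlu', chosen once at startup
print(get_device_count())         # 0 for CPU; device count for GPU/MLU
set_device_manual_seed(42)        # seeds torch plus the device-specific RNGs

x = to_device(torch.randn(2, 3))  # .cuda() / .mlu() / .cpu() depending on the type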
4 changes: 2 additions & 2 deletions easytorch/entry_points/easytrain.py
@@ -9,7 +9,7 @@ def parse_args():
parser = ArgumentParser(description='Welcome to EasyTorch!')
parser.add_argument('-c', '--cfg', help='training config', required=True)
parser.add_argument('--node-rank', default=0, type=int, help='node rank for distributed training')
parser.add_argument('--gpus', help='visible gpus', type=str)
parser.add_argument('--devices', help='visible devices', type=str)
return parser.parse_args()


@@ -22,4 +22,4 @@ def easytrain():
args = parse_args()

# train
launch_training(args.cfg, args.gpus, args.node_rank)
launch_training(args.cfg, args.devices, args.node_rank)
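Assuming the package registers this module as an easytrain console script (the entry-point registration is not shown in this diff), the renamed flag would be passed like the following, with the config path and device ids purely illustrative:

easytrain -c my_config.py --devices 0,1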
33 changes: 18 additions & 15 deletions easytorch/launcher/dist_wrap.py
@@ -5,6 +5,7 @@
import torch

from ..utils import get_logger
from ..device import get_device_type, set_device_type, get_device_count, set_device


def dist_func(local_rank: int, dist_params: Dict[str, Any], func: Callable, *args):
@@ -18,7 +19,7 @@ def dist_func(local_rank: int, dist_params: Dict[str, Any], func: Callable, *arg

logger = get_logger('easytorch-launcher')

rank = dist_params['gpu_num'] * dist_params['node_rank'] + local_rank
rank = dist_params['device_num'] * dist_params['node_rank'] + local_rank
logger.info(
'Launching in distributed mode. Distributed parameters:'\
'word_size={:d}, node_rank={:d}, rank={:d}, local_rank={:d}, dist_backend={}, init_method={}'.format(
@@ -27,22 +28,24 @@ def dist_func(local_rank: int, dist_params: Dict[str, Any], func: Callable, *arg
)
)

set_device_type(dist_params['device_type'])

torch.distributed.init_process_group(
backend=dist_params['dist_backend'],
init_method=dist_params['init_method'],
rank=rank,
world_size=dist_params['word_size']
)

torch.cuda.set_device(local_rank)
set_device(local_rank)

args, kwargs = args
func(*args, **kwargs)


def dist_wrap(func: Callable,
node_num: int = 1,
gpu_num: int = 1,
device_num: int = 1,
node_rank: int = 0,
dist_backend: Optional[Union[str, torch.distributed.Backend]] = None,
init_method: Optional[str] = None) -> Callable:
@@ -55,7 +58,7 @@ def dist_wrap(func: Callable,
>>> function_dist = dist_wrap(
>>> function,
>>> node_num=node_num,
>>> gpu_num=gpu_num,
>>> device_num=device_num,
>>> node_rank=node_rank,
>>> dist_backend=dist_backend,
>>> init_method=init_method
@@ -65,7 +68,7 @@ def dist_wrap(func: Callable,
Args:
func (Callable): The function.
node_num (int, optional): Number of node. Defaults to 1.
gpu_num (int, optional): Number of gpus per node. Defaults to 1.
device_num (int, optional): Number of devices per node. Defaults to 1.
node_rank (int, optional): Rank of current node. Defaults to 0.
dist_backend (Optional[Union[str, distributed.Backend]], optional): The backend of DDP.
Defaults to None, means using `nccl` as the backend.
@@ -79,23 +82,22 @@ def dist_wrap(func: Callable,
if node_num < 1:
raise ValueError('The node_num must be greater than 1!')

if gpu_num < 0:
raise ValueError('The gpu_num must be greater than 0!')
if device_num < 0:
raise ValueError('The device_num must be greater than 0!')

word_size = node_num * gpu_num
word_size = node_num * device_num

if word_size == 0:
# CPU mode
return func
else:
# GPU mode
# DEVICE mode
if node_rank >= node_num:
raise ValueError('The node_rank must be less than dist_node_num!')

if gpu_num != torch.cuda.device_count():
raise RuntimeError('GPU num not match, cfg.GPU_NUM = {:d}, but torch.cuda.device_count() = {:d}'.format(
gpu_num, torch.cuda.device_count()
))
if device_num != get_device_count():
raise RuntimeError('Device num not match, cfg.DEVICE_NUM = {:d}, ' \
'but torch.cuda.device_count() = {:d}'.format(device_num, get_device_count()))

if word_size == 1:
return func
Expand All @@ -112,7 +114,8 @@ def dist_wrap(func: Callable,
@functools.wraps(func)
def wrapper(*args, **kwargs):
dist_params = {
'gpu_num': gpu_num,
'device_type': get_device_type(),
'device_num': device_num,
'node_rank': node_rank,
'word_size': word_size,
'dist_backend': dist_backend,
Expand All @@ -122,7 +125,7 @@ def wrapper(*args, **kwargs):
torch.multiprocessing.spawn(
dist_func,
args=(dist_params, func, args, kwargs),
nprocs=gpu_num,
nprocs=device_num,
join=True
)

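A worked example of the rank arithmetic with the renamed parameter (the values are illustrative): with 2 nodes and 4 devices per node, the process with local_rank 2 on node 1 receives global rank 4 * 1 + 2 = 6 in a world size of 8.

node_num, device_num, node_rank, local_rank = 2, 4, 1, 2

word_size = node_num * device_num           # 8 (spelled 'word_size' in the code above)
rank = device_num * node_rank + local_rank  # 6, as computed in dist_func
print(rank, word_size)                      # 6 8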
(Diffs for the remaining 12 of the 18 changed files are not shown.)
