From c6c2be51e75a1955875097488b4ad1c2ff7db559 Mon Sep 17 00:00:00 2001
From: zgplvyou <18756963918@163.com>
Date: Tue, 26 Apr 2022 19:23:54 +0800
Subject: [PATCH] feat(mlu): Support MLU Device on MMaction2

---
 mmaction/apis/train.py             | 22 ++++---
 mmaction/core/dist_utils.py        |  4 +-
 mmaction/utils/__init__.py         |  3 +-
 mmaction/utils/distribution_env.py | 94 ++++++++++++++++++++++++++++++
 tools/test.py                      | 17 +++---
 5 files changed, 122 insertions(+), 18 deletions(-)
 create mode 100644 mmaction/utils/distribution_env.py

diff --git a/mmaction/apis/train.py b/mmaction/apis/train.py
index 5498a286e9..f38b5efeab 100644
--- a/mmaction/apis/train.py
+++ b/mmaction/apis/train.py
@@ -7,7 +7,6 @@
 import numpy as np
 import torch
 import torch.distributed as dist
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
                          build_optimizer, get_dist_info)
 from mmcv.runner.hooks import Fp16OptimizerHook
@@ -15,11 +14,12 @@
 from ..core import (DistEvalHook, EvalHook, OmniSourceDistSamplerSeedHook,
                     OmniSourceRunner)
 from ..datasets import build_dataloader, build_dataset
-from ..utils import PreciseBNHook, get_root_logger
+from ..utils import (PreciseBNHook, build_ddp, build_dp, default_device,
+                     get_root_logger)
 from .test import multi_gpu_test
 
 
-def init_random_seed(seed=None, device='cuda', distributed=True):
+def init_random_seed(seed=None, device=default_device, distributed=True):
     """Initialize random seed.
 
     If the seed is not set, the seed will be automatically randomized,
@@ -122,13 +122,17 @@ def train_model(model,
         find_unused_parameters = cfg.get('find_unused_parameters', False)
         # Sets the `find_unused_parameters` parameter in
         # torch.nn.parallel.DistributedDataParallel
-        model = MMDistributedDataParallel(
-            model.cuda(),
-            device_ids=[torch.cuda.current_device()],
-            broadcast_buffers=False,
-            find_unused_parameters=find_unused_parameters)
+
+        model = build_ddp(
+            model,
+            default_device,
+            default_args=dict(
+                device_ids=[int(os.environ['LOCAL_RANK'])],
+                broadcast_buffers=False,
+                find_unused_parameters=find_unused_parameters))
     else:
-        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
+        model = build_dp(
+            model, default_device, default_args=dict(device_ids=cfg.gpu_ids))
 
     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
diff --git a/mmaction/core/dist_utils.py b/mmaction/core/dist_utils.py
index 32f57b6245..cae452d9bd 100644
--- a/mmaction/core/dist_utils.py
+++ b/mmaction/core/dist_utils.py
@@ -4,8 +4,10 @@
 import torch.distributed as dist
 from mmcv.runner import get_dist_info
 
+from ..utils import default_device
 
-def sync_random_seed(seed=None, device='cuda'):
+
+def sync_random_seed(seed=None, device=default_device):
     """Make sure different ranks share the same seed. All workers must call
     this function, otherwise it will deadlock. This method is generally used
     in `DistributedSampler`, because the seed should be identical across all
diff --git a/mmaction/utils/__init__.py b/mmaction/utils/__init__.py
index 393a1d1325..a1bbbb761a 100644
--- a/mmaction/utils/__init__.py
+++ b/mmaction/utils/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .collect_env import collect_env
+from .distribution_env import build_ddp, build_dp, default_device
 from .gradcam_utils import GradCAM
 from .logger import get_root_logger
 from .misc import get_random_string, get_shm_dir, get_thread_id
@@ -10,5 +11,5 @@
 __all__ = [
     'get_root_logger', 'collect_env', 'get_random_string', 'get_thread_id',
     'get_shm_dir', 'GradCAM', 'PreciseBNHook', 'register_module_hooks',
-    'setup_multi_processes'
+    'setup_multi_processes', 'build_ddp', 'build_dp', 'default_device'
 ]
diff --git a/mmaction/utils/distribution_env.py b/mmaction/utils/distribution_env.py
new file mode 100644
index 0000000000..6e241e032e
--- /dev/null
+++ b/mmaction/utils/distribution_env.py
@@ -0,0 +1,94 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+
+dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel}
+
+ddp_factory = {'cuda': MMDistributedDataParallel}
+
+
+def build_dp(model, device='cuda', default_args=None):
+    """Build DataParallel module by device type.
+
+    If device is cuda, return an MMDataParallel model; if device is mlu,
+    return an MLUDataParallel model.
+
+    Args:
+        model(nn.Module): model to be parallelized.
+        device(str): device type, cuda, cpu or mlu. Defaults to cuda.
+        default_args: dict type, including the following parameters.
+            device_ids(int): device ids of modules to be scattered to.
+                Defaults to None when GPU or MLU is not available.
+
+    Returns:
+        model(nn.Module): the parallelized model.
+    """
+
+    if device == 'cuda':
+        model = model.cuda()
+    elif device == 'mlu':
+        from mmcv.device.mlu import MLUDataParallel
+        dp_factory['mlu'] = MLUDataParallel
+        model = model.mlu()
+
+    return dp_factory[device](model, **default_args)
+
+
+def build_ddp(model, device='cuda', default_args=None):
+    """Build DistributedDataParallel module by device type.
+
+    If device is cuda, return an MMDistributedDataParallel model;
+    if device is mlu, return an MLUDistributedDataParallel model.
+
+    Args:
+        model(:class:`nn.Module`): module to be parallelized.
+        device(str): device type, mlu or cuda.
+        default_args: dict type, including the following parameters.
+            device_ids(int): which represents the only device where the
+                input module corresponding to this process resides.
+                Defaults to None.
+            broadcast_buffers(bool): Flag that enables syncing (broadcasting)
+                buffers of the module at beginning of the forward function.
+                Defaults to True.
+            find_unused_parameters(bool): Traverse the autograd graph of all
+                tensors contained in the return value of the wrapped module's
+                ``forward`` function. Parameters that don't receive gradients
+                as part of this graph are preemptively marked as being ready
+                to be reduced. Note that all ``forward`` outputs that are
+                derived from module parameters must participate in
+                calculating loss and later the gradient computation. If they
+                don't, this wrapper will hang waiting for autograd to produce
+                gradients for those parameters. Any outputs derived from
+                module parameters that are otherwise unused can be detached
+                from the autograd graph using ``torch.Tensor.detach``.
+                Defaults to False.
+
+    Returns:
+        model(nn.Module): the parallelized module.
+
+    References:
+        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
+               DistributedDataParallel.html
+    """
+
+    assert device in ['cuda', 'mlu'], \
+        'Only available for cuda or mlu devices currently.'
+    if device == 'cuda':
+        model = model.cuda()
+    elif device == 'mlu':
+        from mmcv.device.mlu import MLUDistributedDataParallel
+        ddp_factory['mlu'] = MLUDistributedDataParallel
+        model = model.mlu()
+
+    return ddp_factory[device](model, **default_args)
+
+
+def is_mlu_available():
+    """Returns a bool indicating if MLU is currently available."""
+    return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
+
+
+def get_device():
+    """Returns an available device, cpu, cuda or mlu."""
+    is_device_available = {
+        'cuda': torch.cuda.is_available(),
+        'mlu': is_mlu_available()
+    }
+    device_list = [k for k, v in is_device_available.items() if v]
+    return device_list[0] if len(device_list) == 1 else 'cpu'
+
+
+default_device = get_device()
diff --git a/tools/test.py b/tools/test.py
index 31514498cf..6b52e9fd1f 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -9,13 +9,13 @@
 from mmcv import Config, DictAction
 from mmcv.cnn import fuse_conv_bn
 from mmcv.fileio.io import file_handlers
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import get_dist_info, init_dist, load_checkpoint
 from mmcv.runner.fp16_utils import wrap_fp16_model
 
 from mmaction.datasets import build_dataloader, build_dataset
 from mmaction.models import build_model
-from mmaction.utils import register_module_hooks, setup_multi_processes
+from mmaction.utils import (build_ddp, build_dp, default_device,
+                            register_module_hooks, setup_multi_processes)
 
 # TODO import test functions from mmcv and delete them from mmaction2
 try:
@@ -157,13 +157,16 @@ def inference_pytorch(args, cfg, distributed, data_loader):
         model = fuse_conv_bn(model)
 
     if not distributed:
-        model = MMDataParallel(model, device_ids=[0])
+        model = build_dp(
+            model, default_device, default_args=dict(device_ids=cfg.gpu_ids))
         outputs = single_gpu_test(model, data_loader)
     else:
-        model = MMDistributedDataParallel(
-            model.cuda(),
-            device_ids=[torch.cuda.current_device()],
-            broadcast_buffers=False)
+        model = build_ddp(
+            model,
+            default_device,
+            default_args=dict(
+                device_ids=[int(os.environ['LOCAL_RANK'])],
+                broadcast_buffers=False))
         outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                  args.gpu_collect)
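
Note (not part of the patch): a minimal usage sketch of the helpers added in
mmaction/utils/distribution_env.py. It assumes mmcv 1.x is installed (and, for
MLU, an MLU-enabled build of mmcv providing mmcv.device.mlu); the Conv3d below
is only a placeholder for a real recognizer built with build_model.

    import torch.nn as nn

    from mmaction.utils import build_dp, default_device

    model = nn.Conv3d(3, 8, kernel_size=3)  # placeholder for a real model

    # Single-process wrapper: MMDataParallel on cuda/cpu, MLUDataParallel on
    # mlu. On CPU-only machines nn.DataParallel ignores device_ids, so this
    # call is safe regardless of what get_device() detected.
    model = build_dp(model, default_device, default_args=dict(device_ids=[0]))

    # Inside an initialized distributed job, build_ddp is called the same way,
    # mirroring mmaction/apis/train.py (requires `import os` and LOCAL_RANK
    # set by the launcher):
    #   model = build_ddp(
    #       model,
    #       default_device,
    #       default_args=dict(
    #           device_ids=[int(os.environ['LOCAL_RANK'])],
    #           broadcast_buffers=False))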