Skip to content

Commit

Permalink
[Feature] Support MLU Device on MMaction2 (#1608)
Browse files Browse the repository at this point in the history
  • Loading branch information
zgplvyou authored Apr 29, 2022
1 parent 1af84d0 commit 5e853b1
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 18 deletions.
22 changes: 13 additions & 9 deletions mmaction/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
build_optimizer, get_dist_info)
from mmcv.runner.hooks import Fp16OptimizerHook

from ..core import (DistEvalHook, EvalHook, OmniSourceDistSamplerSeedHook,
OmniSourceRunner)
from ..datasets import build_dataloader, build_dataset
from ..utils import PreciseBNHook, get_root_logger
from ..utils import (PreciseBNHook, build_ddp, build_dp, default_device,
get_root_logger)
from .test import multi_gpu_test


def init_random_seed(seed=None, device='cuda', distributed=True):
def init_random_seed(seed=None, device=default_device, distributed=True):
"""Initialize random seed.
If the seed is not set, the seed will be automatically randomized,
Expand Down Expand Up @@ -122,13 +122,17 @@ def train_model(model,
find_unused_parameters = cfg.get('find_unused_parameters', False)
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)

model = build_ddp(
model,
default_device,
default_args=dict(
device_ids=[int(os.environ['LOCAL_RANK'])],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters))
else:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
model = build_dp(
model, default_device, default_args=dict(device_ids=cfg.gpu_ids))

# build runner
optimizer = build_optimizer(model, cfg.optimizer)
Expand Down
4 changes: 3 additions & 1 deletion mmaction/core/dist_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import torch.distributed as dist
from mmcv.runner import get_dist_info

from ..utils import default_device

def sync_random_seed(seed=None, device='cuda'):

def sync_random_seed(seed=None, device=default_device):
"""Make sure different ranks share the same seed. All workers must call
this function, otherwise it will deadlock. This method is generally used in
`DistributedSampler`, because the seed should be identical across all
Expand Down
3 changes: 2 additions & 1 deletion mmaction/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .distribution_env import build_ddp, build_dp, default_device
from .gradcam_utils import GradCAM
from .logger import get_root_logger
from .misc import get_random_string, get_shm_dir, get_thread_id
Expand All @@ -10,5 +11,5 @@
__all__ = [
'get_root_logger', 'collect_env', 'get_random_string', 'get_thread_id',
'get_shm_dir', 'GradCAM', 'PreciseBNHook', 'register_module_hooks',
'setup_multi_processes'
'setup_multi_processes', 'build_ddp', 'build_dp', 'default_device'
]
94 changes: 94 additions & 0 deletions mmaction/utils/distribution_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel}

ddp_factory = {'cuda': MMDistributedDataParallel}


def build_dp(model, device='cuda', default_args=None):
"""build DataParallel module by device type.
if device is cuda, return a MMDataParallel model; if device is mlu,
return a MLUDataParallel model.
Args:
model(nn.Module): model to be parallelized.
device(str): device type, cuda, cpu or mlu. Defaults to cuda.
default_args: dict type, include the following parameters.
device_ids(int): device ids of modules to be scattered to.
Defaults to None when GPU or MLU is not available.
Returns:
model(nn.Module): the model to be parallelized.
"""

if device == 'cuda':
model = model.cuda()
elif device == 'mlu':
from mmcv.device.mlu import MLUDataParallel
dp_factory['mlu'] = MLUDataParallel
model = model.mlu()

return dp_factory[device](model, **default_args)


def build_ddp(model, device='cuda', default_args=None):
"""Build DistributedDataParallel module by device type.
If device is cuda, return a MMDistributedDataParallel model;
if device is mlu, return a MLUDistributedDataParallel model.
Args:
model(:class:`nn.Moudle`): module to be parallelized.
device(str): device type, mlu or cuda.
default_args: dict type, include the following parameters.
device_ids(int): which represents the only device where the input
module corresponding to this process resides. Defaults to None.
broadcast_buffers(bool): Flag that enables syncing (broadcasting)
buffers of the module at beginning of the forward function.
Defaults to True.
find_unused_parameters(bool): Traverse the autograd graph of all
tensors contained in the return value of the wrapped module's
``forward`` function.
Parameters that don't receive gradients as part of this graph
are preemptively marked as being ready to be reduced. Note that
all ``forward`` outputs that are derived from module parameters
must participate in calculating loss and later the gradient
computation. If they don't, this wrapper will hang waiting
for autograd to produce gradients for those parameters. Any
outputs derived from module parameters that are otherwise
unused can be detached from the autograd graph using
``torch.Tensor.detach``. Defaults to False.
Returns:
model(nn.Module): the module to be parallelized
References:
.. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
DistributedDataParallel.html
"""

assert device in ['cuda', 'mlu'
], 'Only available for cuda or mlu devices currently.'
if device == 'cuda':
model = model.cuda()
elif device == 'mlu':
from mmcv.device.mlu import MLUDistributedDataParallel
ddp_factory['mlu'] = MLUDistributedDataParallel
model = model.mlu()

return ddp_factory[device](model, **default_args)


def is_mlu_available():
"""Returns a bool indicating if MLU is currently available."""
return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()


def get_device():
"""Returns an available device, cpu, cuda or mlu."""
is_device_available = {
'cuda': torch.cuda.is_available(),
'mlu': is_mlu_available()
}
device_list = [k for k, v in is_device_available.items() if v]
return device_list[0] if len(device_list) == 1 else 'cpu'


default_device = get_device()
17 changes: 10 additions & 7 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.fileio.io import file_handlers
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from mmcv.runner.fp16_utils import wrap_fp16_model

from mmaction.datasets import build_dataloader, build_dataset
from mmaction.models import build_model
from mmaction.utils import register_module_hooks, setup_multi_processes
from mmaction.utils import (build_ddp, build_dp, default_device,
register_module_hooks, setup_multi_processes)

# TODO import test functions from mmcv and delete them from mmaction2
try:
Expand Down Expand Up @@ -157,13 +157,16 @@ def inference_pytorch(args, cfg, distributed, data_loader):
model = fuse_conv_bn(model)

if not distributed:
model = MMDataParallel(model, device_ids=[0])
model = build_dp(
model, default_device, default_args=dict(device_ids=cfg.gpu_ids))
outputs = single_gpu_test(model, data_loader)
else:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False)
model = build_ddp(
model,
default_device,
default_args=dict(
device_ids=[int(os.environ['LOCAL_RANK'])],
broadcast_buffers=False))
outputs = multi_gpu_test(model, data_loader, args.tmpdir,
args.gpu_collect)

Expand Down

0 comments on commit 5e853b1

Please sign in to comment.