torch: update ddp #432

Open · wants to merge 1 commit into base: master
byteps/torch/parallel/distributed.py — 80 changes: 75 additions & 5 deletions
@@ -10,6 +10,51 @@
from torch.cuda._utils import _get_device_index
import os

+def recursive_traverse_grad_fn(fn, seen_fns, seen_params):
+    ''' Traverse a grad_fn graph recursively. '''
+    if fn in seen_fns:
+        return
+    seen_fns.add(fn)
+
+    # record parameters held by leaf (AccumulateGrad) nodes
+    if hasattr(fn, 'variable') and isinstance(fn.variable, torch.nn.Parameter):
+        seen_params.add(fn.variable)
+
+    # recursively traverse upstream nodes and any saved tensors
+    if hasattr(fn, 'next_functions'):
+        for u in fn.next_functions:
+            if u[0] is not None:
+                recursive_traverse_grad_fn(u[0], seen_fns, seen_params)
+    if hasattr(fn, 'saved_tensors'):
+        for t in fn.saved_tensors:
+            recursive_traverse_grad_fn(t, seen_fns, seen_params)
+
+def find_parameters(tensors):
+    ''' Find the parameters that appear in the autograd graph of the given tensor(s). '''
+    if isinstance(tensors, torch.Tensor):
+        tensors = [tensors]
+
+    grad_fns = set()
+    params = set()
+    for tensor in tensors:
+        if not isinstance(tensor, torch.Tensor):
+            continue
+        recursive_traverse_grad_fn(tensor.grad_fn, grad_fns, params)
+    return params, grad_fns
+
+def _find_tensors(obj):
+    r"""
+    Recursively find all tensors contained in the specified object.
+    """
+    if isinstance(obj, torch.Tensor):
+        return [obj]
+    if isinstance(obj, (list, tuple)):
+        return itertools.chain(*map(_find_tensors, obj))
+    if isinstance(obj, dict):
+        return itertools.chain(*map(_find_tensors, obj.values()))
+    return []
+
+
class DistributedDataParallel(Module):
    r"""Implements distributed data parallelism that is based on
    byteps push-pull.
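
As an illustration of what the new helpers compute (this snippet is not part of the patch; it assumes find_parameters and _find_tensors can be imported from this module), a branch that never runs contributes no grad_fn nodes, so its parameters are excluded from the count that later becomes self._num_grads:

import torch

class Branchy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.used = torch.nn.Linear(4, 4)
        self.unused = torch.nn.Linear(4, 4)   # never called in forward()

    def forward(self, x):
        return self.used(x)

model = Branchy()
outputs = model(torch.randn(2, 4))

params_in_graph, _ = find_parameters(list(_find_tensors(outputs)))
trainable = {p for p in model.parameters() if p.requires_grad}
print(len(trainable), len(params_in_graph & trainable))   # prints "4 2"
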
@@ -125,10 +170,13 @@ def __init__(self, module, device_ids=None,
                 ):
        super(DistributedDataParallel, self).__init__()

-        assert device_ids and len(device_ids) == 1, (
-            "DistributedDataParallel device_ids contain exactlyone entry,"
-            " but got {}.").format(device_ids)
-        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+        if device_ids is None:
+            self.device_ids = None
+        else:
+            assert device_ids and len(device_ids) == 1, (
+                "DistributedDataParallel device_ids must contain exactly one entry,"
+                " but got {}.").format(device_ids)
+            self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
        self.module = module
        self.broadcast_buffers = broadcast_buffers
        self.require_forward_param_sync = broadcast_buffers
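
With this change device_ids becomes optional: passing None skips the single-device bookkeeping, e.g. for a module that is already placed on the right device or that lives on CPU. A minimal construction sketch (not part of the patch; the import path follows the file touched by this PR, and whether CPU modules are fully supported elsewhere in the class is not shown in this diff):

import torch
import byteps.torch as bps
from byteps.torch.parallel.distributed import DistributedDataParallel

bps.init()

model = torch.nn.Linear(8, 2)
ddp_model = DistributedDataParallel(model, device_ids=None)
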
@@ -181,6 +229,8 @@ def __init__(self, module, device_ids=None,
        if len(module_states) > 0:
            bps.torch.broadcast_parameters(self.module.state_dict(), root_rank=0)

+        print("Using the BytePS DistributedDataParallel Module")
+        self._step = -1
    @contextmanager
    def no_sync(self):
        r"""
@@ -207,6 +257,9 @@ def no_sync(self):
            self._require_backward_grad_sync = old_require_backward_grad_sync

    def forward(self, *inputs, **kwargs):
+        self._step += 1
+        num_handles = len(self._handles)
+        assert num_handles == 0, f'step {self._step} num_handles is {num_handles}'
        if self.require_forward_param_sync:
            self._sync_params()
        return self.module(*inputs, **kwargs)
Expand Down Expand Up @@ -269,9 +322,23 @@ def hook(*ignore):
                self.synchronize()
            return hook

+    def _model_post_fwd_hook(self, module, _inputs, outputs):
+        ''' Post-forward hook: record how many parameters will receive gradients. '''
+        if not module.training:
+            return
+
+        out_tensors = list(_find_tensors(outputs))
+        params_in_graph, _ = find_parameters(out_tensors)
+        self._num_grads = len(set(params_in_graph) & set(self._trainable_params))
+        byteps_torch_set_num_grads(self._num_grads)
+
    def synchronize(self):
        missing_p = self._requires_update - set(self._handles.keys())
        for p in missing_p:
+            if p.grad is None:
+                continue
+
+            assert False, "This should never be reached."
            handle, ctx, grad_count = self._push_pull_grad_group_sync(p, self._num_grads)
            self._handles[p] = (handle, ctx)
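
The hook above has the standard (module, inputs, outputs) signature of a PyTorch forward hook; its registration is not shown in this hunk and presumably happens in __init__ via register_forward_hook. A standalone sketch of that mechanism (not part of the patch):

import torch

def post_fwd_hook(module, inputs, outputs):
    print(module.training, outputs.shape)

model = torch.nn.Linear(4, 2)
handle = model.register_forward_hook(post_fwd_hook)
model(torch.randn(1, 4))   # prints: True torch.Size([1, 2])
handle.remove()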

@@ -280,8 +347,11 @@ def synchronize(self):
            if handle is None:
                handle, ctx, grad_count = self._push_pull_grad_group_sync(p)
                self._handles[p] = (handle, ctx)
-        for p, (handle, _) in self._handles.items():
+        for p, (handle, ctx) in self._handles.items():
            output = synchronize(handle)
            if not self._enable_async:
+                if p.grad is None:
+                    assert False, "This should never be reached."
+                    continue
                p.grad.set_(self._compression.decompress(output, ctx))
        self._handles.clear()
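
Taken together, a typical training loop with this wrapper would look roughly like the sketch below (not part of the patch; it assumes the usual byteps.torch launch via bpslaunch and that the per-parameter backward hooks trigger the push-pull and synchronize() as in the context above, so no explicit synchronization call is needed in user code):

import torch
import byteps.torch as bps
from byteps.torch.parallel.distributed import DistributedDataParallel

bps.init()
torch.cuda.set_device(bps.local_rank())

model = torch.nn.Linear(16, 4).cuda()
ddp = DistributedDataParallel(model, device_ids=[bps.local_rank()])
optimizer = torch.optim.SGD(ddp.parameters(), lr=0.01)

for _ in range(10):
    data = torch.randn(32, 16).cuda()
    target = torch.randn(32, 4).cuda()
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(ddp(data), target)
    loss.backward()   # gradient hooks push-pull and decompress into p.grad
    optimizer.step()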