diff --git a/byteps/torch/parallel/distributed.py b/byteps/torch/parallel/distributed.py
index c4f8e2945..8785ad5bb 100644
--- a/byteps/torch/parallel/distributed.py
+++ b/byteps/torch/parallel/distributed.py
@@ -10,6 +10,51 @@ from torch.cuda._utils import _get_device_index
 import os
 
 
+def recursive_traverse_grad_fn(fn, seen_fns, seen_params):
+    ''' traverse a grad fn recursively. '''
+    if fn in seen_fns:
+        return
+    seen_fns.add(fn)
+
+    # record leaf parameters
+    if hasattr(fn, 'variable') and isinstance(fn.variable, torch.nn.Parameter):
+        seen_params.add(fn.variable)
+
+    # recursively traverse
+    if hasattr(fn, 'next_functions'):
+        for u in fn.next_functions:
+            if u[0] is not None:
+                recursive_traverse_grad_fn(u[0], seen_fns, seen_params)
+    if hasattr(fn, 'saved_tensors'):
+        for t in fn.saved_tensors:
+            recursive_traverse_grad_fn(t, seen_fns, seen_params)
+
+def find_parameters(tensors):
+    ''' find the parameters in the autograd graph of these tensors. '''
+    if isinstance(tensors, torch.Tensor):
+        tensors = [tensors]
+
+    grad_fns = set()
+    params = set()
+    for tensor in tensors:
+        if not isinstance(tensor, torch.Tensor):
+            continue
+        recursive_traverse_grad_fn(tensor.grad_fn, grad_fns, params)
+    return params, grad_fns
+
+def _find_tensors(obj):
+    r"""
+    Recursively find all tensors contained in the specified object.
+    """
+    if isinstance(obj, torch.Tensor):
+        return [obj]
+    if isinstance(obj, (list, tuple)):
+        return itertools.chain(*map(_find_tensors, obj))
+    if isinstance(obj, dict):
+        return itertools.chain(*map(_find_tensors, obj.values()))
+    return []
+
+
 class DistributedDataParallel(Module):
     r"""Implements distributed data parallelism that is based on
     byteps push-pull.
@@ -125,10 +170,13 @@ def __init__(self, module, device_ids=None,
                  ):
         super(DistributedDataParallel, self).__init__()
-        assert device_ids and len(device_ids) == 1, (
-            "DistributedDataParallel device_ids contain exactlyone entry,"
-            " but got {}.").format(device_ids)
-        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+        if device_ids is None:
+            self.device_ids = None
+        else:
+            assert device_ids and len(device_ids) == 1, (
+                "DistributedDataParallel device_ids must contain exactly one entry,"
+                " but got {}.").format(device_ids)
+            self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
         self.module = module
         self.broadcast_buffers = broadcast_buffers
         self.require_forward_param_sync = broadcast_buffers
@@ -181,6 +229,8 @@ def __init__(self, module, device_ids=None,
         if len(module_states) > 0:
             bps.torch.broadcast_parameters(self.module.state_dict(), root_rank=0)
+        print("Using the BytePS DistributedDataParallel Module")
+        self._step = -1
 
     @contextmanager
     def no_sync(self):
         r"""
@@ -207,6 +257,9 @@ def no_sync(self):
         self._require_backward_grad_sync = old_require_backward_grad_sync
 
     def forward(self, *inputs, **kwargs):
+        self._step += 1
+        num_handles = len(self._handles)
+        assert num_handles == 0, f'step {self._step} num_handles is {num_handles}'
         if self.require_forward_param_sync:
             self._sync_params()
         return self.module(*inputs, **kwargs)
@@ -269,9 +322,23 @@ def hook(*ignore):
             self.synchronize()
         return hook
 
+    def _model_post_fwd_hook(self, module, _inputs, outputs):
+        ''' post model forward hook. '''
+        if not module.training:
+            return
+
+        out_tensors = list(_find_tensors(outputs))
+        params_in_graph, _ = find_parameters(out_tensors)
+        self._num_grads = len(set(params_in_graph) & set(self._trainable_params))
+        byteps_torch_set_num_grads(self._num_grads)
+
     def synchronize(self):
         missing_p = self._requires_update - set(self._handles.keys())
         for p in missing_p:
+            if p.grad is None:
+                continue
+
+            assert False, "This should never be reached."
             handle, ctx, grad_count = self._push_pull_grad_group_sync(p, self._num_grads)
             self._handles[p] = (handle, ctx)
@@ -280,8 +347,11 @@ def synchronize(self):
             if handle is None:
                 handle, ctx, grad_count = self._push_pull_grad_group_sync(p)
                 self._handles[p] = (handle, ctx)
-        for p, (handle, _) in self._handles.items():
+        for p, (handle, ctx) in self._handles.items():
             output = synchronize(handle)
             if not self._enable_async:
+                if p.grad is None:
+                    assert False, "This should never be reached."
+                    continue
                 p.grad.set_(self._compression.decompress(output, ctx))
         self._handles.clear()
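
For reference, a minimal sketch of what the new graph-traversal helpers compute, assuming this patch is installed so that find_parameters is importable from byteps.torch.parallel.distributed; the toy model and the expected count follow standard torch autograd behaviour and are not part of this diff:

    import torch
    import torch.nn as nn

    # assumes a BytePS build that already contains the helpers added above
    from byteps.torch.parallel.distributed import find_parameters

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    out = model(torch.randn(3, 4))      # forward pass builds the autograd graph
    params, _ = find_parameters(out)    # walk grad_fn back to the leaf Parameters
    print(len(params))                  # expected: 4 (two weights + two biases)

This is the quantity _model_post_fwd_hook intersects with the trainable parameters to set self._num_grads after each forward pass.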
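
A hedged end-to-end usage sketch of the wrapper itself; bps.init(), bps.local_rank() and the single-entry device_ids convention are assumed from the usual BytePS API, and the hook/optimizer wiring is illustrative rather than something this hunk shows:

    import torch
    import byteps.torch as bps
    from byteps.torch.parallel.distributed import DistributedDataParallel

    bps.init()                                   # assumed BytePS-style initialization
    torch.cuda.set_device(bps.local_rank())

    model = torch.nn.Linear(10, 10).cuda()
    model = DistributedDataParallel(model, device_ids=[bps.local_rank()])
    opt = torch.optim.SGD(model.parameters(), lr=0.01)

    x = torch.randn(32, 10).cuda()
    y = torch.randn(32, 10).cuda()

    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)  # forward() asserts no handles are pending
    loss.backward()                                   # gradients are reduced via BytePS push-pull
    opt.step()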