Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix distributed sync in multigpu with compute_on_cpu=True #2510

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,9 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group:
if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1:
input_dict[attr] = [dim_zero_cat(input_dict[attr])]

if dist_sync_fn == gather_all_tensors:
dist_sync_fn = functools.partial(gather_all_tensors, device=self.device)

output_dict = apply_to_collection(
input_dict,
Tensor,
Expand Down
9 changes: 8 additions & 1 deletion src/torchmetrics/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> L
return gathered_result


def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
def gather_all_tensors(
result: Tensor, group: Optional[Any] = None, device: Optional[torch.device] = None
) -> List[Tensor]:
"""Gather all tensors from several ddp processes onto a list that is broadcasted to all processes.

Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
Expand All @@ -103,6 +105,7 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
Args:
result: the value to sync
group: the process group to gather results from. Defaults to all processes (world)
device: optional device to move the result tensor to before gathering

Return:
list with size equal to the process group where element i corresponds to result tensor from process i
Expand All @@ -117,6 +120,10 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
world_size = torch.distributed.get_world_size(group)
torch.distributed.barrier(group=group)

# make sure this works with CPU tensors
if device is not None:
result = result.to(device)

# if the tensor is scalar, things are easy
if result.ndim == 0:
return _simple_gather_all_tensors(result, group, world_size)
Expand Down
16 changes: 16 additions & 0 deletions tests/unittests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,19 @@ def _test_sync_with_empty_lists(rank):
def test_sync_with_empty_lists():
"""Test that synchronization of states can be enabled and disabled for compute."""
pytest.pool.map(_test_sync_with_empty_lists, range(NUM_PROCESSES))


def _test_compute_on_cpu_distributed(rank):
dummy = DummyListMetric(compute_on_cpu=True).to(f"cuda:{rank}")
dummy.update(tensor(rank + 1))
val = dummy.compute()
assert val == [tensor(rank + 1)]


@pytest.mark.DDP()
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Test requires at least 2 GPUs")
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.skipif(not hasattr(pytest, "pool"), reason="DDP pool not available.")
def test_compute_on_cpu_distributed_multi_gpu():
"""Check that compute_on_cpu works with DDP and multiple GPUs."""
pytest.pool.map(_test_compute_on_cpu_distributed, range(NUM_PROCESSES))
Loading