Commit
Merge pull request #1170 from helmholtz-analytics/bug/1121-print-fails-on-gpu

`ht.print` can now print arrays distributed over `n>1` GPUs
mrfh92 authored Jul 24, 2023
2 parents 2b2ed0f + 7b4e466 commit 009a91a
Showing 2 changed files with 28 additions and 5 deletions.
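
For reference, a minimal usage sketch of what this enables (hypothetical demo script, not part of the commit; assumes a CUDA build of Heat with CUDA-aware MPI and one GPU per MPI process, launched e.g. with `mpirun -n 2 python demo.py`):

import heat as ht

# a 1024 x 1024 float32 array placed on the GPU and split along axis 0 across the MPI processes
x = ht.arange(2**20, dtype=ht.float32, device="gpu").reshape((2**10, 2**10)).resplit_(0)

# with this fix, the summarized chunks gathered from the individual GPUs are moved to a
# common device on rank 0 before concatenation, so printing no longer fails for n > 1 GPUs
ht.print(x)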
25 changes: 20 additions & 5 deletions heat/core/printing.py
@@ -238,9 +238,15 @@ def _torch_data(dndarray, summarize) -> DNDarray:

             # non-split dimension, can slice locally
             if i != dndarray.split:
-                start_tensor = torch.index_select(data, i, torch.arange(edgeitems + 1))
+                start_tensor = torch.index_select(
+                    data, i, torch.arange(edgeitems + 1, device=data.device)
+                )
                 end_tensor = torch.index_select(
-                    data, i, torch.arange(dndarray.lshape[i] - edgeitems, dndarray.lshape[i])
+                    data,
+                    i,
+                    torch.arange(
+                        dndarray.lshape[i] - edgeitems, dndarray.lshape[i], device=data.device
+                    ),
                 )
                 data = torch.cat([start_tensor, end_tensor], dim=i)
             # split-dimension , need to respect the global offset
@@ -249,18 +255,27 @@ def _torch_data(dndarray, summarize) -> DNDarray:

                 if offset < edgeitems + 1:
                     end = min(dndarray.lshape[i], edgeitems + 1 - offset)
-                    data = torch.index_select(data, i, torch.arange(end))
+                    data = torch.index_select(data, i, torch.arange(end, device=data.device))
                 elif dndarray.gshape[i] - edgeitems < offset - dndarray.lshape[i]:
                     global_start = dndarray.gshape[i] - edgeitems
                     data = torch.index_select(
-                        data, i, torch.arange(max(0, global_start - offset), dndarray.lshape[i])
+                        data,
+                        i,
+                        torch.arange(
+                            max(0, global_start - offset),
+                            dndarray.lshape[i],
+                            device=data.device,
+                        ),
                     )
 
         # exchange data
         received = dndarray.comm.gather(data)
         if dndarray.comm.rank == 0:
             # concatenate data along the split axis
+            # problem: CUDA-aware MPI `gather`s all `data` in a list of tensors on MPI process no. 0, but not necessarily on the same CUDA device.
+            # Indeed, `received` may be a list of tensors on cuda device 0, cuda device 1, ...; therefore, we need to move all entries of the list to cuda device 0 before applying `cat`.
+            device0 = received[0].device
+            received = [tens.to(device0) for tens in received]
             data = torch.cat(received, dim=dndarray.split)
 
     return data


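
The change is two-fold: index tensors handed to `torch.index_select` are now created on the same device as the local data, and the chunks gathered on rank 0 are moved to one device before `torch.cat`. A minimal pure-PyTorch sketch of both pitfalls (illustration only, not Heat code; assumes at least two visible CUDA devices):

import torch

data = torch.arange(16.0, device="cuda:0").reshape(4, 4)

# 1) the index tensor must live on the same device as `data`;
#    a plain torch.arange(2) would sit on the CPU and make index_select raise a device-mismatch error
idx = torch.arange(2, device=data.device)
rows = torch.index_select(data, 0, idx)

# 2) chunks gathered from several ranks may sit on different GPUs; torch.cat requires a single
#    device, so move every chunk to the first chunk's device before concatenating
received = [rows, rows.to("cuda:1")]
device0 = received[0].device
received = [t.to(device0) for t in received]
print(torch.cat(received, dim=0).device)  # cuda:0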
8 changes: 8 additions & 0 deletions heat/core/tests/test_printing.py
@@ -429,3 +429,11 @@ def test_split_2_above_threshold(self):

         if dndarray.comm.rank == 0:
             self.assertEqual(comparison, __str)
+
+
+class TestPrintingGPU(TestCase):
+    def test_print_GPU(self):
+        # this test case now also covers GPUs; the output is not checked, only that the routine itself runs
+        a0 = ht.arange(2**20, dtype=ht.float32).reshape((2**10, 2**10)).resplit_(0)
+        a1 = ht.arange(2**20, dtype=ht.float32).reshape((2**10, 2**10)).resplit_(1)
+        print(a0, a1)
