Commit 300f3a0

Merge branch '1031-support-distribution-of-xarray' of github.com:helmholtz-analytics/heat into 1031-support-distribution-of-xarray
Hoppe committed Jul 26, 2023
2 parents 316105b + 5d9179d commit 300f3a0
Showing 7 changed files with 71 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/feature_request.md
@@ -2,7 +2,7 @@
 name: Feature request
 about: Suggest missing features or functionalities
 title: ''
-labels: ''
+labels: 'enhancement'
 assignees: ''

 ---
18 changes: 12 additions & 6 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -1,3 +1,15 @@
+## Due Diligence
+<!--- Please address the following points before setting your PR "ready for review".
+--->
+- General:
+  - [ ] **base branch** must be `main` for new features, latest release branch (e.g. `release/1.3.x`) for bug fixes
+  - [ ] **title** of the PR is suitable to appear in the [Release Notes](https://github.com/helmholtz-analytics/heat/releases/latest)
+- Implementation:
+  - [ ] unit tests: all split configurations tested
+  - [ ] unit tests: multiple dtypes tested
+  - [ ] documentation updated where needed
+
+
 ## Description

 <!--- Include a summary of the change/s.
@@ -42,11 +54,5 @@ may be illegible. It may be easiest to save the output of each to a file.
 --->


-## Due Diligence
-- [ ] All split configurations tested
-- [ ] Multiple dtypes tested in relevant functions
-- [ ] Documentation updated (if needed)
-- [ ] Title of PR is suitable for corresponding CHANGELOG entry
-
 #### Does this change modify the behaviour of other functions? If so, which?
 yes / no
13 changes: 13 additions & 0 deletions .github/issue-branch.yml
@@ -0,0 +1,13 @@
+mode: auto
+silent: false
+branchName: full
+defaultBranch: 'main'
+
+branches:
+  - label: bug
+    name: release/1.3.x
+    prefix: bugs/
+  - label: enhancement
+    prefix: features/
+  - label: documentation
+    prefix: docs/
16 changes: 16 additions & 0 deletions .github/workflows/create-branch-on-assignment.yml
@@ -0,0 +1,16 @@
+name: Create branch on assignment
+
+on:
+  # The issues event below is only needed for the default (auto) mode,
+  # you can remove it otherwise
+  issues:
+    types: [ assigned ]
+
+jobs:
+  create_issue_branch_job:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Create Issue Branch
+        uses: robvanderleek/create-issue-branch@main
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
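
Taken together, the two files above automate branch housekeeping: assigning an issue triggers the robvanderleek/create-issue-branch action, which in `auto` mode creates a branch named after the full issue title (`branchName: full`), prefixed according to the issue's label (`bugs/`, `features/`, or `docs/`). Reading the config against the action's conventions, the `name: release/1.3.x` entry appears to tie bug branches to the 1.3.x release line while everything else branches from the `defaultBranch` `main`; this interpretation is inferred from the config, not stated in the commit.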
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -10,7 +10,7 @@ repos:
       - id: check-added-large-files
       - id: flake8
   - repo: https://github.com/psf/black
-    rev: 23.3.0
+    rev: 23.7.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/pydocstyle
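
(Aside: `rev` bumps like this are typically produced by running `pre-commit autoupdate`, which rewrites each `rev:` to the hook repository's latest tag; nothing else in the hook configuration changes here.)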
25 changes: 20 additions & 5 deletions heat/core/printing.py
@@ -238,9 +238,15 @@ def _torch_data(dndarray, summarize) -> DNDarray:

         # non-split dimension, can slice locally
         if i != dndarray.split:
-            start_tensor = torch.index_select(data, i, torch.arange(edgeitems + 1))
+            start_tensor = torch.index_select(
+                data, i, torch.arange(edgeitems + 1, device=data.device)
+            )
             end_tensor = torch.index_select(
-                data, i, torch.arange(dndarray.lshape[i] - edgeitems, dndarray.lshape[i])
+                data,
+                i,
+                torch.arange(
+                    dndarray.lshape[i] - edgeitems, dndarray.lshape[i], device=data.device
+                ),
             )
             data = torch.cat([start_tensor, end_tensor], dim=i)
         # split-dimension, need to respect the global offset
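
Why `device=data.device` is now threaded through every `torch.arange` call: `torch.arange` allocates on the CPU unless told otherwise, while `torch.index_select` requires its index tensor on the same device as the input, so the old code raised a device-mismatch error for DNDarrays held on a GPU. A minimal standalone sketch of the failure mode and the fix (plain PyTorch, not heat code):

import torch

data = torch.arange(6.0).reshape(2, 3)
if torch.cuda.is_available():
    data = data.cuda()

# torch.arange without an explicit device always allocates on the CPU ...
idx_cpu = torch.arange(2)
# ... so for a GPU-resident `data` the following raises a device-mismatch
# RuntimeError (left commented out so the sketch runs on any machine):
# torch.index_select(data, 1, idx_cpu)

# The fix mirrors the diff: allocate the index tensor on data's device.
idx = torch.arange(2, device=data.device)
print(torch.index_select(data, 1, idx))  # first two columns of `data`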
@@ -249,18 +255,27 @@ def _torch_data(dndarray, summarize) -> DNDarray:

             if offset < edgeitems + 1:
                 end = min(dndarray.lshape[i], edgeitems + 1 - offset)
-                data = torch.index_select(data, i, torch.arange(end))
+                data = torch.index_select(data, i, torch.arange(end, device=data.device))
             elif dndarray.gshape[i] - edgeitems < offset - dndarray.lshape[i]:
                 global_start = dndarray.gshape[i] - edgeitems
                 data = torch.index_select(
-                    data, i, torch.arange(max(0, global_start - offset), dndarray.lshape[i])
+                    data,
+                    i,
+                    torch.arange(
+                        max(0, global_start - offset),
+                        dndarray.lshape[i],
+                        device=data.device,
+                    ),
                 )
     # exchange data
     received = dndarray.comm.gather(data)
     if dndarray.comm.rank == 0:
         # concatenate data along the split axis
+        # problem: CUDA-aware MPI `gather`s all `data` in a list of tensors on MPI process no. 0, but not necessarily on the same CUDA device.
+        # Indeed, `received` may be a list of tensors on CUDA device 0, CUDA device 1, ...; therefore, we need to move all entries of the list to CUDA device 0 before applying `cat`.
+        device0 = received[0].device
+        received = [tens.to(device0) for tens in received]
         data = torch.cat(received, dim=dndarray.split)

     return data
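
The two lines added before `torch.cat` address a second device problem that only appears with CUDA-aware MPI. Below is a standalone sketch of the same normalization; `gathered` is a hypothetical stand-in for the list that `dndarray.comm.gather(data)` returns on rank 0, so no MPI is needed to run it:

import torch

# `gathered` stands in for the per-process slices collected on rank 0,
# each potentially living on a different CUDA device.
if torch.cuda.device_count() >= 2:
    gathered = [torch.arange(4.0, device="cuda:0"), torch.arange(4.0, device="cuda:1")]
else:
    gathered = [torch.arange(4.0), torch.arange(4.0)]  # CPU fallback so the sketch runs anywhere

# torch.cat requires all of its inputs on one device, so move every
# entry to the device of the first one before concatenating.
device0 = gathered[0].device
gathered = [tens.to(device0) for tens in gathered]
result = torch.cat(gathered, dim=0)
print(result.device, result.shape)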
8 changes: 8 additions & 0 deletions heat/core/tests/test_printing.py
@@ -429,3 +429,11 @@ def test_split_2_above_threshold(self):

         if dndarray.comm.rank == 0:
             self.assertEqual(comparison, __str)
+
+
+class TestPrintingGPU(TestCase):
+    def test_print_GPU(self):
+        # this test case now also covers GPUs; the output is not checked, we only test that the routine itself runs
+        a0 = ht.arange(2**20, dtype=ht.float32).reshape((2**10, 2**10)).resplit_(0)
+        a1 = ht.arange(2**20, dtype=ht.float32).reshape((2**10, 2**10)).resplit_(1)
+        print(a0, a1)
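
A note on the new test: it prints the same 1024 × 1024 array twice, once distributed along the rows (`resplit_(0)`) and once along the columns (`resplit_(1)`), so both the split-dimension and non-split-dimension slicing paths patched in printing.py are exercised. The output itself is deliberately not asserted; on a multi-GPU machine each rank's chunk lives on a different CUDA device, which is the gather/cat situation the fix above addresses.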
