From 445fc9497fe5672e4e9c28277a8e4c2b7ccc5f20 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 17 Feb 2022 13:40:33 +0100 Subject: [PATCH 001/132] Broken. __getitem__ refactoring in prep for distributed/non-ordered indexing --- heat/core/dndarray.py | 593 +++++++++++++++++++++++++++--------------- 1 file changed, 384 insertions(+), 209 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 539dc5e604..2786583ebe 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -683,233 +683,408 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar (1/2) >>> tensor([0.]) (2/2) >>> tensor([0., 0.]) """ - key = getattr(key, "copy()", key) - l_dtype = self.dtype.torch_type() - advanced_ind = False - if isinstance(key, DNDarray) and key.ndim == self.ndim: - """ if the key is a DNDarray and it has as many dimensions as self, then each of the - entries in the 0th dim refer to a single element. To handle this, the key is split - into the torch tensors for each dimension. This signals that advanced indexing is - to be used. """ - # NOTE: this gathers the entire key on every process!! - # TODO: remove this resplit!! - key = manipulations.resplit(key) - if key.larray.dtype in [torch.bool, torch.uint8]: - key = indexing.nonzero(key) - - if key.ndim > 1: - key = list(key.larray.split(1, dim=1)) - # key is now a list of tensors with dimensions (key.ndim, 1) - # squeeze singleton dimension: - key = list(key[i].squeeze_(1) for i in range(len(key))) - else: - key = [key] - advanced_ind = True - elif not isinstance(key, tuple): - """ this loop handles all other cases. DNDarrays which make it to here refer to - advanced indexing slices, as do the torch tensors. Both DNDaarrys and torch.Tensors - are cast into lists here by PyTorch. lists mean advanced indexing will be used""" - h = [slice(None, None, None)] * max(self.ndim, 1) - if isinstance(key, DNDarray): - key = manipulations.resplit(key) - if key.larray.dtype in [torch.bool, torch.uint8]: - h[0] = torch.nonzero(key.larray).flatten() # .tolist() - else: - h[0] = key.larray.tolist() - elif isinstance(key, torch.Tensor): - if key.dtype in [torch.bool, torch.uint8]: - # (coquelin77) i am not certain why this works without being a list. 
but it works...for now - h[0] = torch.nonzero(key).flatten() # .tolist() - else: - h[0] = key.tolist() - else: - h[0] = key - - key = list(h) + # key can be: int, tuple, list, slice, DNDarray, torch tensor, numpy array, or sequence thereof + self_proxy = self.__torch_proxy__() + self_proxy.names = [ + "split" if (self.split is not None and i == self.split) else "_{}".format(i) + for i in range(self_proxy.ndim) + ] - if isinstance(key, (list, tuple)): - key = list(key) - for i, k in enumerate(key): - # this might be a good place to check if the dtype is there + try: + indexed_proxy = self_proxy[key] + except IndexError as e: + # key might be a DNDarray or contain DNDarrays, torch returns IndexError + try: + # key might be a DNDarray + key_proxy = key.__torch_proxy__() + key_proxy.names = [ + "split" if (key.split is not None and i == key.split) else "_{}".format(i) + for i in range(key_proxy.ndim) + ] + indexed_proxy = self_proxy[key_proxy] + except AttributeError: + # key might be sequence of DNDarrays + key = list(key.copy()) + for i in len(key): + if isinstance(key[i], DNDarray): + if key[i].is_distributed: + raise NotImplementedError( + "Advanced indexing with distributed DNDarrays not supported yet" + ) + key[i] = key[i].larray try: - k = manipulations.resplit(k) - key[i] = k.larray - except AttributeError: - pass - - # ellipsis - key = list(key) - key_classes = [type(n) for n in key] - # if any(isinstance(n, ellipsis) for n in key): - n_elips = key_classes.count(type(...)) - if n_elips > 1: - raise ValueError("key can only contain 1 ellipsis") - elif n_elips == 1: - # get which item is the ellipsis - ell_ind = key_classes.index(type(...)) - kst = key[:ell_ind] - kend = key[ell_ind + 1 :] - slices = [slice(None)] * (self.ndim - (len(kst) + len(kend))) - key = kst + slices + kend - else: - key = key + [slice(None)] * (self.ndim - len(key)) + indexed_proxy = self_proxy[tuple(key)] + except IndexError: + raise e + # TODO: catch torch exceptions, return reasonable error message - self_proxy = self.__torch_proxy__() - for i in range(len(key)): - if self.__key_adds_dimension(key, i, self_proxy): - key[i] = slice(None) - return self.expand_dims(i)[tuple(key)] + output_shape = tuple(indexed_proxy.shape) + try: + output_split = indexed_proxy.names.index("split") + except ValueError: + output_split = None - key = tuple(key) - # assess final global shape - gout_full = list(self_proxy[key].shape) - - # calculate new split axis - new_split = self.split - # when slicing, squeezed singleton dimensions may affect new split axis - if self.split is not None and len(gout_full) < self.ndim: - if advanced_ind: - new_split = 0 + try: + key_ndims = getattr(key, "ndim", len(key)) + except TypeError: + # key is a scalar or a slice + key = (key,) + key_ndims = 1 + + # expand key to match the number of dimensions of the DNDarray + if key_ndims < self.ndim: + expand_key = [slice(None)] * self.ndim + # account for ellipsis + if key.count(...): + ellipsis_index = key.index(...) 
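+                # a minimal sketch of the intended expansion (assuming self.ndim == 3):
+                # a key like (0, ...) should act as (0, slice(None), slice(None)), e.g.
+                #   >>> expand_key = [slice(None)] * 3
+                #   >>> expand_key[:1] = [0]
+                #   >>> tuple(expand_key)
+                #   (0, slice(None, None, None), slice(None, None, None))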
+ expand_key[:ellipsis_index] = key[:ellipsis_index] + expand_key[ellipsis_index + 2 :] = key[ellipsis_index + 1 :] else: - for i in range(len(key[: self.split + 1])): - if self.__key_is_singular(key, i, self_proxy): - new_split = None if i == self.split else new_split - 1 + expand_key[:key_ndims] = key + key = tuple(expand_key) - key = tuple(key) - if not self.is_distributed(): - arr = self.__array[key].reshape(gout_full) + # data are not distributed or split dimension is not affected by indexing + if not self.is_distributed or key[self.split] == slice(None): return DNDarray( - arr, tuple(gout_full), self.dtype, new_split, self.device, self.comm, self.balanced + self.larray[key], + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + balanced=self.balanced, + comm=self.comm, ) - # else: (DNDarray is distributed) - arr = torch.tensor([], dtype=self.__array.dtype, device=self.__array.device) - rank = self.comm.rank - counts, chunk_starts = self.counts_displs() - counts, chunk_starts = torch.tensor(counts), torch.tensor(chunk_starts) - chunk_ends = chunk_starts + counts - chunk_start = chunk_starts[rank] - chunk_end = chunk_ends[rank] + # data are distributed and split dimension is affected by indexing + _, offsets = self.counts_displs() + split = self.split - if len(key) == 0: # handle empty list - # this will return an array of shape (0, ...) - arr = self.__array[key] - - """ At the end of the following if/elif/elif block the output array will be set. - each block handles the case where the element of the key along the split axis - is a different type and converts the key from global indices to local indices. """ - lout = gout_full.copy() - - if ( - isinstance(key[self.split], (list, torch.Tensor, DNDarray, np.ndarray)) - and len(key[self.split]) > 1 - ): - # advanced indexing, elements in the split dimension are adjusted to the local indices - lkey = list(key) - if isinstance(key[self.split], DNDarray): - lkey[self.split] = key[self.split].larray - - if not isinstance(lkey[self.split], torch.Tensor): - inds = torch.tensor( - lkey[self.split], dtype=torch.long, device=self.device.torch_device + # slice along the split axis + if isinstance(key[split], slice): + if key[split].start is None: + slice_start = 0 + else: + slice_start = ( + key[split].start + if key[split].start > 0 + else key[split].start + self.gshape[split] ) + if key[split].stop is None: + slice_stop = self.gshape[split] else: - if lkey[self.split].dtype in [torch.bool, torch.uint8]: # or torch.byte? - # need to convert the bools to indices - inds = torch.nonzero(lkey[self.split]) + slice_stop = ( + key[split].stop if key[split].stop > 0 else key[split].stop + self.gshape[split] + ) + slice_step = key[split].step + + # identify active ranks + offsets = torch.tensor(offsets, dtype=torch.int64, device=self.larray.device) + first_active = torch.where(offsets - slice_start <= 0)[0][-1].item() + last_active = torch.where(offsets - slice_stop <= 0)[0][-1].item() + active_ranks = range(first_active, last_active + 1) + + if self.comm.rank in active_ranks: + if slice_step is None: + slice_step = 1 + # calculate local slice + if ( + slice_start >= offsets[self.comm.rank] + and slice_start < self.lshape[split] + offsets[self.comm.rank] + ): + local_slice_start = slice_start - offsets[self.comm.rank] else: - inds = lkey[self.split] - # todo: remove where in favor of nonzero? might be a speed upgrade. 
testing required - loc_inds = torch.where((inds >= chunk_start) & (inds < chunk_end)) - # if there are no local indices on a process, then `arr` is empty - # if local indices exist: - if len(loc_inds[0]) != 0: - # select same local indices for other (non-split) dimensions if necessary - for i, k in enumerate(lkey): - if isinstance(k, (list, torch.Tensor, DNDarray)): - if i != self.split: - lkey[i] = k[loc_inds] - # correct local indices for offset - inds = inds[loc_inds] - chunk_start - lkey[self.split] = inds - lout[new_split] = len(inds) - arr = self.__array[tuple(lkey)].reshape(tuple(lout)) - elif len(loc_inds[0]) == 0: - if new_split is not None: - lout[new_split] = len(loc_inds[0]) + if slice_step != 1: + local_slice_start = torch.arange( + offsets[self.comm.rank], + offsets[self.comm.rank] + slice_step, + dtype=torch.int64, + device=self.larray.device, + ) + local_slice_start = ( + torch.where(local_slice_start % slice_step == 0)[0].item() + - offsets[self.comm.rank] + ) + else: + local_slice_start = 0 + if ( + slice_stop >= offsets[self.comm.rank] + and slice_stop < self.lshape[split] + offsets[self.comm.rank] + ): + local_slice_stop = slice_stop - offsets[self.comm.rank] else: - lout = [0] * len(gout_full) - arr = torch.tensor([], dtype=self.larray.dtype, device=self.larray.device).reshape( - tuple(lout) + if slice_step != 1: + local_slice_stop = torch.arange( + offsets[self.comm.rank] + 1 - slice_step, + offsets[self.comm.rank] + 1, + dtype=torch.int64, + device=self.larray.device, + ) + local_slice_stop = ( + torch.where(local_slice_stop % slice_step == 0)[0].item() + - offsets[self.comm.rank] + ) + else: + local_slice_stop = self.lshape[split] + # slice local tensor + local_slice = slice(local_slice_start, local_slice_stop, slice_step) + key = key[:split] + (local_slice,) + key[split + 1 :] + local_tensor = self.larray[key] + else: + # local tensor is empty + local_shape = list(output_shape) + local_shape[output_split] = 0 + local_tensor = torch.zeros( + tuple(local_shape), dtype=self.larray.dtype, device=self.larray.device ) - elif isinstance(key[self.split], slice): - # standard slicing along the split axis, - # adjust the slice start, stop, and step, then run it on the processes which have the requested data - key = list(key) - key[self.split] = stride_tricks.sanitize_slice(key[self.split], self.gshape[self.split]) - key_start, key_stop, key_step = ( - key[self.split].start, - key[self.split].stop, - key[self.split].step, + return DNDarray( + local_tensor, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + balanced=False, + comm=self.comm, ) - og_key_start = key_start - st_pr = torch.where(key_start < chunk_ends)[0] - st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size - sp_pr = torch.where(key_stop >= chunk_starts)[0] - sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0 - actives = list(range(st_pr, sp_pr + 1)) - if rank in actives: - key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] - key_stop = counts[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] - key_start, key_stop = self.__xitem_get_key_start_stop( - rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start - ) - key[self.split] = slice(key_start, key_stop, key_step) - lout[new_split] = ( - math.ceil((key_stop - key_start) / key_step) - if key_step is not None - else key_stop - key_start - ) - arr = self.__array[tuple(key)].reshape(lout) - else: - lout[new_split] = 0 - arr = torch.empty(lout, dtype=self.__array.dtype, 
device=self.__array.device) - elif self.__key_is_singular(key, self.split, self_proxy): - # getting one item along split axis: - key = list(key) - if isinstance(key[self.split], list): - key[self.split] = key[self.split].pop() - elif isinstance(key[self.split], (torch.Tensor, DNDarray, np.ndarray)): - key[self.split] = key[self.split].item() - # translate negative index - if key[self.split] < 0: - key[self.split] += self.gshape[self.split] - - active_rank = torch.where(key[self.split] >= chunk_starts)[0][-1].item() - # slice `self` on `active_rank`, allocate `arr` on all other ranks in preparation for Bcast - if rank == active_rank: - key[self.split] -= chunk_start.item() - arr = self.__array[tuple(key)].reshape(tuple(lout)) - else: - arr = torch.empty(tuple(lout), dtype=self.larray.dtype, device=self.larray.device) - # broadcast result - # TODO: Replace with `self.comm.Bcast(arr, root=active_rank)` after fixing #784 - arr = self.comm.bcast(arr, root=active_rank) - if arr.device != self.larray.device: - # todo: remove when unnecessary (also after #784) - arr = arr.to(device=self.larray.device) - - return DNDarray( - arr.type(l_dtype), - gout_full if isinstance(gout_full, tuple) else tuple(gout_full), - self.dtype, - new_split, - self.device, - self.comm, - balanced=True if new_split is None else None, - ) + # local indexing cases: + # self is not distributed, key is not distributed - DONE + # self is distributed, key along split is a slice - DONE + # self is distributed, key is boolean mask (what about distributed boolean mask?) + + # distributed indexing: + # key is distributed + # key calls for advanced indexing + # key is a non-sorted sequence + # key is a sorted sequence (descending) + + # key = getattr(key, "copy()", key) + # l_dtype = self.dtype.torch_type() + # advanced_ind = False + # if isinstance(key, DNDarray) and key.ndim == self.ndim: + # """ if the key is a DNDarray and it has as many dimensions as self, then each of the + # entries in the 0th dim refer to a single element. To handle this, the key is split + # into the torch tensors for each dimension. This signals that advanced indexing is + # to be used. """ + # # NOTE: this gathers the entire key on every process!! + # # TODO: remove this resplit!! + # key = manipulations.resplit(key) + # if key.larray.dtype in [torch.bool, torch.uint8]: + # key = indexing.nonzero(key) + + # if key.ndim > 1: + # key = list(key.larray.split(1, dim=1)) + # # key is now a list of tensors with dimensions (key.ndim, 1) + # # squeeze singleton dimension: + # key = list(key[i].squeeze_(1) for i in range(len(key))) + # else: + # key = [key] + # advanced_ind = True + # elif not isinstance(key, tuple): + # """ this loop handles all other cases. DNDarrays which make it to here refer to + # advanced indexing slices, as do the torch tensors. Both DNDaarrys and torch.Tensors + # are cast into lists here by PyTorch. lists mean advanced indexing will be used""" + # h = [slice(None, None, None)] * max(self.ndim, 1) + # if isinstance(key, DNDarray): + # key = manipulations.resplit(key) + # if key.larray.dtype in [torch.bool, torch.uint8]: + # h[0] = torch.nonzero(key.larray).flatten() # .tolist() + # else: + # h[0] = key.larray.tolist() + # elif isinstance(key, torch.Tensor): + # if key.dtype in [torch.bool, torch.uint8]: + # # (coquelin77) i am not certain why this works without being a list. 
but it works...for now + # h[0] = torch.nonzero(key).flatten() # .tolist() + # else: + # h[0] = key.tolist() + # else: + # h[0] = key + + # key = list(h) + + # if isinstance(key, (list, tuple)): + # key = list(key) + # for i, k in enumerate(key): + # # this might be a good place to check if the dtype is there + # try: + # k = manipulations.resplit(k) + # key[i] = k.larray + # except AttributeError: + # pass + + # # ellipsis + # key = list(key) + # key_classes = [type(n) for n in key] + # # if any(isinstance(n, ellipsis) for n in key): + # n_elips = key_classes.count(type(...)) + # if n_elips > 1: + # raise ValueError("key can only contain 1 ellipsis") + # elif n_elips == 1: + # # get which item is the ellipsis + # ell_ind = key_classes.index(type(...)) + # kst = key[:ell_ind] + # kend = key[ell_ind + 1 :] + # slices = [slice(None)] * (self.ndim - (len(kst) + len(kend))) + # key = kst + slices + kend + # else: + # key = key + [slice(None)] * (self.ndim - len(key)) + + # self_proxy = self.__torch_proxy__() + # for i in range(len(key)): + # if self.__key_adds_dimension(key, i, self_proxy): + # key[i] = slice(None) + # return self.expand_dims(i)[tuple(key)] + + # key = tuple(key) + # # assess final global shape + # gout_full = list(self_proxy[key].shape) + + # # calculate new split axis + # new_split = self.split + # # when slicing, squeezed singleton dimensions may affect new split axis + # if self.split is not None and len(gout_full) < self.ndim: + # if advanced_ind: + # new_split = 0 + # else: + # for i in range(len(key[: self.split + 1])): + # if self.__key_is_singular(key, i, self_proxy): + # new_split = None if i == self.split else new_split - 1 + + # key = tuple(key) + # if not self.is_distributed(): + # arr = self.__array[key].reshape(gout_full) + # return DNDarray( + # arr, tuple(gout_full), self.dtype, new_split, self.device, self.comm, self.balanced + # ) + + # # else: (DNDarray is distributed) + # arr = torch.tensor([], dtype=self.__array.dtype, device=self.__array.device) + # rank = self.comm.rank + # counts, chunk_starts = self.counts_displs() + # counts, chunk_starts = torch.tensor(counts), torch.tensor(chunk_starts) + # chunk_ends = chunk_starts + counts + # chunk_start = chunk_starts[rank] + # chunk_end = chunk_ends[rank] + + # if len(key) == 0: # handle empty list + # # this will return an array of shape (0, ...) + # arr = self.__array[key] + + # """ At the end of the following if/elif/elif block the output array will be set. + # each block handles the case where the element of the key along the split axis + # is a different type and converts the key from global indices to local indices. """ + # lout = gout_full.copy() + + # if ( + # isinstance(key[self.split], (list, torch.Tensor, DNDarray, np.ndarray)) + # and len(key[self.split]) > 1 + # ): + # # advanced indexing, elements in the split dimension are adjusted to the local indices + # lkey = list(key) + # if isinstance(key[self.split], DNDarray): + # lkey[self.split] = key[self.split].larray + + # if not isinstance(lkey[self.split], torch.Tensor): + # inds = torch.tensor( + # lkey[self.split], dtype=torch.long, device=self.device.torch_device + # ) + # else: + # if lkey[self.split].dtype in [torch.bool, torch.uint8]: # or torch.byte? + # # need to convert the bools to indices + # inds = torch.nonzero(lkey[self.split]) + # else: + # inds = lkey[self.split] + # # todo: remove where in favor of nonzero? might be a speed upgrade. 
testing required + # loc_inds = torch.where((inds >= chunk_start) & (inds < chunk_end)) + # # if there are no local indices on a process, then `arr` is empty + # # if local indices exist: + # if len(loc_inds[0]) != 0: + # # select same local indices for other (non-split) dimensions if necessary + # for i, k in enumerate(lkey): + # if isinstance(k, (list, torch.Tensor, DNDarray)): + # if i != self.split: + # lkey[i] = k[loc_inds] + # # correct local indices for offset + # inds = inds[loc_inds] - chunk_start + # lkey[self.split] = inds + # lout[new_split] = len(inds) + # arr = self.__array[tuple(lkey)].reshape(tuple(lout)) + # elif len(loc_inds[0]) == 0: + # if new_split is not None: + # lout[new_split] = len(loc_inds[0]) + # else: + # lout = [0] * len(gout_full) + # arr = torch.tensor([], dtype=self.larray.dtype, device=self.larray.device).reshape( + # tuple(lout) + # ) + + # elif isinstance(key[self.split], slice): + # # standard slicing along the split axis, + # # adjust the slice start, stop, and step, then run it on the processes which have the requested data + # key = list(key) + # key[self.split] = stride_tricks.sanitize_slice(key[self.split], self.gshape[self.split]) + # key_start, key_stop, key_step = ( + # key[self.split].start, + # key[self.split].stop, + # key[self.split].step, + # ) + # og_key_start = key_start + # st_pr = torch.where(key_start < chunk_ends)[0] + # st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size + # sp_pr = torch.where(key_stop >= chunk_starts)[0] + # sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0 + # actives = list(range(st_pr, sp_pr + 1)) + # if rank in actives: + # key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] + # key_stop = counts[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] + # key_start, key_stop = self.__xitem_get_key_start_stop( + # rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start + # ) + # key[self.split] = slice(key_start, key_stop, key_step) + # lout[new_split] = ( + # math.ceil((key_stop - key_start) / key_step) + # if key_step is not None + # else key_stop - key_start + # ) + # arr = self.__array[tuple(key)].reshape(lout) + # else: + # lout[new_split] = 0 + # arr = torch.empty(lout, dtype=self.__array.dtype, device=self.__array.device) + + # elif self.__key_is_singular(key, self.split, self_proxy): + # # getting one item along split axis: + # key = list(key) + # if isinstance(key[self.split], list): + # key[self.split] = key[self.split].pop() + # elif isinstance(key[self.split], (torch.Tensor, DNDarray, np.ndarray)): + # key[self.split] = key[self.split].item() + # # translate negative index + # if key[self.split] < 0: + # key[self.split] += self.gshape[self.split] + + # active_rank = torch.where(key[self.split] >= chunk_starts)[0][-1].item() + # # slice `self` on `active_rank`, allocate `arr` on all other ranks in preparation for Bcast + # if rank == active_rank: + # key[self.split] -= chunk_start.item() + # arr = self.__array[tuple(key)].reshape(tuple(lout)) + # else: + # arr = torch.empty(tuple(lout), dtype=self.larray.dtype, device=self.larray.device) + # # broadcast result + # # TODO: Replace with `self.comm.Bcast(arr, root=active_rank)` after fixing #784 + # arr = self.comm.bcast(arr, root=active_rank) + # if arr.device != self.larray.device: + # # todo: remove when unnecessary (also after #784) + # arr = arr.to(device=self.larray.device) + + # return DNDarray( + # arr.type(l_dtype), + # gout_full if isinstance(gout_full, tuple) else tuple(gout_full), + # self.dtype, + # 
new_split, + # self.device, + # self.comm, + # balanced=True if new_split is None else None, + # ) if torch.cuda.device_count() > 0: From 6641d1eb607d081aaa5ef8d2e44621a2b887c7eb Mon Sep 17 00:00:00 2001 From: Ben Bourgart Date: Tue, 22 Feb 2022 14:31:44 +0100 Subject: [PATCH 002/132] Preprocess key, workaround torch_proxy for advanced indexing, simplify slice-indexing. UNTESTED --- heat/core/dndarray.py | 215 ++++++++++++++++++------------------------ 1 file changed, 94 insertions(+), 121 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 2786583ebe..3c069dc93c 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -9,7 +9,7 @@ from inspect import stack from mpi4py import MPI from pathlib import Path -from typing import List, Union, Tuple, TypeVar, Optional +from typing import List, Union, Tuple, TypeVar, Optional, Iterable warnings.simplefilter("always", ResourceWarning) @@ -684,65 +684,93 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar (2/2) >>> tensor([0., 0.]) """ # key can be: int, tuple, list, slice, DNDarray, torch tensor, numpy array, or sequence thereof - self_proxy = self.__torch_proxy__() + # Trivial cases + if key is None: + return self.expand_dims(0) + if key == ... or key == slice(None): # latter doesnt work with torch for 0-dim tensors + return self + # Preprocess: Process Ellipsis + 'None' indexing; make Iterables to DNDarrays + advanced_indexing = False + if isinstance( + key, DNDarray + ): # DNDARRAY CURRENTLY DOES NOT IMPLEMENT the Iterable interface, need to define __iter__() + advanced_indexing = True + # TODO: check for key.ndim = 0 and treat that as int + # TODO: get outshape + outsplit; depends on wether key is bool or int and key.ndim + elif isinstance(key, Iterable) and not isinstance(key, tuple): + advanced_indexing = True + key = factories.array( + key + ) # DOES NOT WORK FOR SEQUENCE OF TENSORS OR DNDARRAYS, works for sequence of ndarrays though + # TODO: get outshape + outsplit; depends on wether key is bool or int and key.ndim + elif isinstance(key, tuple): + key = list(key) + for i, k in enumerate(key): + if isinstance(k, Iterable) or isinstance(key, DNDarray): + advanced_indexing = True + key[i] = factories.array( + key[i] + ) # DOES NOT WORK FOR SEQUENCE OF TENSORS OR DNDARRAYS, works for seq of ndarrays though + # TODO: check for key.ndim = 0 and treat that as int + # TODO: get outshape + outsplit; depends on wether key is bool or int and key.ndim + add_dims = sum(k is None for k in key) # (np.newaxis is None)===true + ellipsis = sum(isinstance(k, type(...)) for k in key) + if ellipsis > 1: + raise ValueError("key can only contain 1 ellipsis") + elif ellipsis == 1: + expand_key = [slice(None)] * (self.ndim + add_dims) + ellipsis_index = key.index(...) + expand_key[:ellipsis_index] = key[:ellipsis_index] + expand_key[ellipsis_index - len(key) :] = key[ellipsis_index + 1 :] + key = expand_key + if add_dims: + for i, k in reversed(enumerate(key)): + if k is None: + key[i] = slice(None) + self = self.expand_dims(i - add_dims + 1) # is the -1 correct? + add_dims -= 1 + # expand key to match the number of dimensions of the DNDarray + key = tuple(key + [slice(None)] * (self.ndim - len(key))) + else: # key is integer or slice + key = tuple([key] + [slice(None)] * (self.ndim - 1)) + + # To use torch_proxy with advanced indexing, add empty dimensions instead of + # advanced index. 
Later, replace the empty dimensions with the shape of the advanced index + proxy_key = key + proxy = self + if advanced_indexing: + proxy_key = [] + replace = {} + for i, k in reversed(enumerate(key)): + if isinstance(k, DNDarray): # all iterables have been made DNDarrays + # TODO Bool indexing (sometimes) is collapsed into one dimension + replace[i] = k.shape + proxy_key.extend([slice(None)] * k.ndim) + for _ in range(k.ndim - 1): + proxy = proxy.expand_dims(i) + else: + proxy_key.append(k) + proxy_key = tuple(reversed(proxy_key)) + + self_proxy = proxy.__torch_proxy__() self_proxy.names = [ - "split" if (self.split is not None and i == self.split) else "_{}".format(i) + "split" if (proxy.split is not None and i == proxy.split) else "_{}".format(i) for i in range(self_proxy.ndim) ] + indexed_proxy = self_proxy[proxy_key] - try: - indexed_proxy = self_proxy[key] - except IndexError as e: - # key might be a DNDarray or contain DNDarrays, torch returns IndexError - try: - # key might be a DNDarray - key_proxy = key.__torch_proxy__() - key_proxy.names = [ - "split" if (key.split is not None and i == key.split) else "_{}".format(i) - for i in range(key_proxy.ndim) - ] - indexed_proxy = self_proxy[key_proxy] - except AttributeError: - # key might be sequence of DNDarrays - key = list(key.copy()) - for i in len(key): - if isinstance(key[i], DNDarray): - if key[i].is_distributed: - raise NotImplementedError( - "Advanced indexing with distributed DNDarrays not supported yet" - ) - key[i] = key[i].larray - try: - indexed_proxy = self_proxy[tuple(key)] - except IndexError: - raise e - # TODO: catch torch exceptions, return reasonable error message + output_shape = list(indexed_proxy.shape) + if advanced_indexing: + for i, shape in replace.values(): + # TODO Bool indexing (sometimes) is collapsed into one dimension + output_shape[i : i + len(shape)] = shape + output_shape = tuple(output_shape) - output_shape = tuple(indexed_proxy.shape) try: output_split = indexed_proxy.names.index("split") except ValueError: output_split = None - try: - key_ndims = getattr(key, "ndim", len(key)) - except TypeError: - # key is a scalar or a slice - key = (key,) - key_ndims = 1 - - # expand key to match the number of dimensions of the DNDarray - if key_ndims < self.ndim: - expand_key = [slice(None)] * self.ndim - # account for ellipsis - if key.count(...): - ellipsis_index = key.index(...) 
- expand_key[:ellipsis_index] = key[:ellipsis_index] - expand_key[ellipsis_index + 2 :] = key[ellipsis_index + 1 :] - else: - expand_key[:key_ndims] = key - key = tuple(expand_key) - # data are not distributed or split dimension is not affected by indexing if not self.is_distributed or key[self.split] == slice(None): return DNDarray( @@ -758,79 +786,24 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # data are distributed and split dimension is affected by indexing _, offsets = self.counts_displs() split = self.split - # slice along the split axis if isinstance(key[split], slice): - if key[split].start is None: - slice_start = 0 - else: - slice_start = ( - key[split].start - if key[split].start > 0 - else key[split].start + self.gshape[split] - ) - if key[split].stop is None: - slice_stop = self.gshape[split] - else: - slice_stop = ( - key[split].stop if key[split].stop > 0 else key[split].stop + self.gshape[split] - ) - slice_step = key[split].step - - # identify active ranks - offsets = torch.tensor(offsets, dtype=torch.int64, device=self.larray.device) - first_active = torch.where(offsets - slice_start <= 0)[0][-1].item() - last_active = torch.where(offsets - slice_stop <= 0)[0][-1].item() - active_ranks = range(first_active, last_active + 1) - - if self.comm.rank in active_ranks: - if slice_step is None: - slice_step = 1 - # calculate local slice - if ( - slice_start >= offsets[self.comm.rank] - and slice_start < self.lshape[split] + offsets[self.comm.rank] - ): - local_slice_start = slice_start - offsets[self.comm.rank] - else: - if slice_step != 1: - local_slice_start = torch.arange( - offsets[self.comm.rank], - offsets[self.comm.rank] + slice_step, - dtype=torch.int64, - device=self.larray.device, - ) - local_slice_start = ( - torch.where(local_slice_start % slice_step == 0)[0].item() - - offsets[self.comm.rank] - ) - else: - local_slice_start = 0 - if ( - slice_stop >= offsets[self.comm.rank] - and slice_stop < self.lshape[split] + offsets[self.comm.rank] - ): - local_slice_stop = slice_stop - offsets[self.comm.rank] - else: - if slice_step != 1: - local_slice_stop = torch.arange( - offsets[self.comm.rank] + 1 - slice_step, - offsets[self.comm.rank] + 1, - dtype=torch.int64, - device=self.larray.device, - ) - local_slice_stop = ( - torch.where(local_slice_stop % slice_step == 0)[0].item() - - offsets[self.comm.rank] - ) - else: - local_slice_stop = self.lshape[split] - # slice local tensor - local_slice = slice(local_slice_start, local_slice_stop, slice_step) - key = key[:split] + (local_slice,) + key[split + 1 :] - local_tensor = self.larray[key] - else: - # local tensor is empty + key = list(key) + key[split] = stride_tricks.sanitize_slice(key[split], self.shape[split]) + start, stop, step = key[split].start, key[split].stop, key[split].step + if step < 0: # NOT supported by torch; TODO throw Exception + key[split] = slice(stop + 1, start + 1, abs(step)) + return self[tuple(key)].flip(axis=self.split) + + offset = offsets[self.comm.rank] + range_proxy = range(self.lshape[split]) + local_inds = range_proxy[start - offset : stop - offset] + local_inds = local_inds[(offset - start) % step :: step] + if len(local_inds): + local_slice = slice(local_inds.start, local_inds.stop, local_inds.step) + key[split] = local_slice + local_tensor = self.larray[tuple(key)] + else: # local tensor is empty local_shape = list(output_shape) local_shape[output_split] = 0 local_tensor = torch.zeros( From cd78ecbbe67fb2b6ece25ea0e7fad4f03abe2cb3 Mon Sep 17 00:00:00 
2001 From: Ben Bourgart Date: Tue, 22 Feb 2022 15:46:41 +0100 Subject: [PATCH 003/132] put advanced index shape in the dimensions name to get the correct position in the index_proxy --- heat/core/dndarray.py | 54 +++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 3c069dc93c..3d32ead45f 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -691,26 +691,23 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar return self # Preprocess: Process Ellipsis + 'None' indexing; make Iterables to DNDarrays advanced_indexing = False - if isinstance( - key, DNDarray - ): # DNDARRAY CURRENTLY DOES NOT IMPLEMENT the Iterable interface, need to define __iter__() + if isinstance(key, DNDarray): + # DNDARRAY CURRENTLY DOES NOT IMPLEMENT the Iterable interface, need to define __iter__() advanced_indexing = True # TODO: check for key.ndim = 0 and treat that as int # TODO: get outshape + outsplit; depends on wether key is bool or int and key.ndim elif isinstance(key, Iterable) and not isinstance(key, tuple): advanced_indexing = True - key = factories.array( - key - ) # DOES NOT WORK FOR SEQUENCE OF TENSORS OR DNDARRAYS, works for sequence of ndarrays though + key = factories.array(key) + # DOES NOT WORK FOR SEQUENCE OF TENSORS OR DNDARRAYS, works for sequence of ndarrays though # TODO: get outshape + outsplit; depends on wether key is bool or int and key.ndim elif isinstance(key, tuple): key = list(key) for i, k in enumerate(key): if isinstance(k, Iterable) or isinstance(key, DNDarray): advanced_indexing = True - key[i] = factories.array( - key[i] - ) # DOES NOT WORK FOR SEQUENCE OF TENSORS OR DNDARRAYS, works for seq of ndarrays though + key[i] = factories.array(key[i]) + # DOES NOT WORK FOR SEQUENCE OF TENSORS OR DNDARRAYS, works for seq of ndarrays though # TODO: check for key.ndim = 0 and treat that as int # TODO: get outshape + outsplit; depends on wether key is bool or int and key.ndim add_dims = sum(k is None for k in key) # (np.newaxis is None)===true @@ -736,34 +733,37 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # To use torch_proxy with advanced indexing, add empty dimensions instead of # advanced index. Later, replace the empty dimensions with the shape of the advanced index - proxy_key = key proxy = self + names = [ + "split" if (proxy.split is not None and i == proxy.split) else "_{}".format(i) + for i in range(proxy.ndim) + ] + proxy_key = list(key) if advanced_indexing: - proxy_key = [] - replace = {} + proxy_key = list(key) for i, k in reversed(enumerate(key)): if isinstance(k, DNDarray): # all iterables have been made DNDarrays - # TODO Bool indexing (sometimes) is collapsed into one dimension - replace[i] = k.shape - proxy_key.extend([slice(None)] * k.ndim) + # TODO: Bool indexing (sometimes) is collapsed into one dimension + # TODO: What to do if advanced index is in split dimension?? 
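+                    # illustrative sketch of the naming trick (assuming k.shape == (4,)
+                    # and i == 1): the proxy dimension is renamed to carry the shape of
+                    # the advanced index, e.g.
+                    #   names[1] = "replace" + str((4,))   # -> "replace(4,)"
+                    # so the shape can be parsed back out of indexed_proxy.names below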
+ names[i] = "replace" + str(k.shape) # put shape into name + proxy_key[i] = slice(None) for _ in range(k.ndim - 1): proxy = proxy.expand_dims(i) - else: - proxy_key.append(k) - proxy_key = tuple(reversed(proxy_key)) + names.insert(i + 1, "_{}".format(len(names))) + proxy_key.insert(i + 1, slice(None)) + proxy_key = tuple(proxy_key) self_proxy = proxy.__torch_proxy__() - self_proxy.names = [ - "split" if (proxy.split is not None and i == proxy.split) else "_{}".format(i) - for i in range(self_proxy.ndim) - ] + self_proxy.names = names indexed_proxy = self_proxy[proxy_key] output_shape = list(indexed_proxy.shape) if advanced_indexing: - for i, shape in replace.values(): - # TODO Bool indexing (sometimes) is collapsed into one dimension - output_shape[i : i + len(shape)] = shape + for i, n in enumerate(indexed_proxy.names): + if "replace" in n: + shape = eval(n.split("replace")[1]) # extract shape from name + # TODO Bool indexing (sometimes) is collapsed into one dimension + output_shape[i : i + len(shape)] = shape output_shape = tuple(output_shape) try: @@ -791,14 +791,14 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar key = list(key) key[split] = stride_tricks.sanitize_slice(key[split], self.shape[split]) start, stop, step = key[split].start, key[split].stop, key[split].step - if step < 0: # NOT supported by torch; TODO throw Exception + if step < 0: # NOT supported by torch, should be filtered by torch_proxy key[split] = slice(stop + 1, start + 1, abs(step)) return self[tuple(key)].flip(axis=self.split) offset = offsets[self.comm.rank] range_proxy = range(self.lshape[split]) local_inds = range_proxy[start - offset : stop - offset] - local_inds = local_inds[(offset - start) % step :: step] + local_inds = local_inds[max(offset - start, 0) % step :: step] if len(local_inds): local_slice = slice(local_inds.start, local_inds.stop, local_inds.step) key[split] = local_slice From 7d97ea2cd765fb394819cf4dd98c23a2bd238d40 Mon Sep 17 00:00:00 2001 From: Ben Bourgart Date: Tue, 22 Feb 2022 23:04:11 +0100 Subject: [PATCH 004/132] first changes to setitem --- heat/core/dndarray.py | 159 +++++++++++++++++++++++++++++------------- 1 file changed, 111 insertions(+), 48 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 3d32ead45f..4743cc3b4b 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -653,43 +653,19 @@ def fill_diagonal(self, value: float) -> DNDarray: return self - def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray: + def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]]) -> Tuple: """ - Global getter function for DNDarrays. - Returns a new DNDarray composed of the elements of the original tensor selected by the indices - given. This does *NOT* redistribute or rebalance the resulting tensor. If the selection of values is - unbalanced then the resultant tensor is also unbalanced! - To redistributed the ``DNDarray`` use :func:`balance()` (issue #187) + Private method for processing keys for indexing. Returns wether advanced indexing is used as well as a processed key and self. + A processed key: + - doesn't cotain any ellipses or newaxis + - all Iterables are converted to ``DNDarrays`` + - has the same dimensionality as the ``DNDarray`` it indexes Parameters ---------- key : int, slice, Tuple[int,...], List[int,...] - Indices to get from the tensor. 
- - Examples - -------- - >>> a = ht.arange(10, split=0) - (1/2) >>> tensor([0, 1, 2, 3, 4], dtype=torch.int32) - (2/2) >>> tensor([5, 6, 7, 8, 9], dtype=torch.int32) - >>> a[1:6] - (1/2) >>> tensor([1, 2, 3, 4], dtype=torch.int32) - (2/2) >>> tensor([5], dtype=torch.int32) - >>> a = ht.zeros((4,5), split=0) - (1/2) >>> tensor([[0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0.]]) - (2/2) >>> tensor([[0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0.]]) - >>> a[1:4, 1] - (1/2) >>> tensor([0.]) - (2/2) >>> tensor([0., 0.]) + Indices for the tensor. """ - # key can be: int, tuple, list, slice, DNDarray, torch tensor, numpy array, or sequence thereof - # Trivial cases - if key is None: - return self.expand_dims(0) - if key == ... or key == slice(None): # latter doesnt work with torch for 0-dim tensors - return self - # Preprocess: Process Ellipsis + 'None' indexing; make Iterables to DNDarrays advanced_indexing = False if isinstance(key, DNDarray): # DNDARRAY CURRENTLY DOES NOT IMPLEMENT the Iterable interface, need to define __iter__() @@ -715,7 +691,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar if ellipsis > 1: raise ValueError("key can only contain 1 ellipsis") elif ellipsis == 1: - expand_key = [slice(None)] * (self.ndim + add_dims) + expand_key = [slice(None)] * (arr.ndim + add_dims) ellipsis_index = key.index(...) expand_key[:ellipsis_index] = key[:ellipsis_index] expand_key[ellipsis_index - len(key) :] = key[ellipsis_index + 1 :] @@ -724,12 +700,75 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar for i, k in reversed(enumerate(key)): if k is None: key[i] = slice(None) - self = self.expand_dims(i - add_dims + 1) # is the -1 correct? + arr = arr.expand_dims(i - add_dims + 1) # is the -1 correct? add_dims -= 1 # expand key to match the number of dimensions of the DNDarray - key = tuple(key + [slice(None)] * (self.ndim - len(key))) + key = tuple(key + [slice(None)] * (arr.ndim - len(key))) else: # key is integer or slice - key = tuple([key] + [slice(None)] * (self.ndim - 1)) + key = tuple([key] + [slice(None)] * (arr.ndim - 1)) + return advanced_indexing, arr, key + + def __get_local_slice(self, key: slice): + split = self.split + if split is None: + return key + key = stride_tricks.sanitize_slice(key, self.shape[split]) + start, stop, step = key.start, key.stop, key.step + if step < 0: # NOT supported by torch, should be filtered by torch_proxy + key = self.__get_local_slice(slice(stop + 1, start + 1, abs(step))) + if key is None: + return None + start, stop, step = key.start, key.stop, key.step + return slice(key.stop - 1, key.start - 1, -1 * key.step) + + _, offsets = self.counts_displs() + offset = offsets[self.comm.rank] + range_proxy = range(self.lshape[split]) + local_inds = range_proxy[start - offset : stop - offset] # only works if stop - offset > 0 + local_inds = local_inds[max(offset - start, 0) % step :: step] + if len(local_inds) and stop > offset: + # otherwise if (stop-offset) > -self.lshape[split] this can index into the local chunk despite ending before it + return slice(local_inds.start, local_inds.stop, local_inds.step) + return None + + def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray: + """ + Global getter function for DNDarrays. + Returns a new DNDarray composed of the elements of the original tensor selected by the indices + given. This does *NOT* redistribute or rebalance the resulting tensor. 
If the selection of values is + unbalanced then the resultant tensor is also unbalanced! + To redistributed the ``DNDarray`` use :func:`balance()` (issue #187) + + Parameters + ---------- + key : int, slice, Tuple[int,...], List[int,...] + Indices to get from the tensor. + + Examples + -------- + >>> a = ht.arange(10, split=0) + (1/2) >>> tensor([0, 1, 2, 3, 4], dtype=torch.int32) + (2/2) >>> tensor([5, 6, 7, 8, 9], dtype=torch.int32) + >>> a[1:6] + (1/2) >>> tensor([1, 2, 3, 4], dtype=torch.int32) + (2/2) >>> tensor([5], dtype=torch.int32) + >>> a = ht.zeros((4,5), split=0) + (1/2) >>> tensor([[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]]) + (2/2) >>> tensor([[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]]) + >>> a[1:4, 1] + (1/2) >>> tensor([0.]) + (2/2) >>> tensor([0., 0.]) + """ + # key can be: int, tuple, list, slice, DNDarray, torch tensor, numpy array, or sequence thereof + # Trivial cases + if key is None: + return self.expand_dims(0) + if key == ... or key == slice(None): # latter doesnt work with torch for 0-dim tensors + return self + # Preprocess: Process Ellipsis + 'None' indexing; make Iterables to DNDarrays + advanced_indexing, self, key = self.__process_key(key) # To use torch_proxy with advanced indexing, add empty dimensions instead of # advanced index. Later, replace the empty dimensions with the shape of the advanced index @@ -788,19 +827,9 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar split = self.split # slice along the split axis if isinstance(key[split], slice): - key = list(key) - key[split] = stride_tricks.sanitize_slice(key[split], self.shape[split]) - start, stop, step = key[split].start, key[split].stop, key[split].step - if step < 0: # NOT supported by torch, should be filtered by torch_proxy - key[split] = slice(stop + 1, start + 1, abs(step)) - return self[tuple(key)].flip(axis=self.split) - - offset = offsets[self.comm.rank] - range_proxy = range(self.lshape[split]) - local_inds = range_proxy[start - offset : stop - offset] - local_inds = local_inds[max(offset - start, 0) % step :: step] - if len(local_inds): - local_slice = slice(local_inds.start, local_inds.stop, local_inds.step) + local_slice = self.__get_local_slice(key[split]) + if local_slice is not None: + key = list(key) key[split] = local_slice local_tensor = self.larray[tuple(key)] else: # local tensor is empty @@ -1542,6 +1571,40 @@ def __setitem__( (2/2) >>> tensor([[0., 1., 0., 0., 0.], [0., 1., 0., 0., 0.]]) """ + + def __set(arr: DNDarray, value: DNDarray): + """ + Setter for not advanced indexing, i.e. when arr[key] is an in-place view of arr. + """ + if not isinstance(value, DNDarray): + value = factories.array(value, device=arr.device, comm=arr.comm) + while value.ndim < arr.ndim: # broadcasting + value = value.expand_dims(0) + sanitation.sanitize_out(arr, value.shape, value.split, value.device, value.comm) + value = sanitation.sanitize_distribution(value, target=arr) + arr.larray[None] = value.larray + return + + if key is None or key == ... 
or key == slice(None): + return __set(self, value) + + advanced_indexing, self, key = self.__process_key(key) + if advanced_indexing: + raise Exception("Advanced indexing is not supported yet") + + split = self.split + if not self.is_distributed or key[split] == slice(None): + return __set(self[key], value) + + if isinstance(key[split], slice): + return __set(self[key], value) + + if np.isscalar(key[split]): + key = list(key) + idx = int(key[split]) + key[split] = slice(idx, idx + 1) + return __set(self[tuple(key)], value) + key = getattr(key, "copy()", key) try: if value.split != self.split: From 0c37abfef0f1605711e455516e252a02400b5a28 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 24 Feb 2022 17:37:34 +0100 Subject: [PATCH 005/132] Expand `__process_key()` to address advanced indexing. --- heat/core/dndarray.py | 59 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 4743cc3b4b..4a11cbcf64 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -667,36 +667,86 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] Indices for the tensor. """ advanced_indexing = False + advanced_indexing_dims = [] + + output_shape = list(arr.gshape) + # output_split = arr.split + if isinstance(key, DNDarray): # DNDARRAY CURRENTLY DOES NOT IMPLEMENT the Iterable interface, need to define __iter__() advanced_indexing = True + advanced_indexing_dims.append(0) # TODO: check for key.ndim = 0 and treat that as int # TODO: get outshape + outsplit; depends on wether key is bool or int and key.ndim - elif isinstance(key, Iterable) and not isinstance(key, tuple): + elif isinstance(key, Iterable) and not isinstance(key, (tuple, list)): advanced_indexing = True + advanced_indexing_dims.append(0) key = factories.array(key) # DOES NOT WORK FOR SEQUENCE OF TENSORS OR DNDARRAYS, works for sequence of ndarrays though # TODO: get outshape + outsplit; depends on wether key is bool or int and key.ndim - elif isinstance(key, tuple): + elif isinstance(key, (tuple, list)): key = list(key) for i, k in enumerate(key): if isinstance(k, Iterable) or isinstance(key, DNDarray): advanced_indexing = True + advanced_indexing_dims.append(i) + # TODO: specify split axis key[i] = factories.array(key[i]) # DOES NOT WORK FOR SEQUENCE OF TENSORS OR DNDARRAYS, works for seq of ndarrays though # TODO: check for key.ndim = 0 and treat that as int # TODO: get outshape + outsplit; depends on wether key is bool or int and key.ndim + if advanced_indexing: + # shapes of indexing arrays must be broadcastable + advanced_indexing_shapes = tuple(key[i].shape for i in advanced_indexing_dims) + try: + broadcasted_shape = torch.broadcast_shapes(advanced_indexing_shapes) + except RuntimeError: + raise IndexError( + "Shape mismatch: indexing arrays could not be broadcast together with shapes: {}".format( + advanced_indexing_shapes + ) + ) + add_dims = len(broadcasted_shape) - len(advanced_indexing_dims) + if ( + len(advanced_indexing_dims) == 1 + or list(range(advanced_indexing_dims[0], advanced_indexing_dims[-1] + 1)) + == advanced_indexing_dims + ): + # dimensions affected by advanced indexing are consecutive: + output_shape[advanced_indexing_dims[0]] = broadcasted_shape + else: + # advanced-indexing dimensions are not consecutive: + # transpose array to make the advanced-indexing dimensions consecutive as the first dimensions + non_adv_ind_dims = list( + i for i in range(arr.ndim) if i not in 
advanced_indexing_dims + ) + arr = arr.transpose(advanced_indexing_dims + non_adv_ind_dims) + output_shape = list(arr.gshape) + output_shape[: len(advanced_indexing_dims)] = broadcasted_shape + # modify key to match the new dimension order + key = [key[i] for i in advanced_indexing_dims] + [ + key[i] for i in non_adv_ind_dims + ] + # update advanced-indexing dims + advanced_indexing_dims = list(range(len(advanced_indexing_dims))) + # expand dimensions of input array, key to match output_shape + if add_dims > 0: + for i in range(add_dims): + arr = arr.expand_dims(advanced_indexing_dims[0]) + key.insert(advanced_indexing_dims[0], slice(None)) + + # now check for ellipsis, newaxis add_dims = sum(k is None for k in key) # (np.newaxis is None)===true ellipsis = sum(isinstance(k, type(...)) for k in key) if ellipsis > 1: raise ValueError("key can only contain 1 ellipsis") - elif ellipsis == 1: + if ellipsis == 1: expand_key = [slice(None)] * (arr.ndim + add_dims) ellipsis_index = key.index(...) expand_key[:ellipsis_index] = key[:ellipsis_index] expand_key[ellipsis_index - len(key) :] = key[ellipsis_index + 1 :] key = expand_key - if add_dims: + while add_dims > 0: for i, k in reversed(enumerate(key)): if k is None: key[i] = slice(None) @@ -706,6 +756,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] key = tuple(key + [slice(None)] * (arr.ndim - len(key))) else: # key is integer or slice key = tuple([key] + [slice(None)] * (arr.ndim - 1)) + return advanced_indexing, arr, key def __get_local_slice(self, key: slice): From b1508b9be5665b8a304e407e33e9287fe53416ad Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sat, 26 Feb 2022 08:09:18 +0100 Subject: [PATCH 006/132] Address boolean indexing --- heat/core/dndarray.py | 112 +++++++++++++++++++++++------------------- 1 file changed, 62 insertions(+), 50 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 4a11cbcf64..9724da65dd 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -655,6 +655,7 @@ def fill_diagonal(self, value: float) -> DNDarray: def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]]) -> Tuple: """ + TODO: expand docs. This function processes key, manipulates `arr` if necessary, returns the final output shape Private method for processing keys for indexing. Returns wether advanced indexing is used as well as a processed key and self. A processed key: - doesn't cotain any ellipses or newaxis @@ -672,68 +673,79 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] output_shape = list(arr.gshape) # output_split = arr.split + if isinstance(key, Iterable) and not isinstance(key, (tuple, list)): + # key is np.ndarray or torch.Tensor + key = factories.array(key) + if isinstance(key, DNDarray): - # DNDARRAY CURRENTLY DOES NOT IMPLEMENT the Iterable interface, need to define __iter__() - advanced_indexing = True - advanced_indexing_dims.append(0) + if key.dtype in (canonical_heat_type.bool, canonical_heat_type.uint8): + # boolean indexing + if not key.gshape == arr.gshape: + raise IndexError( + "IndexError: shape of boolean index {} did not match shape of indexed array {}".format( + key.gshape, arr.gshape + ) + ) + key = indexing.nonzero(key) + # TODO: fix indexing.nonzero to return a tuple of 1D dndarrays + else: + advanced_indexing = True + advanced_indexing_dims.append(0) + # TODO: check for dimensions of indexing array here? 
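+                # illustrative example (assuming a 1-D integer index array): a
+                # non-boolean DNDarray key fancy-indexes dimension 0, so the result
+                # takes the key's shape along that axis, e.g.
+                #   >>> x = ht.arange(4) * 10
+                #   >>> x[ht.array([2, 0])]
+                #   DNDarray([20, 0], dtype=ht.int32, ...)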
# TODO: check for key.ndim = 0 and treat that as int # TODO: get outshape + outsplit; depends on wether key is bool or int and key.ndim - elif isinstance(key, Iterable) and not isinstance(key, (tuple, list)): - advanced_indexing = True - advanced_indexing_dims.append(0) - key = factories.array(key) - # DOES NOT WORK FOR SEQUENCE OF TENSORS OR DNDARRAYS, works for sequence of ndarrays though - # TODO: get outshape + outsplit; depends on wether key is bool or int and key.ndim - elif isinstance(key, (tuple, list)): + if isinstance(key, (tuple, list)): key = list(key) for i, k in enumerate(key): - if isinstance(k, Iterable) or isinstance(key, DNDarray): + if isinstance(k, Iterable) or isinstance(k, DNDarray): advanced_indexing = True advanced_indexing_dims.append(i) # TODO: specify split axis - key[i] = factories.array(key[i]) + if not isinstance(k, DNDarray): + key[i] = factories.array(k) # DOES NOT WORK FOR SEQUENCE OF TENSORS OR DNDARRAYS, works for seq of ndarrays though # TODO: check for key.ndim = 0 and treat that as int # TODO: get outshape + outsplit; depends on wether key is bool or int and key.ndim - if advanced_indexing: - # shapes of indexing arrays must be broadcastable - advanced_indexing_shapes = tuple(key[i].shape for i in advanced_indexing_dims) - try: - broadcasted_shape = torch.broadcast_shapes(advanced_indexing_shapes) - except RuntimeError: - raise IndexError( - "Shape mismatch: indexing arrays could not be broadcast together with shapes: {}".format( - advanced_indexing_shapes - ) - ) - add_dims = len(broadcasted_shape) - len(advanced_indexing_dims) - if ( - len(advanced_indexing_dims) == 1 - or list(range(advanced_indexing_dims[0], advanced_indexing_dims[-1] + 1)) - == advanced_indexing_dims - ): - # dimensions affected by advanced indexing are consecutive: - output_shape[advanced_indexing_dims[0]] = broadcasted_shape - else: - # advanced-indexing dimensions are not consecutive: - # transpose array to make the advanced-indexing dimensions consecutive as the first dimensions - non_adv_ind_dims = list( - i for i in range(arr.ndim) if i not in advanced_indexing_dims + if advanced_indexing: + # shapes of indexing arrays must be broadcastable + advanced_indexing_shapes = tuple(key[i].shape for i in advanced_indexing_dims) + try: + broadcasted_shape = torch.broadcast_shapes(advanced_indexing_shapes) + except RuntimeError: + raise IndexError( + "Shape mismatch: indexing arrays could not be broadcast together with shapes: {}".format( + advanced_indexing_shapes ) - arr = arr.transpose(advanced_indexing_dims + non_adv_ind_dims) - output_shape = list(arr.gshape) - output_shape[: len(advanced_indexing_dims)] = broadcasted_shape - # modify key to match the new dimension order - key = [key[i] for i in advanced_indexing_dims] + [ - key[i] for i in non_adv_ind_dims - ] - # update advanced-indexing dims - advanced_indexing_dims = list(range(len(advanced_indexing_dims))) - # expand dimensions of input array, key to match output_shape - if add_dims > 0: - for i in range(add_dims): - arr = arr.expand_dims(advanced_indexing_dims[0]) - key.insert(advanced_indexing_dims[0], slice(None)) + ) + add_dims = len(broadcasted_shape) - len(advanced_indexing_dims) + if ( + len(advanced_indexing_dims) == 1 + or list(range(advanced_indexing_dims[0], advanced_indexing_dims[-1] + 1)) + == advanced_indexing_dims + ): + # dimensions affected by advanced indexing are consecutive: + output_shape[ + advanced_indexing_dims[0] : advanced_indexing_dims[0] + + len(advanced_indexing_dims) + ] = broadcasted_shape + 
else: + # advanced-indexing dimensions are not consecutive: + # transpose array to make the advanced-indexing dimensions consecutive as the first dimensions + non_adv_ind_dims = list( + i for i in range(arr.ndim) if i not in advanced_indexing_dims + ) + arr = arr.transpose(advanced_indexing_dims + non_adv_ind_dims) + output_shape = list(arr.gshape) + output_shape[: len(advanced_indexing_dims)] = broadcasted_shape + # modify key to match the new dimension order + key = [key[i] for i in advanced_indexing_dims] + [key[i] for i in non_adv_ind_dims] + # update advanced-indexing dims + advanced_indexing_dims = list(range(len(advanced_indexing_dims))) + # expand dimensions of input array, key, to match output_shape + while add_dims > 0: + arr = arr.expand_dims(advanced_indexing_dims[0]) + key.insert(advanced_indexing_dims[0], slice(None)) + add_dims -= 1 # now check for ellipsis, newaxis add_dims = sum(k is None for k in key) # (np.newaxis is None)===true From ae5af94ad6befed23f01eab33fcbaa787bf85415 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 28 Feb 2022 08:58:14 +0100 Subject: [PATCH 007/132] separate advanced indexing on dim 0 from adv ind across dimensions --- heat/core/dndarray.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 9724da65dd..5ffce7a777 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -667,9 +667,6 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] key : int, slice, Tuple[int,...], List[int,...] Indices for the tensor. """ - advanced_indexing = False - advanced_indexing_dims = [] - output_shape = list(arr.gshape) # output_split = arr.split @@ -680,6 +677,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] if isinstance(key, DNDarray): if key.dtype in (canonical_heat_type.bool, canonical_heat_type.uint8): # boolean indexing + # transform to sequence of indexing arrays if not key.gshape == arr.gshape: raise IndexError( "IndexError: shape of boolean index {} did not match shape of indexed array {}".format( @@ -689,23 +687,22 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] key = indexing.nonzero(key) # TODO: fix indexing.nonzero to return a tuple of 1D dndarrays else: - advanced_indexing = True - advanced_indexing_dims.append(0) - # TODO: check for dimensions of indexing array here? 
+ # advanced indexing on first dimension + output_shape = list(key.gshape) + output_shape[1:] # TODO: check for key.ndim = 0 and treat that as int - # TODO: get outshape + outsplit; depends on wether key is bool or int and key.ndim + + advanced_indexing = False + advanced_indexing_dims = [] if isinstance(key, (tuple, list)): key = list(key) for i, k in enumerate(key): if isinstance(k, Iterable) or isinstance(k, DNDarray): + # advanced indexing across dimensions advanced_indexing = True advanced_indexing_dims.append(i) - # TODO: specify split axis if not isinstance(k, DNDarray): key[i] = factories.array(k) - # DOES NOT WORK FOR SEQUENCE OF TENSORS OR DNDARRAYS, works for seq of ndarrays though # TODO: check for key.ndim = 0 and treat that as int - # TODO: get outshape + outsplit; depends on wether key is bool or int and key.ndim if advanced_indexing: # shapes of indexing arrays must be broadcastable advanced_indexing_shapes = tuple(key[i].shape for i in advanced_indexing_dims) @@ -742,10 +739,11 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] # update advanced-indexing dims advanced_indexing_dims = list(range(len(advanced_indexing_dims))) # expand dimensions of input array, key, to match output_shape - while add_dims > 0: - arr = arr.expand_dims(advanced_indexing_dims[0]) - key.insert(advanced_indexing_dims[0], slice(None)) - add_dims -= 1 + # while add_dims > 0: + # # TODO: check this out, I think this is wrong or only right if added dimension is of size (1,) + # arr = arr.expand_dims(advanced_indexing_dims[0]) + # key.insert(advanced_indexing_dims[0], slice(None)) + # add_dims -= 1 # now check for ellipsis, newaxis add_dims = sum(k is None for k in key) # (np.newaxis is None)===true From 0a8cb356387b81e0e53a71a8d15bbfa0268cbb7a Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 11 Mar 2022 05:39:54 +0100 Subject: [PATCH 008/132] Replace `sanitize_in` with `try:...except:` construct --- heat/core/indexing.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/heat/core/indexing.py b/heat/core/indexing.py index ac7598b9b9..939041bc30 100644 --- a/heat/core/indexing.py +++ b/heat/core/indexing.py @@ -51,16 +51,19 @@ def nonzero(x: DNDarray) -> DNDarray: >>> y[ht.nonzero(y > 3)] DNDarray([4, 5, 6, 7, 8, 9], dtype=ht.int64, device=cpu:0, split=0) """ - sanitation.sanitize_in(x) + try: + local_x = x.larray + except AttributeError: + raise TypeError("Input must be a DNDarray, is {}".format(type(x))) if x.split is None: # if there is no split then just return the values from torch - lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False) + lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False) gout = list(lcl_nonzero.size()) is_split = None else: # a is split - lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False) + lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False) _, _, slices = x.comm.chunk(x.shape, x.split) lcl_nonzero[..., x.split] += slices[x.split].start gout = list(lcl_nonzero.size()) From 6c7c10ae8294c6f8bf4a98b6dc4fd97af8c3bc16 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 11 Mar 2022 05:43:02 +0100 Subject: [PATCH 009/132] `nonzero()`: do not assume input DNDarray is load-balanced --- heat/core/indexing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/core/indexing.py b/heat/core/indexing.py index 939041bc30..a4bbf7b8c7 100644 --- a/heat/core/indexing.py +++ b/heat/core/indexing.py @@ -64,8 +64,8 @@ def nonzero(x: DNDarray) -> DNDarray: else: # a is split 
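        # note: torch.nonzero() below sees only the local chunk, so indices
        # along the split axis are process-local; e.g. for a length-10 array
        # split over 2 ranks (displacements [0, 5]), a local hit at row 2 on
        # rank 1 is global row 2 + 5 = 7 -- hence the displacement shift below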
lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False) - _, _, slices = x.comm.chunk(x.shape, x.split) - lcl_nonzero[..., x.split] += slices[x.split].start + _, displs = x.counts_displs() + lcl_nonzero[..., x.split] += displs[x.comm.rank] gout = list(lcl_nonzero.size()) gout[0] = x.comm.allreduce(gout[0], MPI.SUM) is_split = 0 From fb3524bbb2945f497d8f03c2f6e0e6533b36343f Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 14 Mar 2022 15:33:47 +0100 Subject: [PATCH 010/132] Memory management --- heat/core/indexing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/heat/core/indexing.py b/heat/core/indexing.py index a4bbf7b8c7..6261a072c6 100644 --- a/heat/core/indexing.py +++ b/heat/core/indexing.py @@ -64,8 +64,11 @@ def nonzero(x: DNDarray) -> DNDarray: else: # a is split lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False) + # adjust local indices along split dimension _, displs = x.counts_displs() lcl_nonzero[..., x.split] += displs[x.comm.rank] + del displs + # get global size of split dimension gout = list(lcl_nonzero.size()) gout[0] = x.comm.allreduce(gout[0], MPI.SUM) is_split = 0 From eb297fbf8570a46164d85b6681f70f1301c3d340 Mon Sep 17 00:00:00 2001 From: Ashwath V A <73862377+Mystic-Slice@users.noreply.github.com> Date: Fri, 8 Apr 2022 14:26:54 +0530 Subject: [PATCH 011/132] fix #925: ht.nonzero() returns tuple of 1-D arrays instead of n-D arrays (#937) * Create ci.yaml * Update ci.yaml * Update ci.yaml * Create CITATION.cff * Update CITATION.cff * Update ci.yaml different python and pytorch versions * Update ci.yaml * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Delete pre-commit.yml * Update ci.yaml * Update CITATION.cff * Update tutorial.ipynb delete example with different split axis * Delete logo_heAT.pdf Removal of old logo * ht.nonzero() returns tuple of 1-D arrays instead of n-D arrays * Updated documentation and Unit-tests * replace x.larray with local_x * Code fixes * Fix return type of nonzero function and gout value * Made sure DNDarray meta-data is available to the tuple members * Transpose before if-branching + adjustments to accomodate it * Fixed global shape assignment * Updated changelog Co-authored-by: mtar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel Coquelin Co-authored-by: Markus Goetz Co-authored-by: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> --- .github/workflows/ci.yaml | 44 ++++++++++++++++++++ .github/workflows/pre-commit.yml | 14 ------- CHANGELOG.md | 1 + CITATION.cff | 68 +++++++++++++++++++++++++++++++ doc/images/logo_heAT.pdf | Bin 1690 -> 0 bytes heat/core/dndarray.py | 4 +- heat/core/indexing.py | 57 ++++++++++++++------------ heat/core/tests/test_indexing.py | 14 +++---- scripts/tutorial.ipynb | 32 --------------- 9 files changed, 152 insertions(+), 82 deletions(-) create mode 100644 .github/workflows/ci.yaml delete mode 100644 .github/workflows/pre-commit.yml create mode 100644 CITATION.cff delete mode 100644 doc/images/logo_heAT.pdf diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000000..33237d4424 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,44 @@ +name: ci + +on: + pull_request_review: + types: [submitted] + +jobs: + approved: + if: github.event.review.state == 'approved' + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + py-version: + - 3.7 + - 3.8 + mpi: [ 'openmpi' ] + install-options: [ '.', 
'.[hdf5,netcdf]' ] + pytorch-version: + - 'torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2' + - 'torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1' + - 'torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0' + + + name: Python ${{ matrix.py-version }} with ${{ matrix.pytorch-version }}; options ${{ matrix.install-options }} + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Setup MPI + uses: mpi4py/setup-mpi@v1 + with: + mpi: ${{ matrix.mpi }} + - name: Use Python ${{ matrix.py-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.py-version }} + architecture: x64 + - name: Test + run: | + pip install pytest + pip install ${{ matrix.pytorch-version }} -f https://download.pytorch.org/whl/torch_stable.html + pip install ${{ matrix.install-options }} + mpirun -n 3 pytest heat/ + mpirun -n 4 pytest heat/ diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml deleted file mode 100644 index b52d4afe5c..0000000000 --- a/.github/workflows/pre-commit.yml +++ /dev/null @@ -1,14 +0,0 @@ -name: pre-commit - -on: - pull_request: - push: - branches: [main] - -jobs: - pre-commit: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - - uses: pre-commit/action@v2.0.0 diff --git a/CHANGELOG.md b/CHANGELOG.md index ddb06676e4..3dca59403b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - [#876](https://github.com/helmholtz-analytics/heat/pull/876) Make examples work (Lasso and kNN) - [#894](https://github.com/helmholtz-analytics/heat/pull/894) Change inclusion of license file - [#884](https://github.com/helmholtz-analytics/heat/pull/884) Added capabilities for PyTorch 1.10.0, this is now the recommended version to use. +- [#937](https://github.com/helmholtz-analytics/heat/pull/937) Modified `ht.nonzero()` to return a tuple of 1-D arrays containing the non-zero indices in each dimension. ## Bug Fixes - [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 0000000000..b655ef2fcc --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,68 @@ +cff-version: 1.2.0 +message: "If you use this software, please cite it as below." 
+authors:
+- family-names: "Götz"
+  given-names: "Markus"
+- family-names: "Debus"
+  given-names: "Charlotte"
+- family-names: "Coquelin"
+  given-names: "Daniel"
+- family-names: "Krajsek"
+  given-names: "Kai"
+- family-names: "Comito"
+  given-names: "Claudia"
+- family-names: "Knechtges"
+  given-names: "Philipp"
+- family-names: "Hagemeier"
+  given-names: "Björn"
+- family-names: "Tarnawa"
+  given-names: "Michael"
+- family-names: "Hanselmann"
+  given-names: "Simon"
+- family-names: "Siggel"
+  given-names: "Martin"
+- family-names: "Basermann"
+  given-names: "Achim"
+- family-names: "Streit"
+  given-names: "Achim"
+title: "Heat - Helmholtz Analytics Toolkit"
+version: 1.1.0
+date-released: 2021-09-21
+url: "https://github.com/helmholtz-analytics/heat"
+preferred-citation:
+  type: conference-paper
+  authors:
+  - family-names: "Götz"
+    given-names: "Markus"
+  - family-names: "Debus"
+    given-names: "Charlotte"
+  - family-names: "Coquelin"
+    given-names: "Daniel"
+  - family-names: "Krajsek"
+    given-names: "Kai"
+  - family-names: "Comito"
+    given-names: "Claudia"
+  - family-names: "Knechtges"
+    given-names: "Philipp"
+  - family-names: "Hagemeier"
+    given-names: "Björn"
+  - family-names: "Tarnawa"
+    given-names: "Michael"
+  - family-names: "Hanselmann"
+    given-names: "Simon"
+  - family-names: "Siggel"
+    given-names: "Martin"
+  - family-names: "Basermann"
+    given-names: "Achim"
+  - family-names: "Streit"
+    given-names: "Achim"
+  title: "HeAT -- a Distributed and GPU-accelerated Tensor Framework for Data Analytics"
+  year: 2020
+  collection-title: "2020 IEEE International Conference on Big Data (IEEE Big Data 2020)"
+  collection-doi: 10.1109/BigData50022.2020.9378050
+  conference:
+    name: 2020 IEEE International Conference on Big Data (IEEE Big Data 2020)
+    date-start: 2020-12-10
+    date-end: 2020-12-13
+  start: 276
+  end: 287
diff --git a/doc/images/logo_heAT.pdf b/doc/images/logo_heAT.pdf
deleted file mode 100644
index d839eade2b7bed5a6a288bde60136fc4475df68d..0000000000000000000000000000000000000000
Binary files a/doc/images/logo_heAT.pdf and /dev/null differ
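(A brief aside on the `is_distributed()` fixes in the next hunks -- a minimal
sketch in plain Python, not taken from the patch: referencing a method without
calling it yields the bound-method object, which is always truthy, so the old
`if not self.is_distributed or ...` test could never take the distributed branch.)

    class Grid:
        def is_distributed(self):
            return False

    g = Grid()
    print(bool(g.is_distributed))    # True -- the bound-method object is truthy
    print(bool(g.is_distributed()))  # False -- the actual return value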
diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ ... @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
         output_split = None

     # data are not distributed or split dimension is not affected by indexing
-    if not self.is_distributed or key[self.split] == slice(None):
+    if not self.is_distributed() or key[self.split] == slice(None):
         return DNDarray(
             self.larray[key],
             gshape=output_shape,
@@ -1654,7 +1654,7 @@ def __set(arr: DNDarray, value: DNDarray):
             raise Exception("Advanced indexing is not supported yet")

         split = self.split
-        if not self.is_distributed or key[split] == slice(None):
+        if not self.is_distributed() or key[split] == slice(None):
             return __set(self[key], value)

         if isinstance(key[split], slice):
diff --git a/heat/core/indexing.py b/heat/core/indexing.py
index 6261a072c6..0452000c2f 100644
--- a/heat/core/indexing.py
+++ b/heat/core/indexing.py
@@ -13,12 +13,12 @@
 __all__ = ["nonzero", "where"]


-def nonzero(x: DNDarray) -> DNDarray:
+def nonzero(x: DNDarray) -> Tuple[DNDarray, ...]:
     """
-    Return a :class:`~heat.core.dndarray.DNDarray` containing the indices of the elements that are non-zero.. (using ``torch.nonzero``)
-    If ``x`` is split then the result is split in the 0th dimension. However, this :class:`~heat.core.dndarray.DNDarray`
+    Return a Tuple of :class:`~heat.core.dndarray.DNDarray`s, one for each dimension of ``x``,
+    containing the indices of the non-zero elements in that dimension. If ``x`` is split then
+    the result is split in the 0th dimension. However, this :class:`~heat.core.dndarray.DNDarray`
     can be UNBALANCED as it contains the indices of the non-zero elements on each node.
-    Returns an array with one entry for each dimension of ``x``, containing the indices of the non-zero elements in that dimension.
     The values in ``x`` are always tested and returned in row-major, C-style order.
     The corresponding non-zero values can be obtained with: ``x[nonzero(x)]``.
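(For orientation, the tuple-of-arrays contract described above can be sketched
with plain NumPy -- an illustrative aside using the same matrix as the Examples
below, not part of the patch:)

    import numpy as np

    x = np.array([[3, 0, 0], [0, 4, 1], [0, 6, 0]])
    rows, cols = np.nonzero(x)  # (array([0, 1, 1, 2]), array([0, 1, 2, 1]))
    print(x[rows, cols])        # [3 4 1 6] -- exactly the non-zero values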
@@ -32,10 +32,8 @@ def nonzero(x: DNDarray) -> DNDarray: >>> import heat as ht >>> x = ht.array([[3, 0, 0], [0, 4, 1], [0, 6, 0]], split=0) >>> ht.nonzero(x) - DNDarray([[0, 0], - [1, 1], - [1, 2], - [2, 1]], dtype=ht.int64, device=cpu:0, split=0) + (DNDarray([0, 1, 1, 2], dtype=ht.int64, device=cpu:0, split=None), + DNDarray([0, 1, 2, 1], dtype=ht.int64, device=cpu:0, split=None)) >>> y = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=0) >>> y > 3 DNDarray([[False, False, False], @@ -48,6 +46,8 @@ def nonzero(x: DNDarray) -> DNDarray: [2, 0], [2, 1], [2, 2]], dtype=ht.int64, device=cpu:0, split=0) + (DNDarray([1, 1, 1, 2, 2, 2], dtype=ht.int64, device=cpu:0, split=None), + DNDarray([0, 1, 2, 0, 1, 2], dtype=ht.int64, device=cpu:0, split=None)) >>> y[ht.nonzero(y > 3)] DNDarray([4, 5, 6, 7, 8, 9], dtype=ht.int64, device=cpu:0, split=0) """ @@ -56,39 +56,42 @@ def nonzero(x: DNDarray) -> DNDarray: except AttributeError: raise TypeError("Input must be a DNDarray, is {}".format(type(x))) + lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False).transpose(0, 1) + if x.split is None: - # if there is no split then just return the values from torch - lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False) + # if there is no split then just return the transpose of values from torch + gout = list(lcl_nonzero.size()) is_split = None else: # a is split - lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False) # adjust local indices along split dimension _, displs = x.counts_displs() - lcl_nonzero[..., x.split] += displs[x.comm.rank] + lcl_nonzero[x.split] += displs[x.comm.rank] del displs + # get global size of split dimension gout = list(lcl_nonzero.size()) - gout[0] = x.comm.allreduce(gout[0], MPI.SUM) + gout[1] = x.comm.allreduce(gout[1], MPI.SUM) is_split = 0 - if x.ndim == 1: - lcl_nonzero = lcl_nonzero.squeeze(dim=1) - for g in range(len(gout) - 1, -1, -1): - if gout[g] == 1: - del gout[g] - - return DNDarray( - lcl_nonzero, - gshape=tuple(gout), - dtype=types.canonical_heat_type(lcl_nonzero.dtype), - split=is_split, - device=x.device, - comm=x.comm, - balanced=False, + non_zero_indices = list( + [ + DNDarray( + dim_indices, + gshape=tuple(gout), + dtype=types.canonical_heat_type(lcl_nonzero.dtype), + split=is_split, + device=x.device, + comm=x.comm, + balanced=False, + ) + for dim_indices in lcl_nonzero + ] ) + return tuple(non_zero_indices) + DNDarray.nonzero = lambda self: nonzero(self) DNDarray.nonzero.__doc__ = nonzero.__doc__ diff --git a/heat/core/tests/test_indexing.py b/heat/core/tests/test_indexing.py index 4707aa28ab..58c7410456 100644 --- a/heat/core/tests/test_indexing.py +++ b/heat/core/tests/test_indexing.py @@ -9,18 +9,18 @@ def test_nonzero(self): a = ht.array([[1, 2, 3], [4, 5, 2], [7, 8, 9]], split=None) cond = a > 3 nz = ht.nonzero(cond) - self.assertEqual(nz.gshape, (5, 2)) - self.assertEqual(nz.dtype, ht.int64) - self.assertEqual(nz.split, None) + self.assertEqual(len(nz), 2) + self.assertEqual(len(nz[0]), 5) + self.assertEqual(nz[0].dtype, ht.int64) # split a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1) cond = a > 3 nz = cond.nonzero() - self.assertEqual(nz.gshape, (6, 2)) - self.assertEqual(nz.dtype, ht.int64) - self.assertEqual(nz.split, 0) - a[nz] = 10.0 + self.assertEqual(len(nz), 2) + self.assertEqual(len(nz[0]), 6) + self.assertEqual(nz[0].dtype, ht.int64) + a[nz] = 10 self.assertEqual(ht.all(a[nz] == 10), 1) def test_where(self): diff --git a/scripts/tutorial.ipynb b/scripts/tutorial.ipynb index f2ce191bd2..95cc6e3465 100644 --- 
a/scripts/tutorial.ipynb +++ b/scripts/tutorial.ipynb @@ -1044,38 +1044,6 @@ "a + b" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The example below will show that it is also possible to use operations on tensors with different split and the proper result calculated. However, this should be used seldomly and with small data amounts only, as it entails sending large amounts of data over the network." - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(0/2) tensor([[9., 9., 9., 9., 9., 9.],\n", - "(0/2) [9., 9., 9., 9., 9., 9.]])\n", - "(1/2) tensor([[9., 9., 9., 9., 9., 9.],\n", - "(1/2) [9., 9., 9., 9., 9., 9.]])" - ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = ht.full((4, 6,), 8, split=0)\n", - "b = ht.ones((4, 6,), split=1)\n", - "a + b" - ] - }, { "cell_type": "markdown", "metadata": {}, From a52e518dc53f52718093661ae44a934ee187c837 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 7 Jun 2022 15:44:04 +0200 Subject: [PATCH 012/132] calculate output_shape, split axis bookkeeping for advanced indexing --- heat/core/dndarray.py | 90 +++++++++++++++++++++++++++---------------- 1 file changed, 57 insertions(+), 33 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 7d60c261d4..e9372ca065 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -654,8 +654,8 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] TODO: expand docs. This function processes key, manipulates `arr` if necessary, returns the final output shape Private method for processing keys for indexing. Returns wether advanced indexing is used as well as a processed key and self. A processed key: - - doesn't cotain any ellipses or newaxis - - all Iterables are converted to ``DNDarrays`` + - doesn't contain any ellipses or newaxis + - all Iterables are converted to ``DNDarrays`` TODO: NO, change this. - has the same dimensionality as the ``DNDarray`` it indexes Parameters @@ -664,44 +664,41 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] Indices for the tensor. 
""" output_shape = list(arr.gshape) - # output_split = arr.split - - if isinstance(key, Iterable) and not isinstance(key, (tuple, list)): - # key is np.ndarray or torch.Tensor - key = factories.array(key) - - if isinstance(key, DNDarray): - if key.dtype in (canonical_heat_type.bool, canonical_heat_type.uint8): - # boolean indexing - # transform to sequence of indexing arrays - if not key.gshape == arr.gshape: - raise IndexError( - "IndexError: shape of boolean index {} did not match shape of indexed array {}".format( - key.gshape, arr.gshape - ) - ) - key = indexing.nonzero(key) - # TODO: fix indexing.nonzero to return a tuple of 1D dndarrays - else: - # advanced indexing on first dimension - output_shape = list(key.gshape) + output_shape[1:] - # TODO: check for key.ndim = 0 and treat that as int + split_bookkeeping = [None] * arr.ndim + if arr.is_distributed(): + split_bookkeeping[arr.split] = "split" advanced_indexing = False - advanced_indexing_dims = [] + + if isinstance(key, (DNDarray, torch.Tensor, np.ndarray)): + if key.dtype in (types.bool, types.uint8, torch.bool, torch.uint8, np.bool, np.uint8): + # boolean indexing: transform to sequence of indexing (1-D) arrays + try: + # torch.Tensor key + key = key.nonzero(as_tuple=True) + except AttributeError: + # np.array or DNDarray key + key = key.nonzero() + else: + # advanced indexing on first dimension: first dim expands to shape of key + output_shape = list(key.shape) + output_shape[1:] + # adjust split axis accordingly + split_bookkeeping = [None] * (len(key.shape) - 1) + split_bookkeeping[1:] + if isinstance(key, (tuple, list)): key = list(key) + advanced_indexing_dims = [] for i, k in enumerate(key): if isinstance(k, Iterable) or isinstance(k, DNDarray): # advanced indexing across dimensions advanced_indexing = True advanced_indexing_dims.append(i) if not isinstance(k, DNDarray): - key[i] = factories.array(k) - # TODO: check for key.ndim = 0 and treat that as int + key[i] = torch.tensor(k) + if advanced_indexing: + advanced_indexing_shapes = tuple(tuple(key[i].shape) for i in advanced_indexing_dims) # shapes of indexing arrays must be broadcastable - advanced_indexing_shapes = tuple(key[i].shape for i in advanced_indexing_dims) try: broadcasted_shape = torch.broadcast_shapes(advanced_indexing_shapes) except RuntimeError: @@ -721,6 +718,11 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] advanced_indexing_dims[0] : advanced_indexing_dims[0] + len(advanced_indexing_dims) ] = broadcasted_shape + split_bookkeeping = ( + split_bookkeeping[: advanced_indexing_dims[0]] + + [None] * add_dims + + split_bookkeeping[advanced_indexing_dims[0] :] + ) else: # advanced-indexing dimensions are not consecutive: # transpose array to make the advanced-indexing dimensions consecutive as the first dimensions @@ -730,6 +732,10 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] 
arr = arr.transpose(advanced_indexing_dims + non_adv_ind_dims) output_shape = list(arr.gshape) output_shape[: len(advanced_indexing_dims)] = broadcasted_shape + split_bookkeeping = [None] * arr.ndim + if arr.is_distributed: + split_bookkeeping[arr.split] = "split" + split_bookkeeping = [None] * add_dims + split_bookkeeping # modify key to match the new dimension order key = [key[i] for i in advanced_indexing_dims] + [key[i] for i in non_adv_ind_dims] # update advanced-indexing dims @@ -820,13 +826,16 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar """ # key can be: int, tuple, list, slice, DNDarray, torch tensor, numpy array, or sequence thereof # Trivial cases + print("DEBUGGING: RAW KEY = ", key) if key is None: return self.expand_dims(0) - if key == ... or key == slice(None): # latter doesnt work with torch for 0-dim tensors + if ( + key is ... or isinstance(key, slice) and key == slice(None) + ): # latter doesnt work with torch for 0-dim tensors return self # Preprocess: Process Ellipsis + 'None' indexing; make Iterables to DNDarrays advanced_indexing, self, key = self.__process_key(key) - + print("DEBUGGING: AFTER PROCESSING KEY = ", key, type(key)) # To use torch_proxy with advanced indexing, add empty dimensions instead of # advanced index. Later, replace the empty dimensions with the shape of the advanced index proxy = self @@ -834,9 +843,9 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar "split" if (proxy.split is not None and i == proxy.split) else "_{}".format(i) for i in range(proxy.ndim) ] - proxy_key = list(key) + proxy_key = list(key) # copy OR IS THIS REALLY NEEDED?? + print("DEBUGGING: proxy_key, ADVANCED_INDEXING", proxy_key, advanced_indexing) if advanced_indexing: - proxy_key = list(key) for i, k in reversed(enumerate(key)): if isinstance(k, DNDarray): # all iterables have been made DNDarrays # TODO: Bool indexing (sometimes) is collapsed into one dimension @@ -851,9 +860,16 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar self_proxy = proxy.__torch_proxy__() self_proxy.names = names + print("DEBUGGING: self_proxy = ", self_proxy) + print("debugging: proxy_key", proxy_key) + print("DEBUGGING: self_proxy.shape", self_proxy.shape) + print("DEBUGGING: type(self_proxy)", type(self_proxy)) + indexed_proxy = self_proxy[proxy_key] + print("DEBUGGING: indexed_proxy = ", indexed_proxy) output_shape = list(indexed_proxy.shape) + print("DEBUGGING: output_shape = ", output_shape) if advanced_indexing: for i, n in enumerate(indexed_proxy.names): if "replace" in n: @@ -867,8 +883,11 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar except ValueError: output_split = None + print("DEBUGGING: output_split = ", output_split) # data are not distributed or split dimension is not affected by indexing if not self.is_distributed() or key[self.split] == slice(None): + print("DEBUGGING: NOT DISTRIBUTED OR SPLIT DIMENSION NOT AFFECTED BY INDEXING") + print("DEBUGGING: output_shape = ", output_shape) return DNDarray( self.larray[key], gshape=output_shape, @@ -1224,7 +1243,11 @@ def __len__(self) -> int: """ The length of the ``DNDarray``, i.e. the number of items in the first dimension. 
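+
+        >>> len(ht.zeros((4, 5), split=0))
+        4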
""" - return self.shape[0] + try: + len = self.shape[0] + return len + except IndexError: + raise TypeError("len() of unsized DNDarray") def numpy(self) -> np.array: """ @@ -2032,4 +2055,5 @@ def __xitem_get_key_start_stop( from .devices import Device from .stride_tricks import sanitize_axis +import types from .types import datatype, canonical_heat_type From 59956398d91ffaad3e4d755dcf0b44d5fbbb8254 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 12 Jul 2022 14:22:45 +0200 Subject: [PATCH 013/132] `__process_key()` to return expanded array, expanded key, output gshape and new split axis --- heat/core/dndarray.py | 155 ++++++++++++++++++++++++------------------ 1 file changed, 88 insertions(+), 67 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index e9372ca065..75a50a9027 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -651,11 +651,12 @@ def fill_diagonal(self, value: float) -> DNDarray: def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]]) -> Tuple: """ - TODO: expand docs. This function processes key, manipulates `arr` if necessary, returns the final output shape + TODO: expand docs!! + This function processes key, manipulates `arr` if necessary, returns the final output shape Private method for processing keys for indexing. Returns wether advanced indexing is used as well as a processed key and self. A processed key: - doesn't contain any ellipses or newaxis - - all Iterables are converted to ``DNDarrays`` TODO: NO, change this. + - all Iterables are converted to torch tensors - has the same dimensionality as the ``DNDarray`` it indexes Parameters @@ -687,89 +688,109 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] if isinstance(key, (tuple, list)): key = list(key) - advanced_indexing_dims = [] - for i, k in enumerate(key): - if isinstance(k, Iterable) or isinstance(k, DNDarray): - # advanced indexing across dimensions - advanced_indexing = True - advanced_indexing_dims.append(i) - if not isinstance(k, DNDarray): - key[i] = torch.tensor(k) - if advanced_indexing: - advanced_indexing_shapes = tuple(tuple(key[i].shape) for i in advanced_indexing_dims) - # shapes of indexing arrays must be broadcastable - try: - broadcasted_shape = torch.broadcast_shapes(advanced_indexing_shapes) - except RuntimeError: - raise IndexError( - "Shape mismatch: indexing arrays could not be broadcast together with shapes: {}".format( - advanced_indexing_shapes - ) - ) - add_dims = len(broadcasted_shape) - len(advanced_indexing_dims) - if ( - len(advanced_indexing_dims) == 1 - or list(range(advanced_indexing_dims[0], advanced_indexing_dims[-1] + 1)) - == advanced_indexing_dims - ): - # dimensions affected by advanced indexing are consecutive: - output_shape[ - advanced_indexing_dims[0] : advanced_indexing_dims[0] - + len(advanced_indexing_dims) - ] = broadcasted_shape - split_bookkeeping = ( - split_bookkeeping[: advanced_indexing_dims[0]] - + [None] * add_dims - + split_bookkeeping[advanced_indexing_dims[0] :] - ) - else: - # advanced-indexing dimensions are not consecutive: - # transpose array to make the advanced-indexing dimensions consecutive as the first dimensions - non_adv_ind_dims = list( - i for i in range(arr.ndim) if i not in advanced_indexing_dims - ) - arr = arr.transpose(advanced_indexing_dims + non_adv_ind_dims) - output_shape = list(arr.gshape) - output_shape[: len(advanced_indexing_dims)] = broadcasted_shape - split_bookkeeping = [None] * arr.ndim - if arr.is_distributed: - 
split_bookkeeping[arr.split] = "split" - split_bookkeeping = [None] * add_dims + split_bookkeeping - # modify key to match the new dimension order - key = [key[i] for i in advanced_indexing_dims] + [key[i] for i in non_adv_ind_dims] - # update advanced-indexing dims - advanced_indexing_dims = list(range(len(advanced_indexing_dims))) - # expand dimensions of input array, key, to match output_shape - # while add_dims > 0: - # # TODO: check this out, I think this is wrong or only right if added dimension is of size (1,) - # arr = arr.expand_dims(advanced_indexing_dims[0]) - # key.insert(advanced_indexing_dims[0], slice(None)) - # add_dims -= 1 - - # now check for ellipsis, newaxis + # check for ellipsis, newaxis add_dims = sum(k is None for k in key) # (np.newaxis is None)===true ellipsis = sum(isinstance(k, type(...)) for k in key) if ellipsis > 1: raise ValueError("key can only contain 1 ellipsis") + # replace with explicit `slice(None)` for interested dimensions if ellipsis == 1: + # output_shape, split_bookkeeping not affected expand_key = [slice(None)] * (arr.ndim + add_dims) ellipsis_index = key.index(...) expand_key[:ellipsis_index] = key[:ellipsis_index] expand_key[ellipsis_index - len(key) :] = key[ellipsis_index + 1 :] key = expand_key while add_dims > 0: - for i, k in reversed(enumerate(key)): + # expand array dims, output_shape, split_bookkeeping to reflect newaxis + # replace newaxis with slice(None) in key + for i, k in reversed(list(enumerate(key))): if k is None: key[i] = slice(None) - arr = arr.expand_dims(i - add_dims + 1) # is the -1 correct? + arr = arr.expand_dims(i - add_dims + 1) + output_shape = ( + output_shape[: i - add_dims + 1] + + [1] + + output_shape[i - add_dims + 1 :] + ) + split_bookkeeping = ( + split_bookkeeping[: i - add_dims + 1] + + [None] + + split_bookkeeping[i - add_dims + 1 :] + ) add_dims -= 1 + + # check for advanced indexing + advanced_indexing_dims = [] + for i, k in enumerate(key): + if isinstance(k, Iterable) or isinstance(k, DNDarray): + # advanced indexing across dimensions + advanced_indexing = True + advanced_indexing_dims.append(i) + if not isinstance(k, DNDarray): + key[i] = torch.tensor(k) + + if advanced_indexing: + advanced_indexing_shapes = tuple( + tuple(key[i].shape) for i in advanced_indexing_dims + ) + # shapes of indexing arrays must be broadcastable + try: + broadcasted_shape = torch.broadcast_shapes(advanced_indexing_shapes) + except RuntimeError: + raise IndexError( + "Shape mismatch: indexing arrays could not be broadcast together with shapes: {}".format( + advanced_indexing_shapes + ) + ) + add_dims = len(broadcasted_shape) - len(advanced_indexing_dims) + if ( + len(advanced_indexing_dims) == 1 + or list(range(advanced_indexing_dims[0], advanced_indexing_dims[-1] + 1)) + == advanced_indexing_dims + ): + # dimensions affected by advanced indexing are consecutive: + output_shape[ + advanced_indexing_dims[0] : advanced_indexing_dims[0] + + len(advanced_indexing_dims) + ] = broadcasted_shape + split_bookkeeping = ( + split_bookkeeping[: advanced_indexing_dims[0]] + + [None] * add_dims + + split_bookkeeping[advanced_indexing_dims[0] :] + ) + else: + # advanced-indexing dimensions are not consecutive: + # transpose array to make the advanced-indexing dimensions consecutive as the first dimensions + non_adv_ind_dims = list( + i for i in range(arr.ndim) if i not in advanced_indexing_dims + ) + arr = arr.transpose(advanced_indexing_dims + non_adv_ind_dims) + output_shape = list(arr.gshape) + output_shape[: len(advanced_indexing_dims)] 
= broadcasted_shape
+                    split_bookkeeping = [None] * arr.ndim
+                    if arr.is_distributed:
+                        split_bookkeeping[arr.split] = "split"
+                    split_bookkeeping = [None] * add_dims + split_bookkeeping
+                    # modify key to match the new dimension order
+                    key = [key[i] for i in advanced_indexing_dims] + [
+                        key[i] for i in non_adv_ind_dims
+                    ]
+                    # update advanced-indexing dims
+                    advanced_indexing_dims = list(range(len(advanced_indexing_dims)))
+
         # expand key to match the number of dimensions of the DNDarray
-        key = tuple(key + [slice(None)] * (arr.ndim - len(key)))
+        if arr.ndim > len(key):
+            key += [slice(None)] * (arr.ndim - len(key))
     else:
         # key is integer or slice
-        key = tuple([key] + [slice(None)] * (arr.ndim - 1))
+        key = [key] + [slice(None)] * (arr.ndim - 1)
+
+    key = tuple(key)
+    output_shape = tuple(output_shape)
+    new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None

-    return advanced_indexing, arr, key
+    return advanced_indexing, arr, key, output_shape, new_split

     def __get_local_slice(self, key: slice):
         split = self.split

From 3830e62fc328b19737d25046fb1f200c2da0e315 Mon Sep 17 00:00:00 2001
From: Claudia Comito
Date: Thu, 25 Aug 2022 05:00:46 +0200
Subject: [PATCH 014/132] in `__process_key()`, copy `arr` before manipulations

---
 heat/core/dndarray.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index 75a50a9027..a90ddfe507 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -670,6 +670,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
         split_bookkeeping[arr.split] = "split"

     advanced_indexing = False
+    arr_is_copy = False

     if isinstance(key, (DNDarray, torch.Tensor, np.ndarray)):
         if key.dtype in (types.bool, types.uint8, torch.bool, torch.uint8, np.bool, np.uint8):
@@ -708,6 +709,9 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
             for i, k in reversed(list(enumerate(key))):
                 if k is None:
                     key[i] = slice(None)
+                    if not arr_is_copy:
+                        arr = arr.copy()
+                        arr_is_copy = True
                     arr = arr.expand_dims(i - add_dims + 1)
                     output_shape = (
                         output_shape[: i - add_dims + 1]
@@ -766,6 +770,9 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
                     non_adv_ind_dims = list(
                         i for i in range(arr.ndim) if i not in advanced_indexing_dims
                     )
+                    if not arr_is_copy:
+                        arr = arr.copy()
+                        arr_is_copy = True
                     arr = arr.transpose(advanced_indexing_dims + non_adv_ind_dims)
                     output_shape = list(arr.gshape)
                     output_shape[: len(advanced_indexing_dims)] = broadcasted_shape

From 82b25086cd774cb9ef8b57de9da9de6b907087fe Mon Sep 17 00:00:00 2001
From: Claudia Comito
Date: Sat, 27 Aug 2022 07:28:17 +0200
Subject: [PATCH 015/132] nonzero() to return tuple of 1D arrays, stable distributed results

---
 heat/core/indexing.py | 87 +++++++++++++++++++++++++++----------------
 1 file changed, 55 insertions(+), 32 deletions(-)

diff --git a/heat/core/indexing.py b/heat/core/indexing.py
index 0452000c2f..9946049185 100644
--- a/heat/core/indexing.py
+++ b/heat/core/indexing.py
@@ -9,12 +9,14 @@
 from .dndarray import DNDarray
 from . import sanitation
 from . import types
+from . import manipulations

 __all__ = ["nonzero", "where"]


 def nonzero(x: DNDarray) -> Tuple[DNDarray, ...]:
     """
+    TODO: UPDATE DOCS!
     Return a Tuple of :class:`~heat.core.dndarray.DNDarray`s, one for each dimension of ``x``,
     containing the indices of the non-zero elements in that dimension. If ``x`` is split then
     the result is split in the 0th dimension.
However, this :class:`~heat.core.dndarray.DNDarray` @@ -56,41 +58,62 @@ def nonzero(x: DNDarray) -> Tuple[DNDarray, ...]: except AttributeError: raise TypeError("Input must be a DNDarray, is {}".format(type(x))) - lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False).transpose(0, 1) - - if x.split is None: - # if there is no split then just return the transpose of values from torch - - gout = list(lcl_nonzero.size()) - is_split = None + if not x.is_distributed(): + # nonzero indices as tuple + lcl_nonzero = torch.nonzero(input=local_x, as_tuple=True) + # bookkeeping for final DNDarray construct + output_shape = (lcl_nonzero[0].shape,) + output_split = None else: - # a is split - # adjust local indices along split dimension + lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False) + nonzero_size = torch.tensor( + lcl_nonzero.shape[0], dtype=torch.int64, device=lcl_nonzero.device + ) + # construct global DNDarray of nz indices: + # global shape and split + x.comm.Allreduce(MPI.IN_PLACE, nonzero_size, MPI.SUM) + output_shape = (nonzero_size.item(), x.ndim) + output_split = 0 + # correct indices along split axis _, displs = x.counts_displs() - lcl_nonzero[x.split] += displs[x.comm.rank] - del displs - - # get global size of split dimension - gout = list(lcl_nonzero.size()) - gout[1] = x.comm.allreduce(gout[1], MPI.SUM) - is_split = 0 - - non_zero_indices = list( - [ - DNDarray( - dim_indices, - gshape=tuple(gout), - dtype=types.canonical_heat_type(lcl_nonzero.dtype), - split=is_split, - device=x.device, - comm=x.comm, - balanced=False, - ) - for dim_indices in lcl_nonzero - ] - ) + lcl_nonzero[:, x.split] += displs[x.comm.rank] + global_nonzero = DNDarray( + lcl_nonzero, + gshape=output_shape, + dtype=types.int64, + split=output_split, + device=x.device, + comm=x.comm, + balanced=False, + ) + # stabilize distributed result: vectorize sorting of nz indices along axis 0 + global_nonzero.balance_() + global_nonzero = manipulations.unique(global_nonzero, axis=0) + # return indices as tuple of columns + lcl_nonzero = global_nonzero.larray.split(1, dim=1) + # bookkeeping for final DNDarray construct + output_shape = (global_nonzero.shape[0],) + output_split = 0 + + # return global_nonzero as tuple of DNDarrays + global_nonzero = list(lcl_nonzero) + for i, nz_tensor in enumerate(global_nonzero): + if nz_tensor.ndim > 1: + # extra dimension in distributed case from usage of torch.split() + nz_tensor = nz_tensor.squeeze() + nz_array = DNDarray( + nz_tensor, + gshape=output_shape, + dtype=types.int64, + split=output_split, + device=x.device, + comm=x.comm, + balanced=True, + ) + global_nonzero[i] = nz_array + global_nonzero = tuple(global_nonzero) - return tuple(non_zero_indices) + return tuple(global_nonzero) DNDarray.nonzero = lambda self: nonzero(self) From aafaf99be06f1c74fafb8859365ffb5e4eefebb5 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sat, 27 Aug 2022 07:39:01 +0200 Subject: [PATCH 016/132] update __process_key(), get rid of recursive calls, __getitem__ broken --- heat/core/dndarray.py | 132 +++++++++++++++--------------------------- 1 file changed, 48 insertions(+), 84 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index a90ddfe507..24c4796a60 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -673,19 +673,24 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] 
arr_is_copy = False if isinstance(key, (DNDarray, torch.Tensor, np.ndarray)): - if key.dtype in (types.bool, types.uint8, torch.bool, torch.uint8, np.bool, np.uint8): + if key.dtype in (bool, uint8, torch.bool, torch.uint8, np.bool, np.uint8): # boolean indexing: transform to sequence of indexing (1-D) arrays try: # torch.Tensor key key = key.nonzero(as_tuple=True) - except AttributeError: + except TypeError: # np.array or DNDarray key key = key.nonzero() else: - # advanced indexing on first dimension: first dim expands to shape of key - output_shape = list(key.shape) + output_shape[1:] + # advanced indexing on first dimension: first dim will expand to shape of key + advanced_indexing = True + output_shape = tuple(list(key.shape) + output_shape[1:]) # adjust split axis accordingly split_bookkeeping = [None] * (len(key.shape) - 1) + split_bookkeeping[1:] + new_split = ( + split_bookkeeping.index("split") if "split" in split_bookkeeping else None + ) + return arr, key, output_shape, new_split, advanced_indexing if isinstance(key, (tuple, list)): key = list(key) @@ -733,12 +738,13 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] advanced_indexing = True advanced_indexing_dims.append(i) if not isinstance(k, DNDarray): - key[i] = torch.tensor(k) + key[i] = torch.tensor(k, dtype=torch.int64, device=arr.larray.device) if advanced_indexing: advanced_indexing_shapes = tuple( tuple(key[i].shape) for i in advanced_indexing_dims ) + print("DEBUGGING: advanced_indexing_shapes = ", advanced_indexing_shapes) # shapes of indexing arrays must be broadcastable try: broadcasted_shape = torch.broadcast_shapes(advanced_indexing_shapes) @@ -797,7 +803,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] output_shape = tuple(output_shape) new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None - return advanced_indexing, arr, key, output_shape, new_split + return arr, key, output_shape, new_split, advanced_indexing def __get_local_slice(self, key: slice): split = self.split @@ -862,62 +868,20 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar ): # latter doesnt work with torch for 0-dim tensors return self # Preprocess: Process Ellipsis + 'None' indexing; make Iterables to DNDarrays - advanced_indexing, self, key = self.__process_key(key) - print("DEBUGGING: AFTER PROCESSING KEY = ", key, type(key)) - # To use torch_proxy with advanced indexing, add empty dimensions instead of - # advanced index. Later, replace the empty dimensions with the shape of the advanced index - proxy = self - names = [ - "split" if (proxy.split is not None and i == proxy.split) else "_{}".format(i) - for i in range(proxy.ndim) - ] - proxy_key = list(key) # copy OR IS THIS REALLY NEEDED?? - print("DEBUGGING: proxy_key, ADVANCED_INDEXING", proxy_key, advanced_indexing) - if advanced_indexing: - for i, k in reversed(enumerate(key)): - if isinstance(k, DNDarray): # all iterables have been made DNDarrays - # TODO: Bool indexing (sometimes) is collapsed into one dimension - # TODO: What to do if advanced index is in split dimension?? 
- names[i] = "replace" + str(k.shape) # put shape into name - proxy_key[i] = slice(None) - for _ in range(k.ndim - 1): - proxy = proxy.expand_dims(i) - names.insert(i + 1, "_{}".format(len(names))) - proxy_key.insert(i + 1, slice(None)) - proxy_key = tuple(proxy_key) - - self_proxy = proxy.__torch_proxy__() - self_proxy.names = names - print("DEBUGGING: self_proxy = ", self_proxy) - print("debugging: proxy_key", proxy_key) - print("DEBUGGING: self_proxy.shape", self_proxy.shape) - print("DEBUGGING: type(self_proxy)", type(self_proxy)) - - indexed_proxy = self_proxy[proxy_key] - print("DEBUGGING: indexed_proxy = ", indexed_proxy) - - output_shape = list(indexed_proxy.shape) - print("DEBUGGING: output_shape = ", output_shape) - if advanced_indexing: - for i, n in enumerate(indexed_proxy.names): - if "replace" in n: - shape = eval(n.split("replace")[1]) # extract shape from name - # TODO Bool indexing (sometimes) is collapsed into one dimension - output_shape[i : i + len(shape)] = shape - output_shape = tuple(output_shape) + self, key, output_shape, output_split, advanced_indexing = self.__process_key(key) - try: - output_split = indexed_proxy.names.index("split") - except ValueError: - output_split = None + # TODO: test that key for not affected dims is always slice(None) + # including match between self.split and key after self manipulation - print("DEBUGGING: output_split = ", output_split) # data are not distributed or split dimension is not affected by indexing if not self.is_distributed() or key[self.split] == slice(None): - print("DEBUGGING: NOT DISTRIBUTED OR SPLIT DIMENSION NOT AFFECTED BY INDEXING") - print("DEBUGGING: output_shape = ", output_shape) + try: + indexed_arr = self.larray[key.larray.long()] + except AttributeError: + # key is an ndarray + indexed_arr = self.larray[key] return DNDarray( - self.larray[key], + indexed_arr, gshape=output_shape, dtype=self.dtype, split=output_split, @@ -926,32 +890,32 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar comm=self.comm, ) - # data are distributed and split dimension is affected by indexing - _, offsets = self.counts_displs() - split = self.split - # slice along the split axis - if isinstance(key[split], slice): - local_slice = self.__get_local_slice(key[split]) - if local_slice is not None: - key = list(key) - key[split] = local_slice - local_tensor = self.larray[tuple(key)] - else: # local tensor is empty - local_shape = list(output_shape) - local_shape[output_split] = 0 - local_tensor = torch.zeros( - tuple(local_shape), dtype=self.larray.dtype, device=self.larray.device - ) + # # data are distributed and split dimension is affected by indexing + # _, offsets = self.counts_displs() + # split = self.split + # # slice along the split axis + # if isinstance(key[split], slice): + # local_slice = self.__get_local_slice(key[split]) + # if local_slice is not None: + # key = list(key) + # key[split] = local_slice + # local_tensor = self.larray[tuple(key)] + # else: # local tensor is empty + # local_shape = list(output_shape) + # local_shape[output_split] = 0 + # local_tensor = torch.zeros( + # tuple(local_shape), dtype=self.larray.dtype, device=self.larray.device + # ) - return DNDarray( - local_tensor, - gshape=output_shape, - dtype=self.dtype, - split=output_split, - device=self.device, - balanced=False, - comm=self.comm, - ) + # return DNDarray( + # local_tensor, + # gshape=output_shape, + # dtype=self.dtype, + # split=output_split, + # device=self.device, + # balanced=False, + # comm=self.comm, + 
# ) # local indexing cases: # self is not distributed, key is not distributed - DONE @@ -1696,7 +1660,7 @@ def __set(arr: DNDarray, value: DNDarray): if key is None or key == ... or key == slice(None): return __set(self, value) - advanced_indexing, self, key = self.__process_key(key) + self, key, output_shape, output_split, advanced_indexing = self.__process_key(key) if advanced_indexing: raise Exception("Advanced indexing is not supported yet") @@ -2084,4 +2048,4 @@ def __xitem_get_key_start_stop( from .devices import Device from .stride_tricks import sanitize_axis import types -from .types import datatype, canonical_heat_type +from .types import datatype, canonical_heat_type, bool, uint8 From b7468723010860eefbd2c0f6eb5c23702031935c Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 30 Aug 2022 10:59:48 +0200 Subject: [PATCH 017/132] deal with scalar key, local and distributed cases --- heat/core/dndarray.py | 62 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 59 insertions(+), 3 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 24c4796a60..1c5f3e737f 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -861,15 +861,71 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # key can be: int, tuple, list, slice, DNDarray, torch tensor, numpy array, or sequence thereof # Trivial cases print("DEBUGGING: RAW KEY = ", key) + # early out: key is a scalar + scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 + if scalar: + output_shape = self.gshape[1:] + try: + # is key an ndarray, DNDarray or torch tensor? + key = key.copy().item() + except AttributeError: + # key is already an integer, do nothing + pass + if not self.is_distributed() or self.split != 0: + indexed_arr = self.larray[key] + output_split = None if self.split is None else self.split - 1 + indexed_arr = DNDarray( + indexed_arr, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + comm=self.comm, + balanced=self.balanced, + ) + return indexed_arr + # check for negative key + key = key + self.shape[0] if key < 0 else key + # identify root process + _, displs = self.counts_displs() + if key in displs: + root = displs.index(key) + else: + displs = torch.cat((torch.tensor(displs), torch.tensor(key).reshape(-1)), dim=0) + _, sorted_indices = displs.unique(sorted=True, return_inverse=True) + root = sorted_indices[-1] - 1 + # correct key for relevant displacement + key -= displs[root] + # allocate buffer on all processes + if self.comm.rank == root: + indexed_arr = self.larray[key] + else: + indexed_arr = torch.zeros( + output_shape, dtype=self.larray.dtype, device=self.larray.device + ) + # broadcast result to all processes + self.comm.Bcast(indexed_arr, root=root) + indexed_arr = DNDarray( + indexed_arr, + gshape=output_shape, + dtype=self.dtype, + split=None, + device=self.device, + comm=self.comm, + balanced=True, + ) + return indexed_arr + if key is None: return self.expand_dims(0) if ( key is ... 
or isinstance(key, slice) and key == slice(None) ): # latter doesnt work with torch for 0-dim tensors return self + # Preprocess: Process Ellipsis + 'None' indexing; make Iterables to DNDarrays self, key, output_shape, output_split, advanced_indexing = self.__process_key(key) - + print("DEBUGGING: processed key = ", key) # TODO: test that key for not affected dims is always slice(None) # including match between self.split and key after self manipulation @@ -1661,8 +1717,8 @@ def __set(arr: DNDarray, value: DNDarray): return __set(self, value) self, key, output_shape, output_split, advanced_indexing = self.__process_key(key) - if advanced_indexing: - raise Exception("Advanced indexing is not supported yet") + # if advanced_indexing: + # raise Exception("Advanced indexing is not supported yet") split = self.split if not self.is_distributed() or key[split] == slice(None): From 00fe5380c8cdc65e22e797539822b48ff5de7fe6 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 30 Aug 2022 11:02:21 +0200 Subject: [PATCH 018/132] test getitem separately, follow numpy Indexing on ndarray examples --- heat/core/tests/test_dndarray.py | 929 ++++++++++++++++--------------- 1 file changed, 482 insertions(+), 447 deletions(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index e42c5a9a14..5dd5ced775 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -6,15 +6,15 @@ class TestDNDarray(TestCase): - @classmethod - def setUpClass(cls): - super(TestDNDarray, cls).setUpClass() - N = ht.MPI_WORLD.size - cls.reference_tensor = ht.zeros((N, N + 1, 2 * N)) + # @classmethod + # def setUpClass(cls): + # super(TestDNDarray, cls).setUpClass() + # N = ht.MPI_WORLD.size + # cls.reference_tensor = ht.zeros((N, N + 1, 2 * N)) - for n in range(N): - for m in range(N + 1): - cls.reference_tensor[n, m, :] = ht.arange(0, 2 * N) + m * 10 + n * 100 + # for n in range(N): + # for m in range(N + 1): + # cls.reference_tensor[n, m, :] = ht.arange(0, 2 * N) + m * 10 + n * 100 def test_and(self): int16_tensor = ht.array([[1, 1], [2, 2]], dtype=ht.int16) @@ -516,6 +516,41 @@ def test_float_cast(self): with self.assertRaises(TypeError): float(ht.full((ht.MPI_WORLD.size,), 2, split=0)) + def test_getitem(self): + # following https://numpy.org/doc/stable/user/basics.indexing.html + + # Single element indexing + # 1D, local + x = ht.arange(10) + self.assertTrue(x[2].item() == 2) + self.assertTrue(x[-2].item() == 8) + self.assertTrue(x[2].dtype == ht.int32) + # 1D, distributed + x = ht.arange(10, split=0, dtype=ht.float64) + self.assertTrue(x[2].item() == 2.0) + self.assertTrue(x[-2].item() == 8.0) + self.assertTrue(x[2].dtype == ht.float64) + self.assertTrue(x[2].split is None) + # 3D, local + x = ht.arange(27).reshape(3, 3, 3) + key = -2 + indexed = x[key] + self.assertTrue((indexed.larray == x.larray[key]).all()) + self.assertTrue(indexed.dtype == ht.int32) + self.assertTrue(indexed.split is None) + # 3D, distributed, split = 0 + x_split0 = ht.array(x, dtype=ht.float32, split=0) + indexed_split0 = x_split0[key] + self.assertTrue((indexed_split0.larray == x.larray[key]).all()) + self.assertTrue(indexed_split0.dtype == ht.float32) + self.assertTrue(indexed_split0.split is None) + # 3D, distributed split, != 0 + x_split2 = ht.array(x, dtype=ht.int64, split=2) + indexed_split2 = x_split2[key] + self.assertTrue((indexed_split2.numpy() == x.numpy()[key]).all()) + self.assertTrue(indexed_split2.dtype == ht.int64) + self.assertTrue(indexed_split2.split == 1) + def 
test_int_cast(self): # simple scalar tensor a = ht.ones(1) @@ -1053,445 +1088,445 @@ def test_rshift(self): res = ht.right_shift(ht.array([True]), 2) self.assertTrue(res == 0) - def test_setitem_getitem(self): - # tests for bug #825 - a = ht.ones((102, 102), split=0) - setting = ht.zeros((100, 100), split=0) - a[1:-1, 1:-1] = setting - self.assertTrue(ht.all(a[1:-1, 1:-1] == 0)) - - a = ht.ones((102, 102), split=1) - setting = ht.zeros((30, 100), split=1) - a[-30:, 1:-1] = setting - self.assertTrue(ht.all(a[-30:, 1:-1] == 0)) - - a = ht.ones((102, 102), split=1) - setting = ht.zeros((100, 100), split=1) - a[1:-1, 1:-1] = setting - self.assertTrue(ht.all(a[1:-1, 1:-1] == 0)) - - a = ht.ones((102, 102), split=1) - setting = ht.zeros((100, 20), split=1) - a[1:-1, :20] = setting - self.assertTrue(ht.all(a[1:-1, :20] == 0)) - - # tests for bug 730: - a = ht.ones((10, 25, 30), split=1) - if a.comm.size > 1: - self.assertEqual(a[0].split, 0) - self.assertEqual(a[:, 0, :].split, None) - self.assertEqual(a[:, :, 0].split, 1) - - # set and get single value - a = ht.zeros((13, 5), split=0) - # set value on one node - a[10, np.array(0)] = 1 - self.assertEqual(a[10, 0], 1) - self.assertEqual(a[10, 0].dtype, ht.float32) - - a = ht.zeros((13, 5), split=0) - a[10] = 1 - b = a[torch.tensor(10)] - self.assertTrue((b == 1).all()) - self.assertEqual(b.dtype, ht.float32) - self.assertEqual(b.gshape, (5,)) - - a = ht.zeros((13, 5), split=0) - a[-1] = 1 - b = a[-1] - self.assertTrue((b == 1).all()) - self.assertEqual(b.dtype, ht.float32) - self.assertEqual(b.gshape, (5,)) - - # slice in 1st dim only on 1 node - a = ht.zeros((13, 5), split=0) - a[1:4] = 1 - self.assertTrue((a[1:4] == 1).all()) - self.assertEqual(a[1:4].gshape, (3, 5)) - self.assertEqual(a[1:4].split, 0) - self.assertEqual(a[1:4].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[1:4].lshape, (3, 5)) - else: - self.assertEqual(a[1:4].lshape, (0, 5)) - - a = ht.zeros((13, 5), split=0) - a[1:2] = 1 - self.assertTrue((a[1:2] == 1).all()) - self.assertEqual(a[1:2].gshape, (1, 5)) - self.assertEqual(a[1:2].split, 0) - self.assertEqual(a[1:2].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[1:2].lshape, (1, 5)) - else: - self.assertEqual(a[1:2].lshape, (0, 5)) - - # slice in 1st dim only on 1 node w/ singular second dim - a = ht.zeros((13, 5), split=0) - a[1:4, 1] = 1 - b = a[1:4, np.int64(1)] - self.assertTrue((b == 1).all()) - self.assertEqual(b.gshape, (3,)) - self.assertEqual(b.split, 0) - self.assertEqual(b.dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(b.lshape, (3,)) - else: - self.assertEqual(b.lshape, (0,)) - - # slice in 1st dim across both nodes (2 node case) w/ singular second dim - a = ht.zeros((13, 5), split=0) - a[1:11, 1] = 1 - self.assertTrue((a[1:11, 1] == 1).all()) - self.assertEqual(a[1:11, 1].gshape, (10,)) - self.assertEqual(a[1:11, torch.tensor(1)].split, 0) - self.assertEqual(a[1:11, 1].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 1: - self.assertEqual(a[1:11, 1].lshape, (4,)) - if a.comm.rank == 0: - self.assertEqual(a[1:11, 1].lshape, (6,)) - - # slice in 1st dim across 1 node (2nd) w/ singular second dim - c = ht.zeros((13, 5), split=0) - c[8:12, ht.array(1)] = 1 - b = c[8:12, np.int64(1)] - self.assertTrue((b == 1).all()) - self.assertEqual(b.gshape, (4,)) - self.assertEqual(b.split, 0) - self.assertEqual(b.dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 1: - self.assertEqual(b.lshape, 
(4,)) - if a.comm.rank == 0: - self.assertEqual(b.lshape, (0,)) - - # slice in both directions - a = ht.zeros((13, 5), split=0) - a[3:13, 2:5:2] = 1 - self.assertTrue((a[3:13, 2:5:2] == 1).all()) - self.assertEqual(a[3:13, 2:5:2].gshape, (10, 2)) - self.assertEqual(a[3:13, 2:5:2].split, 0) - self.assertEqual(a[3:13, 2:5:2].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 1: - self.assertEqual(a[3:13, 2:5:2].lshape, (6, 2)) - if a.comm.rank == 0: - self.assertEqual(a[3:13, 2:5:2].lshape, (4, 2)) - - # setting with heat tensor - a = ht.zeros((4, 5), split=0) - a[1, 0:4] = ht.arange(4) - # if a.comm.size == 2: - for c, i in enumerate(range(4)): - self.assertEqual(a[1, c], i) - - # setting with torch tensor - a = ht.zeros((4, 5), split=0) - a[1, 0:4] = torch.arange(4, device=self.device.torch_device) - # if a.comm.size == 2: - for c, i in enumerate(range(4)): - self.assertEqual(a[1, c], i) - - ################################################### - a = ht.zeros((13, 5), split=1) - # # set value on one node - a[10] = 1 - self.assertEqual(a[10].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[10].lshape, (3,)) - if a.comm.rank == 1: - self.assertEqual(a[10].lshape, (2,)) - - a = ht.zeros((13, 5), split=1) - # # set value on one node - a[10, 0] = 1 - self.assertEqual(a[10, 0], 1) - self.assertEqual(a[10, 0].dtype, ht.float32) - - # slice in 1st dim only on 1 node - a = ht.zeros((13, 5), split=1) - a[1:4] = 1 - self.assertTrue((a[1:4] == 1).all()) - self.assertEqual(a[1:4].gshape, (3, 5)) - self.assertEqual(a[1:4].split, 1) - self.assertEqual(a[1:4].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[1:4].lshape, (3, 3)) - if a.comm.rank == 1: - self.assertEqual(a[1:4].lshape, (3, 2)) - - # slice in 1st dim only on 1 node w/ singular second dim - a = ht.zeros((13, 5), split=1) - a[1:4, 1] = 1 - self.assertTrue((a[1:4, 1] == 1).all()) - self.assertEqual(a[1:4, 1].gshape, (3,)) - self.assertEqual(a[1:4, 1].split, None) - self.assertEqual(a[1:4, 1].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[1:4, 1].lshape, (3,)) - if a.comm.rank == 1: - self.assertEqual(a[1:4, 1].lshape, (3,)) - - # slice in 2st dim across both nodes (2 node case) w/ singular fist dim - a = ht.zeros((13, 5), split=1) - a[11, 1:5] = 1 - self.assertTrue((a[11, 1:5] == 1).all()) - self.assertEqual(a[11, 1:5].gshape, (4,)) - self.assertEqual(a[11, 1:5].split, 0) - self.assertEqual(a[11, 1:5].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 1: - self.assertEqual(a[11, 1:5].lshape, (2,)) - if a.comm.rank == 0: - self.assertEqual(a[11, 1:5].lshape, (2,)) - - # slice in 1st dim across 1 node (2nd) w/ singular second dim - a = ht.zeros((13, 5), split=1) - a[8:12, 1] = 1 - self.assertTrue((a[8:12, 1] == 1).all()) - self.assertEqual(a[8:12, 1].gshape, (4,)) - self.assertEqual(a[8:12, 1].split, None) - self.assertEqual(a[8:12, 1].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[8:12, 1].lshape, (4,)) - if a.comm.rank == 1: - self.assertEqual(a[8:12, 1].lshape, (4,)) - - # slice in both directions - a = ht.zeros((13, 5), split=1) - a[3:13, 2::2] = 1 - self.assertTrue((a[3:13, 2:5:2] == 1).all()) - self.assertEqual(a[3:13, 2:5:2].gshape, (10, 2)) - self.assertEqual(a[3:13, 2:5:2].split, 1) - self.assertEqual(a[3:13, 2:5:2].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 1: - self.assertEqual(a[3:13, 2:5:2].lshape, (10, 1)) - if a.comm.rank == 0: - 
self.assertEqual(a[3:13, 2:5:2].lshape, (10, 1)) - - a = ht.zeros((13, 5), split=1) - a[..., 2::2] = 1 - self.assertTrue((a[:, 2:5:2] == 1).all()) - self.assertEqual(a[..., 2:5:2].gshape, (13, 2)) - self.assertEqual(a[..., 2:5:2].split, 1) - self.assertEqual(a[..., 2:5:2].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 1: - self.assertEqual(a[..., 2:5:2].lshape, (13, 1)) - if a.comm.rank == 0: - self.assertEqual(a[:, 2:5:2].lshape, (13, 1)) - - # setting with heat tensor - a = ht.zeros((4, 5), split=1) - a[1, 0:4] = ht.arange(4) - for c, i in enumerate(range(4)): - b = a[1, c] - if b.larray.numel() > 0: - self.assertEqual(b.item(), i) - - # setting with torch tensor - a = ht.zeros((4, 5), split=1) - a[1, 0:4] = torch.arange(4, device=self.device.torch_device) - for c, i in enumerate(range(4)): - self.assertEqual(a[1, c], i) - - #################################################### - a = ht.zeros((13, 5, 7), split=2) - # # set value on one node - a[10, :, :] = 1 - self.assertEqual(a[10, :, :].dtype, ht.float32) - self.assertEqual(a[10, :, :].gshape, (5, 7)) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[10, :, :].lshape, (5, 4)) - if a.comm.rank == 1: - self.assertEqual(a[10, :, :].lshape, (5, 3)) - - a = ht.zeros((13, 5, 7), split=2) - # # set value on one node - a[10, ...] = 1 - self.assertEqual(a[10, ...].dtype, ht.float32) - self.assertEqual(a[10, ...].gshape, (5, 7)) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[10, ...].lshape, (5, 4)) - if a.comm.rank == 1: - self.assertEqual(a[10, ...].lshape, (5, 3)) - - a = ht.zeros((13, 5, 8), split=2) - # # set value on one node - a[10, 0, 0] = 1 - self.assertEqual(a[10, 0, 0], 1) - self.assertEqual(a[10, 0, 0].dtype, ht.float32) - - # # slice in 1st dim only on 1 node - a = ht.zeros((13, 5, 7), split=2) - a[1:4] = 1 - self.assertTrue((a[1:4] == 1).all()) - self.assertEqual(a[1:4].gshape, (3, 5, 7)) - self.assertEqual(a[1:4].split, 2) - self.assertEqual(a[1:4].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[1:4].lshape, (3, 5, 4)) - if a.comm.rank == 1: - self.assertEqual(a[1:4].lshape, (3, 5, 3)) - - # slice in 1st dim only on 1 node w/ singular second dim - a = ht.zeros((13, 5, 7), split=2) - a[1:4, 1, :] = 1 - self.assertTrue((a[1:4, 1, :] == 1).all()) - self.assertEqual(a[1:4, 1, :].gshape, (3, 7)) - if a.comm.size == 2: - self.assertEqual(a[1:4, 1, :].split, 1) - self.assertEqual(a[1:4, 1, :].dtype, ht.float32) - if a.comm.rank == 0: - self.assertEqual(a[1:4, 1, :].lshape, (3, 4)) - if a.comm.rank == 1: - self.assertEqual(a[1:4, 1, :].lshape, (3, 3)) - - # slice in both directions - a = ht.zeros((13, 5, 7), split=2) - a[3:13, 2:5:2, 1:7:3] = 1 - self.assertTrue((a[3:13, 2:5:2, 1:7:3] == 1).all()) - self.assertEqual(a[3:13, 2:5:2, 1:7:3].split, 2) - self.assertEqual(a[3:13, 2:5:2, 1:7:3].dtype, ht.float32) - self.assertEqual(a[3:13, 2:5:2, 1:7:3].gshape, (10, 2, 2)) - if a.comm.size == 2: - out = ht.ones((4, 5, 5), split=1) - self.assertEqual(out[0].gshape, (5, 5)) - if a.comm.rank == 1: - self.assertEqual(a[3:13, 2:5:2, 1:7:3].lshape, (10, 2, 1)) - self.assertEqual(out[0].lshape, (2, 5)) - if a.comm.rank == 0: - self.assertEqual(a[3:13, 2:5:2, 1:7:3].lshape, (10, 2, 1)) - self.assertEqual(out[0].lshape, (3, 5)) - - a = ht.ones((4, 5), split=0).tril() - a[0] = [6, 6, 6, 6, 6] - self.assertTrue((a[0] == 6).all()) - - a = ht.ones((4, 5), split=0).tril() - a[0] = (6, 6, 6, 6, 6) - self.assertTrue((a[0] == 6).all()) - - a = ht.ones((4, 5), split=0).tril() 
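The split expectations asserted throughout these tests follow a single bookkeeping rule: an integer index removes its dimension, so a split axis to its right shifts left by one, while an integer index on the split axis itself leaves the result undistributed. A minimal standalone sketch of that rule for basic indexing (illustrative only, not Heat's implementation):

    def basic_indexing_split(ndim, split, key):
        # one bookkeeping entry per dimension; tag the split axis
        bookkeeping = [None] * ndim
        if split is not None:
            bookkeeping[split] = "split"
        # integer indices drop their dimension, slices keep theirs
        kept = [b for b, k in zip(bookkeeping, key) if not isinstance(k, int)]
        return kept.index("split") if "split" in kept else None

    # reproduces the split assertions above, e.g.:
    # basic_indexing_split(3, 2, (slice(1, 4), 1, slice(None)))  -> 1
    # basic_indexing_split(2, 1, (slice(1, 4), 1))               -> None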
- a[0] = np.array([6, 6, 6, 6, 6]) - self.assertTrue((a[0] == 6).all()) - - a = ht.ones((4, 5), split=0).tril() - a[0] = ht.array([6, 6, 6, 6, 6]) - self.assertTrue((a[ht.array((0,))] == 6).all()) - - a = ht.ones((4, 5), split=0).tril() - a[0] = ht.array([6, 6, 6, 6, 6]) - self.assertTrue((a[ht.array((0,))] == 6).all()) - - # ======================= indexing with bools ================================= - split = None - arr = ht.random.random((20, 20)).resplit(split) - np_arr = arr.numpy() - np_key = np_arr < 0.5 - ht_key = ht.array(np_key, split=split) - arr[ht_key] = 10.0 - np_arr[np_key] = 10.0 - self.assertTrue(np.all(arr.numpy() == np_arr)) - self.assertTrue(ht.all(arr[ht_key] == 10.0)) - - split = 0 - arr = ht.random.random((20, 20)).resplit(split) - np_arr = arr.numpy() - np_key = (np_arr < 0.5)[0] - ht_key = ht.array(np_key, split=split) - arr[ht_key] = 10.0 - np_arr[np_key] = 10.0 - self.assertTrue(np.all(arr.numpy() == np_arr)) - self.assertTrue(ht.all(arr[ht_key] == 10.0)) - - # key -> tuple(ht.bool, int) - split = 0 - arr = ht.random.random((20, 20)).resplit(split) - np_arr = arr.numpy() - np_key = (np_arr < 0.5)[0] - ht_key = ht.array(np_key, split=split) - arr[ht_key, 4] = 10.0 - np_arr[np_key, 4] = 10.0 - self.assertTrue(np.all(arr.numpy() == np_arr)) - self.assertTrue(ht.all(arr[ht_key, 4] == 10.0)) - - # key -> tuple(torch.bool, int) - split = 0 - arr = ht.random.random((20, 20)).resplit(split) - np_arr = arr.numpy() - np_key = (np_arr < 0.5)[0] - t_key = torch.tensor(np_key, device=arr.larray.device) - arr[t_key, 4] = 10.0 - np_arr[np_key, 4] = 10.0 - self.assertTrue(np.all(arr.numpy() == np_arr)) - self.assertTrue(ht.all(arr[t_key, 4] == 10.0)) - - # key -> torch.bool - split = 0 - arr = ht.random.random((20, 20)).resplit(split) - np_arr = arr.numpy() - np_key = (np_arr < 0.5)[0] - t_key = torch.tensor(np_key, device=arr.larray.device) - arr[t_key] = 10.0 - np_arr[np_key] = 10.0 - self.assertTrue(np.all(arr.numpy() == np_arr)) - self.assertTrue(ht.all(arr[t_key] == 10.0)) - - split = 1 - arr = ht.random.random((20, 20, 10)).resplit(split) - np_arr = arr.numpy() - np_key = np_arr < 0.5 - ht_key = ht.array(np_key, split=split) - arr[ht_key] = 10.0 - np_arr[np_key] = 10.0 - self.assertTrue(np.all(arr.numpy() == np_arr)) - self.assertTrue(ht.all(arr[ht_key] == 10.0)) - - split = 2 - arr = ht.random.random((15, 20, 20)).resplit(split) - np_arr = arr.numpy() - np_key = np_arr < 0.5 - ht_key = ht.array(np_key, split=split) - arr[ht_key] = 10.0 - np_arr[np_key] = 10.0 - self.assertTrue(np.all(arr.numpy() == np_arr)) - self.assertTrue(ht.all(arr[ht_key] == 10.0)) - - with self.assertRaises(ValueError): - a[..., ...] - with self.assertRaises(ValueError): - a[..., ...] 
= 1 - if a.comm.size > 1: - with self.assertRaises(ValueError): - x = ht.ones((10, 10), split=0) - setting = ht.zeros((8, 8), split=1) - x[1:-1, 1:-1] = setting - - for split in [None, 0, 1, 2]: - for new_dim in [0, 1, 2]: - for add in [np.newaxis, None]: - arr = ht.ones((4, 3, 2), split=split, dtype=ht.int32) - check = torch.ones((4, 3, 2), dtype=torch.int32) - idx = [slice(None), slice(None), slice(None)] - idx[new_dim] = add - idx = tuple(idx) - arr = arr[idx] - check = check[idx] - self.assertTrue(arr.shape == check.shape) - self.assertTrue(arr.lshape[new_dim] == 1) + # def test_setitem_getitem(self): + # # tests for bug #825 + # a = ht.ones((102, 102), split=0) + # setting = ht.zeros((100, 100), split=0) + # a[1:-1, 1:-1] = setting + # self.assertTrue(ht.all(a[1:-1, 1:-1] == 0)) + + # a = ht.ones((102, 102), split=1) + # setting = ht.zeros((30, 100), split=1) + # a[-30:, 1:-1] = setting + # self.assertTrue(ht.all(a[-30:, 1:-1] == 0)) + + # a = ht.ones((102, 102), split=1) + # setting = ht.zeros((100, 100), split=1) + # a[1:-1, 1:-1] = setting + # self.assertTrue(ht.all(a[1:-1, 1:-1] == 0)) + + # a = ht.ones((102, 102), split=1) + # setting = ht.zeros((100, 20), split=1) + # a[1:-1, :20] = setting + # self.assertTrue(ht.all(a[1:-1, :20] == 0)) + + # # tests for bug 730: + # a = ht.ones((10, 25, 30), split=1) + # if a.comm.size > 1: + # self.assertEqual(a[0].split, 0) + # self.assertEqual(a[:, 0, :].split, None) + # self.assertEqual(a[:, :, 0].split, 1) + + # # set and get single value + # a = ht.zeros((13, 5), split=0) + # # set value on one node + # a[10, np.array(0)] = 1 + # self.assertEqual(a[10, 0], 1) + # self.assertEqual(a[10, 0].dtype, ht.float32) + + # a = ht.zeros((13, 5), split=0) + # a[10] = 1 + # b = a[torch.tensor(10)] + # self.assertTrue((b == 1).all()) + # self.assertEqual(b.dtype, ht.float32) + # self.assertEqual(b.gshape, (5,)) + + # a = ht.zeros((13, 5), split=0) + # a[-1] = 1 + # b = a[-1] + # self.assertTrue((b == 1).all()) + # self.assertEqual(b.dtype, ht.float32) + # self.assertEqual(b.gshape, (5,)) + + # # slice in 1st dim only on 1 node + # a = ht.zeros((13, 5), split=0) + # a[1:4] = 1 + # self.assertTrue((a[1:4] == 1).all()) + # self.assertEqual(a[1:4].gshape, (3, 5)) + # self.assertEqual(a[1:4].split, 0) + # self.assertEqual(a[1:4].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[1:4].lshape, (3, 5)) + # else: + # self.assertEqual(a[1:4].lshape, (0, 5)) + + # a = ht.zeros((13, 5), split=0) + # a[1:2] = 1 + # self.assertTrue((a[1:2] == 1).all()) + # self.assertEqual(a[1:2].gshape, (1, 5)) + # self.assertEqual(a[1:2].split, 0) + # self.assertEqual(a[1:2].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[1:2].lshape, (1, 5)) + # else: + # self.assertEqual(a[1:2].lshape, (0, 5)) + + # # slice in 1st dim only on 1 node w/ singular second dim + # a = ht.zeros((13, 5), split=0) + # a[1:4, 1] = 1 + # b = a[1:4, np.int64(1)] + # self.assertTrue((b == 1).all()) + # self.assertEqual(b.gshape, (3,)) + # self.assertEqual(b.split, 0) + # self.assertEqual(b.dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(b.lshape, (3,)) + # else: + # self.assertEqual(b.lshape, (0,)) + + # # slice in 1st dim across both nodes (2 node case) w/ singular second dim + # a = ht.zeros((13, 5), split=0) + # a[1:11, 1] = 1 + # self.assertTrue((a[1:11, 1] == 1).all()) + # self.assertEqual(a[1:11, 1].gshape, (10,)) + # self.assertEqual(a[1:11, torch.tensor(1)].split, 0) + # 
self.assertEqual(a[1:11, 1].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 1: + # self.assertEqual(a[1:11, 1].lshape, (4,)) + # if a.comm.rank == 0: + # self.assertEqual(a[1:11, 1].lshape, (6,)) + + # # slice in 1st dim across 1 node (2nd) w/ singular second dim + # c = ht.zeros((13, 5), split=0) + # c[8:12, ht.array(1)] = 1 + # b = c[8:12, np.int64(1)] + # self.assertTrue((b == 1).all()) + # self.assertEqual(b.gshape, (4,)) + # self.assertEqual(b.split, 0) + # self.assertEqual(b.dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 1: + # self.assertEqual(b.lshape, (4,)) + # if a.comm.rank == 0: + # self.assertEqual(b.lshape, (0,)) + + # # slice in both directions + # a = ht.zeros((13, 5), split=0) + # a[3:13, 2:5:2] = 1 + # self.assertTrue((a[3:13, 2:5:2] == 1).all()) + # self.assertEqual(a[3:13, 2:5:2].gshape, (10, 2)) + # self.assertEqual(a[3:13, 2:5:2].split, 0) + # self.assertEqual(a[3:13, 2:5:2].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 1: + # self.assertEqual(a[3:13, 2:5:2].lshape, (6, 2)) + # if a.comm.rank == 0: + # self.assertEqual(a[3:13, 2:5:2].lshape, (4, 2)) + + # # setting with heat tensor + # a = ht.zeros((4, 5), split=0) + # a[1, 0:4] = ht.arange(4) + # # if a.comm.size == 2: + # for c, i in enumerate(range(4)): + # self.assertEqual(a[1, c], i) + + # # setting with torch tensor + # a = ht.zeros((4, 5), split=0) + # a[1, 0:4] = torch.arange(4, device=self.device.torch_device) + # # if a.comm.size == 2: + # for c, i in enumerate(range(4)): + # self.assertEqual(a[1, c], i) + + # ################################################### + # a = ht.zeros((13, 5), split=1) + # # # set value on one node + # a[10] = 1 + # self.assertEqual(a[10].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[10].lshape, (3,)) + # if a.comm.rank == 1: + # self.assertEqual(a[10].lshape, (2,)) + + # a = ht.zeros((13, 5), split=1) + # # # set value on one node + # a[10, 0] = 1 + # self.assertEqual(a[10, 0], 1) + # self.assertEqual(a[10, 0].dtype, ht.float32) + + # # slice in 1st dim only on 1 node + # a = ht.zeros((13, 5), split=1) + # a[1:4] = 1 + # self.assertTrue((a[1:4] == 1).all()) + # self.assertEqual(a[1:4].gshape, (3, 5)) + # self.assertEqual(a[1:4].split, 1) + # self.assertEqual(a[1:4].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[1:4].lshape, (3, 3)) + # if a.comm.rank == 1: + # self.assertEqual(a[1:4].lshape, (3, 2)) + + # # slice in 1st dim only on 1 node w/ singular second dim + # a = ht.zeros((13, 5), split=1) + # a[1:4, 1] = 1 + # self.assertTrue((a[1:4, 1] == 1).all()) + # self.assertEqual(a[1:4, 1].gshape, (3,)) + # self.assertEqual(a[1:4, 1].split, None) + # self.assertEqual(a[1:4, 1].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[1:4, 1].lshape, (3,)) + # if a.comm.rank == 1: + # self.assertEqual(a[1:4, 1].lshape, (3,)) + + # # slice in 2st dim across both nodes (2 node case) w/ singular fist dim + # a = ht.zeros((13, 5), split=1) + # a[11, 1:5] = 1 + # self.assertTrue((a[11, 1:5] == 1).all()) + # self.assertEqual(a[11, 1:5].gshape, (4,)) + # self.assertEqual(a[11, 1:5].split, 0) + # self.assertEqual(a[11, 1:5].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 1: + # self.assertEqual(a[11, 1:5].lshape, (2,)) + # if a.comm.rank == 0: + # self.assertEqual(a[11, 1:5].lshape, (2,)) + + # # slice in 1st dim across 1 node (2nd) w/ singular second dim + # a = ht.zeros((13, 5), split=1) + # 
a[8:12, 1] = 1 + # self.assertTrue((a[8:12, 1] == 1).all()) + # self.assertEqual(a[8:12, 1].gshape, (4,)) + # self.assertEqual(a[8:12, 1].split, None) + # self.assertEqual(a[8:12, 1].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[8:12, 1].lshape, (4,)) + # if a.comm.rank == 1: + # self.assertEqual(a[8:12, 1].lshape, (4,)) + + # # slice in both directions + # a = ht.zeros((13, 5), split=1) + # a[3:13, 2::2] = 1 + # self.assertTrue((a[3:13, 2:5:2] == 1).all()) + # self.assertEqual(a[3:13, 2:5:2].gshape, (10, 2)) + # self.assertEqual(a[3:13, 2:5:2].split, 1) + # self.assertEqual(a[3:13, 2:5:2].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 1: + # self.assertEqual(a[3:13, 2:5:2].lshape, (10, 1)) + # if a.comm.rank == 0: + # self.assertEqual(a[3:13, 2:5:2].lshape, (10, 1)) + + # a = ht.zeros((13, 5), split=1) + # a[..., 2::2] = 1 + # self.assertTrue((a[:, 2:5:2] == 1).all()) + # self.assertEqual(a[..., 2:5:2].gshape, (13, 2)) + # self.assertEqual(a[..., 2:5:2].split, 1) + # self.assertEqual(a[..., 2:5:2].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 1: + # self.assertEqual(a[..., 2:5:2].lshape, (13, 1)) + # if a.comm.rank == 0: + # self.assertEqual(a[:, 2:5:2].lshape, (13, 1)) + + # # setting with heat tensor + # a = ht.zeros((4, 5), split=1) + # a[1, 0:4] = ht.arange(4) + # for c, i in enumerate(range(4)): + # b = a[1, c] + # if b.larray.numel() > 0: + # self.assertEqual(b.item(), i) + + # # setting with torch tensor + # a = ht.zeros((4, 5), split=1) + # a[1, 0:4] = torch.arange(4, device=self.device.torch_device) + # for c, i in enumerate(range(4)): + # self.assertEqual(a[1, c], i) + + # #################################################### + # a = ht.zeros((13, 5, 7), split=2) + # # # set value on one node + # a[10, :, :] = 1 + # self.assertEqual(a[10, :, :].dtype, ht.float32) + # self.assertEqual(a[10, :, :].gshape, (5, 7)) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[10, :, :].lshape, (5, 4)) + # if a.comm.rank == 1: + # self.assertEqual(a[10, :, :].lshape, (5, 3)) + + # a = ht.zeros((13, 5, 7), split=2) + # # # set value on one node + # a[10, ...] 
= 1 + # self.assertEqual(a[10, ...].dtype, ht.float32) + # self.assertEqual(a[10, ...].gshape, (5, 7)) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[10, ...].lshape, (5, 4)) + # if a.comm.rank == 1: + # self.assertEqual(a[10, ...].lshape, (5, 3)) + + # a = ht.zeros((13, 5, 8), split=2) + # # # set value on one node + # a[10, 0, 0] = 1 + # self.assertEqual(a[10, 0, 0], 1) + # self.assertEqual(a[10, 0, 0].dtype, ht.float32) + + # # # slice in 1st dim only on 1 node + # a = ht.zeros((13, 5, 7), split=2) + # a[1:4] = 1 + # self.assertTrue((a[1:4] == 1).all()) + # self.assertEqual(a[1:4].gshape, (3, 5, 7)) + # self.assertEqual(a[1:4].split, 2) + # self.assertEqual(a[1:4].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[1:4].lshape, (3, 5, 4)) + # if a.comm.rank == 1: + # self.assertEqual(a[1:4].lshape, (3, 5, 3)) + + # # slice in 1st dim only on 1 node w/ singular second dim + # a = ht.zeros((13, 5, 7), split=2) + # a[1:4, 1, :] = 1 + # self.assertTrue((a[1:4, 1, :] == 1).all()) + # self.assertEqual(a[1:4, 1, :].gshape, (3, 7)) + # if a.comm.size == 2: + # self.assertEqual(a[1:4, 1, :].split, 1) + # self.assertEqual(a[1:4, 1, :].dtype, ht.float32) + # if a.comm.rank == 0: + # self.assertEqual(a[1:4, 1, :].lshape, (3, 4)) + # if a.comm.rank == 1: + # self.assertEqual(a[1:4, 1, :].lshape, (3, 3)) + + # # slice in both directions + # a = ht.zeros((13, 5, 7), split=2) + # a[3:13, 2:5:2, 1:7:3] = 1 + # self.assertTrue((a[3:13, 2:5:2, 1:7:3] == 1).all()) + # self.assertEqual(a[3:13, 2:5:2, 1:7:3].split, 2) + # self.assertEqual(a[3:13, 2:5:2, 1:7:3].dtype, ht.float32) + # self.assertEqual(a[3:13, 2:5:2, 1:7:3].gshape, (10, 2, 2)) + # if a.comm.size == 2: + # out = ht.ones((4, 5, 5), split=1) + # self.assertEqual(out[0].gshape, (5, 5)) + # if a.comm.rank == 1: + # self.assertEqual(a[3:13, 2:5:2, 1:7:3].lshape, (10, 2, 1)) + # self.assertEqual(out[0].lshape, (2, 5)) + # if a.comm.rank == 0: + # self.assertEqual(a[3:13, 2:5:2, 1:7:3].lshape, (10, 2, 1)) + # self.assertEqual(out[0].lshape, (3, 5)) + + # a = ht.ones((4, 5), split=0).tril() + # a[0] = [6, 6, 6, 6, 6] + # self.assertTrue((a[0] == 6).all()) + + # a = ht.ones((4, 5), split=0).tril() + # a[0] = (6, 6, 6, 6, 6) + # self.assertTrue((a[0] == 6).all()) + + # a = ht.ones((4, 5), split=0).tril() + # a[0] = np.array([6, 6, 6, 6, 6]) + # self.assertTrue((a[0] == 6).all()) + + # a = ht.ones((4, 5), split=0).tril() + # a[0] = ht.array([6, 6, 6, 6, 6]) + # self.assertTrue((a[ht.array((0,))] == 6).all()) + + # a = ht.ones((4, 5), split=0).tril() + # a[0] = ht.array([6, 6, 6, 6, 6]) + # self.assertTrue((a[ht.array((0,))] == 6).all()) + + # # ======================= indexing with bools ================================= + # split = None + # arr = ht.random.random((20, 20)).resplit(split) + # np_arr = arr.numpy() + # np_key = np_arr < 0.5 + # ht_key = ht.array(np_key, split=split) + # arr[ht_key] = 10.0 + # np_arr[np_key] = 10.0 + # self.assertTrue(np.all(arr.numpy() == np_arr)) + # self.assertTrue(ht.all(arr[ht_key] == 10.0)) + + # split = 0 + # arr = ht.random.random((20, 20)).resplit(split) + # np_arr = arr.numpy() + # np_key = (np_arr < 0.5)[0] + # ht_key = ht.array(np_key, split=split) + # arr[ht_key] = 10.0 + # np_arr[np_key] = 10.0 + # self.assertTrue(np.all(arr.numpy() == np_arr)) + # self.assertTrue(ht.all(arr[ht_key] == 10.0)) + + # # key -> tuple(ht.bool, int) + # split = 0 + # arr = ht.random.random((20, 20)).resplit(split) + # np_arr = arr.numpy() + # np_key = (np_arr < 0.5)[0] 
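These boolean-mask cases all reduce to the identity that __process_key exploits: a mask selects exactly the elements obtained by indexing with the tuple of its nonzero coordinates, in row-major order. A short self-contained torch illustration (not Heat code):

    import torch

    x = torch.arange(12).reshape(3, 4)
    mask = x % 2 == 0
    key = mask.nonzero(as_tuple=True)    # tuple of per-dimension 1-D index tensors
    assert torch.equal(x[mask], x[key])  # same elements, same row-major order

The per-dimension form is what allows the indices along the split axis to be shifted by each rank's displacement before purely local indexing.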
+ # ht_key = ht.array(np_key, split=split) + # arr[ht_key, 4] = 10.0 + # np_arr[np_key, 4] = 10.0 + # self.assertTrue(np.all(arr.numpy() == np_arr)) + # self.assertTrue(ht.all(arr[ht_key, 4] == 10.0)) + + # # key -> tuple(torch.bool, int) + # split = 0 + # arr = ht.random.random((20, 20)).resplit(split) + # np_arr = arr.numpy() + # np_key = (np_arr < 0.5)[0] + # t_key = torch.tensor(np_key, device=arr.larray.device) + # arr[t_key, 4] = 10.0 + # np_arr[np_key, 4] = 10.0 + # self.assertTrue(np.all(arr.numpy() == np_arr)) + # self.assertTrue(ht.all(arr[t_key, 4] == 10.0)) + + # # key -> torch.bool + # split = 0 + # arr = ht.random.random((20, 20)).resplit(split) + # np_arr = arr.numpy() + # np_key = (np_arr < 0.5)[0] + # t_key = torch.tensor(np_key, device=arr.larray.device) + # arr[t_key] = 10.0 + # np_arr[np_key] = 10.0 + # self.assertTrue(np.all(arr.numpy() == np_arr)) + # self.assertTrue(ht.all(arr[t_key] == 10.0)) + + # split = 1 + # arr = ht.random.random((20, 20, 10)).resplit(split) + # np_arr = arr.numpy() + # np_key = np_arr < 0.5 + # ht_key = ht.array(np_key, split=split) + # arr[ht_key] = 10.0 + # np_arr[np_key] = 10.0 + # self.assertTrue(np.all(arr.numpy() == np_arr)) + # self.assertTrue(ht.all(arr[ht_key] == 10.0)) + + # split = 2 + # arr = ht.random.random((15, 20, 20)).resplit(split) + # np_arr = arr.numpy() + # np_key = np_arr < 0.5 + # ht_key = ht.array(np_key, split=split) + # arr[ht_key] = 10.0 + # np_arr[np_key] = 10.0 + # self.assertTrue(np.all(arr.numpy() == np_arr)) + # self.assertTrue(ht.all(arr[ht_key] == 10.0)) + + # with self.assertRaises(ValueError): + # a[..., ...] + # with self.assertRaises(ValueError): + # a[..., ...] = 1 + # if a.comm.size > 1: + # with self.assertRaises(ValueError): + # x = ht.ones((10, 10), split=0) + # setting = ht.zeros((8, 8), split=1) + # x[1:-1, 1:-1] = setting + + # for split in [None, 0, 1, 2]: + # for new_dim in [0, 1, 2]: + # for add in [np.newaxis, None]: + # arr = ht.ones((4, 3, 2), split=split, dtype=ht.int32) + # check = torch.ones((4, 3, 2), dtype=torch.int32) + # idx = [slice(None), slice(None), slice(None)] + # idx[new_dim] = add + # idx = tuple(idx) + # arr = arr[idx] + # check = check[idx] + # self.assertTrue(arr.shape == check.shape) + # self.assertTrue(arr.lshape[new_dim] == 1) def test_size_gnumel(self): a = ht.zeros((10, 10, 10), split=None) From 4360bd1d58eb70c95982d175f49a032fbd75c5d8 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 30 Aug 2022 11:29:34 +0200 Subject: [PATCH 019/132] test for 0-dim DNDarray key --- heat/core/tests/test_dndarray.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 5dd5ced775..67f1f4425e 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -546,8 +546,9 @@ def test_getitem(self): self.assertTrue(indexed_split0.split is None) # 3D, distributed split, != 0 x_split2 = ht.array(x, dtype=ht.int64, split=2) + key = ht.array(2) indexed_split2 = x_split2[key] - self.assertTrue((indexed_split2.numpy() == x.numpy()[key]).all()) + self.assertTrue((indexed_split2.numpy() == x.numpy()[key.item()]).all()) self.assertTrue(indexed_split2.dtype == ht.int64) self.assertTrue(indexed_split2.split == 1) From 231c1dec0739eace6ca12b736f670555a7fa85b6 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 31 Aug 2022 09:31:27 +0200 Subject: [PATCH 020/132] Expand __process_key() to deal with distributed boolean mask --- heat/core/dndarray.py | 53 
++++++++++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 1c5f3e737f..b73f9301ea 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -668,6 +668,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] split_bookkeeping = [None] * arr.ndim if arr.is_distributed(): split_bookkeeping[arr.split] = "split" + counts, displs = arr.counts_displs() advanced_indexing = False arr_is_copy = False @@ -681,17 +682,39 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] except TypeError: # np.array or DNDarray key key = key.nonzero() - else: - # advanced indexing on first dimension: first dim will expand to shape of key - advanced_indexing = True - output_shape = tuple(list(key.shape) + output_shape[1:]) - # adjust split axis accordingly - split_bookkeeping = [None] * (len(key.shape) - 1) + split_bookkeeping[1:] - new_split = ( - split_bookkeeping.index("split") if "split" in split_bookkeeping else None - ) + key = list(key).copy() + # if key is sequence of DNDarrays, extract local tensors + try: + for i, k in enumerate(key): + key[i] = k.larray + except AttributeError: + pass + if arr.is_distributed(): + # return locally relevant key only + key[arr.split] -= displs[arr.comm.rank] + cond1 = key[arr.split] >= 0 + cond2 = key[arr.split] < counts[arr.comm.rank] + for i, k in enumerate(key): + key[i] = k[cond1 & cond2] + # calculate output_shape + total_nonzero = torch.tensor(key[arr.comm.split].shape[0]) + arr.comm.Allreduce(MPI.IN_PLACE, total_nonzero, MPI.SUM) + output_shape = (total_nonzero,) + new_split = 0 + else: + output_shape = (key[0].shape[0],) + new_split = None if arr.split is None else 0 + key = tuple(key) return arr, key, output_shape, new_split, advanced_indexing + # advanced indexing on first dimension: first dim will expand to shape of key + advanced_indexing = True + output_shape = tuple(list(key.shape) + output_shape[1:]) + # adjust split axis accordingly + split_bookkeeping = [None] * (len(key.shape) - 1) + split_bookkeeping[1:] + new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None + return arr, key, output_shape, new_split, advanced_indexing + if isinstance(key, (tuple, list)): key = list(key) @@ -861,7 +884,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # key can be: int, tuple, list, slice, DNDarray, torch tensor, numpy array, or sequence thereof # Trivial cases print("DEBUGGING: RAW KEY = ", key) - # early out: key is a scalar + + # Single element indexing scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 if scalar: output_shape = self.gshape[1:] @@ -894,10 +918,10 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar displs = torch.cat((torch.tensor(displs), torch.tensor(key).reshape(-1)), dim=0) _, sorted_indices = displs.unique(sorted=True, return_inverse=True) root = sorted_indices[-1] - 1 - # correct key for relevant displacement - key -= displs[root] # allocate buffer on all processes if self.comm.rank == root: + # correct key for rank-specific displacement + key -= displs[root] indexed_arr = self.larray[key] else: indexed_arr = torch.zeros( @@ -923,6 +947,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar ): # latter doesnt work with torch for 0-dim tensors return self + # Many-elements indexing: incl. 
slicing and striding, ordered advanced indexing + # Preprocess: Process Ellipsis + 'None' indexing; make Iterables to DNDarrays self, key, output_shape, output_split, advanced_indexing = self.__process_key(key) print("DEBUGGING: processed key = ", key) @@ -946,7 +972,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar comm=self.comm, ) - # # data are distributed and split dimension is affected by indexing + # data are distributed and split dimension is affected by indexing + # _, offsets = self.counts_displs() # split = self.split # # slice along the split axis From f19f90247e437cf2db432256acb56f4b016a5153 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 31 Aug 2022 09:32:36 +0200 Subject: [PATCH 021/132] Expand test_getitem for distributed single-element indexing, non-distr boolean mask --- heat/core/tests/test_dndarray.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 67f1f4425e..f3933b8d2d 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -531,6 +531,15 @@ def test_getitem(self): self.assertTrue(x[-2].item() == 8.0) self.assertTrue(x[2].dtype == ht.float64) self.assertTrue(x[2].split is None) + # 2D, local + x = ht.arange(10).reshape(2, 5) + self.assertTrue((x[0] == ht.arange(5)).all().item()) + self.assertTrue(x[0].dtype == ht.int32) + # 2D, distributed + x_split0 = ht.array(x, split=0) + self.assertTrue((x_split0[0] == ht.arange(5, split=None)).all().item()) + x_split1 = ht.array(x, split=1) + self.assertTrue((x_split1[-2] == ht.arange(5, split=0)).all().item()) # 3D, local x = ht.arange(27).reshape(3, 3, 3) key = -2 @@ -552,6 +561,13 @@ def test_getitem(self): self.assertTrue(indexed_split2.dtype == ht.int64) self.assertTrue(indexed_split2.split == 1) + # boolean mask, local + arr = ht.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6) + mask = np.random.randint(0, 2, (3, 4, 5, 6), dtype=bool) + self.assertTrue((arr[mask].numpy() == arr.numpy()[mask]).all()) + + # boolean mask, distributed + def test_int_cast(self): # simple scalar tensor a = ht.ones(1) From 7ed435f724c6b222834d7315cbbf00c5a89c41ee Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 31 Aug 2022 09:41:19 +0200 Subject: [PATCH 022/132] Add check for matching boolean index / indexed array shapes --- heat/core/dndarray.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index b73f9301ea..df08816e38 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -675,7 +675,14 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] 
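    # Note: the check introduced in this hunk mirrors NumPy's rule for boolean
    # masks: the mask must have exactly the shape of the array it indexes.
    # NumPy likewise raises IndexError for such a mismatch, e.g. when indexing
    # a (3, 4) array with a (3, 3) mask; here the check fires before any
    # communication takes place.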
if isinstance(key, (DNDarray, torch.Tensor, np.ndarray)): if key.dtype in (bool, uint8, torch.bool, torch.uint8, np.bool, np.uint8): - # boolean indexing: transform to sequence of indexing (1-D) arrays + # boolean indexing: shape must match arr.shape + if not tuple(key.shape) == arr.shape: + raise IndexError( + "Boolean index of shape {} does not match indexed array of shape {}".format( + tuple(key.shape), arr.shape + ) + ) + # transform key to sequence of indexing (1-D) arrays try: # torch.Tensor key key = key.nonzero(as_tuple=True) From 0da7f5663543d08f7a6165b7fc68fb9546b43c70 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sat, 3 Sep 2022 08:01:35 +0200 Subject: [PATCH 023/132] Only sort result if input.split != 0 --- heat/core/indexing.py | 52 ++++++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/heat/core/indexing.py b/heat/core/indexing.py index 9946049185..ece379f0fd 100644 --- a/heat/core/indexing.py +++ b/heat/core/indexing.py @@ -63,40 +63,46 @@ def nonzero(x: DNDarray) -> Tuple[DNDarray, ...]: lcl_nonzero = torch.nonzero(input=local_x, as_tuple=True) # bookkeeping for final DNDarray construct output_shape = (lcl_nonzero[0].shape,) - output_split = None + output_split = None if x.split is None else 0 + output_balanced = True else: lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False) nonzero_size = torch.tensor( lcl_nonzero.shape[0], dtype=torch.int64, device=lcl_nonzero.device ) - # construct global DNDarray of nz indices: - # global shape and split + # global nonzero_size x.comm.Allreduce(MPI.IN_PLACE, nonzero_size, MPI.SUM) - output_shape = (nonzero_size.item(), x.ndim) - output_split = 0 # correct indices along split axis _, displs = x.counts_displs() lcl_nonzero[:, x.split] += displs[x.comm.rank] - global_nonzero = DNDarray( - lcl_nonzero, - gshape=output_shape, - dtype=types.int64, - split=output_split, - device=x.device, - comm=x.comm, - balanced=False, - ) - # stabilize distributed result: vectorize sorting of nz indices along axis 0 - global_nonzero.balance_() - global_nonzero = manipulations.unique(global_nonzero, axis=0) - # return indices as tuple of columns - lcl_nonzero = global_nonzero.larray.split(1, dim=1) - # bookkeeping for final DNDarray construct - output_shape = (global_nonzero.shape[0],) - output_split = 0 + + if x.split != 0: + # construct global 2D DNDarray of nz indices: + shape_2d = (nonzero_size.item(), x.ndim) + global_nonzero = DNDarray( + lcl_nonzero, + gshape=shape_2d, + dtype=types.int64, + split=0, + device=x.device, + comm=x.comm, + balanced=False, + ) + # stabilize distributed result: vectorized sorting of nz indices along axis 0 + global_nonzero.balance_() + global_nonzero = manipulations.unique(global_nonzero, axis=0) + # return indices as tuple of columns + lcl_nonzero = global_nonzero.larray.split(1, dim=1) + output_balanced = True + else: + # return indices as tuple of columns + lcl_nonzero = lcl_nonzero.split(1, dim=1) + output_balanced = False # return global_nonzero as tuple of DNDarrays global_nonzero = list(lcl_nonzero) + output_shape = (nonzero_size.item(),) + output_split = 0 for i, nz_tensor in enumerate(global_nonzero): if nz_tensor.ndim > 1: # extra dimension in distributed case from usage of torch.split() @@ -108,7 +114,7 @@ def nonzero(x: DNDarray) -> Tuple[DNDarray, ...]: split=output_split, device=x.device, comm=x.comm, - balanced=True, + balanced=output_balanced, ) global_nonzero[i] = nz_array global_nonzero = tuple(global_nonzero) From 
e55c7f98da178974d4706e1e0d5a1edcf9f617f5 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sat, 3 Sep 2022 08:03:53 +0200 Subject: [PATCH 024/132] BROKEN: distributed boolean indexing to return stable result for all splits --- heat/core/dndarray.py | 207 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 168 insertions(+), 39 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index df08816e38..824941b351 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -672,6 +672,8 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] advanced_indexing = False arr_is_copy = False + split_key_is_sorted = True + out_is_balanced = False if isinstance(key, (DNDarray, torch.Tensor, np.ndarray)): if key.dtype in (bool, uint8, torch.bool, torch.uint8, np.bool, np.uint8): @@ -682,45 +684,167 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] tuple(key.shape), arr.shape ) ) - # transform key to sequence of indexing (1-D) arrays - try: - # torch.Tensor key - key = key.nonzero(as_tuple=True) - except TypeError: - # np.array or DNDarray key - key = key.nonzero() - key = list(key).copy() - # if key is sequence of DNDarrays, extract local tensors try: + # key is DNDarray or ndarray + key = key.copy() + except AttributeError: + # key is torch tensor + key = key.clone() + if not arr.is_distributed(): + try: + # key is DNDarray, extract torch tensor + key = key.larray + except AttributeError: + pass + try: + # key is torch tensor + key = key.nonzero(as_tuple=True) + except TypeError: + # key is np.ndarray + key = key.nonzero() + output_shape = tuple(key[0].shape) + new_split = None if arr.split is None else 0 + out_is_balanced = True + return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced + + # arr is distributed + if not isinstance(key, DNDarray) or not key.is_distributed(): + key = factories.array(key, split=arr.split, device=arr.device) + else: + if key.split != arr.split: + raise IndexError( + "Boolean index does not match distribution scheme of indexed array. 
index.split is {}, array.split is {}".format( + key.split, arr.split + ) + ) + if arr.split == 0: + # ensure arr and key are aligned + key.redistribute_(target_map=arr.lshape_map) + # transform key to sequence of indexing (1-D) arrays + key = list(key.nonzero()) + output_shape = key[0].shape + new_split = 0 + # all local indexing + out_is_balanced = False for i, k in enumerate(key): key[i] = k.larray - except AttributeError: - pass - if arr.is_distributed(): - # return locally relevant key only key[arr.split] -= displs[arr.comm.rank] - cond1 = key[arr.split] >= 0 - cond2 = key[arr.split] < counts[arr.comm.rank] - for i, k in enumerate(key): - key[i] = k[cond1 & cond2] - # calculate output_shape - total_nonzero = torch.tensor(key[arr.comm.split].shape[0]) - arr.comm.Allreduce(MPI.IN_PLACE, total_nonzero, MPI.SUM) - output_shape = (total_nonzero,) - new_split = 0 + key = tuple(key) else: - output_shape = (key[0].shape[0],) - new_split = None if arr.split is None else 0 - key = tuple(key) - return arr, key, output_shape, new_split, advanced_indexing + # key to distributed 2D matrix of nonzero indices + key = key.larray.nonzero(as_tuple=False) + # swap columns so that indices along split axis are in the first column + col_swap = list(range(key.shape[1])) + col_swap[0], col_swap[arr.split] = arr.split, 0 + key = key.index_select(1, torch.LongTensor(col_swap)) + # construct global key array + nz_size = torch.tensor(key.shape[0], device=key.device, dtype=key.dtype) + arr.comm.Allreduce(MPI.IN_PLACE, nz_size, MPI.SUM) + key_gshape = (nz_size.item(), arr.ndim) + key[:, 0] += displs[arr.comm.rank] + key = DNDarray( + key, + gshape=key_gshape, + dtype=canonical_heat_type(key.dtype), + split=0, + device=arr.device, + comm=arr.comm, + balanced=False, + ) + # vectorized sorting along axis 0 + key.balance_() + key = manipulations.unique(key, axis=0, return_inverse=False) + # redistribute key so that local nonzero indices match local array indices along split axis + first_local_item = key.larray[0, 0] + if first_local_item.item() in displs: + first_send_rank = displs.index(first_local_item.item()) + else: + _, sort_indices = torch.cat( + (torch.tensor(displs), torch.tensor(first_local_item).reshape(-1)), + dim=0, + ).unique(sorted=True, return_inverse=True) + first_send_rank = sort_indices[-1] - 1 + key_counts, _ = key.counts_displs() + sending_counts = torch.zeros( + (1, arr.comm.size), dtype=torch.int64, device=key.larray.device + ) + for i in range(first_send_rank, arr.comm.size): + cond1 = key.larray[:, 0] >= displs[i] + if i != arr.comm.size - 1: + cond2 = key.larray[:, 0] < displs[i + 1] + sending_counts[:, i] = key.larray[:, 0][cond1 & cond2].shape[0] + else: + sending_counts[:, i] = key.larray[:, 0][cond1].shape[0] + # if sending_counts[:,first_send_rank:].sum() == key_counts[arr.comm.rank]: + # # all local counts accounted for + # break + # dispatch sending counts information + sending_counts_buf = torch.zeros( + (arr.comm.size, arr.comm.size), + dtype=sending_counts.dtype, + device=sending_counts.device, + ) + arr.comm.Allgather(sending_counts, sending_counts_buf) + target_counts = sending_counts_buf.sum(dim=0) + target_displs = torch.cat( + ( + torch.tensor( + [0], dtype=target_counts.dtype, device=target_counts.device + ), + target_counts, + ), + dim=0, + ).cumsum(dim=0)[:-1] + target_key_lshape_map = key.lshape_map + target_key_lshape_map[:, 0] = target_displs + key.redistribute_(target_map=target_key_lshape_map) + # finally swap split axis column back into original position + key.larray = 
key.larray.index_select(1, torch.LongTensor(col_swap)) + # return local key as tuple of 1D tensors + key.larray[:, arr.split] -= displs[arr.comm.rank] + key = key.larray.split(1, dim=1) + output_shape = (nz_size.item(),) + new_split = 0 + split_key_is_sorted = True + out_is_balanced = False + return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced # advanced indexing on first dimension: first dim will expand to shape of key - advanced_indexing = True output_shape = tuple(list(key.shape) + output_shape[1:]) # adjust split axis accordingly - split_bookkeeping = [None] * (len(key.shape) - 1) + split_bookkeeping[1:] - new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None - return arr, key, output_shape, new_split, advanced_indexing + if arr.is_distributed(): + if arr.split != 0: + # split axis is not affected + split_bookkeeping = [None] * key.ndim + split_bookkeeping[1:] + new_split = ( + split_bookkeeping.index("split") if "split" in split_bookkeeping else None + ) + out_is_balanced = arr.balanced + else: + # split axis is affected + if key.ndim > 1: + try: + key_numel = key.numel() + except AttributeError: + key_numel = key.size + if key_numel == arr.shape[0]: + new_split = tuple(key.shape).index(arr.shape[0]) + else: + new_split = key.ndim - 1 + else: + new_split = 0 + # assess if key is sorted along split axis + try: + key_split = key[new_split].larray + sorted, _ = key_split.sort() + except AttributeError: + key_split = key[new_split] + sorted = key_split.sort() + split_key_is_sorted = torch.tensor( + (key_split == sorted).all(), dtype=torch.uint8 + ) + + return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced if isinstance(key, (tuple, list)): key = list(key) @@ -892,7 +1016,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # Trivial cases print("DEBUGGING: RAW KEY = ", key) - # Single element indexing + # Single-element indexing scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 if scalar: output_shape = self.gshape[1:] @@ -957,29 +1081,34 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # Many-elements indexing: incl. 
slicing and striding, ordered advanced indexing # Preprocess: Process Ellipsis + 'None' indexing; make Iterables to DNDarrays - self, key, output_shape, output_split, advanced_indexing = self.__process_key(key) + ( + self, + key, + output_shape, + output_split, + split_key_is_sorted, + out_is_balanced, + ) = self.__process_key(key) print("DEBUGGING: processed key = ", key) # TODO: test that key for not affected dims is always slice(None) # including match between self.split and key after self manipulation # data are not distributed or split dimension is not affected by indexing - if not self.is_distributed() or key[self.split] == slice(None): - try: - indexed_arr = self.larray[key.larray.long()] - except AttributeError: - # key is an ndarray - indexed_arr = self.larray[key] + # if not self.is_distributed() or key[self.split] == slice(None): + if split_key_is_sorted: + indexed_arr = self.larray[key] return DNDarray( indexed_arr, gshape=output_shape, dtype=self.dtype, split=output_split, device=self.device, - balanced=self.balanced, + balanced=out_is_balanced, comm=self.comm, ) # data are distributed and split dimension is affected by indexing + # __process_key() returns the local key already # _, offsets = self.counts_displs() # split = self.split From 75d931468f50d6c79eb621e1e29c6eb3d074e5ee Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sat, 3 Sep 2022 08:04:35 +0200 Subject: [PATCH 025/132] Add tests for distributed boolean indexing --- heat/core/tests/test_dndarray.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index f3933b8d2d..e4f320994a 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -567,6 +567,10 @@ def test_getitem(self): self.assertTrue((arr[mask].numpy() == arr.numpy()[mask]).all()) # boolean mask, distributed + arr_split0 = arr.resplit(axis=1) + mask_split0 = ht.array(mask, split=1) + print("DEBUGGING: mask_split0.dtype = ", mask_split0.dtype) + self.assertTrue((arr_split0[mask_split0].numpy() == arr.numpy()[mask]).all()) def test_int_cast(self): # simple scalar tensor From 15a8a28a646684f9194dc6e53b76b242c80c6c67 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sun, 4 Sep 2022 07:55:12 +0200 Subject: [PATCH 026/132] BROKEN: Fixed key redistribution for input.split != 0. --- heat/core/dndarray.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 824941b351..c5a3eb084d 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -775,9 +775,9 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] sending_counts[:, i] = key.larray[:, 0][cond1 & cond2].shape[0] else: sending_counts[:, i] = key.larray[:, 0][cond1].shape[0] - # if sending_counts[:,first_send_rank:].sum() == key_counts[arr.comm.rank]: - # # all local counts accounted for - # break + if sending_counts[:, first_send_rank:].sum() == key_counts[arr.comm.rank]: + # all local counts accounted for + break # dispatch sending counts information sending_counts_buf = torch.zeros( (arr.comm.size, arr.comm.size), @@ -786,26 +786,23 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] 
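                # one row per rank: the Allgather below assembles every rank's
                # sending_counts into the full redistribution matrix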
) arr.comm.Allgather(sending_counts, sending_counts_buf) target_counts = sending_counts_buf.sum(dim=0) - target_displs = torch.cat( - ( - torch.tensor( - [0], dtype=target_counts.dtype, device=target_counts.device - ), - target_counts, - ), - dim=0, - ).cumsum(dim=0)[:-1] target_key_lshape_map = key.lshape_map - target_key_lshape_map[:, 0] = target_displs + target_key_lshape_map[:, 0] = target_counts key.redistribute_(target_map=target_key_lshape_map) # finally swap split axis column back into original position key.larray = key.larray.index_select(1, torch.LongTensor(col_swap)) + # sort local key again after swapping columns + key.larray = key.larray.unique(dim=0, sorted=True, return_inverse=False) # return local key as tuple of 1D tensors key.larray[:, arr.split] -= displs[arr.comm.rank] - key = key.larray.split(1, dim=1) + key = list(key.larray.split(1, dim=1)) + for i, k in enumerate(key): + key[i] = k.squeeze(1) + key = tuple(key) output_shape = (nz_size.item(),) new_split = 0 - split_key_is_sorted = True + # key is local but not sorted in the new_split dimension, needs Alltoallv communication after local indexing + split_key_is_sorted = False out_is_balanced = False return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced @@ -1014,7 +1011,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar """ # key can be: int, tuple, list, slice, DNDarray, torch tensor, numpy array, or sequence thereof # Trivial cases - print("DEBUGGING: RAW KEY = ", key) + # print("DEBUGGING: RAW KEY = ", key) # Single-element indexing scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 @@ -1107,6 +1104,10 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar comm=self.comm, ) + # TODO: boolean indexing with data.split != 0 + # __process_key() returns locally correct key + # after local indexing, Alltoallv for correct order of output + # data are distributed and split dimension is affected by indexing # __process_key() returns the local key already From 8db0511b678812e992285c4bb6883565a8fd6d3b Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sun, 4 Sep 2022 07:56:46 +0200 Subject: [PATCH 027/132] Expanded boolean indexing tests --- heat/core/tests/test_dndarray.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index e4f320994a..60c60dd138 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -562,16 +562,20 @@ def test_getitem(self): self.assertTrue(indexed_split2.split == 1) # boolean mask, local - arr = ht.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6) - mask = np.random.randint(0, 2, (3, 4, 5, 6), dtype=bool) + arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5) + np.random.seed(42) + mask = np.random.randint(0, 2, arr.shape, dtype=bool) self.assertTrue((arr[mask].numpy() == arr.numpy()[mask]).all()) # boolean mask, distributed - arr_split0 = arr.resplit(axis=1) - mask_split0 = ht.array(mask, split=1) - print("DEBUGGING: mask_split0.dtype = ", mask_split0.dtype) + arr_split0 = ht.array(arr, split=0) + mask_split0 = ht.array(mask, split=0) self.assertTrue((arr_split0[mask_split0].numpy() == arr.numpy()[mask]).all()) + arr_split1 = ht.array(arr, split=1) + mask_split1 = ht.array(mask, split=1) + self.assertTrue((arr_split1[mask_split1].numpy() == arr.numpy()[mask]).all()) + def test_int_cast(self): # simple scalar tensor a = ht.ones(1) From 291329e7ab86c89ef6e471f612c170181f0158e7 Mon Sep 17 00:00:00 
2001 From: Claudia Comito Date: Thu, 8 Sep 2022 10:34:53 +0200 Subject: [PATCH 028/132] Set up communication matrix for boolean indexing along non-zero split --- heat/core/dndarray.py | 70 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index c5a3eb084d..d51ecba09b 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1104,6 +1104,76 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar comm=self.comm, ) + # key along new_split is not sorted + # apply local key then reorder global indexed array + indexed_arr = self.larray[key] + # prepare for Alltoallv: allocate buffer + non_ordered = DNDarray( + indexed_arr, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + balanced=out_is_balanced, + comm=self.comm, + ) + _, non_ordered_displs = non_ordered.counts_displs() + ordered = non_ordered.balance() + ordered_counts, _ = ordered.counts_displs() + # for every dimension of self, how many elements on what process + ndim_counts_on_proc = torch.zeros( + (self.ndim, 1), dtype=torch.int64, device=self.larray.device + ) + ndim_displs_on_proc = torch.zeros( + (self.ndim,), dtype=torch.int64, device=self.larray.device + ) + for i in range(0, self.ndim): + where_dim = torch.where(key[output_split] == i)[0] + ndim_counts_on_proc[i, :] = where_dim.shape[0] + ndim_displs_on_proc[i] = where_dim[0].item() + ndim_displs_on_proc += non_ordered_displs[self.comm.rank] + # share info to all processes + global_ndim_counts = torch.empty( + (self.ndim, self.comm.size), + dtype=ndim_counts_on_proc.dtype, + device=ndim_counts_on_proc.device, + ) + self.comm.Allgather(ndim_counts_on_proc, global_ndim_counts) + # construct communication matrix: what process sends how many elements to whom + comm_on_rank = torch.zeros( + (1, self.comm.size), dtype=torch.int64, device=global_ndim_counts.device + ) + counts_bookkeeping = global_ndim_counts.flatten() + ordered_counts = torch.tensor(ordered_counts, device=counts_bookkeeping.device) + _, indices = torch.cat( + (counts_bookkeeping.cumsum(0), ordered_counts.cumsum(0)), dim=0 + ).unique(sorted=True, return_inverse=True) + for i in range(-self.comm.size, 0): + send_r = self.comm.size + i + end = indices[i] + if send_r == 0: + start = 0 + comm_on_rank[:, send_r] = counts_bookkeeping[ + slice(start + self.comm.rank % self.comm.size, end, self.comm.size) + ].sum() + else: + start = indices[i - 1] + slice_start = start + (self.comm.rank - start) % self.comm.size + comm_on_rank[:, send_r] = counts_bookkeeping[ + slice(slice_start, end, self.comm.size) + ].sum() + leftover_counts = ordered_counts[send_r] - counts_bookkeeping[start:end].sum() + if leftover_counts > 0: + counts_bookkeeping[indices[i]] -= leftover_counts + if self.comm.rank == indices[i] % self.comm.size: + comm_on_rank[:, send_r] += leftover_counts + # share info + comm_matrix = torch.zeros( + (self.comm.size, self.comm.size), dtype=torch.int64, device=global_ndim_counts.device + ) + self.comm.Allgather(comm_on_rank, comm_matrix) + # example: comm_matrix[0, 1] returns the counts that rank 0 is about to send to rank 1 + # TODO: boolean indexing with data.split != 0 # __process_key() returns locally correct key # after local indexing, Alltoallv for correct order of output From 6d986dd7f0d80144710818619231d575a835f6fc Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 4 Nov 2022 05:14:11 +0100 Subject: [PATCH 029/132] Implement getitem for non-ordered key 
along split axis --- heat/core/dndarray.py | 213 ++++++++++++++----------------- heat/core/tests/test_dndarray.py | 6 +- 2 files changed, 104 insertions(+), 115 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index d51ecba09b..0ec9efa624 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -731,22 +731,18 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] key[arr.split] -= displs[arr.comm.rank] key = tuple(key) else: - # key to distributed 2D matrix of nonzero indices key = key.larray.nonzero(as_tuple=False) - # swap columns so that indices along split axis are in the first column - col_swap = list(range(key.shape[1])) - col_swap[0], col_swap[arr.split] = arr.split, 0 - key = key.index_select(1, torch.LongTensor(col_swap)) # construct global key array nz_size = torch.tensor(key.shape[0], device=key.device, dtype=key.dtype) arr.comm.Allreduce(MPI.IN_PLACE, nz_size, MPI.SUM) key_gshape = (nz_size.item(), arr.ndim) - key[:, 0] += displs[arr.comm.rank] + key[:, arr.split] += displs[arr.comm.rank] + key_split = 0 key = DNDarray( key, gshape=key_gshape, dtype=canonical_heat_type(key.dtype), - split=0, + split=key_split, device=arr.device, comm=arr.comm, balanced=False, @@ -754,56 +750,16 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] # vectorized sorting along axis 0 key.balance_() key = manipulations.unique(key, axis=0, return_inverse=False) - # redistribute key so that local nonzero indices match local array indices along split axis - first_local_item = key.larray[0, 0] - if first_local_item.item() in displs: - first_send_rank = displs.index(first_local_item.item()) - else: - _, sort_indices = torch.cat( - (torch.tensor(displs), torch.tensor(first_local_item).reshape(-1)), - dim=0, - ).unique(sorted=True, return_inverse=True) - first_send_rank = sort_indices[-1] - 1 - key_counts, _ = key.counts_displs() - sending_counts = torch.zeros( - (1, arr.comm.size), dtype=torch.int64, device=key.larray.device - ) - for i in range(first_send_rank, arr.comm.size): - cond1 = key.larray[:, 0] >= displs[i] - if i != arr.comm.size - 1: - cond2 = key.larray[:, 0] < displs[i + 1] - sending_counts[:, i] = key.larray[:, 0][cond1 & cond2].shape[0] - else: - sending_counts[:, i] = key.larray[:, 0][cond1].shape[0] - if sending_counts[:, first_send_rank:].sum() == key_counts[arr.comm.rank]: - # all local counts accounted for - break - # dispatch sending counts information - sending_counts_buf = torch.zeros( - (arr.comm.size, arr.comm.size), - dtype=sending_counts.dtype, - device=sending_counts.device, - ) - arr.comm.Allgather(sending_counts, sending_counts_buf) - target_counts = sending_counts_buf.sum(dim=0) - target_key_lshape_map = key.lshape_map - target_key_lshape_map[:, 0] = target_counts - key.redistribute_(target_map=target_key_lshape_map) - # finally swap split axis column back into original position - key.larray = key.larray.index_select(1, torch.LongTensor(col_swap)) - # sort local key again after swapping columns - key.larray = key.larray.unique(dim=0, sorted=True, return_inverse=False) - # return local key as tuple of 1D tensors - key.larray[:, arr.split] -= displs[arr.comm.rank] + # return tuple key key = list(key.larray.split(1, dim=1)) for i, k in enumerate(key): key[i] = k.squeeze(1) key = tuple(key) - output_shape = (nz_size.item(),) + + output_shape = (key[0].shape[0],) new_split = 0 - # key is local but not sorted in the new_split dimension, needs Alltoallv communication after local indexing 
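Whether the key's indices along the split axis arrive sorted is what separates the cheap path (index locally, concatenate in rank order) from the request-key exchange implemented in the next hunk. A sketch of the sortedness test itself, assuming a 1-D torch tensor of global indices (illustrative, not Heat's code):

    import torch

    def is_sorted_along_split(split_indices: torch.Tensor) -> bool:
        # if the indices are already non-decreasing, each rank's local
        # result is globally in place after concatenation in rank order
        sorted_indices, _ = split_indices.sort()
        return bool((split_indices == sorted_indices).all())

    # is_sorted_along_split(torch.tensor([0, 2, 5]))  -> True
    # is_sorted_along_split(torch.tensor([5, 0, 2]))  -> False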
split_key_is_sorted = False - out_is_balanced = False + out_is_balanced = True return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced # advanced indexing on first dimension: first dim will expand to shape of key @@ -1104,75 +1060,104 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar comm=self.comm, ) - # key along new_split is not sorted - # apply local key then reorder global indexed array - indexed_arr = self.larray[key] - # prepare for Alltoallv: allocate buffer - non_ordered = DNDarray( - indexed_arr, - gshape=output_shape, - dtype=self.dtype, - split=output_split, - device=self.device, - balanced=out_is_balanced, - comm=self.comm, + # key is sorted along dim 0 but not along self.split + # key is tuple of torch.Tensor + _, displs = self.counts_displs() + original_split = self.split + + # send and receive "request key" info on what data element to ship where + recv_counts = torch.zeros((self.comm.size, 1), dtype=torch.int64, device=self.larray.device) + request_key_shape = (0, self.ndim) + outgoing_request_key = torch.empty( + tuple(request_key_shape), dtype=torch.int64, device=self.larray.device ) - _, non_ordered_displs = non_ordered.counts_displs() - ordered = non_ordered.balance() - ordered_counts, _ = ordered.counts_displs() - # for every dimension of self, how many elements on what process - ndim_counts_on_proc = torch.zeros( - (self.ndim, 1), dtype=torch.int64, device=self.larray.device + outgoing_request_key_counts = torch.zeros( + (self.comm.size,), dtype=torch.int64, device=self.larray.device ) - ndim_displs_on_proc = torch.zeros( - (self.ndim,), dtype=torch.int64, device=self.larray.device + for i in range(self.comm.size): + cond1 = key[original_split] >= displs[i] + if i != self.comm.size - 1: + cond2 = key[original_split] < displs[i + 1] + else: + # cond2 is always true + cond2 = torch.ones( + (key[original_split].shape[0],), dtype=torch.bool, device=self.larray.device + ) + selection = list(k[cond1 & cond2] for k in key) + recv_counts[i, :] = selection[0].shape[0] + selection = torch.stack(selection, dim=1) + outgoing_request_key = torch.cat((outgoing_request_key, selection), dim=0) + + # share recv_counts among all processes + comm_matrix = torch.empty( + (self.comm.size, self.comm.size), dtype=recv_counts.dtype, device=recv_counts.device ) - for i in range(0, self.ndim): - where_dim = torch.where(key[output_split] == i)[0] - ndim_counts_on_proc[i, :] = where_dim.shape[0] - ndim_displs_on_proc[i] = where_dim[0].item() - ndim_displs_on_proc += non_ordered_displs[self.comm.rank] - # share info to all processes - global_ndim_counts = torch.empty( - (self.ndim, self.comm.size), - dtype=ndim_counts_on_proc.dtype, - device=ndim_counts_on_proc.device, + self.comm.Allgather(recv_counts, comm_matrix) + + outgoing_request_key_counts = comm_matrix[self.comm.rank] + outgoing_request_key_displs = torch.cat( + ( + torch.zeros( + (1,), + dtype=outgoing_request_key_counts.dtype, + device=outgoing_request_key_counts.device, + ), + outgoing_request_key_counts, + ), + dim=0, + ).cumsum(dim=0)[:-1] + incoming_request_key_counts = comm_matrix[:, self.comm.rank] + incoming_request_key_displs = torch.cat( + ( + torch.zeros( + (1,), + dtype=outgoing_request_key_counts.dtype, + device=outgoing_request_key_counts.device, + ), + incoming_request_key_counts, + ), + dim=0, + ).cumsum(dim=0)[:-1] + incoming_request_key = torch.empty( + (incoming_request_key_counts.sum(), self.ndim), + dtype=outgoing_request_key_counts.dtype, +
device=outgoing_request_key_counts.device, + ) + # send and receive request keys + self.comm.Alltoallv( + ( + outgoing_request_key, + outgoing_request_key_counts.tolist(), + outgoing_request_key_displs.tolist(), + ), + ( + incoming_request_key, + incoming_request_key_counts.tolist(), + incoming_request_key_displs.tolist(), + ), ) - self.comm.Allgather(ndim_counts_on_proc, global_ndim_counts) - # construct communication matrix: what process sends how many elements to whom - comm_on_rank = torch.zeros( - (1, self.comm.size), dtype=torch.int64, device=global_ndim_counts.device + + incoming_request_key = list(incoming_request_key[:, d] for d in range(self.ndim)) + incoming_request_key[original_split] -= displs[self.comm.rank] + send_buf = self.larray[incoming_request_key] + output_lshape = list(output_shape) + output_lshape[output_split] = key[0].shape[0] + recv_buf = torch.empty( + tuple(output_lshape), dtype=self.larray.dtype, device=self.larray.device ) - counts_bookkeeping = global_ndim_counts.flatten() - ordered_counts = torch.tensor(ordered_counts, device=counts_bookkeeping.device) - _, indices = torch.cat( - (counts_bookkeeping.cumsum(0), ordered_counts.cumsum(0)), dim=0 - ).unique(sorted=True, return_inverse=True) - for i in range(-self.comm.size, 0): - send_r = self.comm.size + i - end = indices[i] - if send_r == 0: - start = 0 - comm_on_rank[:, send_r] = counts_bookkeeping[ - slice(start + self.comm.rank % self.comm.size, end, self.comm.size) - ].sum() - else: - start = indices[i - 1] - slice_start = start + (self.comm.rank - start) % self.comm.size - comm_on_rank[:, send_r] = counts_bookkeeping[ - slice(slice_start, end, self.comm.size) - ].sum() - leftover_counts = ordered_counts[send_r] - counts_bookkeeping[start:end].sum() - if leftover_counts > 0: - counts_bookkeeping[indices[i]] -= leftover_counts - if self.comm.rank == indices[i] % self.comm.size: - comm_on_rank[:, send_r] += leftover_counts - # share info - comm_matrix = torch.zeros( - (self.comm.size, self.comm.size), dtype=torch.int64, device=global_ndim_counts.device + recv_displs = outgoing_request_key_displs + send_counts = incoming_request_key_counts + send_displs = incoming_request_key_displs + self.comm.Alltoallv( + (send_buf, send_counts, send_displs), (recv_buf, recv_counts, recv_displs) ) - self.comm.Allgather(comm_on_rank, comm_matrix) - # example: comm_matrix[0, 1] returns the counts that rank 0 is about to send to rank 1 + + # reorganize incoming counts according to original key order + key = torch.stack(key, dim=1).tolist() + outgoing_request_key = outgoing_request_key.tolist() + map = [outgoing_request_key.index(k) for k in key] + indexed_arr = recv_buf[map] + return factories.array(indexed_arr, is_split=0) # TODO: boolean indexing with data.split != 0 # __process_key() returns locally correct key diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 60c60dd138..564cfd63d9 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -574,7 +574,11 @@ def test_getitem(self): arr_split1 = ht.array(arr, split=1) mask_split1 = ht.array(mask, split=1) - self.assertTrue((arr_split1[mask_split1].numpy() == arr.numpy()[mask]).all()) + self.assert_array_equal(arr_split1[mask_split1], arr.numpy()[mask]) + + arr_split2 = ht.array(arr, split=2) + mask_split2 = ht.array(mask, split=2) + self.assert_array_equal(arr_split2[mask_split2], arr.numpy()[mask]) def test_int_cast(self): # simple scalar tensor From f46ae672578b27aa705f1515ebcebf2c0c53a09d Mon Sep 17 
00:00:00 2001 From: Claudia Comito Date: Mon, 12 Dec 2022 11:17:31 +0100 Subject: [PATCH 030/132] Fix edge-case contiguity mismatch for Allgatherv --- heat/core/communication.py | 43 +++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/heat/core/communication.py b/heat/core/communication.py index ad58dae964..23d633c30f 100644 --- a/heat/core/communication.py +++ b/heat/core/communication.py @@ -240,7 +240,11 @@ def counts_displs_shape( @classmethod def mpi_type_and_elements_of( - cls, obj: Union[DNDarray, torch.Tensor], counts: Tuple[int], displs: Tuple[int] + cls, + obj: Union[DNDarray, torch.Tensor], + counts: Tuple[int], + displs: Tuple[int], + is_contiguous: bool, ) -> Tuple[MPI.Datatype, Tuple[int, ...]]: """ Determines the MPI data type and number of respective elements for the given tensor (:class:`~heat.core.dndarray.DNDarray` @@ -255,12 +259,18 @@ def mpi_type_and_elements_of( Optional counts arguments for variable MPI-calls (e.g. Alltoallv) displs : Tuple[int,...], optional Optional displacements arguments for variable MPI-calls (e.g. Alltoallv) + is_contiguous: bool, optional + Optional information on global contiguity of the memory-distributed object. If `None`, it will be set to local contiguity via ``torch.Tensor.is_contiguous()``. # ToDo: The option to explicitly specify the counts and displacements to be sent still needs proper implementation """ mpi_type, elements = cls.__mpi_type_mappings[obj.dtype], torch.numel(obj) - # simple case, continuous memory can be transmitted as is - if obj.is_contiguous(): + # simple case, contiguous memory can be transmitted as is + if is_contiguous is None: + # determine local contiguity + is_contiguous = obj.is_contiguous() + + if is_contiguous: if counts is None: return mpi_type, elements else: @@ -273,7 +283,7 @@ def mpi_type_and_elements_of( ), ) - # non-continuous memory, e.g. after a transpose, has to be packed in derived MPI types + # non-contiguous memory, e.g. after a transpose, has to be packed in derived MPI types elements = obj.shape[0] shape = obj.shape[1:] strides = [1] * len(shape) @@ -305,7 +315,11 @@ def as_mpi_memory(cls, obj) -> MPI.memory: @classmethod def as_buffer( - cls, obj: torch.Tensor, counts: Tuple[int] = None, displs: Tuple[int] = None + cls, + obj: torch.Tensor, + counts: Tuple[int] = None, + displs: Tuple[int] = None, + is_contiguous: bool = None, ) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]: """ Converts a passed ``torch.Tensor`` into a memory buffer object with associated number of elements and MPI data type. @@ -318,14 +332,15 @@ def as_buffer( Optional counts arguments for variable MPI-calls (e.g. Alltoallv) displs : Tuple[int,...], optional Optional displacements arguments for variable MPI-calls (e.g. Alltoallv) + is_contiguous: bool, optional + Optional information on global contiguity of the memory-distributed object. """ squ = False if not obj.is_contiguous() and obj.ndim == 1: # this makes the math work below this function.
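# Illustrative sketch, not part of the original patch: how a non-contiguous
# (e.g. transposed) tensor can be described to MPI without copying, via a
# derived vector type. Assumes mpi4py is available; values are illustrative.
import torch
from mpi4py import MPI
t = torch.arange(12, dtype=torch.float32).reshape(3, 4).permute(1, 0)  # non-contiguous view
# consecutive elements of one row of `t` sit 4 floats apart in storage
vec = MPI.FLOAT.Create_vector(t.shape[1], 1, 4)  # count, blocklength, stride
vec.Commit()  # now usable as the datatype of a send/recv buffer triple
# the 1-D corner case handled below sidesteps this by unsqueezing to 2-D first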
obj.unsqueeze_(-1) squ = True - mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs) - + mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs, is_contiguous) mpi_mem = cls.as_mpi_memory(obj) if squ: # the squeeze happens in the mpi_type_and_elements_of function in the case of a @@ -1037,7 +1052,6 @@ def __allgather_like( type(sendbuf) ) ) - # unpack the receive buffer if isinstance(recvbuf, tuple): recvbuf, recv_counts, recv_displs = recvbuf @@ -1053,17 +1067,18 @@ def __allgather_like( # keep a reference to the original buffer object original_recvbuf = recvbuf - + sbuf_is_contiguous, rbuf_is_contiguous = True, True # permute the send_axis order so that the split send_axis is the first to be transmitted if axis != 0: send_axis_permutation = list(range(sendbuf.ndimension())) send_axis_permutation[0], send_axis_permutation[axis] = axis, 0 sendbuf = sendbuf.permute(*send_axis_permutation) + sbuf_is_contiguous = False - if axis != 0: recv_axis_permutation = list(range(recvbuf.ndimension())) recv_axis_permutation[0], recv_axis_permutation[axis] = axis, 0 recvbuf = recvbuf.permute(*recv_axis_permutation) + rbuf_is_contiguous = False else: recv_axis_permutation = None @@ -1074,20 +1089,18 @@ def __allgather_like( if sendbuf is MPI.IN_PLACE or not isinstance(sendbuf, torch.Tensor): mpi_sendbuf = sbuf else: - mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs) + mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs, sbuf_is_contiguous) if send_counts is not None: mpi_sendbuf[1] = mpi_sendbuf[1][0][self.rank] if recvbuf is MPI.IN_PLACE or not isinstance(recvbuf, torch.Tensor): mpi_recvbuf = rbuf else: - mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs) + mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs, rbuf_is_contiguous) if recv_counts is None: mpi_recvbuf[1] //= self.size - # perform the scatter operation exit_code = func(mpi_sendbuf, mpi_recvbuf, **kwargs) - return exit_code, sbuf, rbuf, original_recvbuf, recv_axis_permutation def Allgather( @@ -1260,7 +1273,7 @@ def __alltoall_like( # keep a reference to the original buffer object original_recvbuf = recvbuf - # Simple case, continuous buffers can be transmitted as is + # Simple case, contiguous buffers can be transmitted as is if send_axis < 2 and recv_axis < 2: send_axis_permutation = list(range(recvbuf.ndimension())) recv_axis_permutation = list(range(recvbuf.ndimension())) From 27ea911b98c660d29c7ca8033d37b9290d5db95a Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 12 Dec 2022 12:17:47 +0100 Subject: [PATCH 031/132] Update ubuntu --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 822a501a9a..9cd92c30a9 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,5 +1,5 @@ test: - image: nvidia/cuda:11.6.2-runtime-ubuntu20.04 + image: nvidia/cuda:11.6.2-runtime-ubuntu22.04 tags: - cuda - x86_64 From d0fb6c8213b119708addcbdc25c3ec1518cd10d4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Dec 2022 11:18:26 +0000 Subject: [PATCH 032/132] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .github/release-drafter.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml index c1abd3124d..7fef410249 100644 --- a/.github/release-drafter.yml +++ b/.github/release-drafter.yml @@ -34,7 +34,7 @@ categories: label: 
'chore' - title: '🧪 Testing' label: 'testing' - + change-template: '- #$NUMBER $TITLE (by @$AUTHOR)' categorie-template: '### $TITLE' exclude-labels: From 0e704d43e8c8fb5d74a23c0fb4895ea7ce796b13 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 12 Dec 2022 14:02:24 +0100 Subject: [PATCH 033/132] switch back to ubuntu 20.04 --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 9cd92c30a9..822a501a9a 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,5 +1,5 @@ test: - image: nvidia/cuda:11.6.2-runtime-ubuntu22.04 + image: nvidia/cuda:11.6.2-runtime-ubuntu20.04 tags: - cuda - x86_64 From acfe9bdd2dc8ade78d03003eea1d88fdb456758a Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 12 Dec 2022 15:22:05 +0100 Subject: [PATCH 034/132] Upgrade CI to ubuntu 22.04 and cuda 11.7.1 --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 822a501a9a..51e8b292ee 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,5 +1,5 @@ test: - image: nvidia/cuda:11.6.2-runtime-ubuntu20.04 + image: nvidia/cuda:11.7.1-runtime-ubuntu22.04 tags: - cuda - x86_64 From 0fd3d87bf37ee30a31bfe20160e1fd7a3ba0f851 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 20 Dec 2022 05:06:41 +0100 Subject: [PATCH 035/132] avoid unnecessary gathering of test DNDarrays --- heat/core/tests/test_suites/basic_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/core/tests/test_suites/basic_test.py b/heat/core/tests/test_suites/basic_test.py index f094668bc8..65dcea4e96 100644 --- a/heat/core/tests/test_suites/basic_test.py +++ b/heat/core/tests/test_suites/basic_test.py @@ -136,8 +136,8 @@ def assert_array_equal(self, heat_array, expected_array): "Local shapes do not match. 
" "Got {} expected {}".format(heat_array.lshape, expected_array[slices].shape), ) - local_heat_numpy = heat_array.numpy() - self.assertTrue(np.allclose(local_heat_numpy, expected_array)) + # compare local tensors to corresponding slice of expected_array + self.assertTrue(np.allclose(heat_array.larray.numpy(), expected_array[slices])) def assert_func_equal( self, From 3c4c07cf450973965f60525644726656e713a1ff Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 20 Dec 2022 05:14:52 +0100 Subject: [PATCH 036/132] early out for resplit of non-distributed DNDarrays --- heat/core/manipulations.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 33ebf4d365..00a8241bc0 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -3372,6 +3372,9 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray: # early out for unchanged content if axis == arr.split: return arr.copy() + if not arr.is_distributed(): + return factories.array(arr.larray, split=axis, device=arr.device, copy=True) + if axis is None: # new_arr = arr.copy() gathered = torch.empty( From 989e0f4e358e8324d37a0a3ac0ddc1946d54d26b Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 20 Dec 2022 05:17:37 +0100 Subject: [PATCH 037/132] match split of comparison array to expected output --- heat/core/linalg/tests/test_basics.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/heat/core/linalg/tests/test_basics.py b/heat/core/linalg/tests/test_basics.py index a3cb827b84..45d4e34d82 100644 --- a/heat/core/linalg/tests/test_basics.py +++ b/heat/core/linalg/tests/test_basics.py @@ -238,6 +238,7 @@ def test_inv(self): self.assertTrue(ht.allclose(ainv, ares, atol=1e-6)) # distributed + # ares = ht.array([[2.0, 2, 1], [3, 4, 1], [0, 1, -1]], split=0) a = ht.array([[5.0, -3, 2], [-3, 2, -1], [-3, 2, -2]], split=0) ainv = ht.linalg.inv(a) self.assertEqual(ainv.split, a.split) @@ -245,6 +246,7 @@ def test_inv(self): self.assertTupleEqual(ainv.shape, a.shape) self.assertTrue(ht.allclose(ainv, ares, atol=1e-6)) + ares = ht.array([[2.0, 2, 1], [3, 4, 1], [0, 1, -1]], split=1) a = ht.array([[5.0, -3, 2], [-3, 2, -1], [-3, 2, -2]], split=1) ainv = ht.linalg.inv(a) self.assertEqual(ainv.split, a.split) @@ -281,7 +283,7 @@ def test_inv(self): self.assertTrue(ht.allclose(ainv, ares, atol=1e-6)) # pivoting row change - ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double) / 3.0 + ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double, split=0) / 3.0 a = ht.array([[1, 2, 0], [2, 4, 1], [2, 1, 0]], dtype=ht.double, split=0) ainv = ht.linalg.inv(a) self.assertEqual(ainv.split, a.split) @@ -289,6 +291,7 @@ def test_inv(self): self.assertTupleEqual(ainv.shape, a.shape) self.assertTrue(ht.allclose(ainv, ares, atol=1e-6)) + ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double, split=1) / 3.0 a = ht.array([[1, 2, 0], [2, 4, 1], [2, 1, 0]], dtype=ht.double, split=1) ainv = ht.linalg.inv(a) self.assertEqual(ainv.split, a.split) @@ -365,7 +368,8 @@ def test_matmul(self): self.assertEqual(ret00.shape, (n, k)) self.assertEqual(ret00.dtype, ht.float) self.assertEqual(ret00.split, None) - self.assertEqual(a.split, 0) + if a.comm.size > 1: + self.assertEqual(a.split, 0) self.assertEqual(b.split, None) if a.comm.size > 1: From 6d66fad4222c6d13f2ac5a339387c1cc207a76a6 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 20 Dec 2022 05:22:50 +0100 Subject: [PATCH 038/132] avoid MPI calls in non-distributed 
cases --- heat/core/linalg/basics.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/heat/core/linalg/basics.py b/heat/core/linalg/basics.py index bc5d3e9e65..7a2776386b 100644 --- a/heat/core/linalg/basics.py +++ b/heat/core/linalg/basics.py @@ -510,6 +510,13 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: if b.dtype != c_type: b = c_type(b, device=b.device) + # early out for single-process setup, torch matmul + if a.comm.size == 1: + ret = factories.array(torch.matmul(a.larray, b.larray), device=a.device) + if gpu_int_flag: + ret = og_type(ret, device=a.device) + return ret + if a.split is None and b.split is None: # matmul from torch if len(a.gshape) < 2 or len(b.gshape) < 2 or not allow_resplit: # if either of A or B is a vector @@ -517,17 +524,17 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: if gpu_int_flag: ret = og_type(ret, device=a.device) return ret - else: - a.resplit_(0) - slice_0 = a.comm.chunk(a.shape, a.split)[2][0] - hold = a.larray @ b.larray - c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device) - c.larray[slice_0.start : slice_0.stop, :] += hold - c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM) - if gpu_int_flag: - c = og_type(c, device=a.device) - return c + a.resplit_(0) + slice_0 = a.comm.chunk(a.shape, a.split)[2][0] + hold = a.larray @ b.larray + + c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device) + c.larray[slice_0.start : slice_0.stop, :] += hold + c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM) + if gpu_int_flag: + c = og_type(c, device=a.device) + return c # if they are vectors they need to be expanded to be the proper dimensions vector_flag = False # flag to run squeeze at the end of the function From a37b4d3c35c7e81fd5a3528c00aee74c7e80ce8a Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 20 Dec 2022 05:27:10 +0100 Subject: [PATCH 039/132] avoid MPI calls in non-distributed resplit --- heat/core/dndarray.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 9ec0ea89e1..6e9d2c56ef 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1268,8 +1268,11 @@ def resplit_(self, axis: int = None): axis = sanitize_axis(self.shape, axis) # early out for unchanged content + if self.comm.size == 1: + self.__split = axis if axis == self.split: return self + if axis is None: gathered = torch.empty( self.shape, dtype=self.dtype.torch_type(), device=self.device.torch_device From 8eebe10b4359f14ac84b90eb067199446bc4caf5 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 20 Dec 2022 05:28:24 +0100 Subject: [PATCH 040/132] set default to None --- heat/core/communication.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/heat/core/communication.py b/heat/core/communication.py index 23d633c30f..fd800185ca 100644 --- a/heat/core/communication.py +++ b/heat/core/communication.py @@ -340,6 +340,7 @@ def as_buffer( # this makes the math work below this function. 
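# Illustrative sketch, not part of the original patch: why the contiguity flag
# now defaults to None and may be set globally. A purely local
# torch.Tensor.is_contiguous() check can disagree across ranks after a
# permutation, e.g. when one rank holds a single row while another holds many:
import torch
full = torch.arange(6.0).reshape(2, 3).permute(1, 0)  # genuinely non-contiguous
slim = torch.arange(3.0).reshape(1, 3).permute(1, 0)  # one local row only
print(full.is_contiguous(), slim.is_contiguous())  # False True
# Ranks would then build mismatched MPI datatypes for the same collective --
# the edge case that PATCH 030 above addresses by passing the flag explicitly.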
obj.unsqueeze_(-1) squ = True + mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs, is_contiguous) mpi_mem = cls.as_mpi_memory(obj) if squ: @@ -1067,7 +1068,7 @@ def __allgather_like( # keep a reference to the original buffer object original_recvbuf = recvbuf - sbuf_is_contiguous, rbuf_is_contiguous = True, True + sbuf_is_contiguous, rbuf_is_contiguous = None, None # permute the send_axis order so that the split send_axis is the first to be transmitted if axis != 0: send_axis_permutation = list(range(sendbuf.ndimension())) From 22c5c68ffdb5f70ea2c559787f45ce28722dc0ea Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 20 Dec 2022 05:30:07 +0100 Subject: [PATCH 041/132] remove print statement --- heat/core/tests/test_dndarray.py | 1 - 1 file changed, 1 deletion(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index e42c5a9a14..726a85e77a 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -126,7 +126,6 @@ def test_gethalo(self): # test no data on process data_np = np.arange(2 * 12).reshape(2, 12) data = ht.array(data_np, split=0) - print("DEBUGGING: data.lshape_map = ", data.lshape_map) data.get_halo(1) data_with_halos = data.array_with_halos From c692bff3bde5279d6ff3e497bad6fc766fb5ff19 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 20 Dec 2022 05:42:17 +0100 Subject: [PATCH 042/132] upgrade torch version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2210ceaf97..0e8f00b0de 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ install_requires=[ "mpi4py>=3.0.0", "numpy>=1.13.0", - "torch>=1.7.0, <1.13.1", + "torch>=1.7.0, <1.13.2", "scipy>=0.14.0", "pillow>=6.0.0", "torchvision>=0.8.0", From df6a4e567419d7548f926cf1a55a75bb7fb05f9d Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 20 Dec 2022 05:58:21 +0100 Subject: [PATCH 043/132] copy to cpu before comparing --- heat/core/tests/test_suites/basic_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/tests/test_suites/basic_test.py b/heat/core/tests/test_suites/basic_test.py index 65dcea4e96..2ef0c1d96c 100644 --- a/heat/core/tests/test_suites/basic_test.py +++ b/heat/core/tests/test_suites/basic_test.py @@ -137,7 +137,7 @@ def assert_array_equal(self, heat_array, expected_array): "Got {} expected {}".format(heat_array.lshape, expected_array[slices].shape), ) # compare local tensors to corresponding slice of expected_array - self.assertTrue(np.allclose(heat_array.larray.numpy(), expected_array[slices])) + self.assertTrue(np.allclose(heat_array.larray.cpu().numpy(), expected_array[slices])) def assert_func_equal( self, From af0e721d3654d6e0e02f4ea0ff799001408b1b28 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 23 Dec 2022 12:19:41 +0100 Subject: [PATCH 044/132] use ht.allclose instead of np.allclose --- heat/core/tests/test_suites/basic_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/tests/test_suites/basic_test.py b/heat/core/tests/test_suites/basic_test.py index 2ef0c1d96c..b15103a1c5 100644 --- a/heat/core/tests/test_suites/basic_test.py +++ b/heat/core/tests/test_suites/basic_test.py @@ -137,7 +137,7 @@ def assert_array_equal(self, heat_array, expected_array): "Got {} expected {}".format(heat_array.lshape, expected_array[slices].shape), ) # compare local tensors to corresponding slice of expected_array - self.assertTrue(np.allclose(heat_array.larray.cpu().numpy(), expected_array[slices])) + 
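# Illustrative sketch, not part of the original patch: the per-rank comparison
# pattern used here, with made-up values; `slices` stands for the chunk this
# rank owns along the split axis.
import numpy as np
import torch
expected_array = np.arange(12).reshape(4, 3)
slices = (slice(1, 3), slice(None))              # assumed: this rank owns rows 1..2
local_tensor = torch.arange(3, 9).reshape(2, 3)  # stand-in for heat_array.larray
assert np.allclose(local_tensor.cpu().numpy(), expected_array[slices])
# Comparing chunk by chunk avoids gathering the full DNDarray on every process;
# the assertion below applies the same idea.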
self.assertTrue(ht.allclose(heat_array, ht.array(expected_array))) def assert_func_equal( self, From bac6d4e524d2754f59a2fe0986bb34c4ff36983b Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 23 Dec 2022 12:21:35 +0100 Subject: [PATCH 045/132] cast different dtype operands to promoted dtype within torch call --- heat/core/logical.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/heat/core/logical.py b/heat/core/logical.py index a6be081ea7..8106a556ee 100644 --- a/heat/core/logical.py +++ b/heat/core/logical.py @@ -140,7 +140,19 @@ def allclose( t1, t2 = __sanitize_close_input(x, y) # no sanitation for shapes of x and y needed, torch.allclose raises relevant errors - _local_allclose = torch.tensor(torch.allclose(t1.larray, t2.larray, rtol, atol, equal_nan)) + try: + _local_allclose = torch.tensor(torch.allclose(t1.larray, t2.larray, rtol, atol, equal_nan)) + except RuntimeError: + promoted_dtype = torch.promote_types(t1.larray.dtype, t2.larray.dtype) + _local_allclose = torch.tensor( + torch.allclose( + t1.larray.type(promoted_dtype), + t2.larray.type(promoted_dtype), + rtol, + atol, + equal_nan, + ) + ) # If x is distributed, then y is also distributed along the same axis if t1.comm.is_distributed(): From c0c63629a45a20eeef75a00d8d871933b2eb5e48 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 23 Dec 2022 12:52:49 +0100 Subject: [PATCH 046/132] compare local tensors to corresponding slice of expected_array only --- heat/core/tests/test_suites/basic_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/heat/core/tests/test_suites/basic_test.py b/heat/core/tests/test_suites/basic_test.py index b15103a1c5..39f6a5f063 100644 --- a/heat/core/tests/test_suites/basic_test.py +++ b/heat/core/tests/test_suites/basic_test.py @@ -137,7 +137,11 @@ def assert_array_equal(self, heat_array, expected_array): "Got {} expected {}".format(heat_array.lshape, expected_array[slices].shape), ) # compare local tensors to corresponding slice of expected_array - self.assertTrue(ht.allclose(heat_array, ht.array(expected_array))) + is_allclose = np.allclose(heat_array.larray.cpu(), expected_array[slices]) + ht_is_allclose = ht.array( + [is_allclose], dtype=ht.bool, is_split=0, device=heat_array.device + ) + self.assertTrue(ht.all(ht_is_allclose)) def assert_func_equal( self, From 587bc054782ddea297ea27c0979c5aa484b8a517 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 23 Dec 2022 13:38:38 +0100 Subject: [PATCH 047/132] expand tests --- heat/core/linalg/tests/test_basics.py | 18 ++++++++++++++++++ heat/core/tests/test_logical.py | 2 ++ heat/core/tests/test_manipulations.py | 10 ++++++++++ 3 files changed, 30 insertions(+) diff --git a/heat/core/linalg/tests/test_basics.py b/heat/core/linalg/tests/test_basics.py index 45d4e34d82..c379904b18 100644 --- a/heat/core/linalg/tests/test_basics.py +++ b/heat/core/linalg/tests/test_basics.py @@ -372,6 +372,24 @@ def test_matmul(self): self.assertEqual(a.split, 0) self.assertEqual(b.split, None) + # splits 0 None on 1 process + if a.comm.size == 1: + a = ht.ones((n, m), split=0) + b = ht.ones((j, k), split=None) + a[0] = ht.arange(1, m + 1) + a[:, -1] = ht.arange(1, n + 1) + b[0] = ht.arange(1, k + 1) + b[:, 0] = ht.arange(1, j + 1) + ret00 = ht.matmul(a, b, allow_resplit=True) + + self.assertEqual(ht.all(ret00 == ht.array(a_torch @ b_torch)), 1) + self.assertIsInstance(ret00, ht.DNDarray) + self.assertEqual(ret00.shape, (n, k)) + self.assertEqual(ret00.dtype, ht.float) + 
self.assertEqual(ret00.split, None) + self.assertEqual(a.split, 0) + self.assertEqual(b.split, None) + if a.comm.size > 1: # splits 00 a = ht.ones((n, m), split=0, dtype=ht.float64) diff --git a/heat/core/tests/test_logical.py b/heat/core/tests/test_logical.py index 691df7ec62..c2e3d1a786 100644 --- a/heat/core/tests/test_logical.py +++ b/heat/core/tests/test_logical.py @@ -182,6 +182,7 @@ def test_allclose(self): c = ht.zeros((4, 6), split=0) d = ht.zeros((4, 6), split=1) e = ht.zeros((4, 6)) + f = ht.float64([[2.000005, 2.000005], [2.000005, 2.000005]]) self.assertFalse(ht.allclose(a, b)) self.assertTrue(ht.allclose(a, b, atol=1e-04)) @@ -189,6 +190,7 @@ def test_allclose(self): self.assertTrue(ht.allclose(a, 2)) self.assertTrue(ht.allclose(a, 2.0)) self.assertTrue(ht.allclose(2, a)) + self.assertTrue(ht.allclose(f, a)) self.assertTrue(ht.allclose(c, d)) self.assertTrue(ht.allclose(c, e)) self.assertTrue(e.allclose(c)) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 9a41bceab8..4464053fd3 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2992,6 +2992,16 @@ def test_resplit(self): self.assertEqual(data2.lshape, (data.comm.size, 1)) self.assertEqual(data2.split, 1) + # resplitting a non-distributed DNDarray with split not None + if ht.MPI_WORLD.size == 1: + data = ht.zeros(10, 10, split=0) + data2 = ht.resplit(data, 1) + data3 = ht.resplit(data, None) + self.assertTrue((data == data2).all()) + self.assertTrue((data == data3).all()) + self.assertEqual(data2.split, 1) + self.assertTrue(data3.split is None) + # splitting an unsplit tensor should result in slicing the tensor locally shape = (ht.MPI_WORLD.size, ht.MPI_WORLD.size) data = ht.zeros(shape) From 24239a11e22067ec21c8a7a8eb2c4f895459b1a0 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 23 Dec 2022 13:39:15 +0100 Subject: [PATCH 048/132] remove redundant code --- heat/core/manipulations.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 00a8241bc0..7cf02ab016 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -3384,11 +3384,6 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray: arr.comm.Allgatherv(arr.larray, (gathered, counts, displs), recv_axis=arr.split) new_arr = factories.array(gathered, is_split=axis, device=arr.device, dtype=arr.dtype) return new_arr - # tensor needs be split/sliced locally - if arr.split is None: - temp = arr.larray[arr.comm.chunk(arr.shape, axis)[2]] - new_arr = factories.array(temp, is_split=axis, device=arr.device, dtype=arr.dtype) - return new_arr arr_tiles = tiling.SplitTiles(arr) new_arr = factories.empty(arr.gshape, split=axis, dtype=arr.dtype, device=arr.device) From cd65b370a5dd72ee62cd5ebf2c76464f222e8ac1 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 26 Dec 2022 07:53:58 +0100 Subject: [PATCH 049/132] Implement slicing with negative step --- heat/core/dndarray.py | 332 ++++++++++++++++++++++++++---------------- 1 file changed, 210 insertions(+), 122 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 0ec9efa624..5eb3a10418 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -669,6 +669,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] 
if arr.is_distributed(): split_bookkeeping[arr.split] = "split" counts, displs = arr.counts_displs() + new_split = arr.split advanced_indexing = False arr_is_copy = False @@ -799,118 +800,154 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced - if isinstance(key, (tuple, list)): - key = list(key) + key = list(key) if isinstance(key, Iterable) else [key] - # check for ellipsis, newaxis - add_dims = sum(k is None for k in key) # (np.newaxis is None)===true - ellipsis = sum(isinstance(k, type(...)) for k in key) - if ellipsis > 1: - raise ValueError("key can only contain 1 ellipsis") - # replace with explicit `slice(None)` for interested dimensions - if ellipsis == 1: - # output_shape, split_bookkeeping not affected - expand_key = [slice(None)] * (arr.ndim + add_dims) - ellipsis_index = key.index(...) - expand_key[:ellipsis_index] = key[:ellipsis_index] - expand_key[ellipsis_index - len(key) :] = key[ellipsis_index + 1 :] - key = expand_key - while add_dims > 0: - # expand array dims, output_shape, split_bookkeeping to reflect newaxis - # replace newaxis with slice(None) in key - for i, k in reversed(list(enumerate(key))): - if k is None: - key[i] = slice(None) - if not arr_is_copy: - arr = arr.copy() - arr_is_copy = True - arr = arr.expand_dims(i - add_dims + 1) - output_shape = ( - output_shape[: i - add_dims + 1] - + [1] - + output_shape[i - add_dims + 1 :] - ) - split_bookkeeping = ( - split_bookkeeping[: i - add_dims + 1] - + [None] - + split_bookkeeping[i - add_dims + 1 :] - ) - add_dims -= 1 - - # check for advanced indexing - advanced_indexing_dims = [] - for i, k in enumerate(key): - if isinstance(k, Iterable) or isinstance(k, DNDarray): - # advanced indexing across dimensions - advanced_indexing = True - advanced_indexing_dims.append(i) - if not isinstance(k, DNDarray): - key[i] = torch.tensor(k, dtype=torch.int64, device=arr.larray.device) - - if advanced_indexing: - advanced_indexing_shapes = tuple( - tuple(key[i].shape) for i in advanced_indexing_dims - ) - print("DEBUGGING: advanced_indexing_shapes = ", advanced_indexing_shapes) - # shapes of indexing arrays must be broadcastable - try: - broadcasted_shape = torch.broadcast_shapes(advanced_indexing_shapes) - except RuntimeError: - raise IndexError( - "Shape mismatch: indexing arrays could not be broadcast together with shapes: {}".format( - advanced_indexing_shapes - ) + # check for ellipsis, newaxis. NB: (np.newaxis is None)===true + add_dims = sum(k is None for k in key) + ellipsis = sum(isinstance(k, type(...)) for k in key) + if ellipsis > 1: + raise ValueError("key can only contain 1 ellipsis") + # replace with explicit `slice(None)` for interested dimensions + if ellipsis == 1: + # output_shape, split_bookkeeping not affected + expand_key = [slice(None)] * (arr.ndim + add_dims) + ellipsis_index = key.index(...) 
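# Illustrative sketch, not part of the original patch: the intended effect of
# the ellipsis expansion above for a 4-D array, assuming the key contains no
# np.newaxis (the formula below is a simplified stand-in, not the patch's code).
key = [0, ..., slice(2, 5)]
ndim = 4
i = key.index(...)
expanded = key[:i] + [slice(None)] * (ndim - (len(key) - 1)) + key[i + 1 :]
# expanded == [0, slice(None), slice(None), slice(2, 5)]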
+ expand_key[:ellipsis_index] = key[:ellipsis_index] + expand_key[ellipsis_index - len(key) :] = key[ellipsis_index + 1 :] + key = expand_key + while add_dims > 0: + # expand array dims, output_shape, split_bookkeeping to reflect newaxis + # replace newaxis with slice(None) in key + for i, k in reversed(list(enumerate(key))): + if k is None: + key[i] = slice(None) + if not arr_is_copy: + arr = arr.copy() + arr_is_copy = True + arr = arr.expand_dims(i - add_dims + 1) + output_shape = ( + output_shape[: i - add_dims + 1] + [1] + output_shape[i - add_dims + 1 :] ) - add_dims = len(broadcasted_shape) - len(advanced_indexing_dims) - if ( - len(advanced_indexing_dims) == 1 - or list(range(advanced_indexing_dims[0], advanced_indexing_dims[-1] + 1)) - == advanced_indexing_dims - ): - # dimensions affected by advanced indexing are consecutive: - output_shape[ - advanced_indexing_dims[0] : advanced_indexing_dims[0] - + len(advanced_indexing_dims) - ] = broadcasted_shape split_bookkeeping = ( - split_bookkeeping[: advanced_indexing_dims[0]] - + [None] * add_dims - + split_bookkeeping[advanced_indexing_dims[0] :] + split_bookkeeping[: i - add_dims + 1] + + [None] + + split_bookkeeping[i - add_dims + 1 :] ) + add_dims -= 1 + + # check for advanced indexing and slices + print("DEBUGGING: key = ", key) + advanced_indexing_dims = [] + for i, k in enumerate(key): + if isinstance(k, Iterable) or isinstance(k, DNDarray): + # advanced indexing across dimensions + print("DEBUGGING: k = ", k) + advanced_indexing = True + advanced_indexing_dims.append(i) + if not isinstance(k, DNDarray): + key[i] = torch.tensor(k, dtype=torch.int64, device=arr.larray.device) + elif isinstance(k, slice) and k != slice(None): + start, stop, step = k.start, k.stop, k.step + if start is None: + start = 0 + elif start < 0: + start += arr.gshape[i] + if stop is None: + stop = arr.gshape[i] + elif stop < 0: + stop += arr.gshape[i] + if step is None: + step = 1 + if step < 0 and start > stop: + # PyTorch doesn't support negative step as of 1.13 + # Lazy solution, potentially large memory footprint + # TODO: implement ht.fromiter (implemented in ASSET_ht) + key[i] = list(range(start, stop, step)) + output_shape[i] = len(key[i]) + if arr.is_distributed() and new_split == i: + # distribute key and proceed with non-ordered indexing + key[i] = factories.array(key[i], split=0, device=arr.device).larray + split_key_is_sorted = False + out_is_balanced = True + elif step > 0 and start < stop: + output_shape[i] = int(torch.tensor((stop - start) / step).ceil().item()) + if arr.is_distributed() and new_split == i: + split_key_is_sorted = True + out_is_balanced = False + if ( + stop >= displs[arr.comm.rank] + and start < displs[arr.comm.rank] + counts[arr.comm.rank] + ): + index_in_cycle = (displs[arr.comm.rank] - start) % step + local_start = 0 if index_in_cycle == 0 else step - index_in_cycle + local_stop = stop - displs[arr.comm.rank] + key[i] = slice(local_start, local_stop, step) + else: + key[i] = slice(0, 0) + elif step == 0: + raise ValueError("Slice step cannot be zero") else: - # advanced-indexing dimensions are not consecutive: - # transpose array to make the advanced-indexing dimensions consecutive as the first dimensions - non_adv_ind_dims = list( - i for i in range(arr.ndim) if i not in advanced_indexing_dims + key[i] = slice(0, 0) + output_shape[i] = 0 + + if advanced_indexing: + advanced_indexing_shapes = tuple(tuple(key[i].shape) for i in advanced_indexing_dims) + print("DEBUGGING: advanced_indexing_shapes = ", 
advanced_indexing_shapes) + # shapes of indexing arrays must be broadcastable + try: + broadcasted_shape = torch.broadcast_shapes(advanced_indexing_shapes) + except RuntimeError: + raise IndexError( + "Shape mismatch: indexing arrays could not be broadcast together with shapes: {}".format( + advanced_indexing_shapes ) - if not arr_is_copy: - arr = arr.copy() - arr_is_copy = True - arr = arr.transpose(advanced_indexing_dims + non_adv_ind_dims) - output_shape = list(arr.gshape) - output_shape[: len(advanced_indexing_dims)] = broadcasted_shape - split_bookkeeping = [None] * arr.ndim - if arr.is_distributed: - split_bookkeeping[arr.split] = "split" - split_bookkeeping = [None] * add_dims + split_bookkeeping - # modify key to match the new dimension order - key = [key[i] for i in advanced_indexing_dims] + [ - key[i] for i in non_adv_ind_dims - ] - # update advanced-indexing dims - advanced_indexing_dims = list(range(len(advanced_indexing_dims))) - - # expand key to match the number of dimensions of the DNDarray - if arr.ndim > len(key): - key += [slice(None)] * (arr.ndim - len(key)) - else: # key is integer or slice - key = [key] + [slice(None)] * (arr.ndim - 1) + ) + add_dims = len(broadcasted_shape) - len(advanced_indexing_dims) + if ( + len(advanced_indexing_dims) == 1 + or list(range(advanced_indexing_dims[0], advanced_indexing_dims[-1] + 1)) + == advanced_indexing_dims + ): + # dimensions affected by advanced indexing are consecutive: + output_shape[ + advanced_indexing_dims[0] : advanced_indexing_dims[0] + + len(advanced_indexing_dims) + ] = broadcasted_shape + split_bookkeeping = ( + split_bookkeeping[: advanced_indexing_dims[0]] + + [None] * add_dims + + split_bookkeeping[advanced_indexing_dims[0] :] + ) + else: + # advanced-indexing dimensions are not consecutive: + # transpose array to make the advanced-indexing dimensions consecutive as the first dimensions + non_adv_ind_dims = list( + i for i in range(arr.ndim) if i not in advanced_indexing_dims + ) + if not arr_is_copy: + arr = arr.copy() + arr_is_copy = True + arr = arr.transpose(advanced_indexing_dims + non_adv_ind_dims) + output_shape = list(arr.gshape) + output_shape[: len(advanced_indexing_dims)] = broadcasted_shape + split_bookkeeping = [None] * arr.ndim + if arr.is_distributed: + split_bookkeeping[arr.split] = "split" + split_bookkeeping = [None] * add_dims + split_bookkeeping + # modify key to match the new dimension order + key = [key[i] for i in advanced_indexing_dims] + [key[i] for i in non_adv_ind_dims] + # update advanced-indexing dims + advanced_indexing_dims = list(range(len(advanced_indexing_dims))) + + # expand key to match the number of dimensions of the DNDarray + if arr.ndim > len(key): + key += [slice(None)] * (arr.ndim - len(key)) key = tuple(key) output_shape = tuple(output_shape) new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None - return arr, key, output_shape, new_split, advanced_indexing + return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced def __get_local_slice(self, key: slice): split = self.split @@ -967,7 +1004,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar """ # key can be: int, tuple, list, slice, DNDarray, torch tensor, numpy array, or sequence thereof # Trivial cases - # print("DEBUGGING: RAW KEY = ", key) + print("DEBUGGING: RAW KEY = ", key) # Single-element indexing scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 @@ -1048,6 +1085,7 @@ def __getitem__(self, key: Union[int, Tuple[int, 
...], List[int, ...]]) -> DNDar # data are not distributed or split dimension is not affected by indexing # if not self.is_distributed() or key[self.split] == slice(None): + print("split_key_is_sorted, key = ", split_key_is_sorted, key) if split_key_is_sorted: indexed_arr = self.larray[key] return DNDarray( @@ -1060,20 +1098,34 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar comm=self.comm, ) - # key is sorted along dim 0 but not along self.split - # key is tuple of torch.Tensor + # key is not sorted along self.split + # key is tuple of torch.Tensor or mix of torch.Tensors and slices _, displs = self.counts_displs() original_split = self.split - # send and receive "request key" info on what data element to shup where + # determine whether indexed array will be 1D or nD + key_shapes = [] + for k in key: + key_shapes.append(getattr(k, "shape", None)) + return_1d = key_shapes.count(key_shapes[original_split]) == self.ndim + + # send and receive "request key" info on what data element to ship where recv_counts = torch.zeros((self.comm.size, 1), dtype=torch.int64, device=self.larray.device) - request_key_shape = (0, self.ndim) + + # construct empty tensor that we'll append to later + if return_1d: + request_key_shape = (0, self.ndim) + else: + request_key_shape = (0, 1) + outgoing_request_key = torch.empty( tuple(request_key_shape), dtype=torch.int64, device=self.larray.device ) outgoing_request_key_counts = torch.zeros( (self.comm.size,), dtype=torch.int64, device=self.larray.device ) + + # process-local: calculate which/how many elements will be received from what process for i in range(self.comm.size): cond1 = key[original_split] >= displs[i] if i != self.comm.size - 1: @@ -1083,16 +1135,23 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar cond2 = torch.ones( (key[original_split].shape[0],), dtype=torch.bool, device=self.larray.device ) - selection = list(k[cond1 & cond2] for k in key) - recv_counts[i, :] = selection[0].shape[0] - selection = torch.stack(selection, dim=1) + if return_1d: + # advanced indexing returning 1D array (e.g. 
boolean indexing) + selection = list(k[cond1 & cond2] for k in key) + recv_counts[i, :] = selection[0].shape[0] + selection = torch.stack(selection, dim=1) + else: + selection = key[original_split][cond1 & cond2] + recv_counts[i, :] = selection.shape[0] + selection.unsqueeze_(dim=1) outgoing_request_key = torch.cat((outgoing_request_key, selection), dim=0) - + print("DEBUGGING: outgoing_request_key = ", outgoing_request_key) # share recv_counts among all processes comm_matrix = torch.empty( (self.comm.size, self.comm.size), dtype=recv_counts.dtype, device=recv_counts.device ) self.comm.Allgather(recv_counts, comm_matrix) + print("DEBUGGING: comm_matrix = ", comm_matrix) outgoing_request_key_counts = comm_matrix[self.comm.rank] outgoing_request_key_displs = torch.cat( @@ -1106,6 +1165,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar ), dim=0, ).cumsum(dim=0)[:-1] + print("DEBUGGING: outgoing_request_key_displs = ", outgoing_request_key_displs) + print("DEBUGGING: outgoing_request_key_counts = ", outgoing_request_key_counts) incoming_request_key_counts = comm_matrix[:, self.comm.rank] incoming_request_key_displs = torch.cat( ( @@ -1118,11 +1179,21 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar ), dim=0, ).cumsum(dim=0)[:-1] - incoming_request_key = torch.empty( - (incoming_request_key_counts.sum(), self.ndim), - dtype=outgoing_request_key_counts.dtype, - device=outgoing_request_key_counts.device, - ) + print("DEBUGGING: incoming_request_key_displs = ", incoming_request_key_displs) + print("DEBUGGING: incoming_request_key_counts = ", incoming_request_key_counts) + + if return_1d: + incoming_request_key = torch.empty( + (incoming_request_key_counts.sum(), self.ndim), + dtype=outgoing_request_key_counts.dtype, + device=outgoing_request_key_counts.device, + ) + else: + incoming_request_key = torch.empty( + (incoming_request_key_counts.sum(), 1), + dtype=outgoing_request_key_counts.dtype, + device=outgoing_request_key_counts.device, + ) # send and receive request keys self.comm.Alltoallv( ( @@ -1136,12 +1207,22 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar incoming_request_key_displs.tolist(), ), ) + print("DEBUGGING:incoming_request_key = ", incoming_request_key) + if return_1d: + incoming_request_key = list(incoming_request_key[:, d] for d in range(self.ndim)) + incoming_request_key[original_split] -= displs[self.comm.rank] + else: + incoming_request_key -= displs[self.comm.rank] + incoming_request_key = ( + key[:output_split] + + (incoming_request_key.squeeze_(1).tolist(),) + + key[output_split + 1 :] + ) - incoming_request_key = list(incoming_request_key[:, d] for d in range(self.ndim)) - incoming_request_key[original_split] -= displs[self.comm.rank] + print("AFTER: incoming_request_key = ", incoming_request_key) send_buf = self.larray[incoming_request_key] output_lshape = list(output_shape) - output_lshape[output_split] = key[0].shape[0] + output_lshape[output_split] = key[output_split].shape[0] recv_buf = torch.empty( tuple(output_lshape), dtype=self.larray.dtype, device=self.larray.device ) @@ -1152,12 +1233,19 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar (send_buf, send_counts, send_displs), (recv_buf, recv_counts, recv_displs) ) - # reorganize incoming counts according to original key order - key = torch.stack(key, dim=1).tolist() - outgoing_request_key = outgoing_request_key.tolist() + # reorganize incoming counts according to original 
key order along split axis + if return_1d: + key = torch.stack(key, dim=1).tolist() + outgoing_request_key = outgoing_request_key.tolist() + else: + print("key[output_split] = ", key[output_split]) + key = key[output_split].tolist() + print("key = ", key) + outgoing_request_key = outgoing_request_key.squeeze_(1).tolist() + print("outgoing_request_key = ", outgoing_request_key) map = [outgoing_request_key.index(k) for k in key] indexed_arr = recv_buf[map] - return factories.array(indexed_arr, is_split=0) + return factories.array(indexed_arr, is_split=output_split) # TODO: boolean indexing with data.split != 0 # __process_key() returns locally correct key From 86e8801a9332f9da528c5eed9af027b42ad1a25d Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 26 Dec 2022 07:54:24 +0100 Subject: [PATCH 050/132] test slicing with negative step --- heat/core/tests/test_dndarray.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 564cfd63d9..274d9e0177 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -530,7 +530,7 @@ def test_getitem(self): self.assertTrue(x[2].item() == 2.0) self.assertTrue(x[-2].item() == 8.0) self.assertTrue(x[2].dtype == ht.float64) - self.assertTrue(x[2].split is None) + # self.assertTrue(x[2].split is None) # 2D, local x = ht.arange(10).reshape(2, 5) self.assertTrue((x[0] == ht.arange(5)).all().item()) @@ -552,7 +552,7 @@ def test_getitem(self): indexed_split0 = x_split0[key] self.assertTrue((indexed_split0.larray == x.larray[key]).all()) self.assertTrue(indexed_split0.dtype == ht.float32) - self.assertTrue(indexed_split0.split is None) + # self.assertTrue(indexed_split0.split is None) # 3D, distributed split, != 0 x_split2 = ht.array(x, dtype=ht.int64, split=2) key = ht.array(2) @@ -561,6 +561,21 @@ def test_getitem(self): self.assertTrue(indexed_split2.dtype == ht.int64) self.assertTrue(indexed_split2.split == 1) + # Slicing and striding + x = ht.arange(20, split=0) + x_sliced = x[1:11:3] + x_sliced.balance_() + self.assertTrue( + (x_sliced == ht.array([1, 4, 7, 10], dtype=x.dtype, device=x.device, split=0)) + .all() + .item() + ) + + x_3d = ht.arange(20 * 4 * 3, split=0).reshape(20, 4, 3) + x_3d_sliced = x_3d[17:2:-2, :2, 2] + x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(20, 4, 3)[17:2:-2, :2, 2] + self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) + # boolean mask, local arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5) np.random.seed(42) From 3b1f46d3fbb73b88de16982dc6cbcb401a431f52 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 27 Dec 2022 07:11:18 +0100 Subject: [PATCH 051/132] Fix single-element indexing within mixed-type key --- heat/core/dndarray.py | 26 ++++++++++++++++---------- heat/core/tests/test_dndarray.py | 4 +++- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index cc10132554..1f3cddfccb 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -802,13 +802,13 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] key = list(key) if isinstance(key, Iterable) else [key] - # check for ellipsis, newaxis. NB: (np.newaxis is None)===true + # check for ellipsis, newaxis. 
NB: (np.newaxis is None)==True add_dims = sum(k is None for k in key) ellipsis = sum(isinstance(k, type(...)) for k in key) if ellipsis > 1: raise ValueError("key can only contain 1 ellipsis") - # replace with explicit `slice(None)` for interested dimensions if ellipsis == 1: + # replace with explicit `slice(None)` for interested dimensions # output_shape, split_bookkeeping not affected expand_key = [slice(None)] * (arr.ndim + add_dims) ellipsis_index = key.index(...) @@ -816,14 +816,11 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] expand_key[ellipsis_index - len(key) :] = key[ellipsis_index + 1 :] key = expand_key while add_dims > 0: - # expand array dims, output_shape, split_bookkeeping to reflect newaxis + # expand array dims: output_shape, split_bookkeeping to reflect newaxis # replace newaxis with slice(None) in key for i, k in reversed(list(enumerate(key))): if k is None: key[i] = slice(None) - if not arr_is_copy: - arr = arr.copy() - arr_is_copy = True arr = arr.expand_dims(i - add_dims + 1) output_shape = ( output_shape[: i - add_dims + 1] + [1] + output_shape[i - add_dims + 1 :] @@ -841,10 +838,16 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] for i, k in enumerate(key): if isinstance(k, Iterable) or isinstance(k, DNDarray): # advanced indexing across dimensions - print("DEBUGGING: k = ", k) - advanced_indexing = True - advanced_indexing_dims.append(i) - if not isinstance(k, DNDarray): + if getattr(k, "ndim", 1) == 0: + # single-element indexing along axis i + output_shape = output_shape[:i] + output_shape[i + 1 :] + split_bookkeeping = split_bookkeeping[:i] + split_bookkeeping[i + 1 :] + else: + advanced_indexing = True + advanced_indexing_dims.append(i) + if isinstance(k, DNDarray): + key[i] = k.larray + elif not isinstance(k, torch.Tensor): key[i] = torch.tensor(k, dtype=torch.int64, device=arr.larray.device) elif isinstance(k, slice) and k != slice(None): start, stop, step = k.start, k.stop, k.step @@ -924,6 +927,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] 
non_adv_ind_dims = list( i for i in range(arr.ndim) if i not in advanced_indexing_dims ) + # TODO: work this out without array copy if not arr_is_copy: arr = arr.copy() arr_is_copy = True @@ -1007,6 +1011,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar print("DEBUGGING: RAW KEY = ", key) # Single-element indexing + # TODO: single-element indexing along split axis belongs here as well scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 if scalar: output_shape = self.gshape[1:] @@ -1220,6 +1225,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar ) print("AFTER: incoming_request_key = ", incoming_request_key) + print("OUTPUT_SHAPE = ", output_shape) send_buf = self.larray[incoming_request_key] output_lshape = list(output_shape) output_lshape[output_split] = key[output_split].shape[0] diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 9e97b60325..6b58e76f2d 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -570,10 +570,12 @@ def test_getitem(self): .item() ) + # slicing with negative step x_3d = ht.arange(20 * 4 * 3, split=0).reshape(20, 4, 3) - x_3d_sliced = x_3d[17:2:-2, :2, 1] + x_3d_sliced = x_3d[17:2:-2, :2, ht.array(1)] x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(20, 4, 3)[17:2:-2, :2, 1] self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) + self.assertTrue(x_3d_sliced.split == 0) # boolean mask, local arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5) From 1a4bf97160c23d39c5d788bd37bfd4e169ee3c20 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 27 Dec 2022 09:55:17 +0100 Subject: [PATCH 052/132] Non-ordered indexing, split != 0 --- heat/core/dndarray.py | 39 ++++++++++++++++++++------------ heat/core/tests/test_dndarray.py | 11 ++++++++- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 1f3cddfccb..3305b6a135 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -666,10 +666,11 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] """ output_shape = list(arr.gshape) split_bookkeeping = [None] * arr.ndim - if arr.is_distributed(): + if arr.split is not None: split_bookkeeping[arr.split] = "split" - counts, displs = arr.counts_displs() - new_split = arr.split + if arr.is_distributed(): + counts, displs = arr.counts_displs() + new_split = arr.split advanced_indexing = False arr_is_copy = False @@ -849,6 +850,10 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] key[i] = k.larray elif not isinstance(k, torch.Tensor): key[i] = torch.tensor(k, dtype=torch.int64, device=arr.larray.device) + elif isinstance(k, int): + # single-element indexing along axis i + output_shape = output_shape[:i] + output_shape[i + 1 :] + split_bookkeeping = split_bookkeeping[:i] + split_bookkeeping[i + 1 :] elif isinstance(k, slice) and k != slice(None): start, stop, step = k.start, k.stop, k.step if start is None: @@ -949,6 +954,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] 
@@ -949,6 +954,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
         key = tuple(key)
         output_shape = tuple(output_shape)
+        print("DEBUGGING: split_bookkeeping = ", split_bookkeeping)
         new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None

         return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced
@@ -1084,7 +1090,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
             split_key_is_sorted,
             out_is_balanced,
         ) = self.__process_key(key)
-        print("DEBUGGING: processed key = ", key)
+        print("DEBUGGING: processed key, output_split = ", key, output_split)

         # TODO: test that key for not affected dims is always slice(None)
         # including match between self.split and key after self manipulation
@@ -1151,12 +1157,13 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
             selection.unsqueeze_(dim=1)
             outgoing_request_key = torch.cat((outgoing_request_key, selection), dim=0)
         print("DEBUGGING: outgoing_request_key = ", outgoing_request_key)
+        print("RECV_COUNTS = ", recv_counts)
         # share recv_counts among all processes
         comm_matrix = torch.empty(
             (self.comm.size, self.comm.size), dtype=recv_counts.dtype, device=recv_counts.device
         )
         self.comm.Allgather(recv_counts, comm_matrix)
-        print("DEBUGGING: comm_matrix = ", comm_matrix)
+        print("DEBUGGING: comm_matrix = ", comm_matrix, comm_matrix.shape)

         outgoing_request_key_counts = comm_matrix[self.comm.rank]
         outgoing_request_key_displs = torch.cat(
@@ -1170,8 +1177,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
             ),
             dim=0,
         ).cumsum(dim=0)[:-1]
-        print("DEBUGGING: outgoing_request_key_displs = ", outgoing_request_key_displs)
-        print("DEBUGGING: outgoing_request_key_counts = ", outgoing_request_key_counts)
         incoming_request_key_counts = comm_matrix[:, self.comm.rank]
         incoming_request_key_displs = torch.cat(
             (
@@ -1184,8 +1189,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
             ),
             dim=0,
         ).cumsum(dim=0)[:-1]
-        print("DEBUGGING: incoming_request_key_displs = ", incoming_request_key_displs)
-        print("DEBUGGING: incoming_request_key_counts = ", incoming_request_key_counts)

         if return_1d:
             incoming_request_key = torch.empty(
@@ -1232,24 +1235,30 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
             recv_buf = torch.empty(
                 tuple(output_lshape), dtype=self.larray.dtype, device=self.larray.device
             )
-            recv_displs = outgoing_request_key_displs
-            send_counts = incoming_request_key_counts
-            send_displs = incoming_request_key_displs
+            recv_counts = torch.squeeze(recv_counts, dim=1).tolist()
+            recv_displs = outgoing_request_key_displs.tolist()
+            send_counts = incoming_request_key_counts.tolist()
+            send_displs = incoming_request_key_displs.tolist()
+            print("BEFORE ALLTOALLV: recv_counts = ", recv_counts)
             self.comm.Alltoallv(
-                (send_buf, send_counts, send_displs), (recv_buf, recv_counts, recv_displs)
+                (send_buf, send_counts, send_displs),
+                (recv_buf, recv_counts, recv_displs),
+                send_axis=output_split,
             )

             # reorganize incoming counts according to original key order along split axis
             if return_1d:
                 key = torch.stack(key, dim=1).tolist()
                 outgoing_request_key = outgoing_request_key.tolist()
+                map = [outgoing_request_key.index(k) for k in key]
             else:
                 print("key[output_split] = ", key[output_split])
                 key = key[output_split].tolist()
                 print("key = ", key)
                 outgoing_request_key = outgoing_request_key.squeeze_(1).tolist()
-                print("outgoing_request_key = ", outgoing_request_key)
-            map = [outgoing_request_key.index(k) for k in key]
+                map = [slice(None)] * recv_buf.ndim
+                map[output_split] = [outgoing_request_key.index(k) for k in key]
+
             indexed_arr = recv_buf[map]
             return factories.array(indexed_arr, is_split=output_split)

diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py
index 6b58e76f2d..744040104b 100644
--- a/heat/core/tests/test_dndarray.py
+++ b/heat/core/tests/test_dndarray.py
@@ -570,13 +570,22 @@ def test_getitem(self):
             .item()
         )

-        # slicing with negative step
+        # slicing with negative step along the split axis
         x_3d = ht.arange(20 * 4 * 3, split=0).reshape(20, 4, 3)
         x_3d_sliced = x_3d[17:2:-2, :2, ht.array(1)]
         x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(20, 4, 3)[17:2:-2, :2, 1]
         self.assert_array_equal(x_3d_sliced, x_3d_sliced_np)
         self.assertTrue(x_3d_sliced.split == 0)

+        # slicing with negative step, split 1
+        x_3d = ht.arange(20 * 4 * 3).reshape(4, 20, 3)
+        x_3d.resplit_(axis=1)
+        key = [slice(None, 2), slice(17, 2, -2), 1]
+        x_3d_sliced = x_3d[key]
+        x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(4, 20, 3)[:2, 17:2:-2, 1]
+        self.assert_array_equal(x_3d_sliced, x_3d_sliced_np)
+        self.assertTrue(x_3d_sliced.split == 1)
+
         # boolean mask, local
         arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5)
         np.random.seed(42)

From 9e421562682090075a453863186888f2226894f7 Mon Sep 17 00:00:00 2001
From: Claudia Comito
Date: Wed, 28 Dec 2022 07:27:13 +0100
Subject: [PATCH 053/132] generalize negative step slicing to all splits, loss
 of dims

---
 heat/core/dndarray.py            | 34 +++++++++++++++--------
 heat/core/tests/test_dndarray.py | 46 +++++++++++++++++++++++---------
 2 files changed, 56 insertions(+), 24 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index 3305b6a135..d0f6305117 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -836,13 +836,17 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
         # check for advanced indexing and slices
         print("DEBUGGING: key = ", key)
         advanced_indexing_dims = []
+        lose_dims = 0
         for i, k in enumerate(key):
             if isinstance(k, Iterable) or isinstance(k, DNDarray):
                 # advanced indexing across dimensions
                 if getattr(k, "ndim", 1) == 0:
                     # single-element indexing along axis i
-                    output_shape = output_shape[:i] + output_shape[i + 1 :]
-                    split_bookkeeping = split_bookkeeping[:i] + split_bookkeeping[i + 1 :]
+                    output_shape[i] = None
+                    split_bookkeeping = (
+                        split_bookkeeping[: i - lose_dims] + split_bookkeeping[i - lose_dims + 1 :]
+                    )
+                    lose_dims += 1
                 else:
                     advanced_indexing = True
                     advanced_indexing_dims.append(i)
@@ -852,8 +856,11 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
             elif isinstance(k, int):
                 # single-element indexing along axis i
-                output_shape = output_shape[:i] + output_shape[i + 1 :]
-                split_bookkeeping = split_bookkeeping[:i] + split_bookkeeping[i + 1 :]
+                output_shape[i] = None
+                split_bookkeeping = (
+                    split_bookkeeping[: i - lose_dims] + split_bookkeeping[i - lose_dims + 1 :]
+                )
+                lose_dims += 1
             elif isinstance(k, slice) and k != slice(None):
                 start, stop, step = k.start, k.stop, k.step
                 if start is None:
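# ---------------------------------------------------------------------------
# Note: illustrative sketch, not part of the patch. The `map` built above
# reorders the rows received via Alltoallv back into the order in which the
# key requested them; a torch-only toy version of the same idea:
import torch

requested = [7, 2, 9]                      # key order along the split axis
arrived_as = [2, 7, 9]                     # order in which rows were received
recv_buf = torch.arange(12).reshape(3, 4)  # one row per received index
reorder = [arrived_as.index(k) for k in requested]
result = recv_buf[reorder]                 # rows now follow `requested`
# ---------------------------------------------------------------------------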
@@ -953,8 +960,9 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
             key += [slice(None)] * (arr.ndim - len(key))

         key = tuple(key)
+        for i in range(output_shape.count(None)):
+            output_shape.remove(None)
         output_shape = tuple(output_shape)
-        print("DEBUGGING: split_bookkeeping = ", split_bookkeeping)
         new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None

         return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced
@@ -1222,16 +1230,18 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
         else:
             incoming_request_key -= displs[self.comm.rank]
             incoming_request_key = (
-                key[:output_split]
+                key[:original_split]
                 + (incoming_request_key.squeeze_(1).tolist(),)
-                + key[output_split + 1 :]
+                + key[original_split + 1 :]
             )

         print("AFTER: incoming_request_key = ", incoming_request_key)
         print("OUTPUT_SHAPE = ", output_shape)
+        print("OUTPUT_SPLIT = ", output_split)
+
         send_buf = self.larray[incoming_request_key]
         output_lshape = list(output_shape)
-        output_lshape[output_split] = key[output_split].shape[0]
+        output_lshape[output_split] = key[original_split].shape[0]
         recv_buf = torch.empty(
             tuple(output_lshape), dtype=self.larray.dtype, device=self.larray.device
         )
@@ -1252,14 +1262,16 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
             # reorganize incoming counts according to original key order along split axis
             if return_1d:
                 key = torch.stack(key, dim=1).tolist()
                 outgoing_request_key = outgoing_request_key.tolist()
                 map = [outgoing_request_key.index(k) for k in key]
             else:
-                print("key[output_split] = ", key[output_split])
-                key = key[output_split].tolist()
-                print("key = ", key)
+                key = key[original_split].tolist()
                 outgoing_request_key = outgoing_request_key.squeeze_(1).tolist()
                 map = [slice(None)] * recv_buf.ndim
                 map[output_split] = [outgoing_request_key.index(k) for k in key]

             indexed_arr = recv_buf[map]
+            print(
+                factories.array(indexed_arr, is_split=output_split).lshape,
+                factories.array(indexed_arr, is_split=output_split).gshape,
+            )
             return factories.array(indexed_arr, is_split=output_split)

         # TODO: boolean indexing with data.split != 0
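# ---------------------------------------------------------------------------
# Note: illustrative sketch, not part of the patch, with made-up values. For
# a negative-step slice along the split axis, the global indices a rank owns
# can be filtered with the counts/displs bookkeeping used above:
global_idx = list(range(17, 2, -2))   # [17, 15, 13, 11, 9, 7, 5, 3]
counts, displs = [10, 10], [0, 10]    # 2 ranks, 10 rows each
rank = 0
local = [i - displs[rank] for i in global_idx
         if displs[rank] <= i < displs[rank] + counts[rank]]
# rank 0 owns the global rows 9, 7, 5, 3 -> local indices [9, 7, 5, 3]
# ---------------------------------------------------------------------------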
diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py
index 744040104b..ac7c71b257 100644
--- a/heat/core/tests/test_dndarray.py
+++ b/heat/core/tests/test_dndarray.py
@@ -563,29 +563,49 @@ def test_getitem(self):
         # Slicing and striding
         x = ht.arange(20, split=0)
         x_sliced = x[1:11:3]
-        x_sliced.balance_()
-        self.assertTrue(
-            (x_sliced == ht.array([1, 4, 7, 10], dtype=x.dtype, device=x.device, split=0))
-            .all()
-            .item()
-        )
-
-        # slicing with negative step along the split axis
-        x_3d = ht.arange(20 * 4 * 3, split=0).reshape(20, 4, 3)
+        x_np = np.arange(20)
+        x_sliced_np = x_np[1:11:3]
+        self.assert_array_equal(x_sliced, x_sliced_np)
+        self.assertTrue(x_sliced.split == 0)
+
+        # slicing with negative step along split axis 0
+        shape = (20, 4, 3)
+        x_3d = ht.arange(20 * 4 * 3, split=0).reshape(shape)
         x_3d_sliced = x_3d[17:2:-2, :2, ht.array(1)]
-        x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(20, 4, 3)[17:2:-2, :2, 1]
+        x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[17:2:-2, :2, 1]
         self.assert_array_equal(x_3d_sliced, x_3d_sliced_np)
         self.assertTrue(x_3d_sliced.split == 0)

-        # slicing with negative step, split 1
-        x_3d = ht.arange(20 * 4 * 3).reshape(4, 20, 3)
+        # slicing with negative step along split 1
+        shape = (4, 20, 3)
+        x_3d = ht.arange(20 * 4 * 3).reshape(shape)
         x_3d.resplit_(axis=1)
         key = [slice(None, 2), slice(17, 2, -2), 1]
         x_3d_sliced = x_3d[key]
-        x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(4, 20, 3)[:2, 17:2:-2, 1]
+        x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[:2, 17:2:-2, 1]
         self.assert_array_equal(x_3d_sliced, x_3d_sliced_np)
         self.assertTrue(x_3d_sliced.split == 1)

+        # slicing with negative step along split 2 and loss of axis < split
+        shape = (4, 3, 20)
+        x_3d = ht.arange(20 * 4 * 3).reshape(shape)
+        x_3d.resplit_(axis=2)
+        key = [slice(None, 2), 1, slice(17, 10, -2)]
+        x_3d_sliced = x_3d[key]
+        x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[:2, 1, 17:10:-2]
+        self.assert_array_equal(x_3d_sliced, x_3d_sliced_np)
+        self.assertTrue(x_3d_sliced.split == 1)
+
+        # slicing with negative step along split 2 and loss of all axes but split
+        shape = (4, 3, 20)
+        x_3d = ht.arange(20 * 4 * 3).reshape(shape)
+        x_3d.resplit_(axis=2)
+        key = [0, 1, slice(17, 13, -1)]
+        x_3d_sliced = x_3d[key]
+        x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[0, 1, 17:13:-1]
+        self.assert_array_equal(x_3d_sliced, x_3d_sliced_np)
+        self.assertTrue(x_3d_sliced.split == 0)
+
         # boolean mask, local
         arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5)
         np.random.seed(42)

From 1a310a902593429382214a695313dc9cc68bb700 Mon Sep 17 00:00:00 2001
From: Claudia Comito
Date: Wed, 28 Dec 2022 08:51:33 +0100
Subject: [PATCH 054/132] loop over active ranks only when key in descending
 order

---
 heat/core/dndarray.py | 46 ++++++++++++++++++++++++++++++-----------
 1 file changed, 35 insertions(+), 11 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index d0f6305117..8f64fee013 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -674,7 +674,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]

         advanced_indexing = False
         arr_is_copy = False
-        split_key_is_sorted = True
+        split_key_is_sorted = 1
         out_is_balanced = False

         if isinstance(key, (DNDarray, torch.Tensor, np.ndarray)):
@@ -760,7 +760,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]

             output_shape = (key[0].shape[0],)
             new_split = 0
-            split_key_is_sorted = False
+            split_key_is_sorted = 0
             out_is_balanced = True
             return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced
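# ---------------------------------------------------------------------------
# Note: illustrative sketch, not part of the patch. A hunk further below in
# this commit restricts the request loop to the ranks actually touched by a
# descending key; the same rank window can be derived from the displacements,
# e.g. with searchsorted (made-up values, not the exact implementation):
import torch

displs = torch.tensor([0, 5, 10, 15])        # chunk starts, one per rank
key = torch.tensor([13, 11, 9, 8])           # key sorted in descending order
start_rank = int(torch.searchsorted(displs, key.min(), right=True)) - 1
end_rank = int(torch.searchsorted(displs, key.max(), right=True))
# elements 8..13 live on ranks 1 and 2 -> start_rank == 1, end_rank == 3
# ---------------------------------------------------------------------------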
@@ -882,12 +882,12 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
                     if arr.is_distributed() and new_split == i:
                         # distribute key and proceed with non-ordered indexing
                         key[i] = factories.array(key[i], split=0, device=arr.device).larray
-                        split_key_is_sorted = False
+                        split_key_is_sorted = -1
                         out_is_balanced = True
                 elif step > 0 and start < stop:
                     output_shape[i] = int(torch.tensor((stop - start) / step).ceil().item())
                     if arr.is_distributed() and new_split == i:
-                        split_key_is_sorted = True
+                        split_key_is_sorted = 1
                         out_is_balanced = False
                         if (
                             stop >= displs[arr.comm.rank]
                             and start < displs[arr.comm.rank] + counts[arr.comm.rank]
                         ):
                             index_in_cycle = (displs[arr.comm.rank] - start) % step
                             local_start = 0 if index_in_cycle == 0 else step - index_in_cycle
                             local_stop = stop - displs[arr.comm.rank]
                             key[i] = slice(local_start, local_stop, step)
                         else:
                             key[i] = slice(0, 0)
@@ -1105,7 +1105,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
         # data are not distributed or split dimension is not affected by indexing
         # if not self.is_distributed() or key[self.split] == slice(None):
         print("split_key_is_sorted, key = ", split_key_is_sorted, key)
-        if split_key_is_sorted:
+        if split_key_is_sorted == 1:
             indexed_arr = self.larray[key]
             return DNDarray(
                 indexed_arr,
@@ -1145,7 +1145,33 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
         )

         # process-local: calculate which/how many elements will be received from what process
-        for i in range(self.comm.size):
+        if split_key_is_sorted == -1:
+            # key is sorted in descending order
+            # shrink selection of active processes
+            if key[original_split].numel() > 0:
+                key_edges = torch.cat(
+                    (key[original_split][-1].reshape(-1), key[original_split][0].reshape(-1)), dim=0
+                ).unique()
+                displs = torch.tensor(displs, device=self.larray.device)
+                _, inverse, counts = torch.cat((displs, key_edges), dim=0).unique(
+                    sorted=True, return_inverse=True, return_counts=True
+                )
+                if key_edges.numel() == 2:
+                    correction = counts[inverse[-2]] % 2
+                    start_rank = inverse[-2] - correction
+                    correction += counts[inverse[-1]] % 2
+                    end_rank = inverse[-1] - correction + 1
+                elif key_edges.numel() == 1:
+                    correction = counts[inverse[-1]] % 2
+                    start_rank = inverse[-1] - correction
+                    end_rank = start_rank + 1
+            else:
+                start_rank = 0
+                end_rank = 0
+        else:
+            start_rank = 0
+            end_rank = self.comm.size
+        for i in range(start_rank, end_rank):
             cond1 = key[original_split] >= displs[i]
             if i != self.comm.size - 1:
                 cond2 = key[original_split] < displs[i + 1]
@@ -1257,6 +1283,9 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar

             # reorganize incoming counts according to original key order along split axis
+            # if split_key_is_sorted == -1:
+            #     indexed_arr = recv_buf.flip(dims=(output_split,))
+            # else:
             if return_1d:
                 key = torch.stack(key, dim=1).tolist()
                 outgoing_request_key = outgoing_request_key.tolist()

From c2ba0d901cc68e4076a8024b8d26a92725ca8a29 Mon Sep 17 00:00:00 2001
From: Claudia Comito
Date: Thu, 29 Dec 2022 06:10:25 +0100
Subject: [PATCH 055/132] replace list-on-list mapping with argsort mapping for
 non-ordered key

---
 heat/core/dndarray.py | 33 +++++++++++++++++++------------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index 8f64fee013..c73b653509 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -1249,7 +1249,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
                     incoming_request_key_displs.tolist(),
                 ),
             )
-            print("DEBUGGING:incoming_request_key = ", incoming_request_key)
+            # print("DEBUGGING:incoming_request_key = ", incoming_request_key)
             if return_1d:
                 incoming_request_key = list(incoming_request_key[:, d] for d in range(self.ndim))
                 incoming_request_key[original_split] -= displs[self.comm.rank]
@@ -1261,9 +1261,9 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
                 + key[original_split + 1 :]
             )

-        print("AFTER: incoming_request_key = ", incoming_request_key)
-        print("OUTPUT_SHAPE = ", output_shape)
-        print("OUTPUT_SPLIT = ", output_split)
+        # print("AFTER: incoming_request_key = ", incoming_request_key)
+        # print("OUTPUT_SHAPE = ", output_shape)
+        # print("OUTPUT_SPLIT = ", output_split)

         send_buf = self.larray[incoming_request_key]
         output_lshape = list(output_shape)
@@ -1275,7 +1275,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
         recv_displs = outgoing_request_key_displs.tolist()
         send_counts = incoming_request_key_counts.tolist()
         send_displs = incoming_request_key_displs.tolist()
-        print("BEFORE ALLTOALLV: recv_counts = ", recv_counts)
         self.comm.Alltoallv(
             (send_buf, send_counts, send_displs),
             (recv_buf, recv_counts, recv_displs),
@@ -1283,18 +1282,28 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar

         # reorganize incoming counts according to original key order along split axis
         if return_1d:
-            key = torch.stack(key, dim=1).tolist()
+            key = torch.stack(key, dim=1)  # .tolist()
+            unique_keys, inverse = key.unique(dim=0, sorted=True, return_inverse=True)
+            if unique_keys.shape == key.shape:
+                pass
+            key = key.tolist()
             outgoing_request_key = outgoing_request_key.tolist()
+            # TODO: major bottleneck, replace with some vectorized sorting solution or use available info
             map = [outgoing_request_key.index(k) for k in key]
         else:
-            key = key[original_split].tolist()
-            outgoing_request_key = outgoing_request_key.squeeze_(1).tolist()
+            key = key[original_split]
+            outgoing_request_key = outgoing_request_key.squeeze_(1)
+            # incoming elements likely already stacked in ascending or descending order
+            if key == outgoing_request_key:
+                return factories.array(recv_buf, is_split=output_split)
+            if key == outgoing_request_key.flip(dims=(0,)):
+                return factories.array(recv_buf.flip(dims=(output_split,)), is_split=output_split)

             map = [slice(None)] * recv_buf.ndim
-            map[output_split] = [outgoing_request_key.index(k) for k in key]
+            map[output_split] = outgoing_request_key.argsort(stable=True)[
+                key.argsort(stable=True).argsort(stable=True)
+            ]
         indexed_arr = recv_buf[map]
         return factories.array(indexed_arr, is_split=output_split)
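# ---------------------------------------------------------------------------
# Note: illustrative sketch, not part of the patch. Why the argsort mapping
# introduced above works: for 1-D tensors a (arrival order) and b (requested
# order) holding the same elements, a.argsort()[b.argsort().argsort()] is the
# permutation rearranging a into b, without any Python-level list.index():
import torch

a = torch.tensor([2, 7, 9])   # outgoing_request_key, order of arrival
b = torch.tensor([7, 2, 9])   # key, order in which elements were requested
perm = a.argsort(stable=True)[b.argsort(stable=True).argsort(stable=True)]
assert torch.equal(a[perm], b)
# ---------------------------------------------------------------------------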
print("DEBUGGING: advanced_indexing_shapes = ", advanced_indexing_shapes) # shapes of indexing arrays must be broadcastable try: - broadcasted_shape = torch.broadcast_shapes(advanced_indexing_shapes) + broadcasted_shape = torch.broadcast_shapes(*advanced_indexing_shapes) except RuntimeError: raise IndexError( "Shape mismatch: indexing arrays could not be broadcast together with shapes: {}".format( @@ -1283,27 +1283,35 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # reorganize incoming counts according to original key order along split axis if return_1d: - key = torch.stack(key, dim=1) # .tolist() - unique_keys, inverse = key.unique(dim=0, sorted=True, return_inverse=True) - if unique_keys.shape == key.shape: - pass - key = key.tolist() - outgoing_request_key = outgoing_request_key.tolist() - # TODO: major bottleneck, replace with some vectorized sorting solution or use available info - map = [outgoing_request_key.index(k) for k in key] - else: - key = key[original_split] - outgoing_request_key = outgoing_request_key.squeeze_(1) - # incoming elements likely already stacked in ascending or descending order - if key == outgoing_request_key: - return factories.array(recv_buf, is_split=output_split) - if key == outgoing_request_key.flip(dims=(0,)): - return factories.array(recv_buf.flip(dims=(output_split,)), is_split=output_split) - - map = [slice(None)] * recv_buf.ndim - map[output_split] = outgoing_request_key.argsort(stable=True)[ - key.argsort(stable=True).argsort(stable=True) - ] + key = torch.stack(key, dim=1) + _, key_inverse = key.unique(dim=0, sorted=True, return_inverse=True) + if _.shape == key.shape: + _, ork_inverse = outgoing_request_key.unique( + dim=0, sorted=True, return_inverse=True + ) + map = ork_inverse.argsort(stable=True)[ + key_inverse.argsort(stable=True).argsort(stable=True) + ] + else: + # major bottleneck + key = key.tolist() + outgoing_request_key = outgoing_request_key.tolist() + map = [outgoing_request_key.index(k) for k in key] + indexed_arr = recv_buf[map] + return factories.array(indexed_arr, is_split=output_split) + + key = key[original_split] + outgoing_request_key = outgoing_request_key.squeeze_(1) + # incoming elements likely already stacked in ascending or descending order + if (key == outgoing_request_key).all(): + return factories.array(recv_buf, is_split=output_split) + if (key == outgoing_request_key.flip(dims=(0,))).all(): + return factories.array(recv_buf.flip(dims=(output_split,)), is_split=output_split) + + map = [slice(None)] * recv_buf.ndim + map[output_split] = outgoing_request_key.argsort(stable=True)[ + key.argsort(stable=True).argsort(stable=True) + ] indexed_arr = recv_buf[map] return factories.array(indexed_arr, is_split=output_split) From cad99756cbe3af343ae88eebcd789c00c8c44600 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sat, 31 Dec 2022 07:46:20 +0100 Subject: [PATCH 057/132] fix advanced indexing via list, remove last key-mapping bottleneck for unsorted key --- heat/core/dndarray.py | 51 ++++++++++++++++++++++---------- heat/core/tests/test_dndarray.py | 17 +++++++++-- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 88dba106fc..0b4c9a2316 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -666,17 +666,22 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] 
""" output_shape = list(arr.gshape) split_bookkeeping = [None] * arr.ndim + new_split = arr.split if arr.split is not None: split_bookkeeping[arr.split] = "split" if arr.is_distributed(): counts, displs = arr.counts_displs() - new_split = arr.split advanced_indexing = False arr_is_copy = False - split_key_is_sorted = 1 + split_key_is_sorted = 0 out_is_balanced = False + if isinstance(key, list): + try: + key = torch.tensor(key, device=arr.larray.device) + except RuntimeError: + raise IndexError("Invalid indices: expected a list of integers, got {}".format(key)) if isinstance(key, (DNDarray, torch.Tensor, np.ndarray)): if key.dtype in (bool, uint8, torch.bool, torch.uint8, np.bool, np.uint8): # boolean indexing: shape must match arr.shape @@ -707,6 +712,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] output_shape = tuple(key[0].shape) new_split = None if arr.split is None else 0 out_is_balanced = True + split_key_is_sorted = 1 return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced # arr is distributed @@ -732,6 +738,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] key[i] = k.larray key[arr.split] -= displs[arr.comm.rank] key = tuple(key) + split_key_is_sorted = 1 else: key = key.larray.nonzero(as_tuple=False) # construct global key array @@ -809,7 +816,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] if ellipsis > 1: raise ValueError("key can only contain 1 ellipsis") if ellipsis == 1: - # replace with explicit `slice(None)` for interested dimensions + # replace with explicit `slice(None)` for affected dimensions # output_shape, split_bookkeeping not affected expand_key = [slice(None)] * (arr.ndim + add_dims) ellipsis_index = key.index(...) @@ -850,6 +857,9 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] else: advanced_indexing = True advanced_indexing_dims.append(i) + if arr.is_distributed() and i == arr.split: + # make no assumption on data locality wrt key + split_key_is_sorted = 0 if isinstance(k, DNDarray): key[i] = k.larray elif not isinstance(k, torch.Tensor): @@ -1171,6 +1181,10 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar else: start_rank = 0 end_rank = self.comm.size + all_local_indexing = torch.ones( + (self.comm.size,), dtype=torch.bool, device=self.larray.device + ) + all_local_indexing[start_rank:end_rank] = False for i in range(start_rank, end_rank): cond1 = key[original_split] >= displs[i] if i != self.comm.size - 1: @@ -1184,12 +1198,21 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # advanced indexing returning 1D array (e.g. 
boolean indexing) selection = list(k[cond1 & cond2] for k in key) recv_counts[i, :] = selection[0].shape[0] + if i == self.comm.rank: + all_local_indexing[i] = selection[0].shape[0] == key[0].shape[0] selection = torch.stack(selection, dim=1) else: selection = key[original_split][cond1 & cond2] recv_counts[i, :] = selection.shape[0] + if i == self.comm.rank: + all_local_indexing[i] = selection.shape[0] == key[original_split].shape[0] selection.unsqueeze_(dim=1) outgoing_request_key = torch.cat((outgoing_request_key, selection), dim=0) + all_local_indexing = factories.array(all_local_indexing, is_split=0, device=self.device) + if all_local_indexing.all().item(): + indexed_arr = self.larray[key] + return factories.array(indexed_arr, is_split=output_split, device=self.device) + print("DEBUGGING: outgoing_request_key = ", outgoing_request_key) print("RECV_COUNTS = ", recv_counts) # share recv_counts among all processes @@ -1285,18 +1308,16 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar if return_1d: key = torch.stack(key, dim=1) _, key_inverse = key.unique(dim=0, sorted=True, return_inverse=True) - if _.shape == key.shape: - _, ork_inverse = outgoing_request_key.unique( - dim=0, sorted=True, return_inverse=True - ) - map = ork_inverse.argsort(stable=True)[ - key_inverse.argsort(stable=True).argsort(stable=True) - ] - else: - # major bottleneck - key = key.tolist() - outgoing_request_key = outgoing_request_key.tolist() - map = [outgoing_request_key.index(k) for k in key] + # if _.shape == key.shape: + _, ork_inverse = outgoing_request_key.unique(dim=0, sorted=True, return_inverse=True) + map = ork_inverse.argsort(stable=True)[ + key_inverse.argsort(stable=True).argsort(stable=True) + ] + # else: + # # major bottleneck + # key = key.tolist() + # outgoing_request_key = outgoing_request_key.tolist() + # map = [outgoing_request_key.index(k) for k in key] indexed_arr = recv_buf[map] return factories.array(indexed_arr, is_split=output_split) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index ac7c71b257..ac55c69401 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -580,7 +580,7 @@ def test_getitem(self): shape = (4, 20, 3) x_3d = ht.arange(20 * 4 * 3).reshape(shape) x_3d.resplit_(axis=1) - key = [slice(None, 2), slice(17, 2, -2), 1] + key = (slice(None, 2), slice(17, 2, -2), 1) x_3d_sliced = x_3d[key] x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[:2, 17:2:-2, 1] self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) @@ -590,7 +590,7 @@ def test_getitem(self): shape = (4, 3, 20) x_3d = ht.arange(20 * 4 * 3).reshape(shape) x_3d.resplit_(axis=2) - key = [slice(None, 2), 1, slice(17, 10, -2)] + key = (slice(None, 2), 1, slice(17, 10, -2)) x_3d_sliced = x_3d[key] x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[:2, 1, 17:10:-2] self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) @@ -600,7 +600,7 @@ def test_getitem(self): shape = (4, 3, 20) x_3d = ht.arange(20 * 4 * 3).reshape(shape) x_3d.resplit_(axis=2) - key = [0, 1, slice(17, 13, -1)] + key = (0, 1, slice(17, 13, -1)) x_3d_sliced = x_3d[key] x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[0, 1, 17:13:-1] self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) @@ -625,6 +625,17 @@ def test_getitem(self): mask_split2 = ht.array(mask, split=2) self.assert_array_equal(arr_split2[mask_split2], arr.numpy()[mask]) + # TODO: x[(1,1,1,1)] vs. 
x[[1,1,1,1]] + # advanced indexing + x = ht.arange(60, split=0).reshape(5, 3, 4) + x_np = np.arange(60).reshape(5, 3, 4) + k1 = np.array([0, 4, 1, 0]) + k2 = np.array([0, 2, 1, 0]) + k3 = np.array([1, 2, 3, 1]) + self.assert_array_equal( + x[ht.array(k1, split=0), ht.array(k2, split=0), ht.array(k3, split=0)], x_np[k1, k2, k3] + ) + def test_int_cast(self): # simple scalar tensor a = ht.ones(1) From 83e69501f4f93eb553030802c57a960c4f9a2cd4 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 2 Jan 2023 06:58:31 +0100 Subject: [PATCH 058/132] fix local slices, expand tests --- heat/core/dndarray.py | 49 +++++++++++++++++++++++--------- heat/core/tests/test_dndarray.py | 10 ++++++- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 0b4c9a2316..ab27fb09f3 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -674,7 +674,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] advanced_indexing = False arr_is_copy = False - split_key_is_sorted = 0 + split_key_is_sorted = 1 out_is_balanced = False if isinstance(key, list): @@ -891,7 +891,9 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] output_shape[i] = len(key[i]) if arr.is_distributed() and new_split == i: # distribute key and proceed with non-ordered indexing - key[i] = factories.array(key[i], split=0, device=arr.device).larray + key[i] = factories.array( + key[i], split=0, device=arr.device, copy=False + ).larray split_key_is_sorted = -1 out_is_balanced = True elif step > 0 and start < stop: @@ -899,13 +901,26 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] if arr.is_distributed() and new_split == i: split_key_is_sorted = 1 out_is_balanced = False - if ( - stop >= displs[arr.comm.rank] - and start < displs[arr.comm.rank] + counts[arr.comm.rank] - ): + local_arr_end = displs[arr.comm.rank] + counts[arr.comm.rank] + if stop > displs[arr.comm.rank] and start < local_arr_end: + print( + "stop, start, displs[arr.comm.rank], displs[arr.comm.rank] + counts[arr.comm.rank] = ", + stop, + start, + displs[arr.comm.rank], + displs[arr.comm.rank] + counts[arr.comm.rank], + ) index_in_cycle = (displs[arr.comm.rank] - start) % step - local_start = 0 if index_in_cycle == 0 else step - index_in_cycle - local_stop = stop - displs[arr.comm.rank] + if start >= displs[arr.comm.rank]: + # slice begins on current rank + local_start = start - displs[arr.comm.rank] + else: + local_start = 0 if index_in_cycle == 0 else step - index_in_cycle + if stop <= local_arr_end: + # slice ends on current rank + local_stop = stop - displs[arr.comm.rank] + else: + local_stop = local_arr_end key[i] = slice(local_start, local_stop, step) else: key[i] = slice(0, 0) @@ -1208,10 +1223,14 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar all_local_indexing[i] = selection.shape[0] == key[original_split].shape[0] selection.unsqueeze_(dim=1) outgoing_request_key = torch.cat((outgoing_request_key, selection), dim=0) - all_local_indexing = factories.array(all_local_indexing, is_split=0, device=self.device) + all_local_indexing = factories.array( + all_local_indexing, is_split=0, device=self.device, copy=False + ) if all_local_indexing.all().item(): indexed_arr = self.larray[key] - return factories.array(indexed_arr, is_split=output_split, device=self.device) + return factories.array( + indexed_arr, is_split=output_split, device=self.device, copy=False + ) print("DEBUGGING: 
outgoing_request_key = ", outgoing_request_key) print("RECV_COUNTS = ", recv_counts) @@ -1319,22 +1338,24 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # outgoing_request_key = outgoing_request_key.tolist() # map = [outgoing_request_key.index(k) for k in key] indexed_arr = recv_buf[map] - return factories.array(indexed_arr, is_split=output_split) + return factories.array(indexed_arr, is_split=output_split, copy=False) key = key[original_split] outgoing_request_key = outgoing_request_key.squeeze_(1) # incoming elements likely already stacked in ascending or descending order if (key == outgoing_request_key).all(): - return factories.array(recv_buf, is_split=output_split) + return factories.array(recv_buf, is_split=output_split, copy=False) if (key == outgoing_request_key.flip(dims=(0,))).all(): - return factories.array(recv_buf.flip(dims=(output_split,)), is_split=output_split) + return factories.array( + recv_buf.flip(dims=(output_split,)), is_split=output_split, copy=False + ) map = [slice(None)] * recv_buf.ndim map[output_split] = outgoing_request_key.argsort(stable=True)[ key.argsort(stable=True).argsort(stable=True) ] indexed_arr = recv_buf[map] - return factories.array(indexed_arr, is_split=output_split) + return factories.array(indexed_arr, is_split=output_split, copy=False) # TODO: boolean indexing with data.split != 0 # __process_key() returns locally correct key diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index ac55c69401..b9cd3f59af 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -568,6 +568,15 @@ def test_getitem(self): self.assert_array_equal(x_sliced, x_sliced_np) self.assertTrue(x_sliced.split == 0) + # 1-element slice along split axis + x = ht.arange(20).reshape(4, 5) + x.resplit_(axis=1) + x_sliced = x[:, 2:3] + x_np = np.arange(20).reshape(4, 5) + x_sliced_np = x_np[:, 2:3] + self.assert_array_equal(x_sliced, x_sliced_np) + self.assertTrue(x_sliced.split == 1) + # slicing with negative step along split axis 0 shape = (20, 4, 3) x_3d = ht.arange(20 * 4 * 3, split=0).reshape(shape) @@ -625,7 +634,6 @@ def test_getitem(self): mask_split2 = ht.array(mask, split=2) self.assert_array_equal(arr_split2[mask_split2], arr.numpy()[mask]) - # TODO: x[(1,1,1,1)] vs. x[[1,1,1,1]] # advanced indexing x = ht.arange(60, split=0).reshape(5, 3, 4) x_np = np.arange(60).reshape(5, 3, 4) From 28ab92500eb2ce3f7597f25f3504a5a17707181a Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 2 Jan 2023 08:15:41 +0100 Subject: [PATCH 059/132] fix and test dimensional indexing --- heat/core/dndarray.py | 11 +++++++--- heat/core/tests/test_dndarray.py | 37 ++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index ab27fb09f3..07f0a7e55c 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -813,16 +813,19 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] # check for ellipsis, newaxis. NB: (np.newaxis is None)==True add_dims = sum(k is None for k in key) ellipsis = sum(isinstance(k, type(...)) for k in key) - if ellipsis > 1: - raise ValueError("key can only contain 1 ellipsis") if ellipsis == 1: # replace with explicit `slice(None)` for affected dimensions # output_shape, split_bookkeeping not affected expand_key = [slice(None)] * (arr.ndim + add_dims) ellipsis_index = key.index(...) 
From 28ab92500eb2ce3f7597f25f3504a5a17707181a Mon Sep 17 00:00:00 2001
From: Claudia Comito
Date: Mon, 2 Jan 2023 08:15:41 +0100
Subject: [PATCH 059/132] fix and test dimensional indexing

---
 heat/core/dndarray.py            | 11 +++++++---
 heat/core/tests/test_dndarray.py | 37 ++++++++++++++++++++++++++++++++
 2 files changed, 45 insertions(+), 3 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index ab27fb09f3..07f0a7e55c 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -813,16 +813,19 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
         # check for ellipsis, newaxis. NB: (np.newaxis is None)==True
         add_dims = sum(k is None for k in key)
         ellipsis = sum(isinstance(k, type(...)) for k in key)
-        if ellipsis > 1:
-            raise ValueError("key can only contain 1 ellipsis")
         if ellipsis == 1:
             # replace with explicit `slice(None)` for affected dimensions
             # output_shape, split_bookkeeping not affected
             expand_key = [slice(None)] * (arr.ndim + add_dims)
             ellipsis_index = key.index(...)
             expand_key[:ellipsis_index] = key[:ellipsis_index]
-            expand_key[ellipsis_index - len(key) :] = key[ellipsis_index + 1 :]
+            expand_key[ellipsis_index - (len(key) - ellipsis - ellipsis_index) :] = key[
+                ellipsis_index + 1 :
+            ]
             key = expand_key
+            print("DEBUGGING: ELLIPSIS: ", key)
+        elif ellipsis > 1:
+            raise ValueError("key can only contain 1 ellipsis")
         while add_dims > 0:
             # expand array dims: output_shape, split_bookkeeping to reflect newaxis
             # replace newaxis with slice(None) in key
@@ -840,6 +843,8 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
             )
             add_dims -= 1

+        # recalculate new split axis after dimensions manipulation
+        new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None
         # check for advanced indexing and slices
         print("DEBUGGING: key = ", key)
         advanced_indexing_dims = []
diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py
index b9cd3f59af..c401c71479 100644
--- a/heat/core/tests/test_dndarray.py
+++ b/heat/core/tests/test_dndarray.py
@@ -615,6 +615,43 @@ def test_getitem(self):
         self.assert_array_equal(x_3d_sliced, x_3d_sliced_np)
         self.assertTrue(x_3d_sliced.split == 0)

+        # DIMENSIONAL INDEXING
+        # ellipsis
+        x_np = np.array([[[1], [2], [3]], [[4], [5], [6]]])
+        x_np_ellipsis = x_np[..., 0]
+        x = ht.array([[[1], [2], [3]], [[4], [5], [6]]])
+
+        # local
+        x_ellipsis = x[..., 0]
+        x_slice = x[:, :, 0]
+        self.assert_array_equal(x_ellipsis, x_np_ellipsis)
+        self.assert_array_equal(x_slice, x_np_ellipsis)
+
+        # distributed
+        x.resplit_(axis=1)
+        x_ellipsis = x[..., 0]
+        x_slice = x[:, :, 0]
+        self.assert_array_equal(x_ellipsis, x_np_ellipsis)
+        self.assert_array_equal(x_slice, x_np_ellipsis)
+        self.assertTrue(x_ellipsis.split == 1)
+
+        # newaxis: local
+        x = ht.array([[[1], [2], [3]], [[4], [5], [6]]])
+        x_np_newaxis = x_np[:, np.newaxis, :2, :]
+        x_newaxis = x[:, np.newaxis, :2, :]
+        x_none = x[:, None, :2, :]
+        self.assert_array_equal(x_newaxis, x_np_newaxis)
+        self.assert_array_equal(x_none, x_np_newaxis)
+
+        # newaxis: distributed
+        x.resplit_(axis=1)
+        x_newaxis = x[:, np.newaxis, :2, :]
+        x_none = x[:, None, :2, :]
+        self.assert_array_equal(x_newaxis, x_np_newaxis)
+        self.assert_array_equal(x_none, x_np_newaxis)
+        self.assertTrue(x_newaxis.split == 2)
+        self.assertTrue(x_none.split == 2)
+
         # boolean mask, local
         arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5)
         np.random.seed(42)

From bc226fc88c38dc3e22bca5c29476a8ce56c5bd0f Mon Sep 17 00:00:00 2001
From: Claudia Comito
Date: Thu, 5 Jan 2023 06:35:43 +0100
Subject: [PATCH 060/132] Fix same-dim advanced indexing, expand tests

---
 heat/core/dndarray.py            | 127 ++++++++++++++++++++++---------
 heat/core/tests/test_dndarray.py |  35 ++++++---
 2 files changed, 118 insertions(+), 44 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index 07f0a7e55c..b0f799bed8 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -773,8 +773,10 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
             # advanced indexing on first dimension: first dim will expand to shape of key
             output_shape = tuple(list(key.shape) + output_shape[1:])
+            print("DEBUGGING ADV IND: output_shape = ", output_shape)
             # adjust split axis accordingly
             if arr.is_distributed():
+                counts, displs = arr.counts_displs()
                 if arr.split != 0:
                     # split axis is not affected
                     split_bookkeeping = [None] * key.ndim + split_bookkeeping[1:]
                         new_split = tuple(key.shape).index(arr.shape[0])
                     else:
                         new_split = key.ndim - 1
+                    try:
+                        key_split = key[new_split].larray
+                        sorted, _ = key_split.sort(stable=True)
+                    except AttributeError:
+                        key_split = key[new_split]
+                        sorted = key_split.sort()
                 else:
                     new_split = 0
-                # assess if key is sorted along split axis
-                try:
-                    key_split = key[new_split].larray
-                    sorted, _ = key_split.sort()
-                except AttributeError:
-                    key_split = key[new_split]
-                    sorted = key_split.sort()
-                split_key_is_sorted = torch.tensor(
-                    (key_split == sorted).all(), dtype=torch.uint8
-                )
+                    # assess if key is sorted along split axis
+                    try:
+                        # DNDarray key
+                        sorted, _ = torch.sort(key.larray, stable=True)
+                        split_key_is_sorted = torch.tensor(
+                            (key.larray == sorted).all(),
+                            dtype=torch.uint8,
+                            device=key.larray.device,
+                        )
+                        if key.split is not None:
+                            out_is_balanced = key.balanced
+                            split_key_is_sorted = factories.array(
+                                [split_key_is_sorted], is_split=0, device=arr.device, copy=False
+                            ).all()
+                        key = key.larray
+                    except AttributeError:
+                        # torch or ndarray key
+                        try:
+                            sorted, _ = torch.sort(key, stable=True)
+                        except TypeError:
+                            # ndarray key
+                            sorted = torch.tensor(np.sort(key), device=arr.larray.device)
+                        split_key_is_sorted = torch.tensor(
+                            key == sorted, dtype=torch.uint8
+                        ).item()
+                        if not split_key_is_sorted:
+                            # prepare for distributed non-ordered indexing: distribute torch/numpy key
+                            key = factories.array(key, split=0, device=arr.device).larray
+                            out_is_balanced = True
+                    if split_key_is_sorted:
+                        # extract local key
+                        cond1 = key >= displs[arr.comm.rank]
+                        cond2 = key < displs[arr.comm.rank] + counts[arr.comm.rank]
+                        key = key[cond1 & cond2]
+                        out_is_balanced = False

             return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced
@@ -1153,11 +1186,17 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
         original_split = self.split

         # determine whether indexed array will be 1D or nD
-        key_shapes = []
-        for k in key:
-            key_shapes.append(getattr(k, "shape", None))
-        return_1d = key_shapes.count(key_shapes[original_split]) == self.ndim
-
+        try:
+            return_1d = getattr(key, "ndim") == self.ndim
+        except AttributeError:
+            # key is tuple of torch tensors
+            key_shapes = []
+            for k in key:
+                key_shapes.append(getattr(k, "shape", None))
+            print("KEY SHAPES = ", key_shapes)
+            return_1d = key_shapes.count(key_shapes[original_split]) == self.ndim
+
+        print("RANK, RETURN_1D = ", self.comm.rank, return_1d)
         # send and receive "request key" info on what data element to ship where
         recv_counts = torch.zeros((self.comm.size, 1), dtype=torch.int64, device=self.larray.device)
@@ -1245,21 +1284,37 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
         for i in range(start_rank, end_rank):
-            cond1 = key[original_split] >= displs[i]
-            if i != self.comm.size - 1:
-                cond2 = key[original_split] < displs[i + 1]
-            else:
-                # cond2 is always true
-                cond2 = torch.ones(
-                    (key[original_split].shape[0],), dtype=torch.bool, device=self.larray.device
-                )
+            try:
+                cond1 = key >= displs[i]
+                if i != self.comm.size - 1:
+                    cond2 = key < displs[i + 1]
+                else:
+                    # cond2 is always true
+                    cond2 = torch.ones((key.shape[0],), dtype=torch.bool, device=self.larray.device)
+            except TypeError:
+                cond1 = key[original_split] >= displs[i]
+                if i != self.comm.size - 1:
+                    cond2 = key[original_split] < displs[i + 1]
+                else:
+                    # cond2 is always true
+                    cond2 = torch.ones(
+                        (key[original_split].shape[0],), dtype=torch.bool, device=self.larray.device
+                    )
             if return_1d:
-                # advanced indexing returning 1D array (e.g. boolean indexing)
-                selection = list(k[cond1 & cond2] for k in key)
-                recv_counts[i, :] = selection[0].shape[0]
-                if i == self.comm.rank:
-                    all_local_indexing[i] = selection[0].shape[0] == key[0].shape[0]
-                selection = torch.stack(selection, dim=1)
+                # advanced indexing returning 1D array
+                if isinstance(key, torch.Tensor):
+                    selection = key[cond1 & cond2]
+                    recv_counts[i, :] = selection.shape[0]
+                    if i == self.comm.rank:
+                        all_local_indexing[i] = selection.shape[0] == key.shape[0]
+                    selection.unsqueeze_(dim=1)
+                else:
+                    # key is tuple of torch tensors
+                    selection = list(k[cond1 & cond2] for k in key)
+                    recv_counts[i, :] = selection[0].shape[0]
+                    if i == self.comm.rank:
+                        all_local_indexing[i] = selection[0].shape[0] == key[0].shape[0]
+                    selection = torch.stack(selection, dim=1)
             else:
                 selection = key[original_split][cond1 & cond2]
                 recv_counts[i, :] = selection.shape[0]
@@ -1296,7 +1351,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
                     incoming_request_key_displs.tolist(),
                 ),
             )
-            # print("DEBUGGING:incoming_request_key = ", incoming_request_key)
+            print("DEBUGGING:incoming_request_key = ", incoming_request_key)
             if return_1d:
                 incoming_request_key = list(incoming_request_key[:, d] for d in range(self.ndim))
                 incoming_request_key[original_split] -= displs[self.comm.rank]
@@ -1308,13 +1363,16 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
                 + key[original_split + 1 :]
             )

-        # print("AFTER: incoming_request_key = ", incoming_request_key)
+        print("AFTER: incoming_request_key = ", incoming_request_key)
         # print("OUTPUT_SHAPE = ", output_shape)
         # print("OUTPUT_SPLIT = ", output_split)

         send_buf = self.larray[incoming_request_key]
         output_lshape = list(output_shape)
-        output_lshape[output_split] = key[original_split].shape[0]
+        if getattr(key, "ndim", 0) == 1:
+            output_lshape[output_split] = key.shape[0]
+        else:
+            output_lshape[output_split] = key[original_split].shape[0]
         recv_buf = torch.empty(
             tuple(output_lshape), dtype=self.larray.dtype, device=self.larray.device
         )
@@ -1330,7 +1388,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar

         # reorganize incoming counts according to original key order along split axis
         if return_1d:
-            key = torch.stack(key, dim=1)
+            if isinstance(key, tuple):
+                key = torch.stack(key, dim=1)
             _, key_inverse = key.unique(dim=0, sorted=True, return_inverse=True)
             # if _.shape == key.shape:
             _, ork_inverse = outgoing_request_key.unique(dim=0, sorted=True, return_inverse=True)
diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py
index c401c71479..e11c9f280c 100644
--- a/heat/core/tests/test_dndarray.py
+++ b/heat/core/tests/test_dndarray.py
@@ -652,6 +652,31 @@ def test_getitem(self):
         self.assertTrue(x_newaxis.split == 2)
         self.assertTrue(x_none.split == 2)

+        x = ht.arange(5, split=0)
+        x_np = np.arange(5)
+        y = x[:, np.newaxis] + x[np.newaxis, :]
+        y_np = x_np[:, np.newaxis] + x_np[np.newaxis, :]
+        self.assert_array_equal(y, y_np)
+        self.assertTrue(y.split == 0)
+
+        # ADVANCED INDEXING
+        # 1d
+        x = ht.arange(10, 1, -1, split=0)
+        x_np = np.arange(10, 1, -1)
+        x_adv_ind = x[np.array([3, 3, 1, 8])]
+        x_np_adv_ind = x_np[np.array([3, 3, 1, 8])]
+        self.assert_array_equal(x_adv_ind, x_np_adv_ind)
+
+        # 3d, split 0
+        x = ht.arange(60, split=0).reshape(5, 3, 4)
+        x_np = np.arange(60).reshape(5, 3, 4)
+        k1 = np.array([0, 4, 1, 0])
+        k2 = np.array([0, 2, 1, 0])
+        k3 = np.array([1, 2, 3, 1])
+        self.assert_array_equal(
+            x[ht.array(k1, split=0), ht.array(k2, split=0), ht.array(k3, split=0)], x_np[k1, k2, k3]
+        )
+
         # boolean mask, local
         arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5)
         np.random.seed(42)
@@ -671,16 +696,6 @@ def test_getitem(self):
         mask_split2 = ht.array(mask, split=2)
         self.assert_array_equal(arr_split2[mask_split2], arr.numpy()[mask])

-        # advanced indexing
-        x = ht.arange(60, split=0).reshape(5, 3, 4)
-        x_np = np.arange(60).reshape(5, 3, 4)
-        k1 = np.array([0, 4, 1, 0])
-        k2 = np.array([0, 2, 1, 0])
-        k3 = np.array([1, 2, 3, 1])
-        self.assert_array_equal(
-            x[ht.array(k1, split=0), ht.array(k2, split=0), ht.array(k3, split=0)], x_np[k1, k2, k3]
-        )
-
     def test_int_cast(self):
         # simple scalar tensor
         a = ht.ones(1)

From 18329a194c8ee7cab67d81326f8026ed2287838b Mon Sep 17 00:00:00 2001
From: Claudia Comito
Date: Wed, 25 Jan 2023 06:14:31 +0100
Subject: [PATCH 061/132] [skip ci] implement single-element indexing along
 split axis w/ Iterable key

---
 heat/core/dndarray.py            | 133 ++++++++++++++++++++++++++-----
 heat/core/tests/test_dndarray.py |  18 ++++-
 2 files changed, 130 insertions(+), 21 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index b0f799bed8..a93bb4a886 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -667,15 +667,18 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
         output_shape = list(arr.gshape)
         split_bookkeeping = [None] * arr.ndim
         new_split = arr.split
+        arr_is_distributed = False
         if arr.split is not None:
             split_bookkeeping[arr.split] = "split"
         if arr.is_distributed():
             counts, displs = arr.counts_displs()
+            arr_is_distributed = True

         advanced_indexing = False
         arr_is_copy = False
-        split_key_is_sorted = 1
+        split_key_is_sorted = 1  # can be 1: ascending, 0: not sorted, -1: descending
         out_is_balanced = False
+        root = None

         if isinstance(key, list):
             try:
@@ -697,7 +700,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
             except AttributeError:
                 # key is torch tensor
                 key = key.clone()
-            if not arr.is_distributed():
+            if not arr_is_distributed:
                 try:
                     # key is DNDarray, extract torch tensor
                     key = key.larray
@@ -713,7 +716,15 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
                 new_split = None if arr.split is None else 0
                 out_is_balanced = True
                 split_key_is_sorted = 1
-                return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced
+                return (
+                    arr,
+                    key,
+                    output_shape,
+                    new_split,
+                    split_key_is_sorted,
+                    out_is_balanced,
+                    root,
+                )

             # arr is distributed
             if not isinstance(key, DNDarray) or not key.is_distributed():
                 new_split = 0
                 split_key_is_sorted = 0
                 out_is_balanced = True
-                return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced
+                return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced, root

             key = list(key) if isinstance(key, Iterable) else [key]
@@ -857,6 +884,9 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
                 else:
                     advanced_indexing = True
                     advanced_indexing_dims.append(i)
-                    # if arr.is_distributed() and i == arr.split:
-                    #     # make no assumption on data locality wrt key
-                    #     split_key_is_sorted = 0
                     if isinstance(k, DNDarray):
                         key[i] = k.larray
                     elif not isinstance(k, torch.Tensor):
                         key[i] = torch.tensor(k, dtype=torch.int64, device=arr.larray.device)
+                    if arr_is_distributed and i == arr.split:
+                        # make no assumption on data locality wrt key
+                        sorted, _ = torch.sort(key[i], stable=True)
+                        sort_status = torch.tensor(
+                            (key[i] == sorted).all(), dtype=torch.uint8, device=key[i].device
+                        )
+                        arr.comm.Allreduce(MPI.IN_PLACE, sort_status, MPI.SUM)
+                        split_key_is_sorted = 1 if sort_status.item() == arr.comm.size else 0
+                        split_key_shape = key[i].shape
+                        if split_key_is_sorted:
+                            # extract local key
+                            cond1 = key[i] >= displs[arr.comm.rank]
+                            cond2 = key[i] < displs[arr.comm.rank] + counts[arr.comm.rank]
+                            key[i] = key[i][cond1 & cond2]
+                            key[i] -= displs[arr.comm.rank]
+                            out_is_balanced = False
             elif isinstance(k, int):
                 # single-element indexing along axis i
                 output_shape[i] = None
                 split_bookkeeping = (
                    split_bookkeeping[: i - lose_dims] + split_bookkeeping[i - lose_dims + 1 :]
                )
                lose_dims += 1
+                if arr_is_distributed and i == arr.split:
+                    # single-element indexing along split axis
+                    # work out root process for Bcast
+                    key[i] = k + arr.shape[i] if k < 0 else k
+                    if key[i] in displs:
+                        root = displs.index(key[i])
+                    else:
+                        displs = torch.cat(
+                            (torch.tensor(displs), torch.tensor(key[i]).reshape(-1)), dim=0
+                        )
+                        _, sorted_indices = displs.unique(sorted=True, return_inverse=True)
+                        root = sorted_indices[-1] - 1
+                    # allocate buffer on all processes
+                    if arr.comm.rank == root:
+                        # correct key for rank-specific displacement
+                        key[i] -= displs[root]
+
            elif isinstance(k, slice) and k != slice(None):
                start, stop, step = k.start, k.stop, k.step
                if start is None:
@@ -969,7 +980,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
                    # TODO: implement ht.fromiter (implemented in ASSET_ht)
                    key[i] = list(range(start, stop, step))
                    output_shape[i] = len(key[i])
-                    if arr.is_distributed() and new_split == i:
+                    if arr_is_distributed and new_split == i:
                        # distribute key and proceed with non-ordered indexing
                        key[i] = factories.array(
                            key[i], split=0, device=arr.device, copy=False
                        ).larray
                        split_key_is_sorted = -1
                        out_is_balanced = True
                elif step > 0 and start < stop:
                    output_shape[i] = int(torch.tensor((stop - start) / step).ceil().item())
-                    if arr.is_distributed() and new_split == i:
+                    if arr_is_distributed and new_split == i:
                        split_key_is_sorted = 1
                        out_is_balanced = False
                        local_arr_end = displs[arr.comm.rank] + counts[arr.comm.rank]
                        if stop > displs[arr.comm.rank] and start < local_arr_end:
                else:
                    key[i] = slice(0, 0)
                    output_shape[i] = 0

        if advanced_indexing:
+            print("ADV IND KEY = ", key)
            advanced_indexing_shapes = tuple(tuple(key[i].shape) for i in advanced_indexing_dims)
+            if arr_is_distributed:
+                advanced_indexing_shapes = (
+                    advanced_indexing_shapes[: arr.split]
+                    + split_key_shape
+                    + advanced_indexing_shapes[arr.split + 1 :]
+                )
            print("DEBUGGING: advanced_indexing_shapes = ", advanced_indexing_shapes)
            # shapes of indexing arrays must be broadcastable
            try:
                broadcasted_shape = torch.broadcast_shapes(*advanced_indexing_shapes)
                    + [None] * add_dims
                    + split_bookkeeping[advanced_indexing_dims[0] :]
                )
+                print("ADV IND output_shape = ", output_shape)
            else:
                # advanced-indexing dimensions are not consecutive:
                # transpose array to make the advanced-indexing dimensions consecutive as the first dimensions
                output_shape = list(arr.gshape)
                output_shape[: len(advanced_indexing_dims)] = broadcasted_shape
                split_bookkeeping = [None] * arr.ndim
-                if arr.is_distributed:
+                if arr_is_distributed:
                    split_bookkeeping[arr.split] = "split"
                split_bookkeeping = [None] * add_dims + split_bookkeeping
                # modify key to match the new dimension order

        key = tuple(key)
        for i in range(output_shape.count(None)):
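# ---------------------------------------------------------------------------
# Note: illustrative sketch, not part of the patch. The root computation
# above locates the process owning a global index along the split axis; the
# same lookup, expressed with bisect over the displacements:
import bisect

displs = [0, 5, 10, 15]                    # chunk starts, one per rank
g = 12                                     # global index along the split axis
root = bisect.bisect_right(displs, g) - 1  # -> rank 2
local_index = g - displs[root]             # -> 2, used on the root process
# ---------------------------------------------------------------------------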
            output_shape.remove(None)
        output_shape = tuple(output_shape)
        new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None
-
-        return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced
+        print(
+            "key, output_shape, new_split, split_key_is_sorted, out_is_balanced = ",
+            key,
+            output_shape,
+            new_split,
+            split_key_is_sorted,
+            out_is_balanced,
+        )
+        return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced, root

    def __get_local_slice(self, key: slice):
        split = self.split
        # key can be: int, tuple, list, slice, DNDarray, torch tensor, numpy array, or sequence thereof
        # Trivial cases
-        print("DEBUGGING: RAW KEY = ", key)
+        print("DEBUGGING: RAW KEY = ", key, type(key))
+
+        if key is None:
+            return self.expand_dims(0)
+        if (
+            key is ... or isinstance(key, slice) and key == slice(None)
+        ):  # latter doesnt work with torch for 0-dim tensors
+            return self

        # Single-element indexing
        # TODO: single-element indexing along split axis belongs here as well
        scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0
        if scalar:
            output_shape = self.gshape[1:]
                balanced=self.balanced,
            )
            return indexed_arr
+        # single-element indexing along split axis:
        # check for negative key
        key = key + self.shape[0] if key < 0 else key
        # identify root process
            )
            return indexed_arr

-        if key is None:
-            return self.expand_dims(0)
-        if (
-            key is ... or isinstance(key, slice) and key == slice(None)
-        ):  # latter doesnt work with torch for 0-dim tensors
-            return self
-
        # Many-elements indexing: incl. slicing and striding, ordered advanced indexing

        # Preprocess: Process Ellipsis + 'None' indexing; make Iterables to DNDarrays
            output_split,
            split_key_is_sorted,
            out_is_balanced,
+            root,
        ) = self.__process_key(key)
        print("DEBUGGING: processed key, output_split = ", key, output_split)

+        if root is not None:
+            # single-element indexing along split axis
+            # allocate buffer on all processes
+            if self.comm.rank == root:
+                indexed_arr = self.larray[key]
+            else:
+                indexed_arr = torch.zeros(
+                    output_shape, dtype=self.larray.dtype, device=self.larray.device
+                )
+            # broadcast result to all processes
+            self.comm.Bcast(indexed_arr, root=root)
+            indexed_arr = DNDarray(
+                indexed_arr,
+                gshape=output_shape,
+                dtype=self.dtype,
+                split=output_split,
+                device=self.device,
+                comm=self.comm,
+                balanced=True,
+            )
+            return indexed_arr
+
        # TODO: test that key for not affected dims is always slice(None)
        # including match between self.split and key after self manipulation
diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py
index e11c9f280c..4d26b36401 100644
--- a/heat/core/tests/test_dndarray.py
+++ b/heat/core/tests/test_dndarray.py
@@ -660,6 +660,22 @@ def test_getitem(self):
         self.assertTrue(y.split == 0)

         # ADVANCED INDEXING
+        # "x[(1, 2, 3),] is fundamentally different than x[(1, 2, 3)]"
+
+        x_np = np.arange(60).reshape(5, 3, 4)
+        indexed_x_np = x_np[(1, 2, 3)]
+        adv_indexed_x_np = x_np[
+            (1, 2, 3),
+        ]
+        x = ht.array(x_np, split=0)
+        indexed_x = x[(1, 2, 3)]
+        adv_indexed_x = x[
+            (1, 2, 3),
+        ]
+        print("DEBUGGING: indexed_x, indexed_x_np = ", indexed_x.item(), indexed_x_np)
+        self.assertTrue(indexed_x.item() == np.array(indexed_x_np))
+        self.assert_array_equal(adv_indexed_x, adv_indexed_x_np)
+
         # 1d
         x = ht.arange(10, 1, -1, split=0)
         x_np = np.arange(10, 1, -1)
@@ -667,7 +683,7 @@ def test_getitem(self):
         x_np_adv_ind = x_np[np.array([3, 3, 1, 8])]
         self.assert_array_equal(x_adv_ind, x_np_adv_ind)

-        # 3d, split 0
+        # 3d, split 0, non-unique, non-ordered key along split axis
         x = ht.arange(60, split=0).reshape(5, 3, 4)
         x_np = np.arange(60).reshape(5, 3, 4)
         k1 = np.array([0, 4, 1, 0])

From c48c66e5502390d40d1309640380d7c9032d2555 Mon Sep 17 00:00:00 2001
From: Claudia Comito
Date: Thu, 26 Jan 2023 11:32:33 +0100
Subject: [PATCH 062/132] [skip ci] generalize advanced indexing incl.
 distributed DNDarray key

---
 heat/core/dndarray.py            | 78 ++++++++++++++++++--------------
 heat/core/tests/test_dndarray.py |  1 -
 2 files changed, 45 insertions(+), 34 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index a93bb4a886..580a0d11c9 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -786,7 +786,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
             output_shape = tuple(list(key.shape) + output_shape[1:])
             print("DEBUGGING ADV IND: output_shape = ", output_shape)
             # adjust split axis accordingly
-            if arr.is_distributed():
+            if arr_is_distributed:
                 counts, displs = arr.counts_displs()
                 if arr.split != 0:
                     # split axis is not affected
                     split_bookkeeping = [None] * key.ndim + split_bookkeeping[1:]
@@ -901,6 +901,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]
         # check for advanced indexing and slices
         print("DEBUGGING: key = ", key)
         advanced_indexing_dims = []
+        advanced_indexing_shapes = []
         lose_dims = 0
         for i, k in enumerate(key):
             if isinstance(k, Iterable) or isinstance(k, DNDarray):
                 # advanced indexing across dimensions
                 if getattr(k, "ndim", 1) == 0:
                     # single-element indexing along axis i
                     output_shape[i] = None
                     split_bookkeeping = (
                         split_bookkeeping[: i - lose_dims] + split_bookkeeping[i - lose_dims + 1 :]
                     )
                     lose_dims += 1
+                    if arr_is_distributed and i == arr.split:
+                        # single-element indexing along split axis
+                        # work out root process for Bcast
+                        key[i] = k.item() + arr.shape[i] if k < 0 else k.item()
+                        if key[i] in displs:
+                            root = displs.index(key[i])
+                        else:
+                            displs = torch.cat(
+                                (torch.tensor(displs), torch.tensor(key[i]).reshape(-1)), dim=0
+                            )
+                            _, sorted_indices = displs.unique(sorted=True, return_inverse=True)
+                            root = sorted_indices[-1] - 1
+                        # correct key for rank-specific displacement
+                        if arr.comm.rank == root:
+                            key[i] -= displs[root]
+                    else:
+                        key[i] = k.item()
                 else:
                     advanced_indexing = True
                     advanced_indexing_dims.append(i)
+                    if isinstance(k, DNDarray):
+                        advanced_indexing_shapes.append(k.gshape)
+                        if arr_is_distributed and i == arr.split:
+                            out_is_balanced = k.balanced
+                            if k.is_distributed():
+                                #
we have no info on order of indices + split_key_is_sorted = 0 + key[i] = k.larray + elif not isinstance(k, torch.Tensor): + key[i] = torch.tensor(k, dtype=torch.int64, device=arr.larray.device) + advanced_indexing_shapes.append(tuple(key[i].shape)) + # IMPORTANT: here we assume that torch or ndarray key is THE SAME SET OF GLOBAL INDICES on every rank + if arr_is_distributed and i == arr.split: + # make no assumption on data locality wrt key + out_is_balanced = None + # assess if indices are in ascending order + if (key[i] == torch.sort(key[i], stable=True)[0]).all(): + split_key_is_sorted = 1 + # extract local key + cond1 = key[i] >= displs[arr.comm.rank] + cond2 = key[i] < displs[arr.comm.rank] + counts[arr.comm.rank] + key[i] = key[i][cond1 & cond2] + key[i] -= displs[arr.comm.rank] + else: + split_key_is_sorted = 0 elif isinstance(k, int): # single-element indexing along axis i output_shape[i] = None @@ -957,9 +977,8 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] ) _, sorted_indices = displs.unique(sorted=True, return_inverse=True) root = sorted_indices[-1] - 1 - # allocate buffer on all processes + # correct key for rank-specific displacement if arr.comm.rank == root: - # correct key for rank-specific displacement key[i] -= displs[root] elif isinstance(k, slice) and k != slice(None): @@ -1023,13 +1042,6 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] if advanced_indexing: print("ADV IND KEY = ", key) - advanced_indexing_shapes = tuple(tuple(key[i].shape) for i in advanced_indexing_dims) - if arr_is_distributed: - advanced_indexing_shapes = ( - advanced_indexing_shapes[: arr.split] - + split_key_shape - + advanced_indexing_shapes[arr.split + 1 :] - ) print("DEBUGGING: advanced_indexing_shapes = ", advanced_indexing_shapes) # shapes of indexing arrays must be broadcastable try: diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 4d26b36401..a77ade3024 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -672,7 +672,6 @@ def test_getitem(self): adv_indexed_x = x[ (1, 2, 3), ] - print("DEBUGGING: indexed_x, indexed_x_np = ", indexed_x.item(), indexed_x_np) self.assertTrue(indexed_x.item() == np.array(indexed_x_np)) self.assert_array_equal(adv_indexed_x, adv_indexed_x_np) From f024ebb32a7b982cbdf86c5a4c5d47c3cd5b3650 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sun, 29 Jan 2023 08:24:09 +0100 Subject: [PATCH 063/132] [skip ci] Expand tests combined advanced / basic indexing --- heat/core/dndarray.py | 1 + heat/core/tests/test_dndarray.py | 34 ++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 580a0d11c9..a4be249c82 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1382,6 +1382,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar all_local_indexing[i] = selection[0].shape[0] == key[0].shape[0] selection = torch.stack(selection, dim=1) else: + print("DEBUGGING: key[original_split] = ", key[original_split]) selection = key[original_split][cond1 & cond2] recv_counts[i, :] = selection.shape[0] if i == self.comm.rank: diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index a77ade3024..8c8921165e 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -691,6 +691,40 @@ def test_getitem(self): self.assert_array_equal( x[ht.array(k1, split=0), ht.array(k2, 
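# NOTE (editor's sketch, not part of the patch): the rule the
# `advanced_indexing_shapes` bookkeeping above enforces - all index arrays in one
# key must be mutually broadcastable, otherwise the indexing operation raises.
import torch

print(torch.broadcast_shapes((4,), (1,), ()))   # fine: broadcasts to (4,)
try:
    torch.broadcast_shapes((4,), (3,))          # 4 vs. 3: not broadcastable
except RuntimeError as e:
    print("shape mismatch:", e)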
split=0), ht.array(k3, split=0)], x_np[k1, k2, k3] ) + # broadcasting shapes + self.assert_array_equal(x[ht.array(k1, split=0), ht.array(1), 2], x_np[k1, 1, 2]) + # test exception: broadcasting mismatching shapes + k2 = np.array([0, 2, 1]) + with self.assertRaises(IndexError): + x[k1, k2, k3] + + # more broadcasting + x_np = np.arange(12).reshape(4, 3) + rows = np.array([0, 3]) + cols = np.array([0, 2]) + x = ht.arange(12).reshape(4, 3) + x.resplit_(1) + x_np_indexed = x_np[rows[:, np.newaxis], cols] + x_indexed = x[ht.array(rows)[:, np.newaxis], cols] + self.assert_array_equal(x_indexed, x_np_indexed) + self.assertTrue(x_indexed.split == 1) + + # combining advanced and basic indexing + y_np = np.arange(35).reshape(5, 7) + y_np_indexed = y_np[np.array([0, 2, 4]), 1:3] + y = ht.array(y_np, split=1) + y_indexed = y[ht.array([0, 2, 4]), 1:3] + self.assert_array_equal(y_indexed, y_np_indexed) + self.assertTrue(y_indexed.split == 1) + + x_np = np.arange(10 * 20 * 30).reshape(10, 20, 30) + x = ht.array(x_np, split=1) + ind_array = ht.random.randint(0, 20, (2, 3, 4), dtype=ht.int64) + ind_array_np = ind_array.numpy() + x_np_indexed = x_np[..., ind_array_np, :] + x_indexed = x[..., ind_array, :] + self.assert_array_equal(x_indexed, x_np_indexed) + self.assertTrue(x_indexed.split == 3) # boolean mask, local arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5) From 6ae27881e68b11d7ccd5f9a088c0f35b4eea0ae4 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sun, 5 Feb 2023 08:18:12 +0100 Subject: [PATCH 064/132] [skip ci] fix advanced dimensional indexing on non-distributed array --- heat/core/dndarray.py | 52 +++++++++++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index a4be249c82..7e78d818a7 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -787,7 +787,6 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] print("DEBUGGING ADV IND: output_shape = ", output_shape) # adjust split axis accordingly if arr_is_distributed: - counts, displs = arr.counts_displs() if arr.split != 0: # split axis is not affected split_bookkeeping = [None] * key.ndim + split_bookkeeping[1:] @@ -858,7 +857,15 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] key = key[cond1 & cond2] key -= displs[arr.comm.rank] out_is_balanced = False - + else: + try: + out_is_balanced = key.balanced + new_split = key.split + key = key.larray + except AttributeError: + # torch or numpy key, non-distributed indexed array + out_is_balanced = True + new_split = None return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced, root key = list(key) if isinstance(key, Iterable) else [key] @@ -937,9 +944,9 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] advanced_indexing_shapes.append(k.gshape) if arr_is_distributed and i == arr.split: out_is_balanced = k.balanced - if k.is_distributed(): - # we have no info on order of indices - split_key_is_sorted = 0 + # we have no info on order of indices + split_key_is_sorted = 0 + k = k.resplit(-1) key[i] = k.larray elif not isinstance(k, torch.Tensor): key[i] = torch.tensor(k, dtype=torch.int64, device=arr.larray.device) @@ -949,7 +956,14 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] 
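# NOTE (editor's sketch, not part of the patch): how a globally *sorted* key along
# the split axis is localized on each rank, assuming `counts`/`displs` as returned
# by counts_displs() and an illustrative rank number. Each process keeps only the
# indices it owns and shifts them into local coordinates.
import torch

key = torch.tensor([1, 3, 4, 7, 9])    # ascending global indices
counts, displs, rank = [5, 5], [0, 5], 1
cond = (key >= displs[rank]) & (key < displs[rank] + counts[rank])
local_key = key[cond] - displs[rank]   # tensor([2, 4]): global rows 7 and 9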
# make no assumption on data locality wrt key out_is_balanced = None # assess if indices are in ascending order - if (key[i] == torch.sort(key[i], stable=True)[0]).all(): + print( + "DEBUGGING: torch.sort(key[i], stable=True)[0] = ", + torch.sort(key[i], stable=True)[0], + ) + if ( + key[i].ndim == 1 + and (key[i] == torch.sort(key[i], stable=True)[0]).all() + ): split_key_is_sorted = 1 # extract local key cond1 = key[i] >= displs[arr.comm.rank] @@ -1300,8 +1314,17 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar key_shapes.append(getattr(k, "shape", None)) print("KEY SHAPES = ", key_shapes) return_1d = key_shapes.count(key_shapes[original_split]) == self.ndim + # check for broadcasted indexing: key along split axis is not 1D + broadcasted_indexing = ( + key_shapes[original_split] is not None and len(key_shapes[original_split]) > 1 + ) + if broadcasted_indexing: + broadcast_shape = key_shapes[original_split] + key = list(key) + key[original_split] = key[original_split].flatten() + key = tuple(key) + # print("RANK, RETURN_1D, broadcasted_indexing = ", self.comm.rank, return_1d, broadcasted_indexing) - print("RANK, RETURN_1D = ", self.comm.rank, return_1d) # send and receive "request key" info on what data element to ship where recv_counts = torch.zeros((self.comm.size, 1), dtype=torch.int64, device=self.larray.device) @@ -1320,7 +1343,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # process-local: calculate which/how many elements will be received from what process if split_key_is_sorted == -1: - # key is sorted in descending order + # key is sorted in descending order (i.e. slicing w/ negative step) # shrink selection of active processes if key[original_split].numel() > 0: key_edges = torch.cat( @@ -1393,6 +1416,9 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar all_local_indexing, is_split=0, device=self.device, copy=False ) if all_local_indexing.all().item(): + # TODO: if advanced indexing, indexed array must be a copy. 
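# NOTE (editor's sketch, not part of the patch): why the TODO above is plausibly
# already covered by PyTorch semantics - basic slicing returns a view, while
# advanced (tensor) indexing always materializes a copy.
import torch

t = torch.arange(6)
view = t[1:3]                    # basic slicing: shares storage with t
copy = t[torch.tensor([1, 2])]   # advanced indexing: new storage
t[1] = 99
assert view[0] == 99 and copy[0] == 1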
Probably addressed by torch + if broadcasted_indexing: + key[original_split] = key[original_split].reshape(broadcast_shape) indexed_arr = self.larray[key] return factories.array( indexed_arr, is_split=output_split, device=self.device, copy=False @@ -1510,20 +1536,22 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar indexed_arr = recv_buf[map] return factories.array(indexed_arr, is_split=output_split, copy=False) - key = key[original_split] + # key = key[original_split] outgoing_request_key = outgoing_request_key.squeeze_(1) # incoming elements likely already stacked in ascending or descending order - if (key == outgoing_request_key).all(): + if (key[original_split] == outgoing_request_key).all(): return factories.array(recv_buf, is_split=output_split, copy=False) - if (key == outgoing_request_key.flip(dims=(0,))).all(): + if (key[original_split] == outgoing_request_key.flip(dims=(0,))).all(): return factories.array( recv_buf.flip(dims=(output_split,)), is_split=output_split, copy=False ) map = [slice(None)] * recv_buf.ndim map[output_split] = outgoing_request_key.argsort(stable=True)[ - key.argsort(stable=True).argsort(stable=True) + key[original_split].argsort(stable=True).argsort(stable=True) ] + if broadcasted_indexing: + map[output_split] = map[output_split].reshape(broadcast_shape) indexed_arr = recv_buf[map] return factories.array(indexed_arr, is_split=output_split, copy=False) From 09e586cf1227592386734a49bce4a89c0d2c431a Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 27 Jul 2023 16:26:56 +0200 Subject: [PATCH 065/132] fix distr advanced indexing with broadcasted shape --- heat/core/dndarray.py | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 3ea7d64e18..a95307e1ab 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -811,7 +811,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] except RuntimeError: raise IndexError("Invalid indices: expected a list of integers, got {}".format(key)) if isinstance(key, (DNDarray, torch.Tensor, np.ndarray)): - if key.dtype in (bool, uint8, torch.bool, torch.uint8, np.bool, np.uint8): + if key.dtype in (bool, uint8, torch.bool, torch.uint8, np.bool_, np.uint8): # boolean indexing: shape must match arr.shape if not tuple(key.shape) == arr.shape: raise IndexError( @@ -1068,10 +1068,11 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] if isinstance(k, DNDarray): advanced_indexing_shapes.append(k.gshape) if arr_is_distributed and i == arr.split: - out_is_balanced = k.balanced # we have no info on order of indices split_key_is_sorted = 0 + # redistribute key along last axis to match split axis of indexed array k = k.resplit(-1) + out_is_balanced = True key[i] = k.larray elif not isinstance(k, torch.Tensor): key[i] = torch.tensor(k, dtype=torch.int64, device=arr.larray.device) @@ -1081,10 +1082,6 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] 
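# NOTE (editor's sketch, not part of the patch): the flatten/reshape trick used for
# a broadcasted (n-dimensional) key along the split axis - communicate with a 1-D
# key, then restore the broadcast shape on the indexed result.
import torch

t = torch.arange(20).reshape(4, 5)
idx = torch.tensor([[0, 3], [2, 1]])             # 2-D key along axis 0
rows = t[idx.flatten()]                          # 1-D gather: rows 0, 3, 2, 1
restored = rows.reshape(*idx.shape, t.shape[1])  # back to shape (2, 2, 5)
assert (restored == t[idx]).all()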
# make no assumption on data locality wrt key out_is_balanced = None # assess if indices are in ascending order - print( - "DEBUGGING: torch.sort(key[i], stable=True)[0] = ", - torch.sort(key[i], stable=True)[0], - ) if ( key[i].ndim == 1 and (key[i] == torch.sort(key[i], stable=True)[0]).all() @@ -1432,6 +1429,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # determine whether indexed array will be 1D or nD try: return_1d = getattr(key, "ndim") == self.ndim + send_axis = 0 except AttributeError: # key is tuple of torch tensors key_shapes = [] @@ -1448,6 +1446,9 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar key = list(key) key[original_split] = key[original_split].flatten() key = tuple(key) + send_axis = original_split + else: + send_axis = output_split # print("RANK, RETURN_1D, broadcasted_indexing = ", self.comm.rank, return_1d, broadcasted_indexing) # send and receive "request key" info on what data element to ship where @@ -1530,7 +1531,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar all_local_indexing[i] = selection[0].shape[0] == key[0].shape[0] selection = torch.stack(selection, dim=1) else: - print("DEBUGGING: key[original_split] = ", key[original_split]) selection = key[original_split][cond1 & cond2] recv_counts[i, :] = selection.shape[0] if i == self.comm.rank: @@ -1549,7 +1549,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar indexed_arr, is_split=output_split, device=self.device, copy=False ) - print("DEBUGGING: outgoing_request_key = ", outgoing_request_key) print("RECV_COUNTS = ", recv_counts) # share recv_counts among all processes comm_matrix = torch.empty( @@ -1629,7 +1628,14 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar if getattr(key, "ndim", 0) == 1: output_lshape[output_split] = key.shape[0] else: - output_lshape[output_split] = key[original_split].shape[0] + if broadcasted_indexing: + output_lshape = ( + output_lshape[:original_split] + + [torch.prod(torch.tensor(broadcast_shape, device=send_buf.device)).item()] + + output_lshape[output_split + 1 :] + ) + else: + output_lshape[output_split] = key[original_split].shape[0] recv_buf = torch.empty( tuple(output_lshape), dtype=self.larray.dtype, device=self.larray.device ) @@ -1637,10 +1643,11 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar recv_displs = outgoing_request_key_displs.tolist() send_counts = incoming_request_key_counts.tolist() send_displs = incoming_request_key_displs.tolist() + print("DEBUGGING: output_split = ", output_split) self.comm.Alltoallv( (send_buf, send_counts, send_displs), (recv_buf, recv_counts, recv_displs), - send_axis=output_split, + send_axis=send_axis, ) # reorganize incoming counts according to original key order along split axis @@ -1672,11 +1679,17 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar ) map = [slice(None)] * recv_buf.ndim - map[output_split] = outgoing_request_key.argsort(stable=True)[ - key[original_split].argsort(stable=True).argsort(stable=True) - ] + print("DEBUGGING: outgoing_request_key = ", outgoing_request_key) + print("DEBUGGING: key[original_split] = ", key[original_split]) if broadcasted_indexing: - map[output_split] = map[output_split].reshape(broadcast_shape) + map[original_split] = outgoing_request_key.argsort(stable=True)[ + key[original_split].argsort(stable=True).argsort(stable=True) + ] + 
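# NOTE (editor's sketch, not part of the patch): the double-argsort idiom used in
# the `map` computation above. For a permutation p, p.argsort() is its inverse, so
# ork.argsort()[key.argsort().argsort()] sends each received row to the position
# the original (non-sorted) key requested.
import torch

key = torch.tensor([4, 0, 2])   # order the caller asked for
ork = torch.tensor([0, 2, 4])   # order in which the elements actually arrived
map_ = ork.argsort(stable=True)[key.argsort(stable=True).argsort(stable=True)]
assert (ork[map_] == key).all()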
map[original_split] = map[original_split].reshape(broadcast_shape) + else: + map[output_split] = outgoing_request_key.argsort(stable=True)[ + key[original_split].argsort(stable=True).argsort(stable=True) + ] indexed_arr = recv_buf[map] return factories.array(indexed_arr, is_split=output_split, copy=False) From c56ebf443e0746c7b8572f22a5979b126814542d Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Sat, 29 Jul 2023 07:31:15 +0200 Subject: [PATCH 066/132] transpose without copying --- heat/core/dndarray.py | 69 +++++++++++++++++++++++++++++++++---------- 1 file changed, 53 insertions(+), 16 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index a95307e1ab..eaef80bcc5 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -800,10 +800,10 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] arr_is_distributed = True advanced_indexing = False - arr_is_copy = False split_key_is_sorted = 1 # can be 1: ascending, 0: not sorted, -1: descending out_is_balanced = False root = None + transpose_axes = tuple(range(arr.ndim)) if isinstance(key, list): try: @@ -849,6 +849,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] split_key_is_sorted, out_is_balanced, root, + transpose_axes, ) # arr is distributed @@ -905,7 +906,16 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] new_split = 0 split_key_is_sorted = 0 out_is_balanced = True - return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced, root + return ( + arr, + key, + output_shape, + new_split, + split_key_is_sorted, + out_is_balanced, + root, + transpose_axes, + ) # advanced indexing on first dimension: first dim will expand to shape of key output_shape = tuple(list(key.shape) + output_shape[1:]) @@ -991,7 +1001,16 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] # torch or numpy key, non-distributed indexed array out_is_balanced = True new_split = None - return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced, root + return ( + arr, + key, + output_shape, + new_split, + split_key_is_sorted, + out_is_balanced, + root, + transpose_axes, + ) key = list(key) if isinstance(key, Iterable) else [key] @@ -1028,8 +1047,10 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] ) add_dims -= 1 - # recalculate new split axis after dimensions manipulation + # recalculate new_split, transpose_axes after dimensions manipulation new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None + transpose_axes = tuple(range(arr.ndim)) + # check for advanced indexing and slices print("DEBUGGING: key = ", key) advanced_indexing_dims = [] @@ -1211,11 +1232,9 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] 
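# NOTE (editor's sketch, not part of the patch): undoing a transpose without
# copying. argsort of the axes permutation yields the permutation that restores
# the original order - this is the `backwards_transpose_axes` computed below.
import torch

t = torch.zeros(2, 3, 4)
axes = (2, 0, 1)  # e.g. advanced-indexing dims moved to the front
back = torch.tensor(axes).argsort(stable=True).tolist()  # [1, 2, 0]
assert t.permute(axes).permute(back).shape == t.shape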
non_adv_ind_dims = list( i for i in range(arr.ndim) if i not in advanced_indexing_dims ) - # TODO: work this out without array copy - if not arr_is_copy: - arr = arr.copy() - arr_is_copy = True - arr = arr.transpose(advanced_indexing_dims + non_adv_ind_dims) + # keep track of transpose axes order, to be able to transpose back later + transpose_axes = tuple(advanced_indexing_dims + non_adv_ind_dims) + arr = arr.transpose(transpose_axes) output_shape = list(arr.gshape) output_shape[: len(advanced_indexing_dims)] = broadcasted_shape split_bookkeeping = [None] * arr.ndim @@ -1244,7 +1263,16 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] split_key_is_sorted, out_is_balanced, ) - return arr, key, output_shape, new_split, split_key_is_sorted, out_is_balanced, root + return ( + arr, + key, + output_shape, + new_split, + split_key_is_sorted, + out_is_balanced, + root, + transpose_axes, + ) def __get_local_slice(self, key: slice): split = self.split @@ -1378,7 +1406,13 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar split_key_is_sorted, out_is_balanced, root, + transpose_axes, ) = self.__process_key(key) + + backwards_transpose_axes = ( + torch.tensor(transpose_axes, device=self.larray.device).argsort(stable=True).tolist() + ) + print("DEBUGGING: processed key, output_split = ", key, output_split) if root is not None: @@ -1401,6 +1435,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar comm=self.comm, balanced=True, ) + # transpose array back if needed + self = self.transpose(backwards_transpose_axes) return indexed_arr # TODO: test that key for not affected dims is always slice(None) @@ -1411,6 +1447,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar print("split_key_is_sorted, key = ", split_key_is_sorted, key) if split_key_is_sorted == 1: indexed_arr = self.larray[key] + # transpose array back if needed + self = self.transpose(backwards_transpose_axes) return DNDarray( indexed_arr, gshape=output_shape, @@ -1545,6 +1583,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar if broadcasted_indexing: key[original_split] = key[original_split].reshape(broadcast_shape) indexed_arr = self.larray[key] + # transpose array back if needed + self = self.transpose(backwards_transpose_axes) return factories.array( indexed_arr, is_split=output_split, device=self.device, copy=False ) @@ -1649,6 +1689,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar (recv_buf, recv_counts, recv_displs), send_axis=send_axis, ) + # transpose original array back if needed, all further indexing on recv_buf + self = self.transpose(backwards_transpose_axes) # reorganize incoming counts according to original key order along split axis if return_1d: @@ -1660,17 +1702,12 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar map = ork_inverse.argsort(stable=True)[ key_inverse.argsort(stable=True).argsort(stable=True) ] - # else: - # # major bottleneck - # key = key.tolist() - # outgoing_request_key = outgoing_request_key.tolist() - # map = [outgoing_request_key.index(k) for k in key] indexed_arr = recv_buf[map] return factories.array(indexed_arr, is_split=output_split, copy=False) - # key = key[original_split] outgoing_request_key = outgoing_request_key.squeeze_(1) # incoming elements likely already stacked in ascending or descending order + # TODO: is this check really worth it? 
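# NOTE (editor's sketch, not part of the patch): the variable-counts all-to-all
# exchange that ships the requested elements between ranks, shown with plain
# mpi4py/NumPy and its [buffer, counts] shorthand; Heat wraps this in its own
# comm.Alltoallv, which additionally takes a send_axis argument.
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
size, rank = comm.Get_size(), comm.Get_rank()
send_counts = np.full(size, rank + 1)   # rank r sends r+1 items to every rank
send_buf = np.full(send_counts.sum(), rank, dtype=np.int64)
recv_counts = np.arange(1, size + 1)    # and receives q+1 items from rank q
recv_buf = np.empty(recv_counts.sum(), dtype=np.int64)
comm.Alltoallv([send_buf, send_counts], [recv_buf, recv_counts])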
blanket argsort solution below might be ok if (key[original_split] == outgoing_request_key).all(): return factories.array(recv_buf, is_split=output_split, copy=False) if (key[original_split] == outgoing_request_key.flip(dims=(0,))).all(): From 86f704a4db77a8e25d5a98db8eeef1b1dc17c39c Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 1 Aug 2023 14:07:35 +0200 Subject: [PATCH 067/132] [skip ci] document __process_key(), clean up code --- heat/core/dndarray.py | 373 +++++-------------------------- heat/core/tests/test_dndarray.py | 4 +- 2 files changed, 63 insertions(+), 314 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index eaef80bcc5..b0134d0ff9 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -774,20 +774,41 @@ def fill_diagonal(self, value: float) -> DNDarray: return self - def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]]) -> Tuple: + def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> Tuple: """ - TODO: expand docs!! - This function processes key, manipulates `arr` if necessary, returns the final output shape - Private method for processing keys for indexing. Returns wether advanced indexing is used as well as a processed key and self. - A processed key: - - doesn't contain any ellipses or newaxis - - all Iterables are converted to torch tensors - - has the same dimensionality as the ``DNDarray`` it indexes + Private method to process the key used for indexing a ``DNDarray`` so that it can be applied to the process-local data, i.e. `key` must be "torch-proof". + In a processed key: + - any ellipses or newaxis have been replaced with the appropriate number of slice objects + - ndarrays and DNDarrays have been converted to torch tensors + - the dimensionality is the same as the ``DNDarray`` it indexes + This function also manipulates `arr` if necessary, inserting and/or transposing dimensions as indicated by `key`. It calculates the output shape, split axis and balanced status of the indexed array. Parameters ---------- - key : int, slice, Tuple[int,...], List[int,...] - Indices for the tensor. + arr : DNDarray + The ``DNDarray`` to be indexed + key : int, Tuple[int, ...], List[int, ...] + The key used for indexing + + Returns + ------- + arr : DNDarray + The ``DNDarray`` to be indexed. Its dimensions might have been modified if advanced, dimensional, broadcasted indexing is used. + key : Union(Tuple[Any, ...], DNDarray, np.ndarray, torch.Tensor, slice, int, List[int, ...]) + The processed key ready for indexing ``arr``. Its dimensions match the (potentially modified) dimensions of the ``DNDarray``. + Note: the key indices along the split axis are LOCAL indices, i.e. refer to the process-local data, if ordered indexing is used. Otherwise, they are GLOBAL indices, referring to the global memory-distributed DNDarray. Communication to extract the non-ordered elements of the input ``DNDarray`` is handled by the ``__getitem__`` function. + output_shape : Tuple[int, ...] + The shape of the output ``DNDarray`` + new_split : int + The new split axis + split_key_is_sorted : int + Whether the split key is sorted. Can be 1: ascending, 0: not sorted, -1: descending + out_is_balanced : bool + Whether the output ``DNDarray`` is balanced + root : int + The root process for the ``MPI.Bcast`` call when single-element indexing along the split axis is used + backwards_transpose_axes : Tuple[int, ...] 
+ The axes to transpose the input ``DNDarray`` back to its original shape if it has been transposed for advanced indexing """ output_shape = list(arr.gshape) split_bookkeeping = [None] * arr.ndim @@ -803,7 +824,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] split_key_is_sorted = 1 # can be 1: ascending, 0: not sorted, -1: descending out_is_balanced = False root = None - transpose_axes = tuple(range(arr.ndim)) + backwards_transpose_axes = tuple(range(arr.ndim)) if isinstance(key, list): try: @@ -837,6 +858,8 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] except TypeError: # key is np.ndarray key = key.nonzero() + # convert to torch tensor + key = tuple(torch.tensor(k, device=arr.larray.device) for k in key) output_shape = tuple(key[0].shape) new_split = None if arr.split is None else 0 out_is_balanced = True @@ -849,7 +872,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] split_key_is_sorted, out_is_balanced, root, - transpose_axes, + backwards_transpose_axes, ) # arr is distributed @@ -914,7 +937,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] split_key_is_sorted, out_is_balanced, root, - transpose_axes, + backwards_transpose_axes, ) # advanced indexing on first dimension: first dim will expand to shape of key @@ -946,14 +969,6 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] except AttributeError: key_split = key[new_split] sorted = key_split.sort() - # if split_key_is_sorted: - # # extract local key - # cond1 = key >= displs[arr.comm.rank] - # cond2 = key < displs[arr.comm.rank] + counts[arr.comm.rank] - # key = key[cond1 & cond2] - # key -= displs[arr.comm.rank] - # out_is_balanced = False - else: new_split = 0 # assess if key is sorted along split axis @@ -1009,7 +1024,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] split_key_is_sorted, out_is_balanced, root, - transpose_axes, + backwards_transpose_axes, ) key = list(key) if isinstance(key, Iterable) else [key] @@ -1049,8 +1064,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] # recalculate new_split, transpose_axes after dimensions manipulation new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None - transpose_axes = tuple(range(arr.ndim)) - + transpose_axes, backwards_transpose_axes = tuple(range(arr.ndim)), tuple(range(arr.ndim)) # check for advanced indexing and slices print("DEBUGGING: key = ", key) advanced_indexing_dims = [] @@ -1235,6 +1249,12 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] # keep track of transpose axes order, to be able to transpose back later transpose_axes = tuple(advanced_indexing_dims + non_adv_ind_dims) arr = arr.transpose(transpose_axes) + backwards_transpose_axes = tuple( + torch.tensor(transpose_axes, device=arr.larray.device) + .argsort(stable=True) + .tolist() + ) + # output shape and split bookkeeping output_shape = list(arr.gshape) output_shape[: len(advanced_indexing_dims)] = broadcasted_shape split_bookkeeping = [None] * arr.ndim @@ -1271,7 +1291,7 @@ def __process_key(arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...] 
split_key_is_sorted, out_is_balanced, root, - transpose_axes, + backwards_transpose_axes, ) def __get_local_slice(self, key: slice): @@ -1338,20 +1358,27 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar ): # latter doesnt work with torch for 0-dim tensors return self + original_split = self.split # Single-element indexing - # TODO: single-element indexing along split axis belongs here as well scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 if scalar: output_shape = self.gshape[1:] try: - # is key an ndarray, DNDarray or torch tensor? + # is key an ndarray or DNDarray? key = key.copy().item() except AttributeError: - # key is already an integer, do nothing - pass - if not self.is_distributed() or self.split != 0: + try: + # is key a torch tensor? + key = key.clone().item() + except AttributeError: + # key is already an integer, do nothing + pass + if not self.is_distributed() or original_split != 0: + # single-element indexing along non-split axis indexed_arr = self.larray[key] - output_split = None if self.split is None else self.split - 1 + output_split = ( + None if (original_split is None or original_split == 0) else original_split - 1 + ) indexed_arr = DNDarray( indexed_arr, gshape=output_shape, @@ -1395,7 +1422,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar ) return indexed_arr - # Many-elements indexing: incl. slicing and striding, ordered advanced indexing + # Many-elements indexing: incl. slicing and striding, ordered and non-ordered advanced indexing # Preprocess: Process Ellipsis + 'None' indexing; make Iterables to DNDarrays ( @@ -1406,13 +1433,9 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar split_key_is_sorted, out_is_balanced, root, - transpose_axes, + backwards_transpose_axes, ) = self.__process_key(key) - backwards_transpose_axes = ( - torch.tensor(transpose_axes, device=self.larray.device).argsort(stable=True).tolist() - ) - print("DEBUGGING: processed key, output_split = ", key, output_split) if root is not None: @@ -1443,7 +1466,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # including match between self.split and key after self manipulation # data are not distributed or split dimension is not affected by indexing - # if not self.is_distributed() or key[self.split] == slice(None): print("split_key_is_sorted, key = ", split_key_is_sorted, key) if split_key_is_sorted == 1: indexed_arr = self.larray[key] @@ -1462,7 +1484,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # key is not sorted along self.split # key is tuple of torch.Tensor or mix of torch.Tensors and slices _, displs = self.counts_displs() - original_split = self.split # determine whether indexed array will be 1D or nD try: @@ -1507,7 +1528,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # process-local: calculate which/how many elements will be received from what process if split_key_is_sorted == -1: - # key is sorted in descending order (i.e. slicing w/ negative step) + # key is sorted in descending order (i.e. 
slicing w/ negative step): # shrink selection of active processes if key[original_split].numel() > 0: key_edges = torch.cat( @@ -1730,278 +1751,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar indexed_arr = recv_buf[map] return factories.array(indexed_arr, is_split=output_split, copy=False) - # TODO: boolean indexing with data.split != 0 - # __process_key() returns locally correct key - # after local indexing, Alltoallv for correct order of output - - # data are distributed and split dimension is affected by indexing - # __process_key() returns the local key already - - # _, offsets = self.counts_displs() - # split = self.split - # # slice along the split axis - # if isinstance(key[split], slice): - # local_slice = self.__get_local_slice(key[split]) - # if local_slice is not None: - # key = list(key) - # key[split] = local_slice - # local_tensor = self.larray[tuple(key)] - # else: # local tensor is empty - # local_shape = list(output_shape) - # local_shape[output_split] = 0 - # local_tensor = torch.zeros( - # tuple(local_shape), dtype=self.larray.dtype, device=self.larray.device - # ) - - # return DNDarray( - # local_tensor, - # gshape=output_shape, - # dtype=self.dtype, - # split=output_split, - # device=self.device, - # balanced=False, - # comm=self.comm, - # ) - - # local indexing cases: - # self is not distributed, key is not distributed - DONE - # self is distributed, key along split is a slice - DONE - # self is distributed, key is boolean mask (what about distributed boolean mask?) - - # distributed indexing: - # key is distributed - # key calls for advanced indexing - # key is a non-sorted sequence - # key is a sorted sequence (descending) - - # key = getattr(key, "copy()", key) - # l_dtype = self.dtype.torch_type() - # advanced_ind = False - # if isinstance(key, DNDarray) and key.ndim == self.ndim: - # """ if the key is a DNDarray and it has as many dimensions as self, then each of the - # entries in the 0th dim refer to a single element. To handle this, the key is split - # into the torch tensors for each dimension. This signals that advanced indexing is - # to be used. """ - # # NOTE: this gathers the entire key on every process!! - # # TODO: remove this resplit!! - # key = manipulations.resplit(key) - # if key.larray.dtype in [torch.bool, torch.uint8]: - # key = indexing.nonzero(key) - - # if key.ndim > 1: - # key = list(key.larray.split(1, dim=1)) - # # key is now a list of tensors with dimensions (key.ndim, 1) - # # squeeze singleton dimension: - # key = list(key[i].squeeze_(1) for i in range(len(key))) - # else: - # key = [key] - # advanced_ind = True - # elif not isinstance(key, tuple): - # """ this loop handles all other cases. DNDarrays which make it to here refer to - # advanced indexing slices, as do the torch tensors. Both DNDaarrys and torch.Tensors - # are cast into lists here by PyTorch. lists mean advanced indexing will be used""" - # h = [slice(None, None, None)] * max(self.ndim, 1) - # if isinstance(key, DNDarray): - # key = manipulations.resplit(key) - # if key.larray.dtype in [torch.bool, torch.uint8]: - # h[0] = torch.nonzero(key.larray).flatten() # .tolist() - # else: - # h[0] = key.larray.tolist() - # elif isinstance(key, torch.Tensor): - # if key.dtype in [torch.bool, torch.uint8]: - # # (coquelin77) i am not certain why this works without being a list. 
but it works...for now - # h[0] = torch.nonzero(key).flatten() # .tolist() - # else: - # h[0] = key.tolist() - # else: - # h[0] = key - - # key = list(h) - - # if isinstance(key, (list, tuple)): - # key = list(key) - # for i, k in enumerate(key): - # # this might be a good place to check if the dtype is there - # try: - # k = manipulations.resplit(k) - # key[i] = k.larray - # except AttributeError: - # pass - - # # ellipsis - # key = list(key) - # key_classes = [type(n) for n in key] - # # if any(isinstance(n, ellipsis) for n in key): - # n_elips = key_classes.count(type(...)) - # if n_elips > 1: - # raise ValueError("key can only contain 1 ellipsis") - # elif n_elips == 1: - # # get which item is the ellipsis - # ell_ind = key_classes.index(type(...)) - # kst = key[:ell_ind] - # kend = key[ell_ind + 1 :] - # slices = [slice(None)] * (self.ndim - (len(kst) + len(kend))) - # key = kst + slices + kend - # else: - # key = key + [slice(None)] * (self.ndim - len(key)) - - # self_proxy = self.__torch_proxy__() - # for i in range(len(key)): - # if self.__key_adds_dimension(key, i, self_proxy): - # key[i] = slice(None) - # return self.expand_dims(i)[tuple(key)] - - # key = tuple(key) - # # assess final global shape - # gout_full = list(self_proxy[key].shape) - - # # calculate new split axis - # new_split = self.split - # # when slicing, squeezed singleton dimensions may affect new split axis - # if self.split is not None and len(gout_full) < self.ndim: - # if advanced_ind: - # new_split = 0 - # else: - # for i in range(len(key[: self.split + 1])): - # if self.__key_is_singular(key, i, self_proxy): - # new_split = None if i == self.split else new_split - 1 - - # key = tuple(key) - # if not self.is_distributed(): - # arr = self.__array[key].reshape(gout_full) - # return DNDarray( - # arr, tuple(gout_full), self.dtype, new_split, self.device, self.comm, self.balanced - # ) - - # # else: (DNDarray is distributed) - # arr = torch.tensor([], dtype=self.__array.dtype, device=self.__array.device) - # rank = self.comm.rank - # counts, chunk_starts = self.counts_displs() - # counts, chunk_starts = torch.tensor(counts), torch.tensor(chunk_starts) - # chunk_ends = chunk_starts + counts - # chunk_start = chunk_starts[rank] - # chunk_end = chunk_ends[rank] - - # if len(key) == 0: # handle empty list - # # this will return an array of shape (0, ...) - # arr = self.__array[key] - - # """ At the end of the following if/elif/elif block the output array will be set. - # each block handles the case where the element of the key along the split axis - # is a different type and converts the key from global indices to local indices. """ - # lout = gout_full.copy() - - # if ( - # isinstance(key[self.split], (list, torch.Tensor, DNDarray, np.ndarray)) - # and len(key[self.split]) > 1 - # ): - # # advanced indexing, elements in the split dimension are adjusted to the local indices - # lkey = list(key) - # if isinstance(key[self.split], DNDarray): - # lkey[self.split] = key[self.split].larray - - # if not isinstance(lkey[self.split], torch.Tensor): - # inds = torch.tensor( - # lkey[self.split], dtype=torch.long, device=self.device.torch_device - # ) - # else: - # if lkey[self.split].dtype in [torch.bool, torch.uint8]: # or torch.byte? - # # need to convert the bools to indices - # inds = torch.nonzero(lkey[self.split]) - # else: - # inds = lkey[self.split] - # # todo: remove where in favor of nonzero? might be a speed upgrade. 
testing required - # loc_inds = torch.where((inds >= chunk_start) & (inds < chunk_end)) - # # if there are no local indices on a process, then `arr` is empty - # # if local indices exist: - # if len(loc_inds[0]) != 0: - # # select same local indices for other (non-split) dimensions if necessary - # for i, k in enumerate(lkey): - # if isinstance(k, (list, torch.Tensor, DNDarray)): - # if i != self.split: - # lkey[i] = k[loc_inds] - # # correct local indices for offset - # inds = inds[loc_inds] - chunk_start - # lkey[self.split] = inds - # lout[new_split] = len(inds) - # arr = self.__array[tuple(lkey)].reshape(tuple(lout)) - # elif len(loc_inds[0]) == 0: - # if new_split is not None: - # lout[new_split] = len(loc_inds[0]) - # else: - # lout = [0] * len(gout_full) - # arr = torch.tensor([], dtype=self.larray.dtype, device=self.larray.device).reshape( - # tuple(lout) - # ) - - # elif isinstance(key[self.split], slice): - # # standard slicing along the split axis, - # # adjust the slice start, stop, and step, then run it on the processes which have the requested data - # key = list(key) - # key[self.split] = stride_tricks.sanitize_slice(key[self.split], self.gshape[self.split]) - # key_start, key_stop, key_step = ( - # key[self.split].start, - # key[self.split].stop, - # key[self.split].step, - # ) - # og_key_start = key_start - # st_pr = torch.where(key_start < chunk_ends)[0] - # st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size - # sp_pr = torch.where(key_stop >= chunk_starts)[0] - # sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0 - # actives = list(range(st_pr, sp_pr + 1)) - # if rank in actives: - # key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] - # key_stop = counts[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] - # key_start, key_stop = self.__xitem_get_key_start_stop( - # rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start - # ) - # key[self.split] = slice(key_start, key_stop, key_step) - # lout[new_split] = ( - # math.ceil((key_stop - key_start) / key_step) - # if key_step is not None - # else key_stop - key_start - # ) - # arr = self.__array[tuple(key)].reshape(lout) - # else: - # lout[new_split] = 0 - # arr = torch.empty(lout, dtype=self.__array.dtype, device=self.__array.device) - - # elif self.__key_is_singular(key, self.split, self_proxy): - # # getting one item along split axis: - # key = list(key) - # if isinstance(key[self.split], list): - # key[self.split] = key[self.split].pop() - # elif isinstance(key[self.split], (torch.Tensor, DNDarray, np.ndarray)): - # key[self.split] = key[self.split].item() - # # translate negative index - # if key[self.split] < 0: - # key[self.split] += self.gshape[self.split] - - # active_rank = torch.where(key[self.split] >= chunk_starts)[0][-1].item() - # # slice `self` on `active_rank`, allocate `arr` on all other ranks in preparation for Bcast - # if rank == active_rank: - # key[self.split] -= chunk_start.item() - # arr = self.__array[tuple(key)].reshape(tuple(lout)) - # else: - # arr = torch.empty(tuple(lout), dtype=self.larray.dtype, device=self.larray.device) - # # broadcast result - # # TODO: Replace with `self.comm.Bcast(arr, root=active_rank)` after fixing #784 - # arr = self.comm.bcast(arr, root=active_rank) - # if arr.device != self.larray.device: - # # todo: remove when unnecessary (also after #784) - # arr = arr.to(device=self.larray.device) - - # return DNDarray( - # arr.type(l_dtype), - # gout_full if isinstance(gout_full, tuple) else tuple(gout_full), - # self.dtype, - # 
new_split, - # self.device, - # self.comm, - # balanced=True if new_split is None else None, - # ) - if torch.cuda.device_count() > 0: def gpu(self) -> DNDarray: diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 0adf4724f5..2880a2c853 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -546,7 +546,7 @@ def test_getitem(self): self.assertTrue(x[2].item() == 2.0) self.assertTrue(x[-2].item() == 8.0) self.assertTrue(x[2].dtype == ht.float64) - # self.assertTrue(x[2].split is None) + self.assertTrue(x[2].split is None) # 2D, local x = ht.arange(10).reshape(2, 5) self.assertTrue((x[0] == ht.arange(5)).all().item()) @@ -568,7 +568,7 @@ def test_getitem(self): indexed_split0 = x_split0[key] self.assertTrue((indexed_split0.larray == x.larray[key]).all()) self.assertTrue(indexed_split0.dtype == ht.float32) - # self.assertTrue(indexed_split0.split is None) + self.assertTrue(indexed_split0.split is None) # 3D, distributed split, != 0 x_split2 = ht.array(x, dtype=ht.int64, split=2) key = ht.array(2) From 68ead71df298cfb6c1bbcb9df4340420e62e37d5 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 1 Aug 2023 14:10:05 +0200 Subject: [PATCH 068/132] [skip ci] docs edits --- heat/core/dndarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index b0134d0ff9..8357b08d81 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -795,14 +795,14 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> arr : DNDarray The ``DNDarray`` to be indexed. Its dimensions might have been modified if advanced, dimensional, broadcasted indexing is used. key : Union(Tuple[Any, ...], DNDarray, np.ndarray, torch.Tensor, slice, int, List[int, ...]) - The processed key ready for indexing ``arr``. Its dimensions match the (potentially modified) dimensions of the ``DNDarray``. + The processed key ready for indexing ``arr``. Its dimensions match the (potentially modified) dimensions of ``arr``. Note: the key indices along the split axis are LOCAL indices, i.e. refer to the process-local data, if ordered indexing is used. Otherwise, they are GLOBAL indices, referring to the global memory-distributed DNDarray. Communication to extract the non-ordered elements of the input ``DNDarray`` is handled by the ``__getitem__`` function. output_shape : Tuple[int, ...] The shape of the output ``DNDarray`` new_split : int The new split axis split_key_is_sorted : int - Whether the split key is sorted. Can be 1: ascending, 0: not sorted, -1: descending + Whether the split key is sorted or ordered. Can be 1: ascending, 0: not ordered, -1: descending order. out_is_balanced : bool Whether the output ``DNDarray`` is balanced root : int From 252995cb53615512c22f37ba9b73707f98fa6563 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Fri, 4 Aug 2023 06:58:34 +0200 Subject: [PATCH 069/132] fix Ellipsis dimensions --- heat/core/dndarray.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 8357b08d81..784916c64e 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1032,19 +1032,19 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> # check for ellipsis, newaxis. 
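# NOTE (editor's sketch, not part of the patch): what the Ellipsis fix below
# computes. A single `...` must expand to exactly the number of `slice(None)`
# entries not already consumed by the key; `None` entries add new dimensions and
# therefore must not be counted.
key = (0, ..., None, 1)                            # key for a 4-D array
ndim, add_dims = 4, sum(k is None for k in key)
ellipsis_dims = ndim - (len(key) - 1 - add_dims)   # the `...` stands for 2 slices
ellipsis_index = key.index(...)
expanded = key[:ellipsis_index] + (slice(None),) * ellipsis_dims + key[ellipsis_index + 1 :]
assert expanded == (0, slice(None), slice(None), None, 1)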
NB: (np.newaxis is None)==True add_dims = sum(k is None for k in key) ellipsis = sum(isinstance(k, type(...)) for k in key) - if ellipsis == 1: + if ellipsis > 1: + raise ValueError("indexing key can only contain 1 Ellipsis (...)") + if ellipsis: + # key contains exactly 1 ellipsis # replace with explicit `slice(None)` for affected dimensions # output_shape, split_bookkeeping not affected expand_key = [slice(None)] * (arr.ndim + add_dims) ellipsis_index = key.index(...) + ellipsis_dims = arr.ndim - (len(key) - ellipsis - add_dims) expand_key[:ellipsis_index] = key[:ellipsis_index] - expand_key[ellipsis_index - (len(key) - ellipsis - ellipsis_index) :] = key[ - ellipsis_index + 1 : - ] + expand_key[ellipsis_index + ellipsis_dims :] = key[ellipsis_index + 1 :] key = expand_key print("DEBUGGING: ELLIPSIS: ", key) - elif ellipsis > 1: - raise ValueError("key can only contain 1 ellipsis") while add_dims > 0: # expand array dims: output_shape, split_bookkeeping to reflect newaxis # replace newaxis with slice(None) in key From c2a7e204fc8aa30dc7733fc418fbd818e780039f Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Fri, 4 Aug 2023 09:58:00 +0200 Subject: [PATCH 070/132] fix shape and split bookkeeping within advanced indexing --- heat/core/dndarray.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 784916c64e..8f4e71e960 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1075,15 +1075,12 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> # advanced indexing across dimensions if getattr(k, "ndim", 1) == 0: # single-element indexing along axis i - output_shape[i] = None - split_bookkeeping = ( - split_bookkeeping[: i - lose_dims] + split_bookkeeping[i - lose_dims + 1 :] - ) + output_shape[i], split_bookkeeping[i] = None, None lose_dims += 1 if arr_is_distributed and i == arr.split: # single-element indexing along split axis # work out root process for Bcast - key[i] = k.item() + arr.shape[i] if k < 0 else k.item() + key[i] = k.item() + arr.shape[i] if k.item() < 0 else k.item() if key[i] in displs: root = displs.index(key[i]) else: @@ -1131,10 +1128,7 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> split_key_is_sorted = 0 elif isinstance(k, int): # single-element indexing along axis i - output_shape[i] = None - split_bookkeeping = ( - split_bookkeeping[: i - lose_dims] + split_bookkeeping[i - lose_dims + 1 :] - ) + output_shape[i], split_bookkeeping[i] = None, None lose_dims += 1 if arr_is_distributed and i == arr.split: # single-element indexing along split axis @@ -1255,11 +1249,9 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> .tolist() ) # output shape and split bookkeeping - output_shape = list(arr.gshape) + output_shape = list(output_shape[i] for i in transpose_axes) output_shape[: len(advanced_indexing_dims)] = broadcasted_shape - split_bookkeeping = [None] * arr.ndim - if arr_is_distributed: - split_bookkeeping[arr.split] = "split" + split_bookkeeping = list(split_bookkeeping[i] for i in transpose_axes) split_bookkeeping = [None] * add_dims + split_bookkeeping # modify key to match the new dimension order key = [key[i] for i in advanced_indexing_dims] + [key[i] for i in non_adv_ind_dims] @@ -1272,7 +1264,9 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> key = tuple(key) for i in 
range(output_shape.count(None)): + lost_dim = output_shape.index(None) output_shape.remove(None) + split_bookkeeping = split_bookkeeping[:lost_dim] + split_bookkeeping[lost_dim + 1 :] output_shape = tuple(output_shape) new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None print( From 235a7b8ce94d5e9499c8ca785dd0de8f0287fd13 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Fri, 4 Aug 2023 09:58:35 +0200 Subject: [PATCH 071/132] test adv indexing on non consecutive dims --- heat/core/tests/test_dndarray.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 2880a2c853..601e3c4c63 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -677,7 +677,7 @@ def test_getitem(self): self.assertTrue(y.split == 0) # ADVANCED INDEXING - # "x[(1, 2, 3),] is fundamentally different than x[(1, 2, 3)]" + # "x[(1, 2, 3),] is fundamentally different from x[(1, 2, 3)]" x_np = np.arange(60).reshape(5, 3, 4) indexed_x_np = x_np[(1, 2, 3)] @@ -704,7 +704,23 @@ def test_getitem(self): self.assert_array_equal( x[ht.array(k1, split=0), ht.array(k2, split=0), ht.array(k3, split=0)], x_np[k1, k2, k3] ) + # advanced indexing on non-consecutive dimensions + x = ht.arange(60, split=0).reshape(5, 3, 4, new_split=1) + x_copy = x.copy() + x_np = np.arange(60).reshape(5, 3, 4) + k1 = np.array([0, 4, 1, 0]) + k2 = 0 + k3 = np.array([1, 2, 3, 1]) + key = (k1, k2, k3) + self.assert_array_equal(x[key], x_np[key]) + # check that x is unchanged after internal manipulation + self.assertTrue(x.shape == x_copy.shape) + self.assertTrue(x.split == x_copy.split) + self.assertTrue(x.lshape == x_copy.lshape) + self.assertTrue((x == x_copy).all().item()) + # broadcasting shapes + x.resplit_(axis=0) self.assert_array_equal(x[ht.array(k1, split=0), ht.array(1), 2], x_np[k1, 1, 2]) # test exception: broadcasting mismatching shapes k2 = np.array([0, 2, 1]) From 4e936e84eb352b11aab101ea37e1e2389f52e2ec Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 7 Aug 2023 10:49:20 +0200 Subject: [PATCH 072/132] abstract scalar key checks for both getitem and setitem --- heat/core/dndarray.py | 105 +++++++++++++++++++++++++++++++----------- 1 file changed, 77 insertions(+), 28 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 8f4e71e960..ac4d356e48 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -156,6 +156,7 @@ def larray(self, array: torch.Tensor): ----------- Please use this function with care, as it might corrupt/invalidate the metadata in the ``DNDarray`` instance. """ + print("DEBUGGING: larray setter") # sanitize tensor input sanitation.sanitize_in_tensor(array) # verify consistency of tensor shape with global DNDarray @@ -1288,6 +1289,49 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> backwards_transpose_axes, ) + def __process_scalar_key( + arr: DNDarray, key: Union[int, DNDarray, torch.Tensor, np.ndarray] + ) -> Tuple(int, int): + """ + Private method to process a single-item scalar key used for indexing a ``DNDarray``. + + """ + device = arr.larray.device + try: + # is key an ndarray or DNDarray? + key = key.copy().item() + except AttributeError: + try: + # is key a torch tensor? 
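# NOTE (editor's comments, not part of the patch): the return annotation above,
# `-> Tuple(int, int)`, presumably should read `-> Tuple[int, int]` - typing.Tuple
# is subscripted, not called. Below, a sketch of the owner-rank lookup this helper
# performs, equivalent to the unique()-based trick, assuming ascending `displs`.
import bisect

displs = [0, 4, 8]                            # row offsets of three ranks
key = 5                                       # global row index
root = bisect.bisect_right(displs, key) - 1   # -> 1: rank 1 owns global rows 4..7
local_key = key - displs[root]                # -> 1: local row on that rank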
+ key = key.clone().item() + except AttributeError: + # key is already an integer, do nothing + pass + if arr.is_distributed() and arr.split == 0: + # adjust negative key + if key < 0: + key += arr.shape[0] + # work out active process + _, displs = arr.counts_displs() + if key in displs: + root = displs.index(key) + else: + displs = torch.cat( + ( + torch.tensor(displs, device=device), + torch.tensor(key, device=device).reshape(-1), + ), + dim=0, + ) + _, sorted_indices = displs.unique(sorted=True, return_inverse=True) + root = sorted_indices[-1] - 1 + # correct key for rank-specific displacement + if arr.comm.rank == root: + key -= displs[root] + else: + root = None + return key, root + def __get_local_slice(self, key: slice): split = self.split if split is None: @@ -1357,17 +1401,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 if scalar: output_shape = self.gshape[1:] - try: - # is key an ndarray or DNDarray? - key = key.copy().item() - except AttributeError: - try: - # is key a torch tensor? - key = key.clone().item() - except AttributeError: - # key is already an integer, do nothing - pass - if not self.is_distributed() or original_split != 0: + key, root = self.__process_scalar_key(key) + if root is None: # single-element indexing along non-split axis indexed_arr = self.larray[key] output_split = ( @@ -1383,21 +1418,9 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar balanced=self.balanced, ) return indexed_arr - # single-element indexing along split axis: - # check for negative key - key = key + self.shape[0] if key < 0 else key - # identify root process - _, displs = self.counts_displs() - if key in displs: - root = displs.index(key) - else: - displs = torch.cat((torch.tensor(displs), torch.tensor(key).reshape(-1)), dim=0) - _, sorted_indices = displs.unique(sorted=True, return_inverse=True) - root = sorted_indices[-1] - 1 - # allocate buffer on all processes + # root is not None: single-element indexing along split axis + # prepare for Bcast: allocate buffer on all processes if self.comm.rank == root: - # correct key for rank-specific displacement - key -= displs[root] indexed_arr = self.larray[key] else: indexed_arr = torch.zeros( @@ -2240,8 +2263,13 @@ def __set(arr: DNDarray, value: DNDarray): """ Setter for not advanced indexing, i.e. when arr[key] is an in-place view of arr. """ - if not isinstance(value, DNDarray): - value = factories.array(value, device=arr.device, comm=arr.comm) + value_split = value.split if isinstance(value, DNDarray) else None + try: + value = factories.array( + value, dtype=arr.dtype, split=value_split, device=arr.device, comm=arr.comm + ) + except TypeError: + raise TypeError(f"Cannot assign object of type {type(value)} to DNDarray.") while value.ndim < arr.ndim: # broadcasting value = value.expand_dims(0) sanitation.sanitize_out(arr, value.shape, value.split, value.device, value.comm) @@ -2252,7 +2280,28 @@ def __set(arr: DNDarray, value: DNDarray): if key is None or key == ... 
or key == slice(None): return __set(self, value) - self, key, output_shape, output_split, advanced_indexing = self.__process_key(key) + # scalar key + scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 + if scalar: + key, root = self.__process_scalar_key(key) + if root is not None: + if self.comm.rank == root: + self.larray[key] = value.larray + else: + self.larray[key] = value.larray + return + + ( + self, + key, + output_shape, + output_split, + split_key_is_sorted, + out_is_balanced, + root, + backwards_transpose_axes, + ) = self.__process_key(key) + # if advanced_indexing: # raise Exception("Advanced indexing is not supported yet") From 8a74cd9d99a59df75fd28dabbfea457f4b1f8d0a Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 8 Aug 2023 10:22:54 +0200 Subject: [PATCH 073/132] setitem scalar key --- heat/core/dndarray.py | 568 +++++++++++++++++++++------------------- heat/core/sanitation.py | 43 +-- 2 files changed, 321 insertions(+), 290 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index ac4d356e48..ba654caff8 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1076,7 +1076,12 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> # advanced indexing across dimensions if getattr(k, "ndim", 1) == 0: # single-element indexing along axis i - output_shape[i], split_bookkeeping[i] = None, None + try: + output_shape[i], split_bookkeeping[i] = None, None + except IndexError: + raise IndexError( + f"Too many indices for DNDarray: DNDarray is {arr.ndim}-dimensional, but {len(key)} dimensions were indexed" + ) lose_dims += 1 if arr_is_distributed and i == arr.split: # single-element indexing along split axis @@ -1129,7 +1134,12 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> split_key_is_sorted = 0 elif isinstance(k, int): # single-element indexing along axis i - output_shape[i], split_bookkeeping[i] = None, None + try: + output_shape[i], split_bookkeeping[i] = None, None + except IndexError: + raise IndexError( + f"Too many indices for DNDarray: DNDarray is {arr.ndim}-dimensional, but {len(key)} dimensions were indexed" + ) lose_dims += 1 if arr_is_distributed and i == arr.split: # single-element indexing along split axis @@ -2259,7 +2269,10 @@ def __setitem__( [0., 1., 0., 0., 0.]]) """ - def __set(arr: DNDarray, value: DNDarray): + def __set( + arr: Union[DNDarray, torch.Tensor], + value: Union[DNDarray, torch.Tensor, np.ndarray, float, int, list, tuple], + ): """ Setter for not advanced indexing, i.e. when arr[key] is an in-place view of arr. """ @@ -2274,21 +2287,27 @@ def __set(arr: DNDarray, value: DNDarray): value = value.expand_dims(0) sanitation.sanitize_out(arr, value.shape, value.split, value.device, value.comm) value = sanitation.sanitize_distribution(value, target=arr) - arr.larray[None] = value.larray + try: + arr.larray[None] = value.larray + except AttributeError: + # arr is already the process-local torch tensor + arr[None] = value.larray return if key is None or key == ... 
or key == slice(None): return __set(self, value) + # torch_device = self.larray.device + # scalar key scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 if scalar: key, root = self.__process_scalar_key(key) if root is not None: if self.comm.rank == root: - self.larray[key] = value.larray + __set(self.larray[key], value) else: - self.larray[key] = value.larray + __set(self[key], value) return ( @@ -2302,276 +2321,279 @@ def __set(arr: DNDarray, value: DNDarray): backwards_transpose_axes, ) = self.__process_key(key) + # if split_key_is_sorted: + # process-local indices + # if advanced_indexing: # raise Exception("Advanced indexing is not supported yet") - split = self.split - if not self.is_distributed() or key[split] == slice(None): - return __set(self[key], value) - - if isinstance(key[split], slice): - return __set(self[key], value) - - if np.isscalar(key[split]): - key = list(key) - idx = int(key[split]) - key[split] = slice(idx, idx + 1) - return __set(self[tuple(key)], value) - - key = getattr(key, "copy()", key) - try: - if value.split != self.split: - val_split = int(value.split) - sp = self.split - warnings.warn( - f"\nvalue.split {val_split} not equal to this DNDarray's split:" - f" {sp}. this may cause errors or unwanted behavior", - category=RuntimeWarning, - ) - except (AttributeError, TypeError): - pass - - # NOTE: for whatever reason, there is an inplace op which interferes with the abstraction - # of this next block of code. this is shared with __getitem__. I attempted to abstract it - # in a standard way, but it was causing errors in the test suite. If someone else is - # motived to do this they are welcome to, but i have no time right now - # print(key) - if isinstance(key, DNDarray) and key.ndim == self.ndim: - """if the key is a DNDarray and it has as many dimensions as self, then each of the - entries in the 0th dim refer to a single element. To handle this, the key is split - into the torch tensors for each dimension. This signals that advanced indexing is - to be used.""" - key = manipulations.resplit(key) - if key.larray.dtype in [torch.bool, torch.uint8]: - key = indexing.nonzero(key) - - if key.ndim > 1: - key = list(key.larray.split(1, dim=1)) - # key is now a list of tensors with dimensions (key.ndim, 1) - # squeeze singleton dimension: - key = [key[i].squeeze_(1) for i in range(len(key))] - else: - key = [key] - elif not isinstance(key, tuple): - """this loop handles all other cases. DNDarrays which make it to here refer to - advanced indexing slices, as do the torch tensors. Both DNDaarrys and torch.Tensors - are cast into lists here by PyTorch. 
lists mean advanced indexing will be used""" - h = [slice(None, None, None)] * self.ndim - if isinstance(key, DNDarray): - key = manipulations.resplit(key) - if key.larray.dtype in [torch.bool, torch.uint8]: - h[0] = torch.nonzero(key.larray).flatten() # .tolist() - else: - h[0] = key.larray.tolist() - elif isinstance(key, torch.Tensor): - if key.dtype in [torch.bool, torch.uint8]: - # (coquelin77) im not sure why this works without being a list...but it does...for now - h[0] = torch.nonzero(key).flatten() # .tolist() - else: - h[0] = key.tolist() - else: - h[0] = key - key = list(h) - - # key must be torch-proof - if isinstance(key, (list, tuple)): - key = list(key) - for i, k in enumerate(key): - try: # extract torch tensor - k = manipulations.resplit(k) - key[i] = k.larray - except AttributeError: - pass - # remove bools from a torch tensor in favor of indexes - try: - if key[i].dtype in [torch.bool, torch.uint8]: - key[i] = torch.nonzero(key[i]).flatten() - except (AttributeError, TypeError): - pass - - key = list(key) - - # ellipsis stuff - key_classes = [type(n) for n in key] - # if any(isinstance(n, ellipsis) for n in key): - n_elips = key_classes.count(type(...)) - if n_elips > 1: - raise ValueError("key can only contain 1 ellipsis") - elif n_elips == 1: - # get which item is the ellipsis - ell_ind = key_classes.index(type(...)) - kst = key[:ell_ind] - kend = key[ell_ind + 1 :] - slices = [slice(None)] * (self.ndim - (len(kst) + len(kend))) - key = kst + slices + kend - # ---------- end ellipsis stuff ------------- - - for c, k in enumerate(key): - try: - key[c] = k.item() - except (AttributeError, ValueError, RuntimeError): - pass - - rank = self.comm.rank - if self.split is not None: - counts, chunk_starts = self.counts_displs() - else: - counts, chunk_starts = 0, [0] * self.comm.size - counts = torch.tensor(counts, device=self.device.torch_device) - chunk_starts = torch.tensor(chunk_starts, device=self.device.torch_device) - chunk_ends = chunk_starts + counts - chunk_start = chunk_starts[rank] - chunk_end = chunk_ends[rank] - # determine which elements are on the local process (if the key is a torch tensor) - try: - # if isinstance(key[self.split], torch.Tensor): - filter_key = torch.nonzero( - (chunk_start <= key[self.split]) & (key[self.split] < chunk_end) - ) - for k in range(len(key)): - try: - key[k] = key[k][filter_key].flatten() - except TypeError: - pass - except TypeError: # this will happen if the key doesnt have that many - pass - - key = tuple(key) - - if not self.is_distributed(): - return self.__setter(key, value) # returns None - - # raise RuntimeError("split axis of array and the target value are not equal") removed - # this will occur if the local shapes do not match - rank = self.comm.rank - ends = [] - for pr in range(self.comm.size): - _, _, e = self.comm.chunk(self.shape, self.split, rank=pr) - ends.append(e[self.split].stop - e[self.split].start) - ends = torch.tensor(ends, device=self.device.torch_device) - chunk_ends = ends.cumsum(dim=0) - chunk_starts = torch.tensor([0] + chunk_ends.tolist(), device=self.device.torch_device) - _, _, chunk_slice = self.comm.chunk(self.shape, self.split) - chunk_start = chunk_slice[self.split].start - chunk_end = chunk_slice[self.split].stop - - self_proxy = self.__torch_proxy__() - - # if the value is a DNDarray, the divisions need to be balanced: - # this means that we need to know how much data is where for both DNDarrays - # if the value data is not in the right place, then it will need to be moved - - if 
isinstance(key[self.split], slice): - key = list(key) - key_start = key[self.split].start if key[self.split].start is not None else 0 - key_stop = ( - key[self.split].stop - if key[self.split].stop is not None - else self.gshape[self.split] - ) - if key_stop < 0: - key_stop = self.gshape[self.split] + key[self.split].stop - key_step = key[self.split].step - og_key_start = key_start - st_pr = torch.where(key_start < chunk_ends)[0] - st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size - sp_pr = torch.where(key_stop >= chunk_starts)[0] - sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0 - actives = list(range(st_pr, sp_pr + 1)) - - if ( - isinstance(value, type(self)) - and value.split is not None - and value.shape[self.split] != self.shape[self.split] - ): - # setting elements in self with a DNDarray which is not the same size in the - # split dimension - local_keys = [] - # below is used if the target needs to be reshaped - target_reshape_map = torch.zeros( - (self.comm.size, self.ndim), dtype=torch.int64, device=self.device.torch_device - ) - for r in range(self.comm.size): - if r not in actives: - loc_key = key.copy() - loc_key[self.split] = slice(0, 0, 0) - else: - key_start_l = 0 if r != actives[0] else key_start - chunk_starts[r] - key_stop_l = ends[r] if r != actives[-1] else key_stop - chunk_starts[r] - key_start_l, key_stop_l = self.__xitem_get_key_start_stop( - r, actives, key_start_l, key_stop_l, key_step, chunk_ends, og_key_start - ) - loc_key = key.copy() - loc_key[self.split] = slice(key_start_l, key_stop_l, key_step) - - gout_full = torch.tensor( - self_proxy[loc_key].shape, device=self.device.torch_device - ) - target_reshape_map[r] = gout_full - local_keys.append(loc_key) - - key = local_keys[rank] - value = value.redistribute(target_map=target_reshape_map) - - if rank not in actives: - return # non-active ranks can exit here - - chunk_starts_v = target_reshape_map[:, self.split] - value_slice = [slice(None, None, None)] * value.ndim - step2 = key_step if key_step is not None else 1 - key_start = (chunk_starts_v[rank] - og_key_start).item() - - key_start = max(key_start, 0) - key_stop = key_start + key_stop - slice_loc = min(self.split, value.ndim - 1) - value_slice[slice_loc] = slice( - key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 - ) - - self.__setter(tuple(key), value.larray) - return - - # if rank in actives: - if rank not in actives: - return # non-active ranks can exit here - key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] - key_stop = ends[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] - key_start, key_stop = self.__xitem_get_key_start_stop( - rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start - ) - key[self.split] = slice(key_start, key_stop, key_step) - - # todo: need to slice the values to be the right size... 
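The block being removed here carried the per-rank slice arithmetic inline. Its core is the intersection of a global slice key with one rank's chunk of the split axis; the value then has to be cut to the same local extent, which is what the todo above points at. A minimal standalone sketch of that intersection, assuming a unit step (the chunk bounds are illustrative parameters, not Heat API):

    def local_slice(key: slice, chunk_start: int, chunk_end: int, gshape: int) -> slice:
        # normalize None/negative bounds against the global shape
        start, stop, _ = key.indices(gshape)
        lo, hi = max(start, chunk_start), min(stop, chunk_end)
        if lo >= hi:
            # no overlap: this rank is inactive for the assignment
            return slice(0, 0)
        # shift into rank-local coordinates
        return slice(lo - chunk_start, hi - chunk_start)

    # 10 elements split over two ranks owning [0, 5) and [5, 10):
    assert local_slice(slice(2, 8), 0, 5, 10) == slice(2, 5)
    assert local_slice(slice(2, 8), 5, 10, 10) == slice(0, 3)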
- if isinstance(value, (torch.Tensor, type(self))): - # if its a torch tensor, it is assumed to exist on all processes - value_slice = [slice(None, None, None)] * value.ndim - step2 = key_step if key_step is not None else 1 - key_start = (chunk_starts[rank] - og_key_start).item() - key_start = max(key_start, 0) - key_stop = key_start + key_stop - slice_loc = min(self.split, value.ndim - 1) - value_slice[slice_loc] = slice( - key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 - ) - self.__setter(tuple(key), value[tuple(value_slice)]) - else: - self.__setter(tuple(key), value) - elif isinstance(key[self.split], (torch.Tensor, list)): - key = list(key) - key[self.split] -= chunk_start - if len(key[self.split]) != 0: - self.__setter(tuple(key), value) - - elif key[self.split] in range(chunk_start, chunk_end): - key = list(key) - key[self.split] = key[self.split] - chunk_start - self.__setter(tuple(key), value) - - elif key[self.split] < 0: - key = list(key) - if self.gshape[self.split] + key[self.split] in range(chunk_start, chunk_end): - key[self.split] = key[self.split] + self.shape[self.split] - chunk_start - self.__setter(tuple(key), value) + # split = self.split + # if not self.is_distributed() or key[split] == slice(None): + # return __set(self[key], value) + + # if isinstance(key[split], slice): + # return __set(self[key], value) + + # if np.isscalar(key[split]): + # key = list(key) + # idx = int(key[split]) + # key[split] = slice(idx, idx + 1) + # return __set(self[tuple(key)], value) + + # key = getattr(key, "copy()", key) + # try: + # if value.split != self.split: + # val_split = int(value.split) + # sp = self.split + # warnings.warn( + # f"\nvalue.split {val_split} not equal to this DNDarray's split:" + # f" {sp}. this may cause errors or unwanted behavior", + # category=RuntimeWarning, + # ) + # except (AttributeError, TypeError): + # pass + + # # NOTE: for whatever reason, there is an inplace op which interferes with the abstraction + # # of this next block of code. this is shared with __getitem__. I attempted to abstract it + # # in a standard way, but it was causing errors in the test suite. If someone else is + # # motived to do this they are welcome to, but i have no time right now + # # print(key) + # if isinstance(key, DNDarray) and key.ndim == self.ndim: + # """if the key is a DNDarray and it has as many dimensions as self, then each of the + # entries in the 0th dim refer to a single element. To handle this, the key is split + # into the torch tensors for each dimension. This signals that advanced indexing is + # to be used.""" + # key = manipulations.resplit(key) + # if key.larray.dtype in [torch.bool, torch.uint8]: + # key = indexing.nonzero(key) + + # if key.ndim > 1: + # key = list(key.larray.split(1, dim=1)) + # # key is now a list of tensors with dimensions (key.ndim, 1) + # # squeeze singleton dimension: + # key = [key[i].squeeze_(1) for i in range(len(key))] + # else: + # key = [key] + # elif not isinstance(key, tuple): + # """this loop handles all other cases. DNDarrays which make it to here refer to + # advanced indexing slices, as do the torch tensors. Both DNDaarrys and torch.Tensors + # are cast into lists here by PyTorch. 
lists mean advanced indexing will be used""" + # h = [slice(None, None, None)] * self.ndim + # if isinstance(key, DNDarray): + # key = manipulations.resplit(key) + # if key.larray.dtype in [torch.bool, torch.uint8]: + # h[0] = torch.nonzero(key.larray).flatten() # .tolist() + # else: + # h[0] = key.larray.tolist() + # elif isinstance(key, torch.Tensor): + # if key.dtype in [torch.bool, torch.uint8]: + # # (coquelin77) im not sure why this works without being a list...but it does...for now + # h[0] = torch.nonzero(key).flatten() # .tolist() + # else: + # h[0] = key.tolist() + # else: + # h[0] = key + # key = list(h) + + # # key must be torch-proof + # if isinstance(key, (list, tuple)): + # key = list(key) + # for i, k in enumerate(key): + # try: # extract torch tensor + # k = manipulations.resplit(k) + # key[i] = k.larray + # except AttributeError: + # pass + # # remove bools from a torch tensor in favor of indexes + # try: + # if key[i].dtype in [torch.bool, torch.uint8]: + # key[i] = torch.nonzero(key[i]).flatten() + # except (AttributeError, TypeError): + # pass + + # key = list(key) + + # # ellipsis stuff + # key_classes = [type(n) for n in key] + # # if any(isinstance(n, ellipsis) for n in key): + # n_elips = key_classes.count(type(...)) + # if n_elips > 1: + # raise ValueError("key can only contain 1 ellipsis") + # elif n_elips == 1: + # # get which item is the ellipsis + # ell_ind = key_classes.index(type(...)) + # kst = key[:ell_ind] + # kend = key[ell_ind + 1 :] + # slices = [slice(None)] * (self.ndim - (len(kst) + len(kend))) + # key = kst + slices + kend + # # ---------- end ellipsis stuff ------------- + + # for c, k in enumerate(key): + # try: + # key[c] = k.item() + # except (AttributeError, ValueError, RuntimeError): + # pass + + # rank = self.comm.rank + # if self.split is not None: + # counts, chunk_starts = self.counts_displs() + # else: + # counts, chunk_starts = 0, [0] * self.comm.size + # counts = torch.tensor(counts, device=self.device.torch_device) + # chunk_starts = torch.tensor(chunk_starts, device=self.device.torch_device) + # chunk_ends = chunk_starts + counts + # chunk_start = chunk_starts[rank] + # chunk_end = chunk_ends[rank] + # # determine which elements are on the local process (if the key is a torch tensor) + # try: + # # if isinstance(key[self.split], torch.Tensor): + # filter_key = torch.nonzero( + # (chunk_start <= key[self.split]) & (key[self.split] < chunk_end) + # ) + # for k in range(len(key)): + # try: + # key[k] = key[k][filter_key].flatten() + # except TypeError: + # pass + # except TypeError: # this will happen if the key doesnt have that many + # pass + + # key = tuple(key) + + # if not self.is_distributed(): + # return self.__setter(key, value) # returns None + + # # raise RuntimeError("split axis of array and the target value are not equal") removed + # # this will occur if the local shapes do not match + # rank = self.comm.rank + # ends = [] + # for pr in range(self.comm.size): + # _, _, e = self.comm.chunk(self.shape, self.split, rank=pr) + # ends.append(e[self.split].stop - e[self.split].start) + # ends = torch.tensor(ends, device=self.device.torch_device) + # chunk_ends = ends.cumsum(dim=0) + # chunk_starts = torch.tensor([0] + chunk_ends.tolist(), device=self.device.torch_device) + # _, _, chunk_slice = self.comm.chunk(self.shape, self.split) + # chunk_start = chunk_slice[self.split].start + # chunk_end = chunk_slice[self.split].stop + + # self_proxy = self.__torch_proxy__() + + # # if the value is a DNDarray, the divisions need to be 
balanced: + # # this means that we need to know how much data is where for both DNDarrays + # # if the value data is not in the right place, then it will need to be moved + + # if isinstance(key[self.split], slice): + # key = list(key) + # key_start = key[self.split].start if key[self.split].start is not None else 0 + # key_stop = ( + # key[self.split].stop + # if key[self.split].stop is not None + # else self.gshape[self.split] + # ) + # if key_stop < 0: + # key_stop = self.gshape[self.split] + key[self.split].stop + # key_step = key[self.split].step + # og_key_start = key_start + # st_pr = torch.where(key_start < chunk_ends)[0] + # st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size + # sp_pr = torch.where(key_stop >= chunk_starts)[0] + # sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0 + # actives = list(range(st_pr, sp_pr + 1)) + + # if ( + # isinstance(value, type(self)) + # and value.split is not None + # and value.shape[self.split] != self.shape[self.split] + # ): + # # setting elements in self with a DNDarray which is not the same size in the + # # split dimension + # local_keys = [] + # # below is used if the target needs to be reshaped + # target_reshape_map = torch.zeros( + # (self.comm.size, self.ndim), dtype=torch.int64, device=self.device.torch_device + # ) + # for r in range(self.comm.size): + # if r not in actives: + # loc_key = key.copy() + # loc_key[self.split] = slice(0, 0, 0) + # else: + # key_start_l = 0 if r != actives[0] else key_start - chunk_starts[r] + # key_stop_l = ends[r] if r != actives[-1] else key_stop - chunk_starts[r] + # key_start_l, key_stop_l = self.__xitem_get_key_start_stop( + # r, actives, key_start_l, key_stop_l, key_step, chunk_ends, og_key_start + # ) + # loc_key = key.copy() + # loc_key[self.split] = slice(key_start_l, key_stop_l, key_step) + + # gout_full = torch.tensor( + # self_proxy[loc_key].shape, device=self.device.torch_device + # ) + # target_reshape_map[r] = gout_full + # local_keys.append(loc_key) + + # key = local_keys[rank] + # value = value.redistribute(target_map=target_reshape_map) + + # if rank not in actives: + # return # non-active ranks can exit here + + # chunk_starts_v = target_reshape_map[:, self.split] + # value_slice = [slice(None, None, None)] * value.ndim + # step2 = key_step if key_step is not None else 1 + # key_start = (chunk_starts_v[rank] - og_key_start).item() + + # key_start = max(key_start, 0) + # key_stop = key_start + key_stop + # slice_loc = min(self.split, value.ndim - 1) + # value_slice[slice_loc] = slice( + # key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 + # ) + + # self.__setter(tuple(key), value.larray) + # return + + # # if rank in actives: + # if rank not in actives: + # return # non-active ranks can exit here + # key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] + # key_stop = ends[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] + # key_start, key_stop = self.__xitem_get_key_start_stop( + # rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start + # ) + # key[self.split] = slice(key_start, key_stop, key_step) + + # # todo: need to slice the values to be the right size... 
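The commented-out bookkeeping above repeatedly converts per-rank element counts into global chunk bounds. A minimal torch sketch of that cumsum pattern, with made-up counts (illustrative only, not the Heat implementation):

    import torch

    counts = torch.tensor([4, 3, 3])   # elements per rank along the split axis
    chunk_ends = counts.cumsum(dim=0)  # exclusive upper bounds: [4, 7, 10]
    chunk_starts = torch.cat((torch.zeros(1, dtype=counts.dtype), chunk_ends[:-1]))
    # rank r owns global indices [chunk_starts[r], chunk_ends[r])
    assert chunk_starts.tolist() == [0, 4, 7] and chunk_ends.tolist() == [4, 7, 10]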
+ # if isinstance(value, (torch.Tensor, type(self))): + # # if its a torch tensor, it is assumed to exist on all processes + # value_slice = [slice(None, None, None)] * value.ndim + # step2 = key_step if key_step is not None else 1 + # key_start = (chunk_starts[rank] - og_key_start).item() + # key_start = max(key_start, 0) + # key_stop = key_start + key_stop + # slice_loc = min(self.split, value.ndim - 1) + # value_slice[slice_loc] = slice( + # key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 + # ) + # self.__setter(tuple(key), value[tuple(value_slice)]) + # else: + # self.__setter(tuple(key), value) + # elif isinstance(key[self.split], (torch.Tensor, list)): + # key = list(key) + # key[self.split] -= chunk_start + # if len(key[self.split]) != 0: + # self.__setter(tuple(key), value) + + # elif key[self.split] in range(chunk_start, chunk_end): + # key = list(key) + # key[self.split] = key[self.split] - chunk_start + # self.__setter(tuple(key), value) + + # elif key[self.split] < 0: + # key = list(key) + # if self.gshape[self.split] + key[self.split] in range(chunk_start, chunk_end): + # key[self.split] = key[self.split] + self.shape[self.split] - chunk_start + # self.__setter(tuple(key), value) def __setter( self, diff --git a/heat/core/sanitation.py b/heat/core/sanitation.py index 863e140799..d23fa40a7b 100644 --- a/heat/core/sanitation.py +++ b/heat/core/sanitation.py @@ -288,23 +288,31 @@ def sanitize_out( if not isinstance(out, DNDarray): raise TypeError(f"expected `out` to be None or a DNDarray, but was {type(out)}") - out_proxy = out.__torch_proxy__() - out_proxy.names = [ - "split" if (out.split is not None and i == out.split) else f"_{i}" - for i in range(out_proxy.ndim) - ] - out_proxy = out_proxy.squeeze() - - check_proxy = torch.ones(1).expand(output_shape) - check_proxy.names = [ - "split" if (output_split is not None and i == output_split) else f"_{i}" - for i in range(check_proxy.ndim) - ] - check_proxy = check_proxy.squeeze() - - if out_proxy.shape != check_proxy.shape: - raise ValueError(f"Expecting output buffer of shape {output_shape}, got {out.shape}") - count_split = int(out.split is not None) + int(output_split is not None) + if len(output_shape) == 0: + # 0-dimensional arrays don't need so many checks + if len(out.shape) != 0: + raise ValueError(f"Expecting output buffer of shape {output_shape}, got {out.shape}") + # 0-dimensional arrays cannot be split + count_split = 0 + else: + out_proxy = out.__torch_proxy__() + out_proxy.names = [ + "split" if (out.split is not None and i == out.split) else f"_{i}" + for i in range(out_proxy.ndim) + ] + out_proxy = out_proxy.squeeze() + + check_proxy = torch.ones(1).expand(output_shape) + check_proxy.names = [ + "split" if (output_split is not None and i == output_split) else f"_{i}" + for i in range(check_proxy.ndim) + ] + check_proxy = check_proxy.squeeze() + + if out_proxy.shape != check_proxy.shape: + raise ValueError(f"Expecting output buffer of shape {output_shape}, got {out.shape}") + count_split = int(out.split is not None) + int(output_split is not None) + if count_split == 1: raise ValueError( "Split axis of output buffer is inconsistent with split semantics for this operation." @@ -326,6 +334,7 @@ def sanitize_out( raise ValueError( "Split axis of output buffer is inconsistent with split semantics for this operation." 
) + if out.device != output_device: raise ValueError(f"Device mismatch: out is on {out.device}, should be on {output_device}") if output_comm is not None and out.comm != output_comm: From 8cf3ff129d8760b6135e9395f76365214361c0cd Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 9 Aug 2023 09:13:48 +0200 Subject: [PATCH 074/132] DRAFT - abstraction common utilities for getitem and setitem --- heat/core/dndarray.py | 221 ++++++++++++++++++++++-------------------- 1 file changed, 115 insertions(+), 106 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index ba654caff8..8419a1ed3a 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -582,8 +582,8 @@ def counts_displs(self) -> Tuple[Tuple[int], Tuple[int]]: counts = self.lshape_map[:, self.split] displs = [0] + torch.cumsum(counts, dim=0)[:-1].tolist() return tuple(counts.tolist()), tuple(displs) - else: - raise ValueError("Non-distributed DNDarray. Cannot calculate counts and displacements.") + + raise ValueError("Non-distributed DNDarray. Cannot calculate counts and displacements.") def cpu(self) -> DNDarray: """ @@ -775,7 +775,11 @@ def fill_diagonal(self, value: float) -> DNDarray: return self - def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> Tuple: + def __process_key( + arr: DNDarray, + key: Union[Tuple[int, ...], List[int, ...]], + return_local_indices: Optional[bool] = False, + ) -> Tuple: """ Private method to process the key used for indexing a ``DNDarray`` so that it can be applied to the process-local data, i.e. `key` must be "torch-proof". In a processed key: @@ -790,6 +794,8 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> The ``DNDarray`` to be indexed key : int, Tuple[int, ...], List[int, ...] The key used for indexing + return_local_indices : bool, optional + Whether to return the process-local indices of the key in the split dimension. This is only possible when the indexing key in the split dimension is ordered e.g. `split_key_is_sorted == 1`. 
Default: False Returns ------- @@ -822,7 +828,7 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> arr_is_distributed = True advanced_indexing = False - split_key_is_sorted = 1 # can be 1: ascending, 0: not sorted, -1: descending + split_key_is_sorted = 1 out_is_balanced = False root = None backwards_transpose_axes = tuple(range(arr.ndim)) @@ -893,13 +899,13 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> key = list(key.nonzero()) output_shape = key[0].shape new_split = 0 - # all local indexing + split_key_is_sorted = 1 out_is_balanced = False for i, k in enumerate(key): key[i] = k.larray - key[arr.split] -= displs[arr.comm.rank] + if return_local_indices: + key[arr.split] -= displs[arr.comm.rank] key = tuple(key) - split_key_is_sorted = 1 else: key = key.larray.nonzero(as_tuple=False) # construct global key array @@ -1006,7 +1012,8 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> cond1 = key >= displs[arr.comm.rank] cond2 = key < displs[arr.comm.rank] + counts[arr.comm.rank] key = key[cond1 & cond2] - key -= displs[arr.comm.rank] + if return_local_indices: + key -= displs[arr.comm.rank] out_is_balanced = False else: try: @@ -1072,67 +1079,7 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> advanced_indexing_shapes = [] lose_dims = 0 for i, k in enumerate(key): - if isinstance(k, Iterable) or isinstance(k, DNDarray): - # advanced indexing across dimensions - if getattr(k, "ndim", 1) == 0: - # single-element indexing along axis i - try: - output_shape[i], split_bookkeeping[i] = None, None - except IndexError: - raise IndexError( - f"Too many indices for DNDarray: DNDarray is {arr.ndim}-dimensional, but {len(key)} dimensions were indexed" - ) - lose_dims += 1 - if arr_is_distributed and i == arr.split: - # single-element indexing along split axis - # work out root process for Bcast - key[i] = k.item() + arr.shape[i] if k.item() < 0 else k.item() - if key[i] in displs: - root = displs.index(key[i]) - else: - displs = torch.cat( - (torch.tensor(displs), torch.tensor(key[i]).reshape(-1)), dim=0 - ) - _, sorted_indices = displs.unique(sorted=True, return_inverse=True) - root = sorted_indices[-1] - 1 - # correct key for rank-specific displacement - if arr.comm.rank == root: - key[i] -= displs[root] - else: - key[i] = k.item() - else: - advanced_indexing = True - advanced_indexing_dims.append(i) - if isinstance(k, DNDarray): - advanced_indexing_shapes.append(k.gshape) - if arr_is_distributed and i == arr.split: - # we have no info on order of indices - split_key_is_sorted = 0 - # redistribute key along last axis to match split axis of indexed array - k = k.resplit(-1) - out_is_balanced = True - key[i] = k.larray - elif not isinstance(k, torch.Tensor): - key[i] = torch.tensor(k, dtype=torch.int64, device=arr.larray.device) - advanced_indexing_shapes.append(tuple(key[i].shape)) - # IMPORTANT: here we assume that torch or ndarray key is THE SAME SET OF GLOBAL INDICES on every rank - if arr_is_distributed and i == arr.split: - # make no assumption on data locality wrt key - out_is_balanced = None - # assess if indices are in ascending order - if ( - key[i].ndim == 1 - and (key[i] == torch.sort(key[i], stable=True)[0]).all() - ): - split_key_is_sorted = 1 - # extract local key - cond1 = key[i] >= displs[arr.comm.rank] - cond2 = key[i] < displs[arr.comm.rank] + counts[arr.comm.rank] - key[i] = key[i][cond1 & cond2] - key[i] -= displs[arr.comm.rank] - else: - 
split_key_is_sorted = 0 - elif isinstance(k, int): + if np.isscalar(k) or getattr(k, "ndim", 1) == 0: # single-element indexing along axis i try: output_shape[i], split_bookkeeping[i] = None, None @@ -1141,21 +1088,42 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> f"Too many indices for DNDarray: DNDarray is {arr.ndim}-dimensional, but {len(key)} dimensions were indexed" ) lose_dims += 1 - if arr_is_distributed and i == arr.split: - # single-element indexing along split axis - # work out root process for Bcast - key[i] = k + arr.shape[i] if k < 0 else k - if key[i] in displs: - root = displs.index(key[i]) - else: - displs = torch.cat( - (torch.tensor(displs), torch.tensor(key[i]).reshape(-1)), dim=0 - ) - _, sorted_indices = displs.unique(sorted=True, return_inverse=True) - root = sorted_indices[-1] - 1 - # correct key for rank-specific displacement - if arr.comm.rank == root: - key[i] -= displs[root] + key[i], root = arr.__process_scalar_key( + k, split=i, return_local_indices=return_local_indices + ) + elif isinstance(k, Iterable) or isinstance(k, DNDarray): + advanced_indexing = True + advanced_indexing_dims.append(i) + if isinstance(k, DNDarray): + advanced_indexing_shapes.append(k.gshape) + if arr_is_distributed and i == arr.split: + # we have no info on order of indices + split_key_is_sorted = 0 + # redistribute key along last axis to match split axis of indexed array + k = k.resplit(-1) + out_is_balanced = True + key[i] = k.larray + elif not isinstance(k, torch.Tensor): + key[i] = torch.tensor(k, dtype=torch.int64, device=arr.larray.device) + advanced_indexing_shapes.append(tuple(key[i].shape)) + # IMPORTANT: here we assume that torch or ndarray key is THE SAME SET OF GLOBAL INDICES on every rank + if arr_is_distributed and i == arr.split: + # make no assumption on data locality wrt key + out_is_balanced = None + # assess if indices are in ascending order + if ( + key[i].ndim == 1 + and (key[i] == torch.sort(key[i], stable=True)[0]).all() + ): + split_key_is_sorted = 1 + # extract local key + cond1 = key[i] >= displs[arr.comm.rank] + cond2 = key[i] < displs[arr.comm.rank] + counts[arr.comm.rank] + key[i] = key[i][cond1 & cond2] + if return_local_indices: + key[i] -= displs[arr.comm.rank] + else: + split_key_is_sorted = 0 elif isinstance(k, slice) and k != slice(None): start, stop, step = k.start, k.stop, k.step @@ -1300,7 +1268,10 @@ def __process_key(arr: DNDarray, key: Union[Tuple[int, ...], List[int, ...]]) -> ) def __process_scalar_key( - arr: DNDarray, key: Union[int, DNDarray, torch.Tensor, np.ndarray] + arr: DNDarray, + key: Union[int, DNDarray, torch.Tensor, np.ndarray], + split: int, + return_local_indices: Optional[bool] = False, ) -> Tuple(int, int): """ Private method to process a single-item scalar key used for indexing a ``DNDarray``. 
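At its core, this helper locates the process that owns a global scalar index by searching the displacement list; that search is equivalent to a bisection. A standalone sketch of the idea, with an illustrative helper name and example displacements (plain Python, not Heat API):

    from bisect import bisect_right

    def owner_of_global_index(key: int, displs: list, gshape0: int) -> tuple:
        # wrap negative indices, NumPy-style
        if key < 0:
            key += gshape0
        # root = last rank whose displacement is <= key
        root = bisect_right(displs, key) - 1
        # return the rank-local index alongside the owning rank
        return key - displs[root], root

    # 10 elements over 3 ranks with displacements [0, 4, 7]:
    assert owner_of_global_index(5, [0, 4, 7], 10) == (1, 1)
    assert owner_of_global_index(-1, [0, 4, 7], 10) == (2, 2)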
@@ -1317,7 +1288,10 @@ def __process_scalar_key( except AttributeError: # key is already an integer, do nothing pass - if arr.is_distributed() and arr.split == 0: + if not arr.is_distributed(): + root = 0 + return key, root + if arr.is_distributed() and arr.split == split: # adjust negative key if key < 0: key += arr.shape[0] @@ -1336,8 +1310,9 @@ def __process_scalar_key( _, sorted_indices = displs.unique(sorted=True, return_inverse=True) root = sorted_indices[-1] - 1 # correct key for rank-specific displacement - if arr.comm.rank == root: - key -= displs[root] + if return_local_indices: + if arr.comm.rank == root: + key -= displs[root] else: root = None return key, root @@ -1407,11 +1382,12 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar return self original_split = self.split + # Single-element indexing scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 if scalar: output_shape = self.gshape[1:] - key, root = self.__process_scalar_key(key) + key, root = self.__process_scalar_key(key, split=0, return_local_indices=True) if root is None: # single-element indexing along non-split axis indexed_arr = self.larray[key] @@ -1461,7 +1437,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar out_is_balanced, root, backwards_transpose_axes, - ) = self.__process_key(key) + ) = self.__process_key(key, return_local_indices=True) print("DEBUGGING: processed key, output_split = ", key, output_split) @@ -1489,9 +1465,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar self = self.transpose(backwards_transpose_axes) return indexed_arr - # TODO: test that key for not affected dims is always slice(None) - # including match between self.split and key after self manipulation - # data are not distributed or split dimension is not affected by indexing print("split_key_is_sorted, key = ", split_key_is_sorted, key) if split_key_is_sorted == 1: @@ -2270,7 +2243,7 @@ def __setitem__( """ def __set( - arr: Union[DNDarray, torch.Tensor], + arr: DNDarray, value: Union[DNDarray, torch.Tensor, np.ndarray, float, int, list, tuple], ): """ @@ -2283,33 +2256,39 @@ def __set( ) except TypeError: raise TypeError(f"Cannot assign object of type {type(value)} to DNDarray.") + value_shape = value.shape while value.ndim < arr.ndim: # broadcasting + print("DEBUGGING: value.ndim, value.shape = ", value.ndim, value.shape) value = value.expand_dims(0) - sanitation.sanitize_out(arr, value.shape, value.split, value.device, value.comm) + print("DEBUGGING: value.shape = ", value.shape) + try: + value_shape = tuple(torch.broadcast_shapes(value.shape, arr.shape)) + except RuntimeError: + raise ValueError( + f"could not broadcast input array from shape {value.shape} into shape {arr.shape}" + ) + sanitation.sanitize_out(arr, value_shape, value.split, value.device, value.comm) value = sanitation.sanitize_distribution(value, target=arr) - try: - arr.larray[None] = value.larray - except AttributeError: - # arr is already the process-local torch tensor - arr[None] = value.larray + arr.larray[None] = value.larray return if key is None or key == ... 
or key == slice(None): - return __set(self, value) + return __set(self, self.larray, value) # torch_device = self.larray.device - # scalar key + # single-element key scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 if scalar: - key, root = self.__process_scalar_key(key) + key, root = self.__process_scalar_key(key, split=0, return_local_indices=False) if root is not None: if self.comm.rank == root: - __set(self.larray[key], value) + __set(self[key], value) else: __set(self[key], value) return + # multi-element key, incl. slicing and striding, ordered and non-ordered advanced indexing ( self, key, @@ -2319,9 +2298,39 @@ def __set( out_is_balanced, root, backwards_transpose_axes, - ) = self.__process_key(key) + ) = self.__process_key(key, return_local_indices=True) + + # sanitize value + value_split = value.split if isinstance(value, DNDarray) else None + try: + value = factories.array( + value, dtype=self.dtype, split=value_split, device=self.device, comm=self.comm + ) + except TypeError: + raise TypeError(f"Cannot assign object of type {type(value)} to DNDarray.") + value_shape = value.shape + while value.ndim < len(output_shape): # broadcasting + value = value.expand_dims(0) + try: + value_shape = tuple(torch.broadcast_shapes(value_shape, output_shape)) + except RuntimeError: + raise ValueError( + f"could not broadcast input array from shape {value.shape} into shape {output_shape}" + ) + # TODO: sanitize distribution without allocating getitem array + + if split_key_is_sorted: + # data are not distributed or split dimension is not affected by indexing + # key all local + if root is not None: + # single-element assignment along split axis, only one active process + if self.comm.rank == root: + self.larray[key] = value.larray + else: + self.larray[key] = value.larray + self = self.transpose(backwards_transpose_axes) + return - # if split_key_is_sorted: # process-local indices # if advanced_indexing: From b45578adf17d222c14484e7c442e2463d8126785 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 9 Aug 2023 11:23:04 +0200 Subject: [PATCH 075/132] handle all single-element indexing along split axis in same block --- heat/core/dndarray.py | 194 +++++++++++++++---------------- heat/core/tests/test_dndarray.py | 2 +- 2 files changed, 94 insertions(+), 102 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 8419a1ed3a..0d5eb58fb9 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -795,7 +795,7 @@ def __process_key( key : int, Tuple[int, ...], List[int, ...] The key used for indexing return_local_indices : bool, optional - Whether to return the process-local indices of the key in the split dimension. This is only possible when the indexing key in the split dimension is ordered e.g. `split_key_is_sorted == 1`. Default: False + Whether to return the process-local indices of the key in the split dimension. This is only possible when the indexing key in the split dimension is ordered e.g. `split_key_is_ordered == 1`. Default: False Returns ------- @@ -808,7 +808,7 @@ def __process_key( The shape of the output ``DNDarray`` new_split : int The new split axis - split_key_is_sorted : int + split_key_is_ordered : int Whether the split key is sorted or ordered. Can be 1: ascending, 0: not ordered, -1: descending order. 
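The 1 / 0 / -1 convention just described is established for the ascending case by comparing the split key with its stable sort, mirroring the torch.sort comparison in the hunks below; the descending branch in this sketch is an illustrative generalization (the patch itself infers descending order from a negative slice step instead). Hypothetical helper name:

    import torch

    def key_order(key: torch.Tensor) -> int:
        # 1: ascending, -1: descending, 0: unordered
        sorted_key, _ = torch.sort(key, stable=True)
        if torch.equal(key, sorted_key):
            return 1
        if torch.equal(key, sorted_key.flip(0)):
            return -1
        return 0

    assert key_order(torch.tensor([0, 2, 5])) == 1
    assert key_order(torch.tensor([5, 2, 0])) == -1
    assert key_order(torch.tensor([2, 0, 5])) == 0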
out_is_balanced : bool Whether the output ``DNDarray`` is balanced @@ -828,7 +828,7 @@ def __process_key( arr_is_distributed = True advanced_indexing = False - split_key_is_sorted = 1 + split_key_is_ordered = 1 out_is_balanced = False root = None backwards_transpose_axes = tuple(range(arr.ndim)) @@ -870,13 +870,13 @@ def __process_key( output_shape = tuple(key[0].shape) new_split = None if arr.split is None else 0 out_is_balanced = True - split_key_is_sorted = 1 + split_key_is_ordered = 1 return ( arr, key, output_shape, new_split, - split_key_is_sorted, + split_key_is_ordered, out_is_balanced, root, backwards_transpose_axes, @@ -899,7 +899,7 @@ def __process_key( key = list(key.nonzero()) output_shape = key[0].shape new_split = 0 - split_key_is_sorted = 1 + split_key_is_ordered = 1 out_is_balanced = False for i, k in enumerate(key): key[i] = k.larray @@ -934,14 +934,14 @@ def __process_key( output_shape = (key[0].shape[0],) new_split = 0 - split_key_is_sorted = 0 + split_key_is_ordered = 0 out_is_balanced = True return ( arr, key, output_shape, new_split, - split_key_is_sorted, + split_key_is_ordered, out_is_balanced, root, backwards_transpose_axes, @@ -982,15 +982,18 @@ def __process_key( try: # DNDarray key sorted, _ = torch.sort(key.larray, stable=True) - split_key_is_sorted = torch.tensor( + split_key_is_ordered = torch.tensor( (key.larray == sorted).all(), dtype=torch.uint8, device=key.larray.device, ) if key.split is not None: out_is_balanced = key.balanced - split_key_is_sorted = factories.array( - [split_key_is_sorted], is_split=0, device=arr.device, copy=False + split_key_is_ordered = factories.array( + [split_key_is_ordered], + is_split=0, + device=arr.device, + copy=False, ).all() key = key.larray except AttributeError: @@ -1000,14 +1003,14 @@ def __process_key( except TypeError: # ndarray key sorted = torch.tensor(np.sort(key), device=arr.larray.device) - split_key_is_sorted = torch.tensor( + split_key_is_ordered = torch.tensor( key == sorted, dtype=torch.uint8 ).item() - if not split_key_is_sorted: + if not split_key_is_ordered: # prepare for distributed non-ordered indexing: distribute torch/numpy key key = factories.array(key, split=0, device=arr.device).larray out_is_balanced = True - if split_key_is_sorted: + if split_key_is_ordered: # extract local key cond1 = key >= displs[arr.comm.rank] cond2 = key < displs[arr.comm.rank] + counts[arr.comm.rank] @@ -1029,7 +1032,7 @@ def __process_key( key, output_shape, new_split, - split_key_is_sorted, + split_key_is_ordered, out_is_balanced, root, backwards_transpose_axes, @@ -1088,9 +1091,14 @@ def __process_key( f"Too many indices for DNDarray: DNDarray is {arr.ndim}-dimensional, but {len(key)} dimensions were indexed" ) lose_dims += 1 - key[i], root = arr.__process_scalar_key( - k, split=i, return_local_indices=return_local_indices - ) + if i == arr.split: + key[i], root = arr.__process_scalar_key( + k, indexed_axis=i, return_local_indices=return_local_indices + ) + else: + key[i], _ = arr.__process_scalar_key( + k, indexed_axis=i, return_local_indices=False + ) elif isinstance(k, Iterable) or isinstance(k, DNDarray): advanced_indexing = True advanced_indexing_dims.append(i) @@ -1098,7 +1106,7 @@ def __process_key( advanced_indexing_shapes.append(k.gshape) if arr_is_distributed and i == arr.split: # we have no info on order of indices - split_key_is_sorted = 0 + split_key_is_ordered = 0 # redistribute key along last axis to match split axis of indexed array k = k.resplit(-1) out_is_balanced = True @@ -1115,7 +1123,7 @@ def 
__process_key( key[i].ndim == 1 and (key[i] == torch.sort(key[i], stable=True)[0]).all() ): - split_key_is_sorted = 1 + split_key_is_ordered = 1 # extract local key cond1 = key[i] >= displs[arr.comm.rank] cond2 = key[i] < displs[arr.comm.rank] + counts[arr.comm.rank] @@ -1123,7 +1131,7 @@ def __process_key( if return_local_indices: key[i] -= displs[arr.comm.rank] else: - split_key_is_sorted = 0 + split_key_is_ordered = 0 elif isinstance(k, slice) and k != slice(None): start, stop, step = k.start, k.stop, k.step @@ -1148,12 +1156,12 @@ def __process_key( key[i] = factories.array( key[i], split=0, device=arr.device, copy=False ).larray - split_key_is_sorted = -1 + split_key_is_ordered = -1 out_is_balanced = True elif step > 0 and start < stop: output_shape[i] = int(torch.tensor((stop - start) / step).ceil().item()) if arr_is_distributed and new_split == i: - split_key_is_sorted = 1 + split_key_is_ordered = 1 out_is_balanced = False local_arr_end = displs[arr.comm.rank] + counts[arr.comm.rank] if stop > displs[arr.comm.rank] and start < local_arr_end: @@ -1249,11 +1257,11 @@ def __process_key( output_shape = tuple(output_shape) new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None print( - "key, output_shape, new_split, split_key_is_sorted, out_is_balanced = ", + "key, output_shape, new_split, split_key_is_ordered, out_is_balanced = ", key, output_shape, new_split, - split_key_is_sorted, + split_key_is_ordered, out_is_balanced, ) return ( @@ -1261,7 +1269,7 @@ def __process_key( key, output_shape, new_split, - split_key_is_sorted, + split_key_is_ordered, out_is_balanced, root, backwards_transpose_axes, @@ -1270,7 +1278,7 @@ def __process_key( def __process_scalar_key( arr: DNDarray, key: Union[int, DNDarray, torch.Tensor, np.ndarray], - split: int, + indexed_axis: int, return_local_indices: Optional[bool] = False, ) -> Tuple(int, int): """ @@ -1289,9 +1297,9 @@ def __process_scalar_key( # key is already an integer, do nothing pass if not arr.is_distributed(): - root = 0 + root = None return key, root - if arr.is_distributed() and arr.split == split: + if arr.split == indexed_axis: # adjust negative key if key < 0: key += arr.shape[0] @@ -1386,14 +1394,23 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # Single-element indexing scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 if scalar: + # single-element indexing on axis 0 + if self.ndim == 0: + raise IndexError( + "Too many indices for DNDarray: DNDarray is 0-dimensional, but 1 were indexed" + ) output_shape = self.gshape[1:] - key, root = self.__process_scalar_key(key, split=0, return_local_indices=True) + if original_split is None or original_split == 0: + output_split = None + else: + output_split = original_split - 1 + split_key_is_ordered = 1 + out_is_balanced = True + backwards_transpose_axes = tuple(range(self.ndim)) + key, root = self.__process_scalar_key(key, indexed_axis=0, return_local_indices=True) if root is None: - # single-element indexing along non-split axis + # early out for single-element indexing not affecting split axis indexed_arr = self.larray[key] - output_split = ( - None if (original_split is None or original_split == 0) else original_split - 1 - ) indexed_arr = DNDarray( indexed_arr, gshape=output_shape, @@ -1401,73 +1418,48 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar split=output_split, device=self.device, comm=self.comm, - balanced=self.balanced, + balanced=out_is_balanced, ) return indexed_arr - 
# root is not None: single-element indexing along split axis - # prepare for Bcast: allocate buffer on all processes - if self.comm.rank == root: - indexed_arr = self.larray[key] - else: - indexed_arr = torch.zeros( - output_shape, dtype=self.larray.dtype, device=self.larray.device - ) - # broadcast result to all processes - self.comm.Bcast(indexed_arr, root=root) - indexed_arr = DNDarray( - indexed_arr, - gshape=output_shape, - dtype=self.dtype, - split=None, - device=self.device, - comm=self.comm, - balanced=True, - ) - return indexed_arr - - # Many-elements indexing: incl. slicing and striding, ordered and non-ordered advanced indexing - - # Preprocess: Process Ellipsis + 'None' indexing; make Iterables to DNDarrays - ( - self, - key, - output_shape, - output_split, - split_key_is_sorted, - out_is_balanced, - root, - backwards_transpose_axes, - ) = self.__process_key(key, return_local_indices=True) - - print("DEBUGGING: processed key, output_split = ", key, output_split) + else: + # multi-element key + ( + self, + key, + output_shape, + output_split, + split_key_is_ordered, + out_is_balanced, + root, + backwards_transpose_axes, + ) = self.__process_key(key, return_local_indices=True) - if root is not None: - # single-element indexing along split axis - # allocate buffer on all processes - if self.comm.rank == root: - indexed_arr = self.larray[key] - else: - indexed_arr = torch.zeros( - output_shape, dtype=self.larray.dtype, device=self.larray.device + if split_key_is_ordered == 1: + if root is not None: + # single-element indexing along split axis + # prepare for Bcast: allocate buffer on all processes + if self.comm.rank == root: + indexed_arr = self.larray[key] + else: + indexed_arr = torch.zeros( + output_shape, dtype=self.larray.dtype, device=self.larray.device + ) + # broadcast result to all processes + self.comm.Bcast(indexed_arr, root=root) + indexed_arr = DNDarray( + indexed_arr, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + comm=self.comm, + balanced=out_is_balanced, ) - # broadcast result to all processes - self.comm.Bcast(indexed_arr, root=root) - indexed_arr = DNDarray( - indexed_arr, - gshape=output_shape, - dtype=self.dtype, - split=output_split, - device=self.device, - comm=self.comm, - balanced=True, - ) - # transpose array back if needed - self = self.transpose(backwards_transpose_axes) - return indexed_arr + # transpose array back if needed + self = self.transpose(backwards_transpose_axes) + return indexed_arr - # data are not distributed or split dimension is not affected by indexing - print("split_key_is_sorted, key = ", split_key_is_sorted, key) - if split_key_is_sorted == 1: + # root is None, i.e. indexing does not affect split axis, apply as is indexed_arr = self.larray[key] # transpose array back if needed self = self.transpose(backwards_transpose_axes) @@ -1481,7 +1473,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar comm=self.comm, ) - # key is not sorted along self.split + # key is not ordered along self.split # key is tuple of torch.Tensor or mix of torch.Tensors and slices _, displs = self.counts_displs() @@ -1527,7 +1519,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar ) # process-local: calculate which/how many elements will be received from what process - if split_key_is_sorted == -1: + if split_key_is_ordered == -1: # key is sorted in descending order (i.e. 
slicing w/ negative step):
                 # shrink selection of active processes
                 if key[original_split].numel() > 0:
@@ -2280,7 +2272,7 @@ def __set(
         scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0
         if scalar:
-            key, root = self.__process_scalar_key(key, split=0, return_local_indices=False)
+            key, root = self.__process_scalar_key(key, indexed_axis=0, return_local_indices=False)
             if root is not None:
                 if self.comm.rank == root:
                     __set(self[key], value)
@@ -2294,7 +2286,7 @@ def __set(
             key,
             output_shape,
             output_split,
-            split_key_is_sorted,
+            split_key_is_ordered,
             out_is_balanced,
             root,
             backwards_transpose_axes,
@@ -2319,7 +2311,7 @@ def __set(
         )
         # TODO: sanitize distribution without allocating getitem array

-        if split_key_is_sorted:
+        if split_key_is_ordered:
             # data are not distributed or split dimension is not affected by indexing
             # key all local
             if root is not None:
diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py
index 601e3c4c63..c90378ee8b 100644
--- a/heat/core/tests/test_dndarray.py
+++ b/heat/core/tests/test_dndarray.py
@@ -684,8 +684,8 @@ def test_getitem(self):
         adv_indexed_x_np = x_np[(1, 2, 3),]
         x = ht.array(x_np, split=0)
         indexed_x = x[(1, 2, 3)]
-        adv_indexed_x = x[(1, 2, 3),]
         self.assertTrue(indexed_x.item() == np.array(indexed_x_np))
+        adv_indexed_x = x[(1, 2, 3),]
         self.assert_array_equal(adv_indexed_x, adv_indexed_x_np)

         # 1d

From cec4bb977ac414adc0e18400b767d08e9c954408 Mon Sep 17 00:00:00 2001
From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com>
Date: Thu, 10 Aug 2023 15:11:14 +0200
Subject: [PATCH 076/132] resolve send/recv dimensions mismatch in a few edge
 cases

---
 heat/core/dndarray.py | 114 ++++++++++++++++++++++++++++++++------
 1 file changed, 98 insertions(+), 16 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index 0d5eb58fb9..c04fa72289 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -779,6 +779,7 @@ def __process_key(
         arr: DNDarray,
         key: Union[Tuple[int, ...], List[int, ...]],
         return_local_indices: Optional[bool] = False,
+        op: Optional[str] = None,
     ) -> Tuple:
         """
         Private method to process the key used for indexing a ``DNDarray`` so that it can be applied to the process-local data, i.e. `key` must be "torch-proof".
@@ -796,6 +797,8 @@
             The key used for indexing
         return_local_indices : bool, optional
             Whether to return the process-local indices of the key in the split dimension. This is only possible when the indexing key in the split dimension is ordered e.g. `split_key_is_ordered == 1`. Default: False
+        op : str, optional
+            The indexing operation that the key is being processed for. Can be "get" for `__getitem__` or "set" for `__setitem__`. Default: "get".
Returns ------- @@ -1146,18 +1149,29 @@ def __process_key( if step is None: step = 1 if step < 0 and start > stop: + print("TEST LOCAL SLICE: ", arr.__get_local_slice(k)) # PyTorch doesn't support negative step as of 1.13 # Lazy solution, potentially large memory footprint # TODO: implement ht.fromiter (implemented in ASSET_ht) - key[i] = list(range(start, stop, step)) + key[i] = torch.tensor(list(range(start, stop, step)), device=arr.larray.device) output_shape[i] = len(key[i]) + split_key_is_ordered = -1 if arr_is_distributed and new_split == i: - # distribute key and proceed with non-ordered indexing - key[i] = factories.array( - key[i], split=0, device=arr.device, copy=False - ).larray - split_key_is_ordered = -1 - out_is_balanced = True + if op == "set": + # setitem: flip key and keep process-local indices + key[i] = key[i].flip(0) + cond1 = key[i] >= displs[arr.comm.rank] + cond2 = key[i] < displs[arr.comm.rank] + counts[arr.comm.rank] + key[i] = key[i][cond1 & cond2] + if return_local_indices: + key[i] -= displs[arr.comm.rank] + else: + # getitem: distribute key and proceed with non-ordered indexing + key[i] = factories.array( + key[i], split=0, device=arr.device, copy=False + ).larray + print("DEBUGGING: key[i] = ", key[i]) + out_is_balanced = True elif step > 0 and start < stop: output_shape[i] = int(torch.tensor((stop - start) / step).ceil().item()) if arr_is_distributed and new_split == i: @@ -1668,15 +1682,13 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar incoming_request_key -= displs[self.comm.rank] incoming_request_key = ( key[:original_split] - + (incoming_request_key.squeeze_(1).tolist(),) + + (incoming_request_key.squeeze_(1),) + key[original_split + 1 :] ) print("AFTER: incoming_request_key = ", incoming_request_key) - # print("OUTPUT_SHAPE = ", output_shape) - # print("OUTPUT_SPLIT = ", output_split) - - send_buf = self.larray[incoming_request_key] + print("original_split = ", original_split) + # calculate shape of local recv buffer output_lshape = list(output_shape) if getattr(key, "ndim", 0) == 1: output_lshape[output_split] = key.shape[0] @@ -1684,18 +1696,64 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar if broadcasted_indexing: output_lshape = ( output_lshape[:original_split] - + [torch.prod(torch.tensor(broadcast_shape, device=send_buf.device)).item()] + + [torch.prod(torch.tensor(broadcast_shape, device=self.larray.device)).item()] + output_lshape[output_split + 1 :] ) else: output_lshape[output_split] = key[original_split].shape[0] + # allocate recv buffer recv_buf = torch.empty( tuple(output_lshape), dtype=self.larray.dtype, device=self.larray.device ) + + # index local data into send_buf. + send_empty = sum( + list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in incoming_request_key) + ) # incoming_request_key.count([]) + if send_empty: + # Edge case 1. empty slice along split axis: send_buf is 0-element tensor + empty_shape = list(output_shape) + empty_shape[output_split] = 0 + send_buf = torch.empty(empty_shape, dtype=self.larray.dtype, device=self.larray.device) + else: + send_buf = self.larray[incoming_request_key] + # Edge case 2. 
local single-element indexing results in local loss of split axis
+            if send_buf.ndim < len(output_lshape):
+                all_keys_scalar = sum(
+                    list(
+                        np.isscalar(k) or k.numel() == 1 and getattr(k, "ndim", 2) < 2
+                        for k in incoming_request_key
+                    )
+                ) == len(incoming_request_key)
+                if not all_keys_scalar:
+                    send_buf = send_buf.unsqueeze_(dim=output_split)
+
+        print("OUTPUT_SHAPE = ", output_shape)
+        print("OUTPUT_SPLIT = ", output_split)
+        print("SEND_BUF SHAPE = ", send_buf.shape)
+
+        # output_lshape = list(output_shape)
+        # if getattr(key, "ndim", 0) == 1:
+        #     output_lshape[output_split] = key.shape[0]
+        # else:
+        #     if broadcasted_indexing:
+        #         output_lshape = (
+        #             output_lshape[:original_split]
+        #             + [torch.prod(torch.tensor(broadcast_shape, device=send_buf.device)).item()]
+        #             + output_lshape[output_split + 1 :]
+        #         )
+        #     else:
+        #         output_lshape[output_split] = key[original_split].shape[0]
+        # recv_buf = torch.empty(
+        #     tuple(output_lshape), dtype=self.larray.dtype, device=self.larray.device
+        # )
         recv_counts = torch.squeeze(recv_counts, dim=1).tolist()
         recv_displs = outgoing_request_key_displs.tolist()
         send_counts = incoming_request_key_counts.tolist()
         send_displs = incoming_request_key_displs.tolist()
+        print("DEBUGGING: send_buf recv_buf shape= ", send_buf.shape, recv_buf.shape)
+        print("DEBUGGING: send_counts recv_counts = ", send_counts, recv_counts)
+        print("DEBUGGING: send_displs recv_displs = ", send_displs, recv_displs)
         print("DEBUGGING: output_split = ", output_split)
         self.comm.Alltoallv(
             (send_buf, send_counts, send_displs),
@@ -2265,7 +2323,7 @@ def __set(
             return

         if key is None or key == ... or key == slice(None):
-            return __set(self, self.larray, value)
+            return __set(self, value)

         # torch_device = self.larray.device

@@ -2290,7 +2348,7 @@ def __set(
             out_is_balanced,
             root,
             backwards_transpose_axes,
-        ) = self.__process_key(key, return_local_indices=True)
+        ) = self.__process_key(key, return_local_indices=True, op="set")

         # sanitize value
         value_split = value.split if isinstance(value, DNDarray) else None
@@ -2311,7 +2369,7 @@ def __set(
         )

         # TODO: sanitize distribution without allocating getitem array
-        if split_key_is_ordered:
+        if split_key_is_ordered == 1:
             # data are not distributed or split dimension is not affected by indexing
             # key all local
             if root is not None:
@@ -2323,6 +2381,30 @@ def __set(
             self = self.transpose(backwards_transpose_axes)
             return

+        if split_key_is_ordered == -1:
+            # key is in descending order, i.e.
slice with negative step + + # flip value, match value distribution to keys + value = manipulations.flip(value, axis=output_split) + split_key = factories.array( + key[output_split], is_split=0, device=self.device, comm=self.comm + ) + if value.is_distributed(): + target_map = value.lshape_map + target_map[:, output_split] = split_key.lshape_map[:, 0] + print( + "DEBUGGING: TEST target_map, value.lshape_map = ", target_map, value.lshape_map + ) + value.redistribute_(target_map=target_map) + + process_is_inactive = sum( + list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in key) + ) + if not process_is_inactive: + # only assign values if key does not contain empty slices + self.larray[key] = value.larray + return + # process-local indices # if advanced_indexing: From cc49a49fbf319b1ebe616580cc8103e2243085a9 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Sat, 12 Aug 2023 09:08:31 +0200 Subject: [PATCH 077/132] transpose self back to original shape after indexing --- heat/core/dndarray.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index c04fa72289..33150e6022 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2403,8 +2403,12 @@ def __set( if not process_is_inactive: # only assign values if key does not contain empty slices self.larray[key] = value.larray + self = self.transpose(backwards_transpose_axes) return + # non-ordered key along split axis + # indices are global + # process-local indices # if advanced_indexing: From fe26ae825d07d46b2a8b220de5221033c4a58bb1 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 30 Aug 2023 06:04:30 +0200 Subject: [PATCH 078/132] add setitem tests --- heat/core/tests/test_dndarray.py | 247 +++++++++++++++++++++++++++++++ 1 file changed, 247 insertions(+) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index c90378ee8b..0d19e98023 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1326,6 +1326,253 @@ def test_rshift(self): res = ht.right_shift(ht.array([True]), 2) self.assertTrue(res == 0) + def test_setitem(self): + # following https://numpy.org/doc/stable/user/basics.indexing.html + + # Single element indexing + # 1D, local + x = ht.zeros(10) + x[2] = 2 + x[-2] = 8 + self.assertTrue(x[2].item() == 2) + self.assertTrue(x[-2].item() == 8) + self.assertTrue(x[2].dtype == ht.float32) + # 1D, distributed + x = ht.zeros(10, split=0, dtype=ht.float64) + x[2] = 2 + x[-2] = 8 + self.assertTrue(x[2].item() == 2.0) + self.assertTrue(x[-2].item() == 8.0) + self.assertTrue(x[2].dtype == ht.float64) + self.assertTrue(x.split == 0) + # 2D, local + x = ht.zeros(10).reshape(2, 5) + x[0] = ht.arange(5) + self.assertTrue((x[0] == ht.arange(5)).all().item()) + self.assertTrue(x[0].dtype == ht.float32) + # 2D, distributed + x_split0 = ht.zeros(10, split=0).reshape(2, 5) + x_split0[0] = ht.arange(5) + self.assertTrue((x_split0[0] == ht.arange(5, split=None)).all().item()) + x_split1 = ht.zeros(10, split=0).reshape(2, 5, new_split=1) + x_split1[-2] = ht.arange(5) + self.assertTrue((x_split1[-2] == ht.arange(5, split=0)).all().item()) + # 3D, distributed, split = 0 + x_split0 = ht.zeros(27, split=0).reshape(3, 3, 3) + key = -2 + x_split0[key] = ht.arange(3) + self.assertTrue((x_split0[key].larray == torch.arange(3)).all()) + self.assertTrue(x_split0[key].dtype == ht.float32) + self.assertTrue(x_split0[key].split == 0) + # 
3D, distributed, split != 0
+        x_split2 = ht.zeros(27, dtype=ht.int64, split=0).reshape(3, 3, 3, new_split=2)
+        key = ht.array(2)
+        x_split2[key] = [6, 7, 8]
+        indexed_split2 = x_split2[key]
+        self.assertTrue((indexed_split2.numpy()[0] == np.array([6, 7, 8])).all())
+        self.assertTrue(indexed_split2.dtype == ht.int64)
+        self.assertTrue(x_split2.split == 2)
+
+        # Slicing and striding
+        x = ht.arange(20, split=0)
+        x_sliced = x[1:11:3]
+        x[1:11:3] = ht.array([10, 40, 70, 100])
+        x_np = np.arange(20)
+        x_sliced_np = x_np[1:11:3]
+        x_np[1:11:3] = np.array([10, 40, 70, 100])
+        self.assert_array_equal(x_sliced, x_sliced_np)
+        self.assert_array_equal(x_sliced, np.array([10, 40, 70, 100]))
+        self.assertTrue(x.split == 0)
+
+        # # 1-element slice along split axis
+        # x = ht.arange(20).reshape(4, 5)
+        # x.resplit_(axis=1)
+        # x_sliced = x[:, 2:3]
+        # x_np = np.arange(20).reshape(4, 5)
+        # x_sliced_np = x_np[:, 2:3]
+        # self.assert_array_equal(x_sliced, x_sliced_np)
+        # self.assertTrue(x_sliced.split == 1)
+
+        # # slicing with negative step along split axis 0
+        # shape = (20, 4, 3)
+        # x_3d = ht.arange(20 * 4 * 3, split=0).reshape(shape)
+        # x_3d_sliced = x_3d[17:2:-2, :2, ht.array(1)]
+        # x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[17:2:-2, :2, 1]
+        # self.assert_array_equal(x_3d_sliced, x_3d_sliced_np)
+        # self.assertTrue(x_3d_sliced.split == 0)
+
+        # # slicing with negative step along split 1
+        # shape = (4, 20, 3)
+        # x_3d = ht.arange(20 * 4 * 3).reshape(shape)
+        # x_3d.resplit_(axis=1)
+        # key = (slice(None, 2), slice(17, 2, -2), 1)
+        # x_3d_sliced = x_3d[key]
+        # x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[:2, 17:2:-2, 1]
+        # self.assert_array_equal(x_3d_sliced, x_3d_sliced_np)
+        # self.assertTrue(x_3d_sliced.split == 1)
+
+        # # slicing with negative step along split 2 and loss of axis < split
+        # shape = (4, 3, 20)
+        # x_3d = ht.arange(20 * 4 * 3).reshape(shape)
+        # x_3d.resplit_(axis=2)
+        # key = (slice(None, 2), 1, slice(17, 10, -2))
+        # x_3d_sliced = x_3d[key]
+        # x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[:2, 1, 17:10:-2]
+        # self.assert_array_equal(x_3d_sliced, x_3d_sliced_np)
+        # self.assertTrue(x_3d_sliced.split == 1)
+
+        # # slicing with negative step along split 2 and loss of all axes but split
+        # shape = (4, 3, 20)
+        # x_3d = ht.arange(20 * 4 * 3).reshape(shape)
+        # x_3d.resplit_(axis=2)
+        # key = (0, 1, slice(17, 13, -1))
+        # x_3d_sliced = x_3d[key]
+        # x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[0, 1, 17:13:-1]
+        # self.assert_array_equal(x_3d_sliced, x_3d_sliced_np)
+        # self.assertTrue(x_3d_sliced.split == 0)
+
+        # # DIMENSIONAL INDEXING
+        # # ellipsis
+        # x_np = np.array([[[1], [2], [3]], [[4], [5], [6]]])
+        # x_np_ellipsis = x_np[..., 0]
+        # x = ht.array([[[1], [2], [3]], [[4], [5], [6]]])
+
+        # # local
+        # x_ellipsis = x[..., 0]
+        # x_slice = x[:, :, 0]
+        # self.assert_array_equal(x_ellipsis, x_np_ellipsis)
+        # self.assert_array_equal(x_slice, x_np_ellipsis)
+
+        # # distributed
+        # x.resplit_(axis=1)
+        # x_ellipsis = x[..., 0]
+        # x_slice = x[:, :, 0]
+        # self.assert_array_equal(x_ellipsis, x_np_ellipsis)
+        # self.assert_array_equal(x_slice, x_np_ellipsis)
+        # self.assertTrue(x_ellipsis.split == 1)
+
+        # # newaxis: local
+        # x = ht.array([[[1], [2], [3]], [[4], [5], [6]]])
+        # x_np_newaxis = x_np[:, np.newaxis, :2, :]
+        # x_newaxis = x[:, np.newaxis, :2, :]
+        # x_none = x[:, None, :2, :]
+        # self.assert_array_equal(x_newaxis, x_np_newaxis)
+        # self.assert_array_equal(x_none, x_np_newaxis)
+
+        # # newaxis: distributed
+        # x.resplit_(axis=1)
+ # x_newaxis = x[:, np.newaxis, :2, :] + # x_none = x[:, None, :2, :] + # self.assert_array_equal(x_newaxis, x_np_newaxis) + # self.assert_array_equal(x_none, x_np_newaxis) + # self.assertTrue(x_newaxis.split == 2) + # self.assertTrue(x_none.split == 2) + + # x = ht.arange(5, split=0) + # x_np = np.arange(5) + # y = x[:, np.newaxis] + x[np.newaxis, :] + # y_np = x_np[:, np.newaxis] + x_np[np.newaxis, :] + # self.assert_array_equal(y, y_np) + # self.assertTrue(y.split == 0) + + # # ADVANCED INDEXING + # # "x[(1, 2, 3),] is fundamentally different from x[(1, 2, 3)]" + + # x_np = np.arange(60).reshape(5, 3, 4) + # indexed_x_np = x_np[(1, 2, 3)] + # adv_indexed_x_np = x_np[(1, 2, 3),] + # x = ht.array(x_np, split=0) + # indexed_x = x[(1, 2, 3)] + # self.assertTrue(indexed_x.item() == np.array(indexed_x_np)) + # adv_indexed_x = x[(1, 2, 3),] + # self.assert_array_equal(adv_indexed_x, adv_indexed_x_np) + + # # 1d + # x = ht.arange(10, 1, -1, split=0) + # x_np = np.arange(10, 1, -1) + # x_adv_ind = x[np.array([3, 3, 1, 8])] + # x_np_adv_ind = x_np[np.array([3, 3, 1, 8])] + # self.assert_array_equal(x_adv_ind, x_np_adv_ind) + + # # 3d, split 0, non-unique, non-ordered key along split axis + # x = ht.arange(60, split=0).reshape(5, 3, 4) + # x_np = np.arange(60).reshape(5, 3, 4) + # k1 = np.array([0, 4, 1, 0]) + # k2 = np.array([0, 2, 1, 0]) + # k3 = np.array([1, 2, 3, 1]) + # self.assert_array_equal( + # x[ht.array(k1, split=0), ht.array(k2, split=0), ht.array(k3, split=0)], x_np[k1, k2, k3] + # ) + # # advanced indexing on non-consecutive dimensions + # x = ht.arange(60, split=0).reshape(5, 3, 4, new_split=1) + # x_copy = x.copy() + # x_np = np.arange(60).reshape(5, 3, 4) + # k1 = np.array([0, 4, 1, 0]) + # k2 = 0 + # k3 = np.array([1, 2, 3, 1]) + # key = (k1, k2, k3) + # self.assert_array_equal(x[key], x_np[key]) + # # check that x is unchanged after internal manipulation + # self.assertTrue(x.shape == x_copy.shape) + # self.assertTrue(x.split == x_copy.split) + # self.assertTrue(x.lshape == x_copy.lshape) + # self.assertTrue((x == x_copy).all().item()) + + # # broadcasting shapes + # x.resplit_(axis=0) + # self.assert_array_equal(x[ht.array(k1, split=0), ht.array(1), 2], x_np[k1, 1, 2]) + # # test exception: broadcasting mismatching shapes + # k2 = np.array([0, 2, 1]) + # with self.assertRaises(IndexError): + # x[k1, k2, k3] + + # # more broadcasting + # x_np = np.arange(12).reshape(4, 3) + # rows = np.array([0, 3]) + # cols = np.array([0, 2]) + # x = ht.arange(12).reshape(4, 3) + # x.resplit_(1) + # x_np_indexed = x_np[rows[:, np.newaxis], cols] + # x_indexed = x[ht.array(rows)[:, np.newaxis], cols] + # self.assert_array_equal(x_indexed, x_np_indexed) + # self.assertTrue(x_indexed.split == 1) + + # # combining advanced and basic indexing + # y_np = np.arange(35).reshape(5, 7) + # y_np_indexed = y_np[np.array([0, 2, 4]), 1:3] + # y = ht.array(y_np, split=1) + # y_indexed = y[ht.array([0, 2, 4]), 1:3] + # self.assert_array_equal(y_indexed, y_np_indexed) + # self.assertTrue(y_indexed.split == 1) + + # x_np = np.arange(10 * 20 * 30).reshape(10, 20, 30) + # x = ht.array(x_np, split=1) + # ind_array = ht.random.randint(0, 20, (2, 3, 4), dtype=ht.int64) + # ind_array_np = ind_array.numpy() + # x_np_indexed = x_np[..., ind_array_np, :] + # x_indexed = x[..., ind_array, :] + # self.assert_array_equal(x_indexed, x_np_indexed) + # self.assertTrue(x_indexed.split == 3) + + # # boolean mask, local + # arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5) + # np.random.seed(42) + # mask = np.random.randint(0, 2, 
arr.shape, dtype=bool)
+        # self.assertTrue((arr[mask].numpy() == arr.numpy()[mask]).all())
+
+        # # boolean mask, distributed
+        # arr_split0 = ht.array(arr, split=0)
+        # mask_split0 = ht.array(mask, split=0)
+        # self.assertTrue((arr_split0[mask_split0].numpy() == arr.numpy()[mask]).all())
+
+        # arr_split1 = ht.array(arr, split=1)
+        # mask_split1 = ht.array(mask, split=1)
+        # self.assert_array_equal(arr_split1[mask_split1], arr.numpy()[mask])
+
+        # arr_split2 = ht.array(arr, split=2)
+        # mask_split2 = ht.array(mask, split=2)
+        # self.assert_array_equal(arr_split2[mask_split2], arr.numpy()[mask])
+
     # def test_setitem_getitem(self):
     #     # tests for bug #825
     #     a = ht.ones((102, 102), split=0)

From 6d2e36968dc1e9bc8298a7e33707979a46f7c55f Mon Sep 17 00:00:00 2001
From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com>
Date: Thu, 7 Dec 2023 12:40:51 +0100
Subject: [PATCH 079/132] do not index input unnecessarily for sanitation

---
 heat/core/dndarray.py            | 94 +++++++++++++++++++++++---------
 heat/core/tests/test_dndarray.py |  2 +-
 2 files changed, 68 insertions(+), 28 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index efc36e46b4..88ca54aaea 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -2344,48 +2344,83 @@ def __setitem__(

         def __set(
             arr: DNDarray,
+            key: Union[int, Tuple[int, ...], List[int, ...]],
             value: Union[DNDarray, torch.Tensor, np.ndarray, float, int, list, tuple],
         ):
             """
             Setter for non-advanced indexing, i.e. when arr[key] is an in-place view of arr.
             """
-            value_split = value.split if isinstance(value, DNDarray) else None
+            # # need information on indexed array, use proxy to limit memory usage
+            # subarray = arr.__torch_proxy__()[key]
+            # subarray_shape, subarray_ndim = tuple(subarray.shape), subarray.ndim
+            # while value.ndim < subarray_ndim:  # broadcasting
+            #     value = value.expand_dims(0)
+            # try:
+            #     value_shape = tuple(torch.broadcast_shapes(value_shape, subarray_shape))
+            # except RuntimeError:
+            #     raise ValueError(
+            #         f"could not broadcast input array from shape {value.shape} into shape {arr.shape}"
+            #     )
+            # # TODO: take this out of this function
+            # sanitation.sanitize_out(subarray, value_shape, value.split, value.device, value.comm)
+            # arr.larray[None] = value.larray
+            arr.__array__().__setitem__(key, value.__array__())
+            return
+
+        # make sure `value` is a DNDarray
+        if not isinstance(value, DNDarray):
             try:
                 value = factories.array(
-                    value, dtype=arr.dtype, split=value_split, device=arr.device, comm=arr.comm
+                    value, dtype=self.dtype, split=None, device=self.device, comm=self.comm
                 )
             except TypeError:
                 raise TypeError(f"Cannot assign object of type {type(value)} to DNDarray.")
-            value_shape = value.shape
-            while value.ndim < arr.ndim:  # broadcasting
-                print("DEBUGGING: value.ndim, value.shape = ", value.ndim, value.shape)
-                value = value.expand_dims(0)
-                print("DEBUGGING: value.shape = ", value.shape)
-            try:
-                value_shape = tuple(torch.broadcast_shapes(value.shape, arr.shape))
-            except RuntimeError:
-                raise ValueError(
-                    f"could not broadcast input array from shape {value.shape} into shape {arr.shape}"
-                )
-            sanitation.sanitize_out(arr, value_shape, value.split, value.device, value.comm)
-            value = sanitation.sanitize_distribution(value, target=arr)
-            arr.larray[None] = value.larray
-            return
-        if key is None or key == ...
or key == slice(None): - return __set(self, value) + # use low-memory torch_proxy in sanitation + indexed_proxy = self.__torch_proxy__()[key] + # `value` might be broadcasted + value_shape = value.shape + while value.ndim < indexed_proxy.ndim: # broadcasting + value = value.expand_dims(0) + try: + value_shape = tuple(torch.broadcast_shapes(value.shape, indexed_proxy.shape)) + except RuntimeError: + raise ValueError( + f"could not broadcast input array from shape {value_shape} into shape {tuple(indexed_proxy.shape)}" + ) - # torch_device = self.larray.device + if key is None or key == ... or key == slice(None): + # make sure `self` and `value` distribution are aligned + value = sanitation.sanitize_distribution(value, target=self) + return __set(self, key, value) # single-element key scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 if scalar: - key, root = self.__process_scalar_key(key, indexed_axis=0, return_local_indices=False) + key, root = self.__process_scalar_key(key, indexed_axis=0, return_local_indices=True) + # `root` will be None when the indexed axis is not the split axis, or when the + # indexed axis is the split axis but the indexed element is not local if root is not None: if self.comm.rank == root: - __set(self[key], value) + # verify that `self[key]` and `value` distribution are aligned + # do not index `self` with `key` directly here, as this would MPI-broadcast to all ranks + if indexed_proxy.names.count("split") != 0: + # indexed_split = indexed_proxy.names.index("split") + # lshape_map of indexed subarray is the same as the lshape_map of the original array after losing the first dimension + indexed_lshape_map = self.lshape_map[:, 1:] + if value.lshape_map != indexed_lshape_map: + try: + value.redistribute_(target_map=indexed_lshape_map) + except ValueError: + raise ValueError( + f"cannot assign value to indexed DNDarray because distribution schemes do not match: {value.lshape_map} vs. {indexed_lshape_map}" + ) + __set(self, key, value) else: - __set(self[key], value) + # `root` is None, i.e. the indexed element is local on each process + # verify that `self[key]` and `value` distribution are aligned + value = sanitation.sanitize_distribution(value, target=self[key]) + __set(self, key, value) return # multi-element key, incl. slicing and striding, ordered and non-ordered advanced indexing @@ -2797,11 +2832,16 @@ def tolist(self, keepsplit: bool = False) -> List: def __torch_proxy__(self) -> torch.Tensor: """ - Return a 1-element `torch.Tensor` strided as the global `self` shape. - Used internally for sanitation purposes. + Return a 1-element `torch.Tensor` strided as the global `self` shape, and with named split axis. + Used internally to lower memory footprint of sanitation. 
""" - return torch.ones((1,), dtype=torch.int8, device=self.larray.device).as_strided( - self.gshape, [0] * self.ndim + names = [None] * self.ndim + if self.split is not None: + names[self.split] = "split" + return ( + torch.ones((1,), dtype=torch.int8, device=self.larray.device) + .as_strided(self.gshape, [0] * self.ndim) + .refine_names(*names) ) @staticmethod diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 66149da572..9ce19a815e 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1385,7 +1385,7 @@ def test_setitem(self): x_split0[key] = ht.arange(3) self.assertTrue((x_split0[key].larray == torch.arange(3)).all()) self.assertTrue(x_split0[key].dtype == ht.float32) - self.assertTrue(x_split0[key].split == 0) + self.assertTrue(x_split0.split == 0) # 3D, distributed split, != 0 x_split2 = ht.zeros(27, dtype=ht.int64, split=0).reshape(3, 3, 3, new_split=2) key = ht.array(2) From f528356e7697253787bb768339a80ca31f0bf239 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 7 Dec 2023 12:43:34 +0100 Subject: [PATCH 080/132] test named split dimension for torch_proxy --- heat/core/tests/test_dndarray.py | 1 + 1 file changed, 1 insertion(+) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 9ce19a815e..4d7147ad03 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -2226,6 +2226,7 @@ def test_torch_proxy(self): dndarray_proxy.storage().size() * dndarray_proxy.storage().element_size() ) self.assertTrue(dndarray_proxy_nbytes == 1) + self.assertTrue(dndarray_proxy.names.index("split") == 1) def test_xor(self): int16_tensor = ht.array([[1, 1], [2, 2]], dtype=ht.int16) From 01a1140f672a7488f7f45016adf9e8ea98f8f317 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Fri, 8 Dec 2023 06:04:39 +0100 Subject: [PATCH 081/132] value broadcasting abstraction --- heat/core/dndarray.py | 47 +++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 88ca54aaea..dbc807bd32 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2342,6 +2342,25 @@ def __setitem__( [0., 1., 0., 0., 0.]]) """ + def __broadcast_value( + arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]], value: DNDarray + ): + """ + Broadcasts the given DNDarray `value` to the shape of the indexed array `arr[key]`. 
+ """ + # need information on indexed array, use proxy to avoid MPI communication and limit memory usage + indexed_proxy = arr.__torch_proxy__()[key] + value_shape = value.shape + while value.ndim < indexed_proxy.ndim: # broadcasting + value = value.expand_dims(0) + try: + value_shape = tuple(torch.broadcast_shapes(value.shape, indexed_proxy.shape)) + except RuntimeError: + raise ValueError( + f"could not broadcast input array from shape {value_shape} into shape {tuple(indexed_proxy.shape)}" + ) + return value + def __set( arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]], @@ -2376,34 +2395,28 @@ def __set( except TypeError: raise TypeError(f"Cannot assign object of type {type(value)} to DNDarray.") - # use low-memory torch_proxy in sanitation - indexed_proxy = self.__torch_proxy__()[key] - # `value` might be broadcasted - value_shape = value.shape - while value.ndim < indexed_proxy.ndim: # broadcasting - value = value.expand_dims(0) - try: - value_shape = tuple(torch.broadcast_shapes(value.shape, indexed_proxy.shape)) - except RuntimeError: - raise ValueError( - f"could not broadcast input array from shape {value_shape} into shape {tuple(indexed_proxy.shape)}" - ) - - if key is None or key == ... or key == slice(None): - # make sure `self` and `value` distribution are aligned - value = sanitation.sanitize_distribution(value, target=self) - return __set(self, key, value) + # workaround for Heat issue #1292. TODO: remove when issue is fixed + if not isinstance(key, DNDarray): + if key is None or key is ... or key is slice(None): + # match dimensions + value = __broadcast_value(self, key, value) + # make sure `self` and `value` distribution are aligned + value = sanitation.sanitize_distribution(value, target=self) + return __set(self, key, value) # single-element key scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 if scalar: key, root = self.__process_scalar_key(key, indexed_axis=0, return_local_indices=True) + # match dimensions + value = __broadcast_value(self, key, value) # `root` will be None when the indexed axis is not the split axis, or when the # indexed axis is the split axis but the indexed element is not local if root is not None: if self.comm.rank == root: # verify that `self[key]` and `value` distribution are aligned # do not index `self` with `key` directly here, as this would MPI-broadcast to all ranks + indexed_proxy = self.__torch_proxy__()[key] if indexed_proxy.names.count("split") != 0: # indexed_split = indexed_proxy.names.index("split") # lshape_map of indexed subarray is the same as the lshape_map of the original array after losing the first dimension From f8264a98d8c7b1972f805e5eedbd18d0d373d7c7 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 13 Dec 2023 05:22:38 +0100 Subject: [PATCH 082/132] introduce distr sanitation for value when key is ordered --- heat/core/dndarray.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index dbc807bd32..ca7465e742 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2448,23 +2448,6 @@ def __set( backwards_transpose_axes, ) = self.__process_key(key, return_local_indices=True, op="set") - # sanitize value - value_split = value.split if isinstance(value, DNDarray) else None - try: - value = factories.array( - value, dtype=self.dtype, split=value_split, device=self.device, comm=self.comm - ) - except TypeError: - raise TypeError(f"Cannot assign object of 
type {type(value)} to DNDarray.") - value_shape = value.shape - while value.ndim < len(output_shape): # broadcasting - value = value.expand_dims(0) - try: - value_shape = tuple(torch.broadcast_shapes(value_shape, output_shape)) - except RuntimeError: - raise ValueError( - f"could not broadcast input array from shape {value.shape} into shape {output_shape}" - ) # TODO: sanitize distribution without allocating getitem array if split_key_is_ordered == 1: @@ -2475,6 +2458,10 @@ def __set( if self.comm.rank == root: self.larray[key] = value.larray else: + # indexed elements are process-local + # self[key] is a view and does not trigger communication + # verify that `self[key]` and `value` distribution are aligned + value = sanitation.sanitize_distribution(value, target=self[key]) self.larray[key] = value.larray self = self.transpose(backwards_transpose_axes) return From b1cd02f2db1595026944e81d502b6769eccfe02c Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 13 Dec 2023 06:19:33 +0100 Subject: [PATCH 083/132] keep track of original key --- heat/core/dndarray.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index ca7465e742..8dbd578d56 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2343,12 +2343,16 @@ def __setitem__( """ def __broadcast_value( - arr: DNDarray, key: Union[int, Tuple[int, ...], List[int, ...]], value: DNDarray + arr: DNDarray, key: Union[int, Tuple[int, ...], slice], value: DNDarray ): """ Broadcasts the given DNDarray `value` to the shape of the indexed array `arr[key]`. """ # need information on indexed array, use proxy to avoid MPI communication and limit memory usage + if not isinstance(key, (int, tuple, slice)): + raise TypeError( + f"only integers, slices (`:`), and tuples are valid indices (got {type(key)})" + ) indexed_proxy = arr.__torch_proxy__()[key] value_shape = value.shape while value.ndim < indexed_proxy.ndim: # broadcasting @@ -2437,6 +2441,15 @@ def __set( return # multi-element key, incl. slicing and striding, ordered and non-ordered advanced indexing + # store original key for later use + try: + original_key = key.copy() + except AttributeError: + try: + original_key = key.clone() + except AttributeError: + original_key = key + ( self, key, @@ -2461,7 +2474,7 @@ def __set( # indexed elements are process-local # self[key] is a view and does not trigger communication # verify that `self[key]` and `value` distribution are aligned - value = sanitation.sanitize_distribution(value, target=self[key]) + value = sanitation.sanitize_distribution(value, target=self[original_key]) self.larray[key] = value.larray self = self.transpose(backwards_transpose_axes) return @@ -2832,7 +2845,7 @@ def tolist(self, keepsplit: bool = False) -> List: def __torch_proxy__(self) -> torch.Tensor: """ - Return a 1-element `torch.Tensor` strided as the global `self` shape, and with named split axis. + Return a 1-element `torch.Tensor` strided as the global `self` shape. The split axis of the initial DNDarray is stored in the `names` attribute of the returned tensor. Used internally to lower memory footprint of sanitation. 
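For illustration (not part of the patch), a proxy mimicking a DNDarray with global shape (4, 5) split along axis 1 can be built directly in torch:

    import torch

    proxy = (
        torch.ones((1,), dtype=torch.int8)
        .as_strided((4, 5), [0, 0])    # zero strides: one element backs all entries
        .refine_names(None, "split")   # tag the split dimension by name
    )
    print(proxy.shape, proxy.names)  # torch.Size([4, 5]) (None, 'split')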
""" names = [None] * self.ndim From 31bdb34fd1f8b861b10cfb884634907a5a209e96 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 14 Dec 2023 06:37:02 +0100 Subject: [PATCH 084/132] fix value broadcasting for advanced setitem --- heat/core/dndarray.py | 49 ++++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 8dbd578d56..145c0037c7 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1351,15 +1351,11 @@ def __process_scalar_key( """ device = arr.larray.device try: - # is key an ndarray or DNDarray? - key = key.copy().item() + # is key an ndarray or DNDarray or torch.Tensor? + key = key.item() except AttributeError: - try: - # is key a torch tensor? - key = key.clone().item() - except AttributeError: - # key is already an integer, do nothing - pass + # key is already an integer, do nothing + pass if not arr.is_distributed(): root = None return key, root @@ -1380,7 +1376,8 @@ def __process_scalar_key( dim=0, ) _, sorted_indices = displs.unique(sorted=True, return_inverse=True) - root = sorted_indices[-1] - 1 + root = sorted_indices[-1].item() - 1 + displs = displs.tolist() # correct key for rank-specific displacement if return_local_indices: if arr.comm.rank == root: @@ -2343,19 +2340,30 @@ def __setitem__( """ def __broadcast_value( - arr: DNDarray, key: Union[int, Tuple[int, ...], slice], value: DNDarray + arr: DNDarray, + key: Union[int, Tuple[int, ...], slice], + value: DNDarray, + **kwargs, ): """ Broadcasts the given DNDarray `value` to the shape of the indexed array `arr[key]`. """ - # need information on indexed array, use proxy to avoid MPI communication and limit memory usage - if not isinstance(key, (int, tuple, slice)): - raise TypeError( - f"only integers, slices (`:`), and tuples are valid indices (got {type(key)})" - ) - indexed_proxy = arr.__torch_proxy__()[key] + # need information on indexed array + output_shape = kwargs.get("output_shape", None) + if output_shape is not None: + indexed_dims = len(output_shape) + else: + if isinstance(key, (int, tuple)): + # direct indexing, output_shape has not been calculated + # use proxy to avoid MPI communication and limit memory usage + indexed_proxy = arr.__torch_proxy__()[key] + indexed_dims = indexed_proxy.ndim + else: + raise RuntimeError( + "Not enough information to broadcast value to indexed array, please provide `output_shape`" + ) value_shape = value.shape - while value.ndim < indexed_proxy.ndim: # broadcasting + while value.ndim < indexed_dims: # broadcasting value = value.expand_dims(0) try: value_shape = tuple(torch.broadcast_shapes(value.shape, indexed_proxy.shape)) @@ -2422,8 +2430,7 @@ def __set( # do not index `self` with `key` directly here, as this would MPI-broadcast to all ranks indexed_proxy = self.__torch_proxy__()[key] if indexed_proxy.names.count("split") != 0: - # indexed_split = indexed_proxy.names.index("split") - # lshape_map of indexed subarray is the same as the lshape_map of the original array after losing the first dimension + # distribution map of indexed subarray is the same as the lshape_map of the original array after losing the first dimension indexed_lshape_map = self.lshape_map[:, 1:] if value.lshape_map != indexed_lshape_map: try: @@ -2461,10 +2468,10 @@ def __set( backwards_transpose_axes, ) = self.__process_key(key, return_local_indices=True, op="set") - # TODO: sanitize distribution without allocating getitem array + # 
match dimensions + value = __broadcast_value(self, key, value, output_shape=output_shape) if split_key_is_ordered == 1: - # data are not distributed or split dimension is not affected by indexing # key all local if root is not None: # single-element assignment along split axis, only one active process From c4d674935d6a3ce378191ffcbe3cec5f637cd00c Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Sat, 16 Dec 2023 13:51:37 +0100 Subject: [PATCH 085/132] match broadcasting to numpy --- heat/core/dndarray.py | 50 +++++++++++++++++++++++++++----- heat/core/tests/test_dndarray.py | 21 ++++++-------- 2 files changed, 51 insertions(+), 20 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 145c0037c7..160ffcd8fc 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2358,19 +2358,53 @@ def __broadcast_value( # use proxy to avoid MPI communication and limit memory usage indexed_proxy = arr.__torch_proxy__()[key] indexed_dims = indexed_proxy.ndim + output_shape = tuple(indexed_proxy.shape) else: raise RuntimeError( "Not enough information to broadcast value to indexed array, please provide `output_shape`" ) value_shape = value.shape - while value.ndim < indexed_dims: # broadcasting - value = value.expand_dims(0) - try: - value_shape = tuple(torch.broadcast_shapes(value.shape, indexed_proxy.shape)) - except RuntimeError: - raise ValueError( - f"could not broadcast input array from shape {value_shape} into shape {tuple(indexed_proxy.shape)}" - ) + print("DEBUGGING: OUTPUT SHAPE, value shape = ", output_shape, value_shape) + + if value_shape != output_shape: + # assess whether the shapes are compatible, starting from the trailing dimension + for i in range(1, min(len(value_shape), len(output_shape))): + if i == 1: + if value_shape[-i] != output_shape[-i]: + # shapes are not compatible, raise error + raise ValueError( + f"could not broadcast input array from shape {value_shape} into shape {output_shape}" + ) + else: + if ( + value_shape[-i] != output_shape[-i] + and not value_shape[-i] == 1 + or output_shape[-i] == 1 + ): + # shapes are not compatible, raise error + raise ValueError( + f"could not broadcast input array from shape {value_shape} into shape {output_shape}" + ) + while value.ndim < indexed_dims: + print("DEBUGGING: value ndim = ", value.ndim) + # broadcasting + # expand missing dimensions to align split axis + print("DEBUGGING: value shape before expanding = ", value.shape) + value = value.expand_dims(0) + print("DEBUGGING: value shape after expanding = ", value.shape) + try: + value_shape = tuple(torch.broadcast_shapes(value.shape, output_shape)) + except RuntimeError: + raise ValueError( + f"could not broadcast input array from shape {value_shape} into shape {output_shape}" + ) + return value + # # value has more dimensions than indexed array + # print("DEBUGGING: not broadcastable = ", value.ndim, output_shape) + # raise ValueError( + # f"could not broadcast input array from shape {value_shape} into shape {output_shape}" + # ) + # value and output shape are the same return value def __set( diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 4d7147ad03..2d25e45038 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1397,23 +1397,20 @@ def test_setitem(self): # Slicing and striding x = ht.arange(20, split=0) - x_sliced = x[1:11:3] x[1:11:3] = ht.array([10, 40, 70, 100]) x_np = np.arange(20) - x_sliced_np = x_np[1:11:3] 
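For reference, the trailing-dimension check above approximates NumPy's broadcast-compatibility rule; a standalone sketch with a hypothetical helper (not part of the patch):

    def shapes_compatible(value_shape, output_shape):
        # walk the shapes from the trailing dimension; each pair must match,
        # or at least one of the two entries must be a singleton
        for v, o in zip(reversed(value_shape), reversed(output_shape)):
            if v != o and v != 1 and o != 1:
                return False
        return True

    assert shapes_compatible((4,), (2, 4))
    assert not shapes_compatible((3,), (2, 4))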
x_np[1:11:3] = np.array([10, 40, 70, 100]) - self.assert_array_equal(x_sliced, x_sliced_np) - self.assert_array_equal(x_sliced, np.array([10, 40, 70, 100])) + self.assert_array_equal(x, x_np) self.assertTrue(x.split == 0) - # # 1-element slice along split axis - # x = ht.arange(20).reshape(4, 5) - # x.resplit_(axis=1) - # x_sliced = x[:, 2:3] - # x_np = np.arange(20).reshape(4, 5) - # x_sliced_np = x_np[:, 2:3] - # self.assert_array_equal(x_sliced, x_sliced_np) - # self.assertTrue(x_sliced.split == 1) + # 1-element slice along split axis + x = ht.arange(20).reshape(4, 5) + x.resplit_(axis=1) + x[:, 2:3] = ht.array([10, 40, 70, 100]) + x_np = np.arange(20).reshape(4, 5) + x_np[:, 2:3] = np.array([10, 40, 70, 100]) + self.assert_array_equal(x, x_np) + self.assertTrue(x.split == 1) # # slicing with negative step along split axis 0 # shape = (20, 4, 3) From 5782d6ea014e89de3952c5bdad71ac7c1dfbb92d Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 20 Dec 2023 05:10:09 +0100 Subject: [PATCH 086/132] finalize broadcast_value and fix test --- heat/core/dndarray.py | 23 +++++++++-------------- heat/core/tests/test_dndarray.py | 6 ++++-- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 160ffcd8fc..8da8a79066 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2346,7 +2346,7 @@ def __broadcast_value( **kwargs, ): """ - Broadcasts the given DNDarray `value` to the shape of the indexed array `arr[key]`. + Broadcasts the assignment DNDarray `value` to the shape of the indexed array `arr[key]` if necessary. """ # need information on indexed array output_shape = kwargs.get("output_shape", None) @@ -2364,8 +2364,7 @@ def __broadcast_value( "Not enough information to broadcast value to indexed array, please provide `output_shape`" ) value_shape = value.shape - print("DEBUGGING: OUTPUT SHAPE, value shape = ", output_shape, value_shape) - + # check if value needs to be broadcasted if value_shape != output_shape: # assess whether the shapes are compatible, starting from the trailing dimension for i in range(1, min(len(value_shape), len(output_shape))): @@ -2379,19 +2378,16 @@ def __broadcast_value( if ( value_shape[-i] != output_shape[-i] and not value_shape[-i] == 1 - or output_shape[-i] == 1 + or not output_shape[-i] == 1 ): # shapes are not compatible, raise error raise ValueError( - f"could not broadcast input array from shape {value_shape} into shape {output_shape}" + f"could not broadcast input from shape {value_shape} into shape {output_shape}" ) while value.ndim < indexed_dims: - print("DEBUGGING: value ndim = ", value.ndim) # broadcasting # expand missing dimensions to align split axis - print("DEBUGGING: value shape before expanding = ", value.shape) value = value.expand_dims(0) - print("DEBUGGING: value shape after expanding = ", value.shape) try: value_shape = tuple(torch.broadcast_shapes(value.shape, output_shape)) except RuntimeError: @@ -2399,12 +2395,11 @@ def __broadcast_value( f"could not broadcast input array from shape {value_shape} into shape {output_shape}" ) return value - # # value has more dimensions than indexed array - # print("DEBUGGING: not broadcastable = ", value.ndim, output_shape) - # raise ValueError( - # f"could not broadcast input array from shape {value_shape} into shape {output_shape}" - # ) - # value and output shape are the same + # value has more dimensions than indexed array + if value.ndim > indexed_dims: + raise ValueError( + 
f"could not broadcast input array from shape {value_shape} into shape {output_shape}" + ) return value def __set( diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 2d25e45038..ef08c87e94 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1406,11 +1406,13 @@ def test_setitem(self): # 1-element slice along split axis x = ht.arange(20).reshape(4, 5) x.resplit_(axis=1) - x[:, 2:3] = ht.array([10, 40, 70, 100]) + x[:, 2:3] = ht.array([10, 40, 70, 100]).reshape(4, 1) x_np = np.arange(20).reshape(4, 5) - x_np[:, 2:3] = np.array([10, 40, 70, 100]) + x_np[:, 2:3] = np.array([10, 40, 70, 100]).reshape(4, 1) self.assert_array_equal(x, x_np) self.assertTrue(x.split == 1) + with self.assertRaises(ValueError): + x[:, 2:3] = ht.array([10, 40, 70, 100]) # # slicing with negative step along split axis 0 # shape = (20, 4, 3) From 2174e848139bc9b4a265f8beed7924c08b2fb3f9 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 20 Dec 2023 05:56:27 +0100 Subject: [PATCH 087/132] assignment to negative slice along split axis --- heat/core/dndarray.py | 34 ++++++++++++++++++++++---------- heat/core/tests/test_dndarray.py | 16 ++++++++------- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 8da8a79066..a2b56cbe02 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2520,10 +2520,21 @@ def __set( # flip value, match value distribution to keys value = manipulations.flip(value, axis=output_split) - split_key = factories.array( - key[output_split], is_split=0, device=self.device, comm=self.comm - ) - if value.is_distributed(): + if self.is_distributed(): + split_key = factories.array( + key[output_split], is_split=0, device=self.device, comm=self.comm + ) + if not value.is_distributed(): + # work with a distributed copy of `value` + value = factories.array( + value, + dtype=self.dtype, + split=output_split, + device=self.device, + comm=self.comm, + copy=True, + ) + # match `value` distribution to `self[key]` distribution target_map = value.lshape_map target_map[:, output_split] = split_key.lshape_map[:, 0] print( @@ -2531,12 +2542,15 @@ def __set( ) value.redistribute_(target_map=target_map) - process_is_inactive = sum( - list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in key) - ) - if not process_is_inactive: - # only assign values if key does not contain empty slices - self.larray[key] = value.larray + process_is_inactive = sum( + list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in key) + ) + if not process_is_inactive: + # only assign values if key does not contain empty slices + __set(self, key, value) + else: + # no communication necessary + __set(self, key, value) self = self.transpose(backwards_transpose_axes) return diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index ef08c87e94..e4e3e19516 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1414,13 +1414,15 @@ def test_setitem(self): with self.assertRaises(ValueError): x[:, 2:3] = ht.array([10, 40, 70, 100]) - # # slicing with negative step along split axis 0 - # shape = (20, 4, 3) - # x_3d = ht.arange(20 * 4 * 3, split=0).reshape(shape) - # x_3d_sliced = x_3d[17:2:-2, :2, ht.array(1)] - # x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[17:2:-2, :2, 1] - # self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) - # self.assertTrue(x_3d_sliced.split == 0) + # 
slicing with negative step along split axis 0 + # assign different dtype + shape = (20, 4, 3) + x_3d = ht.arange(20 * 4 * 3, split=0).reshape(shape) + value = ht.random.randn(8, 2) + x_3d[17:2:-2, :2, ht.array(1)] = value + x_3d_sliced = x_3d[17:2:-2, :2, ht.array(1)] + self.assertTrue(ht.allclose(x_3d_sliced, value.astype(x_3d.dtype))) + self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) # # slicing with negative step along split 1 # shape = (4, 20, 3) From 782bde2d02c523734920dffe6bf6ee3cff88dd6d Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 8 Jan 2024 05:31:19 +0100 Subject: [PATCH 088/132] getitem: index underlying tensor with processed key in non-distr case --- heat/core/dndarray.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index a2b56cbe02..333b763e64 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1483,7 +1483,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar ) return indexed_arr else: - # multi-element key + # process multi-element key ( self, key, @@ -1495,6 +1495,21 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar backwards_transpose_axes, ) = self.__process_key(key, return_local_indices=True) + if not self.is_distributed(): + # key is torch-proof, index underlying torch tensor + indexed_arr = self.larray[key] + # transpose array back if needed + self = self.transpose(backwards_transpose_axes) + return DNDarray( + indexed_arr, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + comm=self.comm, + balanced=out_is_balanced, + ) + if split_key_is_ordered == 1: if root is not None: # single-element indexing along split axis From 084371d1449c6f695640127cd4f9e49623083a77 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 8 Jan 2024 05:35:18 +0100 Subject: [PATCH 089/132] setitem: test neg step slice along non-zero split axis --- heat/core/tests/test_dndarray.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index e4e3e19516..7c9c45d40b 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1424,15 +1424,16 @@ def test_setitem(self): self.assertTrue(ht.allclose(x_3d_sliced, value.astype(x_3d.dtype))) self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) - # # slicing with negative step along split 1 - # shape = (4, 20, 3) - # x_3d = ht.arange(20 * 4 * 3).reshape(shape) - # x_3d.resplit_(axis=1) - # key = (slice(None, 2), slice(17, 2, -2), 1) - # x_3d_sliced = x_3d[key] - # x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[:2, 17:2:-2, 1] - # self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) - # self.assertTrue(x_3d_sliced.split == 1) + # slicing with negative step along split 1 + shape = (4, 20, 3) + x_3d = ht.arange(20 * 4 * 3, dtype=ht.float32).reshape(shape) + x_3d.resplit_(axis=1) + key = (slice(None, 2), slice(17, 2, -2), 1) + value = ht.random.randn(2, 8) + x_3d[key] = value + x_3d_sliced = x_3d[key] + self.assertTrue(ht.allclose(x_3d_sliced, value.astype(x_3d.dtype))) + self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) # # slicing with negative step along split 2 and loss of axis < split # shape = (4, 3, 20) From b1aa7aa1629902cb85b6b617a56e96de91fae009 Mon Sep 17 00:00:00 2001 From: Claudia Comito 
<39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 8 Jan 2024 06:07:20 +0100 Subject: [PATCH 090/132] allow for nominal value/self split mismatch --- heat/core/dndarray.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 333b763e64..41d1793916 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2533,11 +2533,12 @@ def __set( if split_key_is_ordered == -1: # key is in descending order, i.e. slice with negative step - # flip value, match value distribution to keys + # flip value, match value distribution to key's + # NB: `value.ndim` might be smaller than `self.ndim`, `value.split` nominally different from `self.split` value = manipulations.flip(value, axis=output_split) if self.is_distributed(): split_key = factories.array( - key[output_split], is_split=0, device=self.device, comm=self.comm + key[self.split], is_split=0, device=self.device, comm=self.comm ) if not value.is_distributed(): # work with a distributed copy of `value` @@ -2561,6 +2562,7 @@ def __set( list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in key) ) if not process_is_inactive: + print("DEBUGGING: value.larray = ", value.larray, value.lshape_map) # only assign values if key does not contain empty slices __set(self, key, value) else: From 1c2b71ef03e26ba806e5259c22046cb13d2c9e1b Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 8 Jan 2024 06:08:43 +0100 Subject: [PATCH 091/132] expand test negative step along split axis --- heat/core/tests/test_dndarray.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 7c9c45d40b..39437e09a1 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1435,15 +1435,16 @@ def test_setitem(self): self.assertTrue(ht.allclose(x_3d_sliced, value.astype(x_3d.dtype))) self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) - # # slicing with negative step along split 2 and loss of axis < split - # shape = (4, 3, 20) - # x_3d = ht.arange(20 * 4 * 3).reshape(shape) - # x_3d.resplit_(axis=2) - # key = (slice(None, 2), 1, slice(17, 10, -2)) - # x_3d_sliced = x_3d[key] - # x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[:2, 1, 17:10:-2] - # self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) - # self.assertTrue(x_3d_sliced.split == 1) + # slicing with negative step along split 2 and loss of axis < split + shape = (4, 3, 20) + x_3d = ht.arange(20 * 4 * 3, dtype=ht.float64).reshape(shape) + x_3d.resplit_(axis=2) + key = (slice(None, 2), 1, slice(17, 10, -2)) + value = ht.random.randn(2, 4) + x_3d[key] = value + x_3d_sliced = x_3d[key] + self.assertTrue(ht.allclose(x_3d_sliced, value.astype(x_3d.dtype))) + self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) # # slicing with negative step along split 2 and loss of all axes but split # shape = (4, 3, 20) From 7201a89ee8bc834398560ea3e08f0cc1f2b8f325 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Fri, 12 Jan 2024 06:05:39 +0100 Subject: [PATCH 092/132] allow value.ndim > indexed_dims if extra dims are singletons --- heat/core/dndarray.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 41d1793916..ed03bbdd1f 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2412,9 +2412,14 @@ def __broadcast_value( return 
value # value has more dimensions than indexed array if value.ndim > indexed_dims: - raise ValueError( - f"could not broadcast input array from shape {value_shape} into shape {output_shape}" + # check if all dimensions except the indexed ones are singletons + all_singletons = value.shape[: value.ndim - indexed_dims] == (1,) * ( + value.ndim - indexed_dims ) + if not all_singletons: + raise ValueError( + f"could not broadcast input array from shape {value_shape} into shape {output_shape}" + ) return value def __set( @@ -2439,7 +2444,7 @@ def __set( # # TODO: take this out of this function # sanitation.sanitize_out(subarray, value_shape, value.split, value.device, value.comm) # arr.larray[None] = value.larray - arr.__array__().__setitem__(key, value.__array__()) + arr.larray[key] = value.larray return # make sure `value` is a DNDarray @@ -2562,7 +2567,6 @@ def __set( list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in key) ) if not process_is_inactive: - print("DEBUGGING: value.larray = ", value.larray, value.lshape_map) # only assign values if key does not contain empty slices __set(self, key, value) else: From dfc7266b2fa4ba28b9975753e66b267098d01e7d Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Fri, 12 Jan 2024 06:06:23 +0100 Subject: [PATCH 093/132] BROKEN: expand negative step tests --- heat/core/tests/test_dndarray.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 39437e09a1..38eded25f4 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1446,15 +1446,24 @@ def test_setitem(self): self.assertTrue(ht.allclose(x_3d_sliced, value.astype(x_3d.dtype))) self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) - # # slicing with negative step along split 2 and loss of all axes but split - # shape = (4, 3, 20) - # x_3d = ht.arange(20 * 4 * 3).reshape(shape) - # x_3d.resplit_(axis=2) - # key = (0, 1, slice(17, 13, -1)) - # x_3d_sliced = x_3d[key] - # x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[0, 1, 17:13:-1] - # self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) - # self.assertTrue(x_3d_sliced.split == 0) + # slicing with negative step along split 2 and loss of all axes but split + shape = (4, 3, 20) + x_3d = ht.arange(20 * 4 * 3).reshape(shape) + x_3d.resplit_(axis=2) + key = (0, 1, slice(17, 13, -1)) + value = ht.random.randint( + 200, + 220, + ( + 1, + 4, + ), + split=1, + ) + x_3d[key] = value + x_3d_sliced = x_3d[key] + self.assertTrue(ht.allclose(x_3d_sliced, value.astype(x_3d.dtype))) + self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) # # DIMENSIONAL INDEXING # # ellipsis From 8bbe242113a3358188b6e9266d88f55ec5c2c426 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 15 Jan 2024 05:10:27 +0100 Subject: [PATCH 094/132] squeeze out singleton dimensions when broadcasting value --- heat/core/dndarray.py | 3 +++ heat/core/tests/test_dndarray.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index ed03bbdd1f..39e1d30c7e 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2420,6 +2420,8 @@ def __broadcast_value( raise ValueError( f"could not broadcast input array from shape {value_shape} into shape {output_shape}" ) + # squeeze out singleton dimensions + value = value.squeeze(tuple(range(value.ndim - indexed_dims))) return value 
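Shape-wise, the finished `__broadcast_value` now behaves like the following standalone sketch (hypothetical helper name, pure torch, not part of the patch): pad missing leading dimensions, drop surplus leading singletons, and validate the remainder.

    import torch

    def sketch_broadcast(value: torch.Tensor, output_shape: tuple) -> torch.Tensor:
        while value.ndim < len(output_shape):  # pad missing leading dims
            value = value.unsqueeze(0)
        extra = value.ndim - len(output_shape)
        if extra > 0:  # drop surplus leading singletons
            if value.shape[:extra] != (1,) * extra:
                raise ValueError("could not broadcast input array")
            value = value.reshape(value.shape[extra:])
        torch.broadcast_shapes(tuple(value.shape), output_shape)  # raises if incompatible
        return value

    print(sketch_broadcast(torch.ones(1, 4), (2, 1, 4)).shape)  # torch.Size([1, 1, 4])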
def __set( @@ -2540,6 +2542,7 @@ def __set( # flip value, match value distribution to key's # NB: `value.ndim` might be smaller than `self.ndim`, `value.split` nominally different from `self.split` + print("DEBUGGING: output_split = ", output_split) value = manipulations.flip(value, axis=output_split) if self.is_distributed(): split_key = factories.array( diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 38eded25f4..6afe6ea147 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1452,8 +1452,8 @@ def test_setitem(self): x_3d.resplit_(axis=2) key = (0, 1, slice(17, 13, -1)) value = ht.random.randint( - 200, - 220, + 0, + 5, ( 1, 4, @@ -1462,7 +1462,7 @@ def test_setitem(self): ) x_3d[key] = value x_3d_sliced = x_3d[key] - self.assertTrue(ht.allclose(x_3d_sliced, value.astype(x_3d.dtype))) + self.assertTrue(ht.allclose(x_3d_sliced, value.squeeze(0).astype(x_3d.dtype))) self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) # # DIMENSIONAL INDEXING From 00a17e61cfccf66f93f0ffa2463bc25c292ea6a3 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 15 Jan 2024 11:08:29 +0100 Subject: [PATCH 095/132] fix negative step slicing on 1 process --- heat/core/dndarray.py | 40 +++++++++++++++----------------- heat/core/tests/test_dndarray.py | 16 +++++++------ 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 39e1d30c7e..1f729bf2e3 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1497,6 +1497,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar if not self.is_distributed(): # key is torch-proof, index underlying torch tensor + print("DEBUGGING: key = ", key) indexed_arr = self.larray[key] # transpose array back if needed self = self.transpose(backwards_transpose_axes) @@ -2381,8 +2382,9 @@ def __broadcast_value( value_shape = value.shape # check if value needs to be broadcasted if value_shape != output_shape: + print("DEBUGGING: value_shape, output_shape = ", value_shape, output_shape) # assess whether the shapes are compatible, starting from the trailing dimension - for i in range(1, min(len(value_shape), len(output_shape))): + for i in range(1, min(len(value_shape), len(output_shape)) + 1): if i == 1: if value_shape[-i] != output_shape[-i]: # shapes are not compatible, raise error @@ -2446,7 +2448,9 @@ def __set( # # TODO: take this out of this function # sanitation.sanitize_out(subarray, value_shape, value.split, value.device, value.comm) # arr.larray[None] = value.larray - arr.larray[key] = value.larray + + # make sure value is same datatype as arr + arr.larray[key] = value.larray.type(arr.dtype.torch_type()) return # make sure `value` is a DNDarray @@ -2538,42 +2542,36 @@ def __set( return if split_key_is_ordered == -1: - # key is in descending order, i.e. slice with negative step - - # flip value, match value distribution to key's - # NB: `value.ndim` might be smaller than `self.ndim`, `value.split` nominally different from `self.split` - print("DEBUGGING: output_split = ", output_split) - value = manipulations.flip(value, axis=output_split) + # key along split axis is in descending order, i.e. 
slice with negative step if self.is_distributed(): + # flip value, match value distribution to key's + # NB: `value.ndim` might be smaller than `self.ndim`, hence `value.split` nominally different from `self.split` + flipped_value = manipulations.flip(value, axis=output_split) split_key = factories.array( key[self.split], is_split=0, device=self.device, comm=self.comm ) - if not value.is_distributed(): - # work with a distributed copy of `value` - value = factories.array( - value, - dtype=self.dtype, + if not flipped_value.is_distributed(): + # work with distributed `flipped_value` + flipped_value = factories.array( + flipped_value.larray, + dtype=flipped_value.dtype, split=output_split, device=self.device, comm=self.comm, - copy=True, ) # match `value` distribution to `self[key]` distribution - target_map = value.lshape_map + target_map = flipped_value.lshape_map target_map[:, output_split] = split_key.lshape_map[:, 0] - print( - "DEBUGGING: TEST target_map, value.lshape_map = ", target_map, value.lshape_map - ) - value.redistribute_(target_map=target_map) + flipped_value.redistribute_(target_map=target_map) process_is_inactive = sum( list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in key) ) if not process_is_inactive: # only assign values if key does not contain empty slices - __set(self, key, value) + __set(self, key, flipped_value) else: - # no communication necessary + # 1 process, no communication needed __set(self, key, value) self = self.transpose(backwards_transpose_axes) return diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 6afe6ea147..2410bc55eb 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1465,17 +1465,19 @@ def test_setitem(self): self.assertTrue(ht.allclose(x_3d_sliced, value.squeeze(0).astype(x_3d.dtype))) self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) - # # DIMENSIONAL INDEXING - # # ellipsis + # DIMENSIONAL INDEXING + # ellipsis # x_np = np.array([[[1], [2], [3]], [[4], [5], [6]]]) # x_np_ellipsis = x_np[..., 0] # x = ht.array([[[1], [2], [3]], [[4], [5], [6]]]) - # # local - # x_ellipsis = x[..., 0] - # x_slice = x[:, :, 0] - # self.assert_array_equal(x_ellipsis, x_np_ellipsis) - # self.assert_array_equal(x_slice, x_np_ellipsis) + # local + # value = x.squeeze()+7 + # x[..., 0] = value + # self.assertTrue(ht.all(x[..., 0] == value)) + # value -= 7 + # x[:, :, 0] = value + # self.assertTrue(ht.all(x[:, :, 0] == value)) # # distributed # x.resplit_(axis=1) From bdd2dd8717ae49163aae67d582cea8a98dc53b13 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 16 Jan 2024 06:17:16 +0100 Subject: [PATCH 096/132] setitem w. 
dimensional indexing, add tests --- heat/core/dndarray.py | 57 ++++++++++++++------ heat/core/tests/test_dndarray.py | 92 +++++++++++++++++--------------- 2 files changed, 91 insertions(+), 58 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 1f729bf2e3..35f704c290 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2392,10 +2392,8 @@ def __broadcast_value( f"could not broadcast input array from shape {value_shape} into shape {output_shape}" ) else: - if ( - value_shape[-i] != output_shape[-i] - and not value_shape[-i] == 1 - or not output_shape[-i] == 1 + if value_shape[-i] != output_shape[-i] and ( + not value_shape[-i] == 1 or not output_shape[-i] == 1 ): # shapes are not compatible, raise error raise ValueError( @@ -2450,7 +2448,10 @@ def __set( # arr.larray[None] = value.larray # make sure value is same datatype as arr - arr.larray[key] = value.larray.type(arr.dtype.torch_type()) + process_is_inactive = arr.larray[key].numel() == 0 + if not process_is_inactive: + # only assign values if key does not contain empty slices + arr.larray[key] = value.larray.type(arr.dtype.torch_type()) return # make sure `value` is a DNDarray @@ -2503,14 +2504,14 @@ def __set( return # multi-element key, incl. slicing and striding, ordered and non-ordered advanced indexing - # store original key for later use - try: - original_key = key.copy() - except AttributeError: - try: - original_key = key.clone() - except AttributeError: - original_key = key + # # store original key for later use + # try: + # original_key = key.copy() + # except AttributeError: + # try: + # original_key = key.clone() + # except AttributeError: + # original_key = key ( self, @@ -2531,13 +2532,37 @@ def __set( if root is not None: # single-element assignment along split axis, only one active process if self.comm.rank == root: - self.larray[key] = value.larray + self.larray[key] = value.larray.type(self.dtype.torch_type()) else: # indexed elements are process-local # self[key] is a view and does not trigger communication # verify that `self[key]` and `value` distribution are aligned - value = sanitation.sanitize_distribution(value, target=self[original_key]) - self.larray[key] = value.larray + if self.is_distributed() and not value.is_distributed(): + # work with distributed `value` + value = factories.array( + value.larray, + dtype=value.dtype, + split=output_split, + device=self.device, + comm=self.comm, + ) + target_shape = torch.tensor( + tuple(self.larray[key].shape), device=self.device.torch_device + ) + target_map = torch.zeros( + (self.comm.size, len(target_shape)), + dtype=torch.int64, + device=self.device.torch_device, + ) + # gather all shapes into target_map + self.comm.Allgather(target_shape, target_map) + value.redistribute_(target_map=target_map) + process_is_inactive = sum( + list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in key) + ) + if not process_is_inactive: + # only assign values if key does not contain empty slices + __set(self, key, value) self = self.transpose(backwards_transpose_axes) return diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 2410bc55eb..c3bf9b6367 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1466,50 +1466,58 @@ def test_setitem(self): self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) # DIMENSIONAL INDEXING - # ellipsis - # x_np = np.array([[[1], [2], [3]], [[4], [5], [6]]]) - # x_np_ellipsis = x_np[..., 0] - # x = ht.array([[[1], [2], [3]], [[4], [5], [6]]]) + # 
ellipsis + x = ht.array([[[1], [2], [3]], [[4], [5], [6]]]) # local - # value = x.squeeze()+7 - # x[..., 0] = value - # self.assertTrue(ht.all(x[..., 0] == value)) - # value -= 7 - # x[:, :, 0] = value - # self.assertTrue(ht.all(x[:, :, 0] == value)) - - # # distributed - # x.resplit_(axis=1) - # x_ellipsis = x[..., 0] - # x_slice = x[:, :, 0] - # self.assert_array_equal(x_ellipsis, x_np_ellipsis) - # self.assert_array_equal(x_slice, x_np_ellipsis) - # self.assertTrue(x_ellipsis.split == 1) - - # # newaxis: local - # x = ht.array([[[1], [2], [3]], [[4], [5], [6]]]) - # x_np_newaxis = x_np[:, np.newaxis, :2, :] - # x_newaxis = x[:, np.newaxis, :2, :] - # x_none = x[:, None, :2, :] - # self.assert_array_equal(x_newaxis, x_np_newaxis) - # self.assert_array_equal(x_none, x_np_newaxis) - - # # newaxis: distributed - # x.resplit_(axis=1) - # x_newaxis = x[:, np.newaxis, :2, :] - # x_none = x[:, None, :2, :] - # self.assert_array_equal(x_newaxis, x_np_newaxis) - # self.assert_array_equal(x_none, x_np_newaxis) - # self.assertTrue(x_newaxis.split == 2) - # self.assertTrue(x_none.split == 2) - - # x = ht.arange(5, split=0) - # x_np = np.arange(5) - # y = x[:, np.newaxis] + x[np.newaxis, :] - # y_np = x_np[:, np.newaxis] + x_np[np.newaxis, :] - # self.assert_array_equal(y, y_np) - # self.assertTrue(y.split == 0) + value = x.squeeze() + 7 + x[..., 0] = value + self.assertTrue(ht.all(x[..., 0] == value).item()) + value -= 7 + x[:, :, 0] = value + self.assertTrue(ht.all(x[:, :, 0] == value).item()) + + # distributed + x.resplit_(axis=1) + value *= 2 + x[..., 0] = value + x_ellipsis = x[..., 0] + self.assertTrue(ht.all(x_ellipsis == value).item()) + value += 2 + x[:, :, 0] = value + self.assertTrue(ht.all(x[:, :, 0] == value).item()) + self.assertTrue(x_ellipsis.split == 1) + + # newaxis: local, w. broadcasting and different dtype + x = ht.array([[[1], [2], [3]], [[4], [5], [6]]]) + value = ht.array([10.0, 20.0]).reshape(2, 1) + x[:, None, :2, :] = value + x_newaxis = x[:, None, :2, :] + self.assertTrue(ht.all(x_newaxis == value).item()) + value += 2 + x[:, None, :2, :] = value + self.assertTrue(ht.all(x[:, None, :2, :] == value).item()) + self.assertTrue(x[:, None, :2, :].dtype == x.dtype) + + # newaxis: distributed w. broadcasting and different dtype + x.resplit_(axis=1) + value = ht.array([30.0, 40.0]).reshape(1, 2, 1) + x[:, np.newaxis, :2, :] = value + x_newaxis = x[:, np.newaxis, :2, :] + self.assertTrue(ht.all(x_newaxis == value).item()) + value += 2 + x[:, None, :2, :] = value + x_none = x[:, None, :2, :] + self.assertTrue(ht.all(x_none == value).item()) + self.assertTrue(x_none.dtype == x.dtype) + + # distributed value + x = ht.arange(6).reshape(1, 1, 2, 3) + x.resplit_(axis=-1) + value = ht.arange(3).reshape(1, 3) + value.resplit_(axis=1) + x[..., 0, :] = value + self.assertTrue(ht.all(x[..., 0, :] == value).item()) # # ADVANCED INDEXING # # "x[(1, 2, 3),] is fundamentally different from x[(1, 2, 3)]" From 1fbd4d6351c566b2cfbde8961e53c41edce6ba44 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 17 Jan 2024 05:46:01 +0100 Subject: [PATCH 097/132] setitem w. 
advanced indexing on first dim --- heat/core/dndarray.py | 65 +++++++++++++++++--------------- heat/core/tests/test_dndarray.py | 22 ++++++----- 2 files changed, 47 insertions(+), 40 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 35f704c290..947cbf39bc 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2364,6 +2364,14 @@ def __broadcast_value( """ Broadcasts the assignment DNDarray `value` to the shape of the indexed array `arr[key]` if necessary. """ + is_scalar = ( + np.isscalar(value) + or getattr(value, "ndim", 1) == 0 + or (value.shape == (1,) and value.split is None) + ) + if is_scalar: + # no need to broadcast + return value, is_scalar # need information on indexed array output_shape = kwargs.get("output_shape", None) if output_shape is not None: @@ -2382,7 +2390,6 @@ def __broadcast_value( value_shape = value.shape # check if value needs to be broadcasted if value_shape != output_shape: - print("DEBUGGING: value_shape, output_shape = ", value_shape, output_shape) # assess whether the shapes are compatible, starting from the trailing dimension for i in range(1, min(len(value_shape), len(output_shape)) + 1): if i == 1: @@ -2399,17 +2406,6 @@ def __broadcast_value( raise ValueError( f"could not broadcast input from shape {value_shape} into shape {output_shape}" ) - while value.ndim < indexed_dims: - # broadcasting - # expand missing dimensions to align split axis - value = value.expand_dims(0) - try: - value_shape = tuple(torch.broadcast_shapes(value.shape, output_shape)) - except RuntimeError: - raise ValueError( - f"could not broadcast input array from shape {value_shape} into shape {output_shape}" - ) - return value # value has more dimensions than indexed array if value.ndim > indexed_dims: # check if all dimensions except the indexed ones are singletons @@ -2422,7 +2418,18 @@ def __broadcast_value( ) # squeeze out singleton dimensions value = value.squeeze(tuple(range(value.ndim - indexed_dims))) - return value + else: + while value.ndim < indexed_dims: + # broadcasting + # expand missing dimensions to align split axis + value = value.expand_dims(0) + try: + value_shape = tuple(torch.broadcast_shapes(value.shape, output_shape)) + except RuntimeError: + raise ValueError( + f"could not broadcast input array from shape {value_shape} into shape {output_shape}" + ) + return value, is_scalar def __set( arr: DNDarray, @@ -2467,7 +2474,7 @@ def __set( if not isinstance(key, DNDarray): if key is None or key is ... 
or key is slice(None): # match dimensions - value = __broadcast_value(self, key, value) + value, _ = __broadcast_value(self, key, value) # make sure `self` and `value` distribution are aligned value = sanitation.sanitize_distribution(value, target=self) return __set(self, key, value) @@ -2477,7 +2484,7 @@ def __set( if scalar: key, root = self.__process_scalar_key(key, indexed_axis=0, return_local_indices=True) # match dimensions - value = __broadcast_value(self, key, value) + value, _ = __broadcast_value(self, key, value) # `root` will be None when the indexed axis is not the split axis, or when the # indexed axis is the split axis but the indexed element is not local if root is not None: @@ -2525,7 +2532,7 @@ def __set( ) = self.__process_key(key, return_local_indices=True, op="set") # match dimensions - value = __broadcast_value(self, key, value, output_shape=output_shape) + value, value_is_scalar = __broadcast_value(self, key, value, output_shape=output_shape) if split_key_is_ordered == 1: # key all local @@ -2535,9 +2542,7 @@ def __set( self.larray[key] = value.larray.type(self.dtype.torch_type()) else: # indexed elements are process-local - # self[key] is a view and does not trigger communication - # verify that `self[key]` and `value` distribution are aligned - if self.is_distributed() and not value.is_distributed(): + if self.is_distributed() and not value_is_scalar and not value.is_distributed(): # work with distributed `value` value = factories.array( value.larray, @@ -2546,17 +2551,17 @@ def __set( device=self.device, comm=self.comm, ) - target_shape = torch.tensor( - tuple(self.larray[key].shape), device=self.device.torch_device - ) - target_map = torch.zeros( - (self.comm.size, len(target_shape)), - dtype=torch.int64, - device=self.device.torch_device, - ) - # gather all shapes into target_map - self.comm.Allgather(target_shape, target_map) - value.redistribute_(target_map=target_map) + # verify that `self[key]` and `value` distribution are aligned + target_shape = torch.tensor( + tuple(self.larray[key].shape), device=self.device.torch_device + ) + target_map = torch.zeros( + (self.comm.size, len(target_shape)), + dtype=torch.int64, + device=self.device.torch_device, + ) + self.comm.Allgather(target_shape, target_map) + value.redistribute_(target_map=target_map) process_is_inactive = sum( list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in key) ) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index c3bf9b6367..a5a606ef11 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1519,17 +1519,19 @@ def test_setitem(self): x[..., 0, :] = value self.assertTrue(ht.all(x[..., 0, :] == value).item()) - # # ADVANCED INDEXING - # # "x[(1, 2, 3),] is fundamentally different from x[(1, 2, 3)]" + # ADVANCED INDEXING + # "x[(1, 2, 3),] is fundamentally different from x[(1, 2, 3)]" - # x_np = np.arange(60).reshape(5, 3, 4) - # indexed_x_np = x_np[(1, 2, 3)] - # adv_indexed_x_np = x_np[(1, 2, 3),] - # x = ht.array(x_np, split=0) - # indexed_x = x[(1, 2, 3)] - # self.assertTrue(indexed_x.item() == np.array(indexed_x_np)) - # adv_indexed_x = x[(1, 2, 3),] - # self.assert_array_equal(adv_indexed_x, adv_indexed_x_np) + x = ht.arange(60, split=0).reshape(5, 3, 4) + value = 99.0 + x[(1, 2, 3)] = value + indexed_x = x[(1, 2, 3)] + self.assertTrue((indexed_x == value).item()) + self.assertTrue(indexed_x.dtype == x.dtype) + x[(1, 2, 3),] = value + adv_indexed_x = x[(1, 2, 3),] + self.assertTrue(ht.all(adv_indexed_x == 
value).item()) + self.assertTrue(adv_indexed_x.dtype == x.dtype) # # 1d # x = ht.arange(10, 1, -1, split=0) From 95d3c920c6a4ac6133ada19010d559bad663cbc9 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 17 Jan 2024 06:17:40 +0100 Subject: [PATCH 098/132] setitem: test boolean indexing, local and split=0 --- heat/core/dndarray.py | 7 +++++++ heat/core/tests/test_dndarray.py | 31 +++++++++++++++++++------------ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 947cbf39bc..659aa3ccb1 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2374,6 +2374,7 @@ def __broadcast_value( return value, is_scalar # need information on indexed array output_shape = kwargs.get("output_shape", None) + print("DEBUGGING: output_shape = ", output_shape) if output_shape is not None: indexed_dims = len(output_shape) else: @@ -2390,6 +2391,7 @@ def __broadcast_value( value_shape = value.shape # check if value needs to be broadcasted if value_shape != output_shape: + print("DEBUGGING: value_shape, output_shape = ", value_shape, output_shape) # assess whether the shapes are compatible, starting from the trailing dimension for i in range(1, min(len(value_shape), len(output_shape)) + 1): if i == 1: @@ -2532,6 +2534,11 @@ def __set( ) = self.__process_key(key, return_local_indices=True, op="set") # match dimensions + print( + "DEBUGGING: BEFORE BROADCAST: OUTPUT_SHAPE, SPLIT_KEY_IS_ORDERED = ", + output_shape, + split_key_is_ordered, + ) value, value_is_scalar = __broadcast_value(self, key, value, output_shape=output_shape) if split_key_is_ordered == 1: diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index a5a606ef11..1340cf2e06 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1600,21 +1600,28 @@ def test_setitem(self): # self.assert_array_equal(x_indexed, x_np_indexed) # self.assertTrue(x_indexed.split == 3) - # # boolean mask, local - # arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5) - # np.random.seed(42) - # mask = np.random.randint(0, 2, arr.shape, dtype=bool) - # self.assertTrue((arr[mask].numpy() == arr.numpy()[mask]).all()) - - # # boolean mask, distributed - # arr_split0 = ht.array(arr, split=0) - # mask_split0 = ht.array(mask, split=0) - # self.assertTrue((arr_split0[mask_split0].numpy() == arr.numpy()[mask]).all()) + # boolean mask, local + arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5) + np.random.seed(42) + mask = np.random.randint(0, 2, arr.shape, dtype=bool) + value = 99.0 + arr[mask] = value + self.assertTrue((arr[mask] == value).all().item()) + self.assertTrue(arr[mask].dtype == arr.dtype) + value = ht.ones_like(arr) + arr[mask] = value[mask] + self.assertTrue((arr[mask] == value[mask]).all().item()) + # boolean mask, distributed + arr_split0 = ht.array(arr, split=0) + mask_split0 = ht.array(mask, split=0) + arr_split0[mask_split0] = value[mask] + self.assertTrue((arr_split0[mask_split0] == value[mask]).all().item()) # arr_split1 = ht.array(arr, split=1) # mask_split1 = ht.array(mask, split=1) - # self.assert_array_equal(arr_split1[mask_split1], arr.numpy()[mask]) - + # print("DEBUGGING: arr_split1[mask_split1].shape, value[mask].shape = ", arr_split1[mask_split1].shape, value[mask].shape) + # arr_split1[mask_split1] = value[mask] + # self.assertTrue((arr_split1[mask_split1] == value[mask]).all().item()) # arr_split2 = ht.array(arr, split=2) # mask_split2 = ht.array(mask, split=2) # 
self.assert_array_equal(arr_split2[mask_split2], arr.numpy()[mask]) From f335aa8c935f2cc86281b5cd01c8d5bd1a765071 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 18 Jan 2024 06:32:03 +0100 Subject: [PATCH 099/132] fix output shape for boolean indexing w. split>0 --- heat/core/dndarray.py | 104 +++++++++++++++---------------- heat/core/tests/test_dndarray.py | 14 +++-- 2 files changed, 59 insertions(+), 59 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 659aa3ccb1..211a1701f3 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -976,19 +976,20 @@ def __process_key( comm=arr.comm, balanced=False, ) - # vectorized sorting along axis 0 key.balance_() + # set output parameters + output_shape = (key.gshape[0],) + new_split = 0 + split_key_is_ordered = 0 + out_is_balanced = True + # vectorized sorting of key along axis 0 key = manipulations.unique(key, axis=0, return_inverse=False) - # return tuple key + # return tuple key of torch tensors key = list(key.larray.split(1, dim=1)) for i, k in enumerate(key): key[i] = k.squeeze(1) key = tuple(key) - output_shape = (key[0].shape[0],) - new_split = 0 - split_key_is_ordered = 0 - out_is_balanced = True return ( arr, key, @@ -2374,7 +2375,6 @@ def __broadcast_value( return value, is_scalar # need information on indexed array output_shape = kwargs.get("output_shape", None) - print("DEBUGGING: output_shape = ", output_shape) if output_shape is not None: indexed_dims = len(output_shape) else: @@ -2391,7 +2391,6 @@ def __broadcast_value( value_shape = value.shape # check if value needs to be broadcasted if value_shape != output_shape: - print("DEBUGGING: value_shape, output_shape = ", value_shape, output_shape) # assess whether the shapes are compatible, starting from the trailing dimension for i in range(1, min(len(value_shape), len(output_shape)) + 1): if i == 1: @@ -2456,21 +2455,18 @@ def __set( # sanitation.sanitize_out(subarray, value_shape, value.split, value.device, value.comm) # arr.larray[None] = value.larray - # make sure value is same datatype as arr + # only assign values if key does not contain empty slices process_is_inactive = arr.larray[key].numel() == 0 if not process_is_inactive: - # only assign values if key does not contain empty slices + # make sure value is same datatype as arr arr.larray[key] = value.larray.type(arr.dtype.torch_type()) return # make sure `value` is a DNDarray - if not isinstance(value, DNDarray): - try: - value = factories.array( - value, dtype=self.dtype, split=None, device=self.device, comm=self.comm - ) - except TypeError: - raise TypeError(f"Cannot assign object of type {type(value)} to DNDarray.") + try: + value = factories.array(value) + except TypeError: + raise TypeError(f"Cannot assign object of type {type(value)} to DNDarray.") # workaround for Heat issue #1292. TODO: remove when issue is fixed if not isinstance(key, DNDarray): @@ -2513,15 +2509,6 @@ def __set( return # multi-element key, incl. 
slicing and striding, ordered and non-ordered advanced indexing - # # store original key for later use - # try: - # original_key = key.copy() - # except AttributeError: - # try: - # original_key = key.clone() - # except AttributeError: - # original_key = key - ( self, key, @@ -2541,6 +2528,14 @@ def __set( ) value, value_is_scalar = __broadcast_value(self, key, value, output_shape=output_shape) + # early out for non-distributed case + if not self.is_distributed() and not value.is_distributed(): + # no communication needed + __set(self, key, value) + self = self.transpose(backwards_transpose_axes) + return + + # distributed case if split_key_is_ordered == 1: # key all local if root is not None: @@ -2580,39 +2575,40 @@ def __set( if split_key_is_ordered == -1: # key along split axis is in descending order, i.e. slice with negative step - if self.is_distributed(): - # flip value, match value distribution to key's - # NB: `value.ndim` might be smaller than `self.ndim`, hence `value.split` nominally different from `self.split` - flipped_value = manipulations.flip(value, axis=output_split) - split_key = factories.array( - key[self.split], is_split=0, device=self.device, comm=self.comm - ) - if not flipped_value.is_distributed(): - # work with distributed `flipped_value` - flipped_value = factories.array( - flipped_value.larray, - dtype=flipped_value.dtype, - split=output_split, - device=self.device, - comm=self.comm, - ) - # match `value` distribution to `self[key]` distribution - target_map = flipped_value.lshape_map - target_map[:, output_split] = split_key.lshape_map[:, 0] - flipped_value.redistribute_(target_map=target_map) + # N.B. PyTorch doesn't support negative-step slices. Key has been processed into torch tensor. - process_is_inactive = sum( - list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in key) + # flip value, match value distribution to key's + # NB: `value.ndim` can be smaller than `self.ndim`, hence `value.split` nominally different from `self.split` + flipped_value = manipulations.flip(value, axis=output_split) + split_key = factories.array( + key[self.split], is_split=0, device=self.device, comm=self.comm + ) + if not flipped_value.is_distributed(): + # work with distributed `flipped_value` + flipped_value = factories.array( + flipped_value.larray, + dtype=flipped_value.dtype, + split=output_split, + device=self.device, + comm=self.comm, ) - if not process_is_inactive: - # only assign values if key does not contain empty slices - __set(self, key, flipped_value) - else: - # 1 process, no communication needed - __set(self, key, value) + # match `value` distribution to `self[key]` distribution + target_map = flipped_value.lshape_map + target_map[:, output_split] = split_key.lshape_map[:, 0] + flipped_value.redistribute_(target_map=target_map) + + process_is_inactive = sum( + list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in key) + ) + if not process_is_inactive: + # only assign values if key does not contain empty slices + __set(self, key, flipped_value) self = self.transpose(backwards_transpose_axes) return + # split_key_is_ordered == 0 -> key along split axis is unordered, communication needed + # key along the split axis is 1-D torch tensor, indices are global + # non-ordered key along split axis # indices are global diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 1340cf2e06..78db68e81a 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1617,11 +1617,15 @@ def test_setitem(self): 
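        # (editor's note, not part of the original patch) with a distributed
        # boolean mask, `arr[mask]` is flattened to a 1-D DNDarray split along
        # axis 0 regardless of `arr.split`; that is why the same
        # non-distributed `value[mask]` can be assigned below for split 0, 1
        # and 2 alike.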
mask_split0 = ht.array(mask, split=0) arr_split0[mask_split0] = value[mask] self.assertTrue((arr_split0[mask_split0] == value[mask]).all().item()) - # arr_split1 = ht.array(arr, split=1) - # mask_split1 = ht.array(mask, split=1) - # print("DEBUGGING: arr_split1[mask_split1].shape, value[mask].shape = ", arr_split1[mask_split1].shape, value[mask].shape) - # arr_split1[mask_split1] = value[mask] - # self.assertTrue((arr_split1[mask_split1] == value[mask]).all().item()) + arr_split1 = ht.array(arr, split=1) + mask_split1 = ht.array(mask, split=1) + print( + "DEBUGGING: arr_split1[mask_split1].shape, value[mask].shape = ", + arr_split1[mask_split1].shape, + value[mask].shape, + ) + arr_split1[mask_split1] = value[mask] + self.assertTrue((arr_split1[mask_split1] == value[mask]).all().item()) # arr_split2 = ht.array(arr, split=2) # mask_split2 = ht.array(mask, split=2) # self.assert_array_equal(arr_split2[mask_split2], arr.numpy()[mask]) From d520ddf070bcba6e03b28650dfbfd6d3e1803230 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 18 Jan 2024 13:19:46 +0100 Subject: [PATCH 100/132] setitem with non-ordered, mask-like key and non-distr value --- heat/core/dndarray.py | 71 ++++++++++++++++++++++---------- heat/core/tests/test_dndarray.py | 5 --- 2 files changed, 49 insertions(+), 27 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 211a1701f3..d97d36177e 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1442,7 +1442,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar """ # key can be: int, tuple, list, slice, DNDarray, torch tensor, numpy array, or sequence thereof # Trivial cases - print("DEBUGGING: RAW KEY = ", key, type(key)) + # print("DEBUGGING: RAW KEY = ", key, type(key)) if key is None: return self.expand_dims(0) @@ -1498,7 +1498,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar if not self.is_distributed(): # key is torch-proof, index underlying torch tensor - print("DEBUGGING: key = ", key) + # print("DEBUGGING: key = ", key) indexed_arr = self.larray[key] # transpose array back if needed self = self.transpose(backwards_transpose_axes) @@ -1564,7 +1564,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar key_shapes = [] for k in key: key_shapes.append(getattr(k, "shape", None)) - print("KEY SHAPES = ", key_shapes) + # print("KEY SHAPES = ", key_shapes) return_1d = key_shapes.count(key_shapes[original_split]) == self.ndim # check for broadcasted indexing: key along split axis is not 1D broadcasted_indexing = ( @@ -1680,13 +1680,13 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar indexed_arr, is_split=output_split, device=self.device, copy=False ) - print("RECV_COUNTS = ", recv_counts) + # print("RECV_COUNTS = ", recv_counts) # share recv_counts among all processes comm_matrix = torch.empty( (self.comm.size, self.comm.size), dtype=recv_counts.dtype, device=recv_counts.device ) self.comm.Allgather(recv_counts, comm_matrix) - print("DEBUGGING: comm_matrix = ", comm_matrix, comm_matrix.shape) + # print("DEBUGGING: comm_matrix = ", comm_matrix, comm_matrix.shape) outgoing_request_key_counts = comm_matrix[self.comm.rank] outgoing_request_key_displs = torch.cat( @@ -1738,7 +1738,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar incoming_request_key_displs.tolist(), ), ) - print("DEBUGGING:incoming_request_key = ", 
incoming_request_key) + # print("DEBUGGING:incoming_request_key = ", incoming_request_key) if return_1d: incoming_request_key = list(incoming_request_key[:, d] for d in range(self.ndim)) incoming_request_key[original_split] -= displs[self.comm.rank] @@ -1750,8 +1750,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar + key[original_split + 1 :] ) - print("AFTER: incoming_request_key = ", incoming_request_key) - print("original_split = ", original_split) + # print("AFTER: incoming_request_key = ", incoming_request_key) + # print("original_split = ", original_split) # calculate shape of local recv buffer output_lshape = list(output_shape) if getattr(key, "ndim", 0) == 1: @@ -1792,9 +1792,9 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar if not all_keys_scalar: send_buf = send_buf.unsqueeze_(dim=output_split) - print("OUTPUT_SHAPE = ", output_shape) - print("OUTPUT_SPLIT = ", output_split) - print("SEND_BUF SHAPE = ", send_buf.shape) + # print("OUTPUT_SHAPE = ", output_shape) + # print("OUTPUT_SPLIT = ", output_split) + # print("SEND_BUF SHAPE = ", send_buf.shape) # output_lshape = list(output_shape) # if getattr(key, "ndim", 0) == 1: @@ -1815,10 +1815,10 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar recv_displs = outgoing_request_key_displs.tolist() send_counts = incoming_request_key_counts.tolist() send_displs = incoming_request_key_displs.tolist() - print("DEBUGGING: send_buf recv_buf shape= ", send_buf.shape, recv_buf.shape) - print("DEBUGGING: send_counts recv_counts = ", send_counts, recv_counts) - print("DEBUGGING: send_displs recv_displs = ", send_displs, recv_displs) - print("DEBUGGING: output_split = ", output_split) + # print("DEBUGGING: send_buf recv_buf shape= ", send_buf.shape, recv_buf.shape) + # print("DEBUGGING: send_counts recv_counts = ", send_counts, recv_counts) + # print("DEBUGGING: send_displs recv_displs = ", send_displs, recv_displs) + # print("DEBUGGING: output_split = ", output_split) self.comm.Alltoallv( (send_buf, send_counts, send_displs), (recv_buf, recv_counts, recv_displs), @@ -1851,8 +1851,8 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar ) map = [slice(None)] * recv_buf.ndim - print("DEBUGGING: outgoing_request_key = ", outgoing_request_key) - print("DEBUGGING: key[original_split] = ", key[original_split]) + # print("DEBUGGING: outgoing_request_key = ", outgoing_request_key) + # print("DEBUGGING: key[original_split] = ", key[original_split]) if broadcasted_indexing: map[original_split] = outgoing_request_key.argsort(stable=True)[ key[original_split].argsort(stable=True).argsort(stable=True) @@ -2607,12 +2607,39 @@ def __set( return # split_key_is_ordered == 0 -> key along split axis is unordered, communication needed - # key along the split axis is 1-D torch tensor, indices are global - - # non-ordered key along split axis - # indices are global + # key along the split axis is 1-D torch tensor, but indices are GLOBAL + counts, displs = self.counts_displs() + # rank, size = self.comm.rank, self.comm.size + rank = self.comm.rank + # define key as mask_like if each element of key is a torch.Tensor and all elements of key are of the same shape + key_is_mask_like = ( + all(isinstance(k, torch.Tensor) for k in key) and len(set(k.shape for k in key)) == 1 + ) - # process-local indices + if not value.is_distributed(): + if key_is_mask_like: + split_key = key[self.split] + # find elements of `split_key` that are local to this 
process + local_indices = torch.nonzero( + (split_key >= displs[rank]) & (split_key < displs[rank] + counts[rank]) + ).flatten() + # keep local indexing key only and correct for displacements along the split axis + key = list(key) + key = tuple( + [ + key[i][local_indices] - displs[rank] + if i == self.split + else key[i][local_indices] + for i in range(len(key)) + ] + ) + # set local elements of `self` to corresponding elements of `value` + # + self.larray[key] = value.larray[local_indices].type(self.dtype.torch_type()) + self = self.transpose(backwards_transpose_axes) + return + # key not mask_like + # both `self` and `value` are distributed # if advanced_indexing: # raise Exception("Advanced indexing is not supported yet") diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 78db68e81a..3cbdaa0837 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1619,11 +1619,6 @@ def test_setitem(self): self.assertTrue((arr_split0[mask_split0] == value[mask]).all().item()) arr_split1 = ht.array(arr, split=1) mask_split1 = ht.array(mask, split=1) - print( - "DEBUGGING: arr_split1[mask_split1].shape, value[mask].shape = ", - arr_split1[mask_split1].shape, - value[mask].shape, - ) arr_split1[mask_split1] = value[mask] self.assertTrue((arr_split1[mask_split1] == value[mask]).all().item()) # arr_split2 = ht.array(arr, split=2) From d754a9c9d7d8b28f9fd06eba23e9eefe8a64858b Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Fri, 19 Jan 2024 05:37:19 +0100 Subject: [PATCH 101/132] allow for partial boolean indexing on first key.ndim dims of array --- heat/core/dndarray.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index d97d36177e..17f29b26d8 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -893,8 +893,9 @@ def __process_key( raise IndexError("Invalid indices: expected a list of integers, got {}".format(key)) if isinstance(key, (DNDarray, torch.Tensor, np.ndarray)): if key.dtype in (bool, uint8, torch.bool, torch.uint8, np.bool_, np.uint8): - # boolean indexing: shape must match arr.shape - if not tuple(key.shape) == arr.shape: + # boolean indexing: shape must be consistent with arr.shape + key_ndim = key.ndim + if not tuple(key.shape) == arr.shape[:key_ndim]: raise IndexError( "Boolean index of shape {} does not match indexed array of shape {}".format( tuple(key.shape), arr.shape @@ -920,7 +921,7 @@ def __process_key( key = key.nonzero() # convert to torch tensor key = tuple(torch.tensor(k, device=arr.larray.device) for k in key) - output_shape = tuple(key[0].shape) + output_shape = tuple(key[0].shape) + arr.shape[key_ndim:] new_split = None if arr.split is None else 0 out_is_balanced = True split_key_is_ordered = 1 From 5e69fe689413d7c8d467675edd12df2eef26243a Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Fri, 19 Jan 2024 05:57:23 +0100 Subject: [PATCH 102/132] remove unnecessary check --- heat/core/dndarray.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 17f29b26d8..664f885a85 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1842,15 +1842,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar return factories.array(indexed_arr, is_split=output_split, copy=False) outgoing_request_key = outgoing_request_key.squeeze_(1) - # 
incoming elements likely already stacked in ascending or descending order - # TODO: is this check really worth it? blanket argsort solution below might be ok - if (key[original_split] == outgoing_request_key).all(): - return factories.array(recv_buf, is_split=output_split, copy=False) - if (key[original_split] == outgoing_request_key.flip(dims=(0,))).all(): - return factories.array( - recv_buf.flip(dims=(output_split,)), is_split=output_split, copy=False - ) - map = [slice(None)] * recv_buf.ndim # print("DEBUGGING: outgoing_request_key = ", outgoing_request_key) # print("DEBUGGING: key[original_split] = ", key[original_split]) From 8d9849ee5a2795cb6e4a1b31d32f2d4024436480 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Fri, 19 Jan 2024 05:57:50 +0100 Subject: [PATCH 103/132] add tests for partial boolean indexing --- heat/core/tests/test_dndarray.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 3cbdaa0837..81d416bc39 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1621,9 +1621,10 @@ def test_setitem(self): mask_split1 = ht.array(mask, split=1) arr_split1[mask_split1] = value[mask] self.assertTrue((arr_split1[mask_split1] == value[mask]).all().item()) - # arr_split2 = ht.array(arr, split=2) - # mask_split2 = ht.array(mask, split=2) - # self.assert_array_equal(arr_split2[mask_split2], arr.numpy()[mask]) + arr_split2 = ht.array(arr, split=2) + mask_split2 = ht.array(mask, split=2) + arr_split2[mask_split2] = value[mask] + self.assertTrue((arr_split2[mask_split2] == value[mask]).all().item()) # def test_setitem_getitem(self): # # tests for bug #825 From 66ae3710cd665ab5a40cdf529cae0a56a39aebfb Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 22 Jan 2024 06:16:30 +0100 Subject: [PATCH 104/132] set w. 
single-tensor key and non-distr value --- heat/core/dndarray.py | 76 +++++++++++++++----------------- heat/core/tests/test_dndarray.py | 18 ++++---- 2 files changed, 46 insertions(+), 48 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 664f885a85..4f5d16c58b 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1044,12 +1044,19 @@ def __process_key( ) if key.split is not None: out_is_balanced = key.balanced - split_key_is_ordered = factories.array( - [split_key_is_ordered], - is_split=0, - device=arr.device, - copy=False, - ).all() + split_key_is_ordered = ( + factories.array( + [split_key_is_ordered], + is_split=0, + device=arr.device, + copy=False, + ) + .all() + .astype(types.canonical_heat_types.uint8) + .item() + ) + else: + split_key_is_ordered = split_key_is_ordered.item() key = key.larray except AttributeError: # torch or ndarray key @@ -1565,7 +1572,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar key_shapes = [] for k in key: key_shapes.append(getattr(k, "shape", None)) - # print("KEY SHAPES = ", key_shapes) return_1d = key_shapes.count(key_shapes[original_split]) == self.ndim # check for broadcasted indexing: key along split axis is not 1D broadcasted_indexing = ( @@ -1579,7 +1585,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar send_axis = original_split else: send_axis = output_split - # print("RANK, RETURN_1D, broadcasted_indexing = ", self.comm.rank, return_1d, broadcasted_indexing) # send and receive "request key" info on what data element to ship where recv_counts = torch.zeros((self.comm.size, 1), dtype=torch.int64, device=self.larray.device) @@ -1681,13 +1686,11 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar indexed_arr, is_split=output_split, device=self.device, copy=False ) - # print("RECV_COUNTS = ", recv_counts) # share recv_counts among all processes comm_matrix = torch.empty( (self.comm.size, self.comm.size), dtype=recv_counts.dtype, device=recv_counts.device ) self.comm.Allgather(recv_counts, comm_matrix) - # print("DEBUGGING: comm_matrix = ", comm_matrix, comm_matrix.shape) outgoing_request_key_counts = comm_matrix[self.comm.rank] outgoing_request_key_displs = torch.cat( @@ -1739,7 +1742,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar incoming_request_key_displs.tolist(), ), ) - # print("DEBUGGING:incoming_request_key = ", incoming_request_key) if return_1d: incoming_request_key = list(incoming_request_key[:, d] for d in range(self.ndim)) incoming_request_key[original_split] -= displs[self.comm.rank] @@ -1751,8 +1753,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar + key[original_split + 1 :] ) - # print("AFTER: incoming_request_key = ", incoming_request_key) - # print("original_split = ", original_split) # calculate shape of local recv buffer output_lshape = list(output_shape) if getattr(key, "ndim", 0) == 1: @@ -1793,33 +1793,10 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar if not all_keys_scalar: send_buf = send_buf.unsqueeze_(dim=output_split) - # print("OUTPUT_SHAPE = ", output_shape) - # print("OUTPUT_SPLIT = ", output_split) - # print("SEND_BUF SHAPE = ", send_buf.shape) - - # output_lshape = list(output_shape) - # if getattr(key, "ndim", 0) == 1: - # output_lshape[output_split] = key.shape[0] - # else: - # if broadcasted_indexing: - # output_lshape = ( - # output_lshape[:original_split] - # + 
[torch.prod(torch.tensor(broadcast_shape, device=send_buf.device)).item()] - # + output_lshape[output_split + 1 :] - # ) - # else: - # output_lshape[output_split] = key[original_split].shape[0] - # recv_buf = torch.empty( - # tuple(output_lshape), dtype=self.larray.dtype, device=self.larray.device - # ) recv_counts = torch.squeeze(recv_counts, dim=1).tolist() recv_displs = outgoing_request_key_displs.tolist() send_counts = incoming_request_key_counts.tolist() send_displs = incoming_request_key_displs.tolist() - # print("DEBUGGING: send_buf recv_buf shape= ", send_buf.shape, recv_buf.shape) - # print("DEBUGGING: send_counts recv_counts = ", send_counts, recv_counts) - # print("DEBUGGING: send_displs recv_displs = ", send_displs, recv_displs) - # print("DEBUGGING: output_split = ", output_split) self.comm.Alltoallv( (send_buf, send_counts, send_displs), (recv_buf, recv_counts, recv_displs), @@ -2603,12 +2580,30 @@ def __set( counts, displs = self.counts_displs() # rank, size = self.comm.rank, self.comm.size rank = self.comm.rank - # define key as mask_like if each element of key is a torch.Tensor and all elements of key are of the same shape - key_is_mask_like = ( - all(isinstance(k, torch.Tensor) for k in key) and len(set(k.shape for k in key)) == 1 - ) + # + single_tensor_key = isinstance(key, torch.Tensor) + key_is_mask_like = False + # define key as mask_like if each element of key is a torch.Tensor and all elements of key are of the same shape + if not single_tensor_key: + key_is_mask_like = ( + all(isinstance(k, torch.Tensor) for k in key) + and len(set(k.shape for k in key)) == 1 + ) if not value.is_distributed(): + if single_tensor_key: + # key is a single torch.Tensor + split_key = key + # find elements of `split_key` that are local to this process + local_indices = torch.nonzero( + (split_key >= displs[rank]) & (split_key < displs[rank] + counts[rank]) + ).flatten() + # keep local indexing key only and correct for displacements along the split axis + key = key[local_indices] - displs[rank] + # set local elements of `self` to corresponding elements of `value` + self.larray[key] = value.larray[local_indices].type(self.dtype.torch_type()) + self = self.transpose(backwards_transpose_axes) + return if key_is_mask_like: split_key = key[self.split] # find elements of `split_key` that are local to this process @@ -2631,6 +2626,7 @@ def __set( self = self.transpose(backwards_transpose_axes) return # key not mask_like + # both `self` and `value` are distributed # if advanced_indexing: diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 81d416bc39..ebdaea270b 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1533,13 +1533,13 @@ def test_setitem(self): self.assertTrue(ht.all(adv_indexed_x == value).item()) self.assertTrue(adv_indexed_x.dtype == x.dtype) - # # 1d - # x = ht.arange(10, 1, -1, split=0) - # x_np = np.arange(10, 1, -1) - # x_adv_ind = x[np.array([3, 3, 1, 8])] - # x_np_adv_ind = x_np[np.array([3, 3, 1, 8])] - # self.assert_array_equal(x_adv_ind, x_np_adv_ind) - + # 1d + x = ht.arange(10, 1, -1, split=0) + value = ht.arange(4) + x[ht.array([3, 2, 1, 8])] = value + x_adv_ind = x[np.array([3, 2, 1, 8])] + self.assertTrue(ht.all(x_adv_ind == value).item()) + self.assertTrue(x_adv_ind.dtype == x.dtype) # # 3d, split 0, non-unique, non-ordered key along split axis # x = ht.arange(60, split=0).reshape(5, 3, 4) # x_np = np.arange(60).reshape(5, 3, 4) @@ -1612,7 +1612,7 @@ def test_setitem(self): arr[mask] = 
value[mask] self.assertTrue((arr[mask] == value[mask]).all().item()) - # boolean mask, distributed + # boolean mask, distributed, non-distributed `value` arr_split0 = ht.array(arr, split=0) mask_split0 = ht.array(mask, split=0) arr_split0[mask_split0] = value[mask] @@ -1626,6 +1626,8 @@ def test_setitem(self): arr_split2[mask_split2] = value[mask] self.assertTrue((arr_split2[mask_split2] == value[mask]).all().item()) + # TODO boolean mask, distributed, distributed `value` + # def test_setitem_getitem(self): # # tests for bug #825 # a = ht.ones((102, 102), split=0) From ae4d4239fd05f8db38d727913f3939ec626ca2cc Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 5 Feb 2024 06:10:53 +0100 Subject: [PATCH 105/132] non-ordered, non-mask-like key and local value --- heat/core/dndarray.py | 113 ++++++++++++++++++++++-------------------- 1 file changed, 58 insertions(+), 55 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 4f5d16c58b..e4e93d5ea7 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2533,12 +2533,7 @@ def __set( ) self.comm.Allgather(target_shape, target_map) value.redistribute_(target_map=target_map) - process_is_inactive = sum( - list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in key) - ) - if not process_is_inactive: - # only assign values if key does not contain empty slices - __set(self, key, value) + __set(self, key, value) self = self.transpose(backwards_transpose_axes) return @@ -2565,67 +2560,75 @@ def __set( target_map = flipped_value.lshape_map target_map[:, output_split] = split_key.lshape_map[:, 0] flipped_value.redistribute_(target_map=target_map) - - process_is_inactive = sum( - list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in key) - ) - if not process_is_inactive: - # only assign values if key does not contain empty slices - __set(self, key, flipped_value) + __set(self, key, flipped_value) self = self.transpose(backwards_transpose_axes) return - # split_key_is_ordered == 0 -> key along split axis is unordered, communication needed - # key along the split axis is 1-D torch tensor, but indices are GLOBAL - counts, displs = self.counts_displs() - # rank, size = self.comm.rank, self.comm.size - rank = self.comm.rank + if split_key_is_ordered == 0: + # key along split axis is unordered, communication needed + # key along the split axis is 1-D torch tensor, but indices are GLOBAL + counts, displs = self.counts_displs() + # rank, size = self.comm.rank, self.comm.size + rank = self.comm.rank - # - single_tensor_key = isinstance(key, torch.Tensor) - key_is_mask_like = False - # define key as mask_like if each element of key is a torch.Tensor and all elements of key are of the same shape - if not single_tensor_key: - key_is_mask_like = ( - all(isinstance(k, torch.Tensor) for k in key) - and len(set(k.shape for k in key)) == 1 - ) - if not value.is_distributed(): - if single_tensor_key: - # key is a single torch.Tensor - split_key = key - # find elements of `split_key` that are local to this process - local_indices = torch.nonzero( - (split_key >= displs[rank]) & (split_key < displs[rank] + counts[rank]) - ).flatten() - # keep local indexing key only and correct for displacements along the split axis - key = key[local_indices] - displs[rank] - # set local elements of `self` to corresponding elements of `value` - self.larray[key] = value.larray[local_indices].type(self.dtype.torch_type()) - self = self.transpose(backwards_transpose_axes) - return - if 
key_is_mask_like: + # + key_is_single_tensor = isinstance(key, torch.Tensor) + key_is_mask_like = False + # define key as mask_like if each element of key is a torch.Tensor and all elements of key are of the same shape + if not key_is_single_tensor: + key_is_mask_like = ( + all(isinstance(k, torch.Tensor) for k in key) + and len(set(k.shape for k in key)) == 1 + ) + if not value.is_distributed(): + if key_is_single_tensor: + # key is a single torch.Tensor + split_key = key + # find elements of `split_key` that are local to this process + local_indices = torch.nonzero( + (split_key >= displs[rank]) & (split_key < displs[rank] + counts[rank]) + ).flatten() + # keep local indexing key only and correct for displacements along the split axis + key = key[local_indices] - displs[rank] + # set local elements of `self` to corresponding elements of `value` + self.larray[key] = value.larray[local_indices].type(self.dtype.torch_type()) + self = self.transpose(backwards_transpose_axes) + return + # key is a sequence of torch.Tensors split_key = key[self.split] # find elements of `split_key` that are local to this process local_indices = torch.nonzero( (split_key >= displs[rank]) & (split_key < displs[rank] + counts[rank]) ).flatten() - # keep local indexing key only and correct for displacements along the split axis key = list(key) - key = tuple( - [ - key[i][local_indices] - displs[rank] - if i == self.split - else key[i][local_indices] - for i in range(len(key)) - ] - ) - # set local elements of `self` to corresponding elements of `value` - # - self.larray[key] = value.larray[local_indices].type(self.dtype.torch_type()) + if key_is_mask_like: + # keep local indexing keys across all dimensions + # correct for displacements along the split axis + key = tuple( + [ + key[i][local_indices] - displs[rank] + if i == self.split + else key[i][local_indices] + for i in range(len(key)) + ] + ) + if not key[self.split].numel() == 0: + self.larray[key] = value.larray[local_indices].type(self.dtype.torch_type()) + else: + # keep local indexing key and correct for displacements along split dimension + key[self.split] = key[self.split][local_indices] - displs[rank] + key = tuple(key) + value_key = tuple( + [ + local_indices if i == output_split else slice(None) + for i in range(value.ndim) + ] + ) + # set local elements of `self` to corresponding elements of `value` + if not key[self.split].numel() == 0: + self.larray[key] = value.larray[value_key].type(self.dtype.torch_type()) self = self.transpose(backwards_transpose_axes) return - # key not mask_like # both `self` and `value` are distributed From b695e5ac80bb69676d8ce90495ad69f50c4c4be8 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 7 Feb 2024 06:24:52 +0100 Subject: [PATCH 106/132] broken: set up comm map for full distributed setitem --- heat/core/dndarray.py | 49 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index e4e93d5ea7..4fc76c32d2 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2630,7 +2630,54 @@ def __set( self = self.transpose(backwards_transpose_axes) return - # both `self` and `value` are distributed + # both `self` and `value` are distributed + # distribution of `key` and `value` must be aligned + if key_is_mask_like: + # redistribute `value` to match distribution of `key` in one pass + split_key = key[self.split] + global_split_key = factories.array( + split_key, 
is_split=0, device=self.device, comm=self.comm, copy=False + ) + target_map = value.lshape_map + target_map[:, value.split] = global_split_key.lshape_map[:, 0] + value.redistribute_(target_map=target_map) + else: + # redistribute split-axis `key` to match distribution of `value` in one pass + if key_is_single_tensor: + # key is a single torch.Tensor + split_key = key + elif not key_is_mask_like: + split_key = key[self.split] + global_split_key = factories.array( + split_key, is_split=0, device=self.device, comm=self.comm, copy=False + ) + target_map = global_split_key.lshape_map + target_map[:, 0] = value.lshape_map[:, value.split] + global_split_key.redistribute_(target_map=target_map) + split_key = global_split_key.larray + + # key and value are now aligned + # create communication map, stack `value`elements according to destination rank + value_counts, value_displs = value.counts_displs() + recv_counts = torch.zeros( + self.comm.size, dtype=torch.int64, device=self.device.torch_device + ) + recv_displs = torch.zeros_like(recv_counts) + send_buf = torch.zeros_like(value.larray) + for recv_process in range(self.comm.size): + # find elements of `split_key` that are local to `recv_process` + local_indices = torch.nonzero( + (split_key >= displs[recv_process]) + & (split_key < displs[recv_process] + counts[recv_process]) + ).flatten() + recv_counts[recv_process] = local_indices.numel() + recv_displs[recv_process] = ( + recv_counts[:recv_process].sum().item() if recv_process > 0 else 0 + ) + send_buf[ + recv_displs[recv_process] : recv_displs[recv_process] + + recv_counts[recv_process] + ] = value.larray[local_indices] # if advanced_indexing: # raise Exception("Advanced indexing is not supported yet") From e6c1e1008242477dfe0f461a962d2a1fed58d87c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Feb 2024 05:26:09 +0000 Subject: [PATCH 107/132] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- heat/core/dndarray.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 4fc76c32d2..7a2663babd 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2606,9 +2606,11 @@ def __set( # correct for displacements along the split axis key = tuple( [ - key[i][local_indices] - displs[rank] - if i == self.split - else key[i][local_indices] + ( + key[i][local_indices] - displs[rank] + if i == self.split + else key[i][local_indices] + ) for i in range(len(key)) ] ) From d42f1cbb92b98b2f567de03564df260264cca436 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 8 Feb 2024 10:38:28 +0100 Subject: [PATCH 108/132] implement setitem w. 
distributed non-ordered key --- heat/core/dndarray.py | 96 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 79 insertions(+), 17 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 4fc76c32d2..2d6e278b1d 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2657,27 +2657,89 @@ def __set( split_key = global_split_key.larray # key and value are now aligned - # create communication map, stack `value`elements according to destination rank - value_counts, value_displs = value.counts_displs() - recv_counts = torch.zeros( + + # prepare for `value` Alltoallv: + # work along axis 0, transpose if necessary + transpose_axes = list(range(value.ndim)) + transpose_axes[0], transpose_axes[value.split] = ( + transpose_axes[value.split], + transpose_axes[0], + ) + value = value.transpose(transpose_axes) + send_counts = torch.zeros( self.comm.size, dtype=torch.int64, device=self.device.torch_device ) - recv_displs = torch.zeros_like(recv_counts) - send_buf = torch.zeros_like(value.larray) - for recv_process in range(self.comm.size): - # find elements of `split_key` that are local to `recv_process` - local_indices = torch.nonzero( - (split_key >= displs[recv_process]) - & (split_key < displs[recv_process] + counts[recv_process]) + send_displs = torch.zeros_like(send_counts) + # allocate send buffer: add 1 column to store sent indices + send_buf_shape = list(value.lshape) + send_buf_shape[-1] += 1 + send_buf = torch.zeros( + send_buf_shape, dtype=value.dtype.torch_type(), device=self.device.torch_device + ) + for proc in range(self.comm.size): + # calculate what local elements of `value` belong on process `proc` + send_indices = torch.nonzero( + (split_key >= displs[proc]) & (split_key < displs[proc] + counts[proc]) ).flatten() - recv_counts[recv_process] = local_indices.numel() - recv_displs[recv_process] = ( - recv_counts[:recv_process].sum().item() if recv_process > 0 else 0 - ) + # calculate outgoing counts and displacements for each process + send_counts[proc] = send_indices.numel() + send_displs[proc] = send_counts[:proc].sum() + # compose send buffer: stack local elements of `value` according to destination process send_buf[ - recv_displs[recv_process] : recv_displs[recv_process] - + recv_counts[recv_process] - ] = value.larray[local_indices] + send_displs[proc] : send_displs[proc] + send_counts[proc], :-1 + ] = value.larray[send_indices] + # store outgoing indices in the last column of send_buf + while send_indices.ndim < send_buf.ndim: + # broadcast send_indices to correct shape + send_indices = send_indices.unsqueeze(-1) + send_buf[ + send_displs[proc] : send_displs[proc] + send_counts[proc], -1 + ] = send_indices + + # compose communication matrix: share `send_counts` information with all processes + comm_matrix = torch.zeros( + (self.comm.size, self.comm.size), + dtype=torch.int64, + device=self.device.torch_device, + ) + self.comm.Allgather(send_counts, comm_matrix) + # comm_matrix columns contain recv_counts for each process + recv_counts = comm_matrix[:, self.comm.rank].squeeze(0) + recv_displs = torch.zeros_like(recv_counts) + recv_displs[1:] = recv_counts.cumsum(0)[:-1] + # allocate receive buffer, with 1 extra column for incoming indices + recv_shape = value.lshape_map[self.comm.rank] + recv_shape[value.split] = recv_counts.sum() + recv_shape[-1] += 1 + recv_shape = tuple(recv_shape.tolist()) + recv_buf = torch.zeros( + recv_shape, dtype=value.dtype.torch_type(), device=self.device.torch_device + ) + # perform Alltoallv along the 0 axis + 
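            # (editor's note, descriptive comment added for clarity) Alltoallv
            # exchanges variable-sized blocks along axis 0 of the buffers
            # prepared above: this rank sends send_counts[p] rows starting at
            # send_displs[p] to each rank p, and receives recv_counts[p] rows
            # at recv_displs[p] from rank p. E.g. with send_counts == [2, 0, 1],
            # rows 0-1 go to rank 0 and row 2 to rank 2.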
self.comm.Alltoallv( + (send_buf, send_counts, send_displs), (recv_buf, recv_counts, recv_displs) + ) + del send_buf, comm_matrix + # store incoming indices in int 1-D tensor and correct for rank offset + recv_indices = recv_buf[..., -1].type(torch.int64) - displs[rank] + # remove last column from recv_buf + recv_buf = recv_buf[..., :-1] + # transpose back value and recv_buf if necessary, wrap recv_buf in DNDarray + value = value.transpose(transpose_axes) + recv_buf = DNDarray( + recv_buf.permute(*transpose_axes), + gshape=value.gshape, + split=value.split, + device=value.device, + comm=value.comm, + balanced=value.balanced, + ) + # replace split-axis key with incoming local indices + key = list(key) + key[self.split] = recv_indices + key = tuple(key) + # set local elements of `self` to corresponding elements of `value` + __set(self, key, recv_buf) # if advanced_indexing: # raise Exception("Advanced indexing is not supported yet") From 7868fa0133190266e668ca8c30fb385f24a935d5 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 8 Feb 2024 10:51:42 +0100 Subject: [PATCH 109/132] [skip ci] broken: add tests for distr value non-ordered key --- heat/core/tests/test_dndarray.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index ebdaea270b..2e9c3e786e 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1540,15 +1540,18 @@ def test_setitem(self): x_adv_ind = x[np.array([3, 2, 1, 8])] self.assertTrue(ht.all(x_adv_ind == value).item()) self.assertTrue(x_adv_ind.dtype == x.dtype) - # # 3d, split 0, non-unique, non-ordered key along split axis - # x = ht.arange(60, split=0).reshape(5, 3, 4) - # x_np = np.arange(60).reshape(5, 3, 4) - # k1 = np.array([0, 4, 1, 0]) - # k2 = np.array([0, 2, 1, 0]) - # k3 = np.array([1, 2, 3, 1]) - # self.assert_array_equal( - # x[ht.array(k1, split=0), ht.array(k2, split=0), ht.array(k3, split=0)], x_np[k1, k2, k3] - # ) + + # TODO: n-d value + + # 3d, split 0, non-unique, non-ordered key along split axis, key mask-like + x = ht.arange(60, split=0).reshape(5, 3, 4) + k1 = np.array([0, 4, 1, 0]) + k2 = np.array([0, 2, 1, 0]) + k3 = np.array([1, 2, 3, 1]) + value = ht.array([99, 98, 97, 96], split=0) + x[k1, k2, k3] = value + print(x.comm.rank, x.larray) + # self.assertTrue((x[k1, k2, k3] == value).all().item()) # # advanced indexing on non-consecutive dimensions # x = ht.arange(60, split=0).reshape(5, 3, 4, new_split=1) # x_copy = x.copy() From 2944903a5ceec32bd015063647f65a279dcd9b1b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Feb 2024 09:54:07 +0000 Subject: [PATCH 110/132] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- heat/core/dndarray.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index c8a5fc9f8e..25de8e003c 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2687,16 +2687,16 @@ def __set( send_counts[proc] = send_indices.numel() send_displs[proc] = send_counts[:proc].sum() # compose send buffer: stack local elements of `value` according to destination process - send_buf[ - send_displs[proc] : send_displs[proc] + send_counts[proc], :-1 - ] = value.larray[send_indices] + send_buf[send_displs[proc] : send_displs[proc] + send_counts[proc], 
:-1] = ( + value.larray[send_indices] + ) # store outgoing indices in the last column of send_buf while send_indices.ndim < send_buf.ndim: # broadcast send_indices to correct shape send_indices = send_indices.unsqueeze(-1) - send_buf[ - send_displs[proc] : send_displs[proc] + send_counts[proc], -1 - ] = send_indices + send_buf[send_displs[proc] : send_displs[proc] + send_counts[proc], -1] = ( + send_indices + ) # compose communication matrix: share `send_counts` information with all processes comm_matrix = torch.zeros( From 1bffd26b1ea284742696600b6bb54646053f3c93 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 12 Feb 2024 07:42:34 +0100 Subject: [PATCH 111/132] __process_key(): refactor adv indexing tensor extraction --- heat/core/dndarray.py | 196 ++++++++++++++++++++++++++++++++---------- 1 file changed, 151 insertions(+), 45 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index c8a5fc9f8e..64ceade256 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -937,8 +937,12 @@ def __process_key( ) # arr is distributed - if not isinstance(key, DNDarray) or not key.is_distributed(): - key = factories.array(key, split=arr.split, device=arr.device) + if not isinstance(key, DNDarray) or ( + isinstance(key, DNDarray) and not key.is_distributed() + ): + key = factories.array( + key, split=arr.split, device=arr.device, comm=arr.comm, copy=None + ) else: if key.split != arr.split: raise IndexError( @@ -1164,36 +1168,24 @@ def __process_key( elif isinstance(k, Iterable) or isinstance(k, DNDarray): advanced_indexing = True advanced_indexing_dims.append(i) - if isinstance(k, DNDarray): - advanced_indexing_shapes.append(k.gshape) - if arr_is_distributed and i == arr.split: - # we have no info on order of indices + # work with DNDarrays to assess distribution + # torch tensors will be extracted in the advanced indexing section below + k = factories.array(k, device=arr.device, comm=arr.comm, copy=None) + advanced_indexing_shapes.append(k.gshape) + if arr_is_distributed and i == arr.split: + if ( + not k.is_distributed() + and k.ndim == 1 + and (k.larray == torch.sort(k.larray, stable=True)[0]).all() + ): + split_key_is_ordered = 1 + out_is_balanced = None + else: split_key_is_ordered = 0 # redistribute key along last axis to match split axis of indexed array k = k.resplit(-1) out_is_balanced = True - key[i] = k.larray - elif not isinstance(k, torch.Tensor): - key[i] = torch.tensor(k, dtype=torch.int64, device=arr.larray.device) - advanced_indexing_shapes.append(tuple(key[i].shape)) - # IMPORTANT: here we assume that torch or ndarray key is THE SAME SET OF GLOBAL INDICES on every rank - if arr_is_distributed and i == arr.split: - # make no assumption on data locality wrt key - out_is_balanced = None - # assess if indices are in ascending order - if ( - key[i].ndim == 1 - and (key[i] == torch.sort(key[i], stable=True)[0]).all() - ): - split_key_is_ordered = 1 - # extract local key - cond1 = key[i] >= displs[arr.comm.rank] - cond2 = key[i] < displs[arr.comm.rank] + counts[arr.comm.rank] - key[i] = key[i][cond1 & cond2] - if return_local_indices: - key[i] -= displs[arr.comm.rank] - else: - split_key_is_ordered = 0 + key[i] = k elif isinstance(k, slice) and k != slice(None): start, stop, step = k.start, k.stop, k.step @@ -1266,9 +1258,66 @@ def __process_key( output_shape[i] = 0 if advanced_indexing: + # adv indexing key elements are DNDarrays: extract torch tensors + # options: 1. 
key is mask-like (covers boolean mask as well), 2. key along arr.split is DNDarray, 3. everything else + # 1. define key as mask_like if each element of key is a DNDarray, and all elements of key are of the same shape, and the advanced-indexing dimensions are consecutive + key_is_mask_like = ( + all(isinstance(k, DNDarray) for k in key) + and len(set(k.shape for k in key)) == 1 + and torch.tensor(advanced_indexing_dims).diff().eq(1).all() + ) + if key_is_mask_like: + key = list(key) + key_splits = [k.split for k in key] + non_split_dims = list(advanced_indexing_dims).copy() + non_split_dims.remove(arr.split) + if not key_splits.count(key_splits[arr.split]) == len(key_splits): + if ( + key_splits[arr.split] is not None + and key_splits.count(None) == len(key_splits) - 1 + ): + for i in non_split_dims: + key[i] = factories.array( + key[i], + split=key_splits[arr.split], + device=arr.device, + comm=arr.comm, + copy=None, + ) + else: + raise IndexError( + f"Indexing arrays must be distributed along the same dimension, got splits {key_splits}." + ) + # all key elements are now DNDarrays of the same shape, same split axis + # 2. key along arr.split is DNDarray + if arr.split in advanced_indexing_dims: + if split_key_is_ordered == 1: + # extract torch tensors, keep process-local indices only + k = key[arr.split].larray + cond1 = k >= displs[arr.comm.rank] + cond2 = k < displs[arr.comm.rank] + counts[arr.comm.rank] + k = k[cond1 & cond2] + if return_local_indices: + k -= displs[arr.comm.rank] + key[arr.split] = k + if key_is_mask_like: + # select the same elements along non-split dimensions + for i in non_split_dims: + key[i] = key[i].larray[cond1 & cond2] + elif split_key_is_ordered == 0: + # extract torch tensors, any other communication + mask-like case are handled in __getitem__ or __setitem__ + for i in advanced_indexing_dims: + key[i] = key[i].larray + # split_key_is_ordered == -1 not treated here as it is slicing, not advanced indexing + else: + # advanced indexing does not affect split axis, return torch tensors + for i in advanced_indexing_dims: + key[i] = key[i].larray print("ADV IND KEY = ", key) print("DEBUGGING: advanced_indexing_shapes = ", advanced_indexing_shapes) - # shapes of indexing arrays must be broadcastable + # all adv indexing keys are now torch tensors + + # shapes of adv indexing arrays must be broadcastable try: broadcasted_shape = torch.broadcast_shapes(*advanced_indexing_shapes) except RuntimeError: @@ -2641,7 +2690,10 @@ def __set( split_key, is_split=0, device=self.device, comm=self.comm, copy=False ) target_map = value.lshape_map + print("DEBUGGING: target_map = ", target_map) + print("DEBUGGING: global_split_key.lshape_map = ", global_split_key.lshape_map) target_map[:, value.split] = global_split_key.lshape_map[:, 0] + print("DEBUGGING: target_map AFTER = ", target_map) value.redistribute_(target_map=target_map) else: # redistribute split-axis `key` to match distribution of `value` in one pass @@ -2674,29 +2726,49 @@ def __set( send_displs = torch.zeros_like(send_counts) # allocate send buffer: add 1 column to store sent indices send_buf_shape = list(value.lshape) - send_buf_shape[-1] += 1 + if value.ndim < 2: + send_buf_shape.append(1) + if key_is_mask_like: + send_buf_shape[-1] += len(key) + else: + send_buf_shape[-1] += 1 + print("DEBUGGING: send_buf_shape = ", send_buf_shape) send_buf = torch.zeros( send_buf_shape, dtype=value.dtype.torch_type(), device=self.device.torch_device ) + print("DEBUGGING: BEFORE LOOP: counts, displs = ", counts, displs) + 
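# NOTE: minimal sketch of a mask-like key, assuming a 2-D array: key = (tensor([0, 4, 1]), tensor([0, 2, 1])) addresses the elements (0, 0), (4, 2) and (1, 1); all index tensors share one shape, e.g. as returned by a boolean mask's nonzero(). + # Element j of `value` is destined for the process p whose local slice contains its target index, i.e. displs[p] <= split_key[j] < displs[p] + counts[p]; the loop below applies exactly this test to fill send_buf. With duplicate targets in split_key the expected outcome follows NumPy's last-write-wins convention, cf. the adapted test_setitem assertion [96, 98, 97, 96] further down. +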
print("debugging: key_is_mask_like = ", key_is_mask_like) for proc in range(self.comm.size): # calculate what local elements of `value` belong on process `proc` send_indices = torch.nonzero( (split_key >= displs[proc]) & (split_key < displs[proc] + counts[proc]) ).flatten() + print( + "DEBUGGING: proc, send_indices = ", proc, send_indices, split_key[send_indices] + ) # calculate outgoing counts and displacements for each process send_counts[proc] = send_indices.numel() send_displs[proc] = send_counts[:proc].sum() # compose send buffer: stack local elements of `value` according to destination process - send_buf[ - send_displs[proc] : send_displs[proc] + send_counts[proc], :-1 - ] = value.larray[send_indices] - # store outgoing indices in the last column of send_buf - while send_indices.ndim < send_buf.ndim: - # broadcast send_indices to correct shape - send_indices = send_indices.unsqueeze(-1) - send_buf[ - send_displs[proc] : send_displs[proc] + send_counts[proc], -1 - ] = send_indices + if send_indices.numel() > 0: + send_buf[ + send_displs[proc] : send_displs[proc] + send_counts[proc], :-1 + ] = value.larray[send_indices] + # store outgoing GLOBAL indices in the last column of send_buf + # TODO: if key_is_mask_like: apply send_indices to all dimensions of key + if key_is_mask_like: + for i in range(-len(key), 0): + send_buf[ + send_displs[proc] : send_displs[proc] + send_counts[proc], i + ] = key[i + len(key)][send_indices] + else: + while send_indices.ndim < send_buf.ndim: + send_indices = split_key[send_indices] + # broadcast send_indices to correct shape + send_indices = send_indices.unsqueeze(-1) + send_buf[ + send_displs[proc] : send_displs[proc] + send_counts[proc], -1 + ] = send_indices # compose communication matrix: share `send_counts` information with all processes comm_matrix = torch.zeros( @@ -2705,32 +2777,66 @@ def __set( device=self.device.torch_device, ) self.comm.Allgather(send_counts, comm_matrix) + print("DEBUGGING:, RANK, SEND_BUF = ", self.comm.rank, send_buf) # comm_matrix columns contain recv_counts for each process recv_counts = comm_matrix[:, self.comm.rank].squeeze(0) recv_displs = torch.zeros_like(recv_counts) recv_displs[1:] = recv_counts.cumsum(0)[:-1] # allocate receive buffer, with 1 extra column for incoming indices - recv_shape = value.lshape_map[self.comm.rank] - recv_shape[value.split] = recv_counts.sum() - recv_shape[-1] += 1 - recv_shape = tuple(recv_shape.tolist()) + recv_buf_shape = value.lshape_map[self.comm.rank] + recv_buf_shape[value.split] = recv_counts.sum() + recv_buf_shape = recv_buf_shape.tolist() + if value.ndim < 2: + recv_buf_shape.append(1) + if key_is_mask_like: + recv_buf_shape[-1] += len(key) + else: + recv_buf_shape[-1] += 1 + recv_buf_shape = tuple(recv_buf_shape) + print("DEBUGGING: recv_buf_shape = ", recv_buf_shape) recv_buf = torch.zeros( - recv_shape, dtype=value.dtype.torch_type(), device=self.device.torch_device + recv_buf_shape, dtype=value.dtype.torch_type(), device=self.device.torch_device ) # perform Alltoallv along the 0 axis + send_counts, send_displs, recv_counts, recv_displs = ( + send_counts.tolist(), + send_displs.tolist(), + recv_counts.tolist(), + recv_displs.tolist(), + ) + print("DEBUGGING: send_buf.shape, recv_buf.shape = ", send_buf.shape, recv_buf.shape) + print( + "DEBUGGING: send_counts, send_displs, recv_counts, recv_displs = ", + send_counts, + send_displs, + recv_counts, + recv_displs, + ) self.comm.Alltoallv( (send_buf, send_counts, send_displs), (recv_buf, recv_counts, recv_displs) ) del send_buf, 
comm_matrix + key = list(key) + print("DEBUGGING: recv_buf = ", recv_buf) + if key_is_mask_like: + # extract incoming indices from recv_buf + recv_indices = recv_buf[..., -len(key) :] + recv_buf = recv_buf[..., : -len(key)] + # store incoming indices in int 1-D tensor and correct for rank offset recv_indices = recv_buf[..., -1].type(torch.int64) - displs[rank] # remove last column from recv_buf recv_buf = recv_buf[..., :-1] # transpose back value and recv_buf if necessary, wrap recv_buf in DNDarray value = value.transpose(transpose_axes) + if value.ndim < 2: + recv_buf.squeeze_(1) + print("DEBUGGING: transpose_axes = ", transpose_axes) + print("DEBUGGING: value.shape, recv_buf.shape = ", value.shape, recv_buf.shape) recv_buf = DNDarray( recv_buf.permute(*transpose_axes), gshape=value.gshape, + dtype=value.dtype, split=value.split, device=value.device, comm=value.comm, From 83842ec7653e0df2f0ca8cba05c705b399a5b187 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 12 Feb 2024 08:05:39 +0100 Subject: [PATCH 112/132] working: setitem w. mask-like adv indexing, non-ordered split key --- heat/core/dndarray.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 64ceade256..dc9b436e94 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1290,7 +1290,7 @@ def __process_key( ) # all key elements are now DNDarrays of the same shape, same split axis # 2. key along arr.split is DNDarray - if arr.split in advanced_indexing_dims: + if arr.is_distributed() and arr.split in advanced_indexing_dims: if split_key_is_ordered == 1: # extract torch tensors, keep process-local indices only k = key[arr.split].larray @@ -2821,12 +2821,21 @@ def __set( if key_is_mask_like: # extract incoming indices from recv_buf recv_indices = recv_buf[..., -len(key) :] + # correct split-axis indices for rank offset + recv_indices[:, 0] -= displs[rank] + key = recv_indices.split(1, dim=1) + key = [key[i].squeeze_(1) for i in range(len(key))] + # remove indices from recv_buf recv_buf = recv_buf[..., : -len(key)] - - # store incoming indices in int 1-D tensor and correct for rank offset - recv_indices = recv_buf[..., -1].type(torch.int64) - displs[rank] - # remove last column from recv_buf - recv_buf = recv_buf[..., :-1] + else: + # store incoming indices in int 1-D tensor and correct for rank offset + recv_indices = recv_buf[..., -1].type(torch.int64) - displs[rank] + # remove last column from recv_buf + recv_buf = recv_buf[..., :-1] + # replace split-axis key with incoming local indices + key = list(key) + key[self.split] = recv_indices + key = tuple(key) # transpose back value and recv_buf if necessary, wrap recv_buf in DNDarray value = value.transpose(transpose_axes) if value.ndim < 2: @@ -2842,10 +2851,6 @@ def __set( comm=value.comm, balanced=value.balanced, ) - # replace split-axis key with incoming local indices - key = list(key) - key[self.split] = recv_indices - key = tuple(key) # set local elements of `self` to corresponding elements of `value` __set(self, key, recv_buf) From 366aaf9e5f0830ce75091ec8fb995b4482ae8a55 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 12 Feb 2024 08:06:13 +0100 Subject: [PATCH 113/132] adapt tests --- heat/core/tests/test_dndarray.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 
2e9c3e786e..44c073b0bf 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1550,8 +1550,7 @@ def test_setitem(self): k3 = np.array([1, 2, 3, 1]) value = ht.array([99, 98, 97, 96], split=0) x[k1, k2, k3] = value - print(x.comm.rank, x.larray) - # self.assertTrue((x[k1, k2, k3] == value).all().item()) + self.assertTrue((x[k1, k2, k3] == ht.array([96, 98, 97, 96], split=0)).all().item()) # # advanced indexing on non-consecutive dimensions # x = ht.arange(60, split=0).reshape(5, 3, 4, new_split=1) # x_copy = x.copy() From bbe0a7b1df0e2cdf1dcc0029eca823f1ac4dae9e Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 12 Feb 2024 08:43:17 +0100 Subject: [PATCH 114/132] refactor __process_key(): address boolean ind within adv ind --- heat/core/dndarray.py | 429 ++++++++++++++++++++++-------------------- 1 file changed, 226 insertions(+), 203 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index dc9b436e94..57cfc0e8e5 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -901,100 +901,209 @@ def __process_key( tuple(key.shape), arr.shape ) ) + # extract non-zero elements try: - # key is DNDarray or ndarray - key = key.copy() - except AttributeError: # key is torch tensor - key = key.clone() - if not arr_is_distributed: + key = key.nonzero(as_tuple=True) + except TypeError: + # key is np.ndarray or DNDarray + key = key.nonzero() + # key is a tuple of arrays/tensors, will be treated as advanced indexing + + # try: + # # key is DNDarray or ndarray + # key = key.copy() + # except AttributeError: + # # key is torch tensor + # key = key.clone() + # if not arr_is_distributed: + # try: + # # key is DNDarray, extract torch tensor + # key = key.larray + # except AttributeError: + # pass + # try: + # # key is torch tensor + # key = key.nonzero(as_tuple=True) + # except TypeError: + # # key is np.ndarray + # key = key.nonzero() + # # convert to torch tensor + # key = tuple(torch.tensor(k, device=arr.larray.device) for k in key) + # output_shape = tuple(key[0].shape) + arr.shape[key_ndim:] + # new_split = None if arr.split is None else 0 + # out_is_balanced = True + # split_key_is_ordered = 1 + # return ( + # arr, + # key, + # output_shape, + # new_split, + # split_key_is_ordered, + # out_is_balanced, + # root, + # backwards_transpose_axes, + # ) + + # # arr is distributed + # if not isinstance(key, DNDarray) or ( + # isinstance(key, DNDarray) and not key.is_distributed() + # ): + # key = factories.array( + # key, split=arr.split, device=arr.device, comm=arr.comm, copy=None + # ) + # else: + # if key.split != arr.split: + # raise IndexError( + # "Boolean index does not match distribution scheme of indexed array. 
index.split is {}, array.split is {}".format( + # key.split, arr.split + # ) + # ) + # if arr.split == 0: + # # ensure arr and key are aligned + # key.redistribute_(target_map=arr.lshape_map) + # # transform key to sequence of indexing (1-D) arrays + # key = list(key.nonzero()) + # output_shape = key[0].shape + # new_split = 0 + # split_key_is_ordered = 1 + # out_is_balanced = False + # for i, k in enumerate(key): + # key[i] = k.larray + # if return_local_indices: + # key[arr.split] -= displs[arr.comm.rank] + # key = tuple(key) + # else: + # key = key.larray.nonzero(as_tuple=False) + # # construct global key array + # nz_size = torch.tensor(key.shape[0], device=key.device, dtype=key.dtype) + # arr.comm.Allreduce(MPI.IN_PLACE, nz_size, MPI.SUM) + # key_gshape = (nz_size.item(), arr.ndim) + # key[:, arr.split] += displs[arr.comm.rank] + # key_split = 0 + # key = DNDarray( + # key, + # gshape=key_gshape, + # dtype=canonical_heat_type(key.dtype), + # split=key_split, + # device=arr.device, + # comm=arr.comm, + # balanced=False, + # ) + # key.balance_() + # # set output parameters + # output_shape = (key.gshape[0],) + # new_split = 0 + # split_key_is_ordered = 0 + # out_is_balanced = True + # # vectorized sorting of key along axis 0 + # key = manipulations.unique(key, axis=0, return_inverse=False) + # # return tuple key of torch tensors + # key = list(key.larray.split(1, dim=1)) + # for i, k in enumerate(key): + # key[i] = k.squeeze(1) + # key = tuple(key) + + # return ( + # arr, + # key, + # output_shape, + # new_split, + # split_key_is_ordered, + # out_is_balanced, + # root, + # backwards_transpose_axes, + # ) + else: + # advanced indexing on first dimension: first dim will expand to shape of key + output_shape = tuple(list(key.shape) + output_shape[1:]) + print("DEBUGGING ADV IND: output_shape = ", output_shape) + # adjust split axis accordingly + if arr_is_distributed: + if arr.split != 0: + # split axis is not affected + split_bookkeeping = [None] * key.ndim + split_bookkeeping[1:] + new_split = ( + split_bookkeeping.index("split") + if "split" in split_bookkeeping + else None + ) + out_is_balanced = arr.balanced + else: + # split axis is affected + if key.ndim > 1: + try: + key_numel = key.numel() + except AttributeError: + key_numel = key.size + if key_numel == arr.shape[0]: + new_split = tuple(key.shape).index(arr.shape[0]) + else: + new_split = key.ndim - 1 + try: + key_split = key[new_split].larray + sorted, _ = key_split.sort(stable=True) + except AttributeError: + key_split = key[new_split] + sorted = key_split.sort() + else: + new_split = 0 + # assess if key is sorted along split axis + try: + # DNDarray key + sorted, _ = torch.sort(key.larray, stable=True) + split_key_is_ordered = torch.tensor( + (key.larray == sorted).all(), + dtype=torch.uint8, + device=key.larray.device, + ) + if key.split is not None: + out_is_balanced = key.balanced + split_key_is_ordered = ( + factories.array( + [split_key_is_ordered], + is_split=0, + device=arr.device, + copy=False, + ) + .all() + .astype(types.canonical_heat_types.uint8) + .item() + ) + else: + split_key_is_ordered = split_key_is_ordered.item() + key = key.larray + except AttributeError: + # torch or ndarray key + try: + sorted, _ = torch.sort(key, stable=True) + except TypeError: + # ndarray key + sorted = torch.tensor(np.sort(key), device=arr.larray.device) + split_key_is_ordered = torch.tensor( + key == sorted, dtype=torch.uint8 + ).item() + if not split_key_is_ordered: + # prepare for distributed non-ordered indexing: distribute 
torch/numpy key + key = factories.array(key, split=0, device=arr.device).larray + out_is_balanced = True + if split_key_is_ordered: + # extract local key + cond1 = key >= displs[arr.comm.rank] + cond2 = key < displs[arr.comm.rank] + counts[arr.comm.rank] + key = key[cond1 & cond2] + if return_local_indices: + key -= displs[arr.comm.rank] + out_is_balanced = False + else: try: - # key is DNDarray, extract torch tensor + out_is_balanced = key.balanced + new_split = key.split key = key.larray except AttributeError: - pass - try: - # key is torch tensor - key = key.nonzero(as_tuple=True) - except TypeError: - # key is np.ndarray - key = key.nonzero() - # convert to torch tensor - key = tuple(torch.tensor(k, device=arr.larray.device) for k in key) - output_shape = tuple(key[0].shape) + arr.shape[key_ndim:] - new_split = None if arr.split is None else 0 - out_is_balanced = True - split_key_is_ordered = 1 - return ( - arr, - key, - output_shape, - new_split, - split_key_is_ordered, - out_is_balanced, - root, - backwards_transpose_axes, - ) - - # arr is distributed - if not isinstance(key, DNDarray) or ( - isinstance(key, DNDarray) and not key.is_distributed() - ): - key = factories.array( - key, split=arr.split, device=arr.device, comm=arr.comm, copy=None - ) - else: - if key.split != arr.split: - raise IndexError( - "Boolean index does not match distribution scheme of indexed array. index.split is {}, array.split is {}".format( - key.split, arr.split - ) - ) - if arr.split == 0: - # ensure arr and key are aligned - key.redistribute_(target_map=arr.lshape_map) - # transform key to sequence of indexing (1-D) arrays - key = list(key.nonzero()) - output_shape = key[0].shape - new_split = 0 - split_key_is_ordered = 1 - out_is_balanced = False - for i, k in enumerate(key): - key[i] = k.larray - if return_local_indices: - key[arr.split] -= displs[arr.comm.rank] - key = tuple(key) - else: - key = key.larray.nonzero(as_tuple=False) - # construct global key array - nz_size = torch.tensor(key.shape[0], device=key.device, dtype=key.dtype) - arr.comm.Allreduce(MPI.IN_PLACE, nz_size, MPI.SUM) - key_gshape = (nz_size.item(), arr.ndim) - key[:, arr.split] += displs[arr.comm.rank] - key_split = 0 - key = DNDarray( - key, - gshape=key_gshape, - dtype=canonical_heat_type(key.dtype), - split=key_split, - device=arr.device, - comm=arr.comm, - balanced=False, - ) - key.balance_() - # set output parameters - output_shape = (key.gshape[0],) - new_split = 0 - split_key_is_ordered = 0 - out_is_balanced = True - # vectorized sorting of key along axis 0 - key = manipulations.unique(key, axis=0, return_inverse=False) - # return tuple key of torch tensors - key = list(key.larray.split(1, dim=1)) - for i, k in enumerate(key): - key[i] = k.squeeze(1) - key = tuple(key) - + # torch or numpy key, non-distributed indexed array + out_is_balanced = True + new_split = None return ( arr, key, @@ -1006,104 +1115,6 @@ def __process_key( backwards_transpose_axes, ) - # advanced indexing on first dimension: first dim will expand to shape of key - output_shape = tuple(list(key.shape) + output_shape[1:]) - print("DEBUGGING ADV IND: output_shape = ", output_shape) - # adjust split axis accordingly - if arr_is_distributed: - if arr.split != 0: - # split axis is not affected - split_bookkeeping = [None] * key.ndim + split_bookkeeping[1:] - new_split = ( - split_bookkeeping.index("split") if "split" in split_bookkeeping else None - ) - out_is_balanced = arr.balanced - else: - # split axis is affected - if key.ndim > 1: - try: - key_numel = 
key.numel() - except AttributeError: - key_numel = key.size - if key_numel == arr.shape[0]: - new_split = tuple(key.shape).index(arr.shape[0]) - else: - new_split = key.ndim - 1 - try: - key_split = key[new_split].larray - sorted, _ = key_split.sort(stable=True) - except AttributeError: - key_split = key[new_split] - sorted = key_split.sort() - else: - new_split = 0 - # assess if key is sorted along split axis - try: - # DNDarray key - sorted, _ = torch.sort(key.larray, stable=True) - split_key_is_ordered = torch.tensor( - (key.larray == sorted).all(), - dtype=torch.uint8, - device=key.larray.device, - ) - if key.split is not None: - out_is_balanced = key.balanced - split_key_is_ordered = ( - factories.array( - [split_key_is_ordered], - is_split=0, - device=arr.device, - copy=False, - ) - .all() - .astype(types.canonical_heat_types.uint8) - .item() - ) - else: - split_key_is_ordered = split_key_is_ordered.item() - key = key.larray - except AttributeError: - # torch or ndarray key - try: - sorted, _ = torch.sort(key, stable=True) - except TypeError: - # ndarray key - sorted = torch.tensor(np.sort(key), device=arr.larray.device) - split_key_is_ordered = torch.tensor( - key == sorted, dtype=torch.uint8 - ).item() - if not split_key_is_ordered: - # prepare for distributed non-ordered indexing: distribute torch/numpy key - key = factories.array(key, split=0, device=arr.device).larray - out_is_balanced = True - if split_key_is_ordered: - # extract local key - cond1 = key >= displs[arr.comm.rank] - cond2 = key < displs[arr.comm.rank] + counts[arr.comm.rank] - key = key[cond1 & cond2] - if return_local_indices: - key -= displs[arr.comm.rank] - out_is_balanced = False - else: - try: - out_is_balanced = key.balanced - new_split = key.split - key = key.larray - except AttributeError: - # torch or numpy key, non-distributed indexed array - out_is_balanced = True - new_split = None - return ( - arr, - key, - output_shape, - new_split, - split_key_is_ordered, - out_is_balanced, - root, - backwards_transpose_axes, - ) - key = list(key) if isinstance(key, Iterable) else [key] # check for ellipsis, newaxis. NB: (np.newaxis is None)==True @@ -1270,24 +1281,36 @@ def __process_key( key = list(key) key_splits = [k.split for k in key] non_split_dims = list(advanced_indexing_dims).copy() - non_split_dims.remove(arr.split) - if not key_splits.count(key_splits[arr.split]) == len(key_splits): - if ( - key_splits[arr.split] is not None - and key_splits.count(None) == len(key_splits) - 1 - ): - for i in non_split_dims: - key[i] = factories.array( - key[i], - split=key_splits[arr.split], - device=arr.device, - comm=arr.comm, - copy=None, + print( + "DEBUGGING: advanced_indexing_dims, arr.split = ", + advanced_indexing_dims, + arr.split, + ) + if arr.split is not None: + non_split_dims.remove(arr.split) + if not key_splits.count(key_splits[arr.split]) == len(key_splits): + if ( + key_splits[arr.split] is not None + and key_splits.count(None) == len(key_splits) - 1 + ): + for i in non_split_dims: + key[i] = factories.array( + key[i], + split=key_splits[arr.split], + device=arr.device, + comm=arr.comm, + copy=None, + ) + else: + raise IndexError( + f"Indexing arrays must be distributed along the same dimension, got splits {key_splits}." ) else: - raise IndexError( - f"Indexing arrays must be distributed along the same dimension, got splits {key_splits}." 
- ) + # all key_splits must be the same, otherwise raise IndexError + if not key_splits.count(key_splits[0]) == len(key_splits): + raise IndexError( + f"Indexing arrays must be distributed along the same dimension, got splits {key_splits}." + ) # all key elements are now DNDarrays of the same shape, same split axis # 2. key along arr.split is DNDarray if arr.is_distributed() and arr.split in advanced_indexing_dims: From 1c47b42afefcc8ffbcd0a02b3057138e70be7374 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Feb 2024 08:39:31 +0000 Subject: [PATCH 115/132] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- heat/core/dndarray.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 57cfc0e8e5..a5af7b0a30 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2774,9 +2774,9 @@ def __set( send_displs[proc] = send_counts[:proc].sum() # compose send buffer: stack local elements of `value` according to destination process if send_indices.numel() > 0: - send_buf[ - send_displs[proc] : send_displs[proc] + send_counts[proc], :-1 - ] = value.larray[send_indices] + send_buf[send_displs[proc] : send_displs[proc] + send_counts[proc], :-1] = ( + value.larray[send_indices] + ) # store outgoing GLOBAL indices in the last column of send_buf # TODO: if key_is_mask_like: apply send_indices to all dimensions of key if key_is_mask_like: @@ -2789,9 +2789,9 @@ def __set( send_indices = split_key[send_indices] # broadcast send_indices to correct shape send_indices = send_indices.unsqueeze(-1) - send_buf[ - send_displs[proc] : send_displs[proc] + send_counts[proc], -1 - ] = send_indices + send_buf[send_displs[proc] : send_displs[proc] + send_counts[proc], -1] = ( + send_indices + ) # compose communication matrix: share `send_counts` information with all processes comm_matrix = torch.zeros( From 4ee9b966cfe7efb9fe3d1c7cb151bd7766af7cdf Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 15 Feb 2024 10:51:48 +0100 Subject: [PATCH 116/132] getitem: address mask-like key --- heat/core/dndarray.py | 681 +++++++++++++++++++++++++----------------- 1 file changed, 411 insertions(+), 270 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 57cfc0e8e5..c7c1ab5c83 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -882,6 +882,7 @@ def __process_key( advanced_indexing = False split_key_is_ordered = 1 + key_is_mask_like = False out_is_balanced = False root = None backwards_transpose_axes = tuple(range(arr.ndim)) @@ -1110,6 +1111,7 @@ def __process_key( output_shape, new_split, split_key_is_ordered, + key_is_mask_like, out_is_balanced, root, backwards_transpose_axes, @@ -1415,6 +1417,7 @@ def __process_key( output_shape, new_split, split_key_is_ordered, + key_is_mask_like, out_is_balanced, root, backwards_transpose_axes, @@ -1571,6 +1574,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar output_shape, output_split, split_key_is_ordered, + key_is_mask_like, out_is_balanced, root, backwards_transpose_axes, @@ -1631,280 +1635,424 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar comm=self.comm, ) - # key is not ordered along self.split - # key is tuple of torch.Tensor or mix of torch.Tensors and slices - _, displs = self.counts_displs() + # key along split 
axis is unordered, communication needed + # key along the split axis is torch tensor, indices are GLOBAL + counts, displs = self.counts_displs() + rank, size = self.comm.rank, self.comm.size - # determine whether indexed array will be 1D or nD - try: - return_1d = getattr(key, "ndim") == self.ndim - send_axis = 0 - except AttributeError: - # key is tuple of torch tensors - key_shapes = [] - for k in key: - key_shapes.append(getattr(k, "shape", None)) - return_1d = key_shapes.count(key_shapes[original_split]) == self.ndim - # check for broadcasted indexing: key along split axis is not 1D - broadcasted_indexing = ( - key_shapes[original_split] is not None and len(key_shapes[original_split]) > 1 - ) - if broadcasted_indexing: - broadcast_shape = key_shapes[original_split] - key = list(key) - key[original_split] = key[original_split].flatten() - key = tuple(key) - send_axis = original_split - else: - send_axis = output_split - - # send and receive "request key" info on what data element to ship where - recv_counts = torch.zeros((self.comm.size, 1), dtype=torch.int64, device=self.larray.device) - - # construct empty tensor that we'll append to later - if return_1d: - request_key_shape = (0, self.ndim) + # determine what elements of the local array will be received from what process + key_is_single_tensor = isinstance(key, torch.Tensor) + if key_is_single_tensor: + split_key = key else: - request_key_shape = (0, 1) - - outgoing_request_key = torch.empty( - tuple(request_key_shape), dtype=torch.int64, device=self.larray.device - ) - outgoing_request_key_counts = torch.zeros( - (self.comm.size,), dtype=torch.int64, device=self.larray.device - ) - - # process-local: calculate which/how many elements will be received from what process - if split_key_is_ordered == -1: - # key is sorted in descending order (i.e. 
slicing w/ negative step): - # shrink selection of active processes - if key[original_split].numel() > 0: - key_edges = torch.cat( - (key[original_split][-1].reshape(-1), key[original_split][0].reshape(-1)), dim=0 - ).unique() - displs = torch.tensor(displs, device=self.larray.device) - _, inverse, counts = torch.cat((displs, key_edges), dim=0).unique( - sorted=True, return_inverse=True, return_counts=True - ) - if key_edges.numel() == 2: - correction = counts[inverse[-2]] % 2 - start_rank = inverse[-2] - correction - correction += counts[inverse[-1]] % 2 - end_rank = inverse[-1] - correction + 1 - elif key_edges.numel() == 1: - correction = counts[inverse[-1]] % 2 - start_rank = inverse[-1] - correction - end_rank = start_rank + 1 - else: - start_rank = 0 - end_rank = 0 + split_key = key[self.split] + if split_key.ndim > 1: + # original_split_key_shape = split_key.shape + split_key = split_key.flatten() + recv_counts = torch.zeros((size, 1), dtype=torch.int64, device=self.larray.device) + if key_is_mask_like: + recv_indices = torch.zeros( + (len(split_key), len(key)), dtype=torch.int64, device=self.larray.device + ) else: - start_rank = 0 - end_rank = self.comm.size - all_local_indexing = torch.ones( - (self.comm.size,), dtype=torch.bool, device=self.larray.device - ) - all_local_indexing[start_rank:end_rank] = False - for i in range(start_rank, end_rank): - try: - cond1 = key >= displs[i] - if i != self.comm.size - 1: - cond2 = key < displs[i + 1] + recv_indices = torch.zeros( + (split_key.shape), dtype=torch.int64, device=self.larray.device + ) + print("DEBUGGING: SPLIY_KEY = ", split_key) + print("DEBUGGING: counts, displs = ", counts, displs) + for p in range(size): + cond1 = split_key >= displs[p] + cond2 = split_key < displs[p] + counts[p] + indices_from_p = torch.nonzero(cond1 & cond2, as_tuple=False) + incoming_indices = split_key[indices_from_p].flatten() + recv_counts[p, 0] = incoming_indices.numel() + print("DEBUGGING: P, RECV_COUNTS = ", p, incoming_indices.numel(), recv_counts) + # store incoming indices in appropiate slice of recv_indices + # TODO: this is a bit of a convenience solution, but it doubles the memory footprint of split_key + start = recv_counts[:p].sum().item() + stop = start + recv_counts[p].item() + # print("DEBUGGING: incoming_indices = ", incoming_indices) + # print("DEBUGGING: start, stop = ", start, stop) + if incoming_indices.numel() > 0: + if key_is_mask_like: + # apply selection to all dimensions + for i in range(len(key)): + recv_indices[start:stop, i] = key[i][indices_from_p].flatten() + recv_indices[start:stop, self.split] -= displs[p] else: - # cond2 is always true - cond2 = torch.ones((key.shape[0],), dtype=torch.bool, device=self.larray.device) - except TypeError: - cond1 = key[original_split] >= displs[i] - if i != self.comm.size - 1: - cond2 = key[original_split] < displs[i + 1] + recv_indices[start:stop] = incoming_indices - displs[p] + print("DEBUGGING: AFTER: INC_INDICES = ", recv_indices[start:stop]) + # build communication matrix by sharing recv_counts with all processes + # comm_matrix rows contain the send_counts for each process, columns contain the recv_counts + comm_matrix = torch.zeros((size, size), dtype=torch.int64, device=self.larray.device) + self.comm.Allgather(recv_counts, comm_matrix) + send_counts = comm_matrix[:, rank] + + # active rank pairs: + active_rank_pairs = torch.nonzero(comm_matrix, as_tuple=False) + + # Communication build-up: + active_recv_indices_from = active_rank_pairs[torch.where(active_rank_pairs[:, 1] == 
rank)][ + :, 0 + ] + active_send_indices_to = active_rank_pairs[torch.where(active_rank_pairs[:, 0] == rank)][ + :, 1 + ] + rank_is_active = active_recv_indices_from.numel() > 0 or active_send_indices_to.numel() > 0 + + # allocate recv_buf for incoming data + recv_buf_shape = list(output_shape) + recv_buf_shape[output_split] = recv_counts.sum().item() + recv_buf = torch.zeros( + tuple(recv_buf_shape), dtype=self.larray.dtype, device=self.larray.device + ) + print("DEBUGGING: recv_counts, send_counts = ", recv_counts, send_counts) + print("DEBUGGING: comm_matrix = ", comm_matrix) + if rank_is_active: + # non-blocking send indices to active_send_indices_to + send_requests = [] + for i in active_send_indices_to: + start = recv_counts[:i].sum().item() + stop = start + recv_counts[i].item() + outgoing_indices = recv_indices[start:stop] + send_requests.append(self.comm.Isend(outgoing_indices, dest=i)) + for i in active_recv_indices_from: + # receive indices from active_recv_indices_from + if key_is_mask_like: + incoming_indices = torch.zeros( + (send_counts[i].item(), len(key)), + dtype=torch.int64, + device=self.larray.device, + ) else: - # cond2 is always true - cond2 = torch.ones( - (key[original_split].shape[0],), dtype=torch.bool, device=self.larray.device + incoming_indices = torch.zeros( + send_counts[i].item(), dtype=torch.int64, device=self.larray.device ) - if return_1d: - # advanced indexing returning 1D array - if isinstance(key, torch.Tensor): - selection = key[cond1 & cond2] - recv_counts[i, :] = selection.shape[0] - if i == self.comm.rank: - all_local_indexing[i] = selection.shape[0] == key.shape[0] - selection.unsqueeze_(dim=1) + self.comm.Recv(incoming_indices, source=i) + # prepare send_buf for outgoing data + if key_is_single_tensor: + send_buf = self.larray[incoming_indices] else: - # key is tuple of torch tensors - selection = list(k[cond1 & cond2] for k in key) - recv_counts[i, :] = selection[0].shape[0] - if i == self.comm.rank: - all_local_indexing[i] = selection[0].shape[0] == key[0].shape[0] - selection = torch.stack(selection, dim=1) - else: - selection = key[original_split][cond1 & cond2] - recv_counts[i, :] = selection.shape[0] - if i == self.comm.rank: - all_local_indexing[i] = selection.shape[0] == key[original_split].shape[0] - selection.unsqueeze_(dim=1) - outgoing_request_key = torch.cat((outgoing_request_key, selection), dim=0) - all_local_indexing = factories.array( - all_local_indexing, is_split=0, device=self.device, copy=False - ) - if all_local_indexing.all().item(): - # TODO: if advanced indexing, indexed array must be a copy. 
Probably addressed by torch - if broadcasted_indexing: - key[original_split] = key[original_split].reshape(broadcast_shape) - indexed_arr = self.larray[key] - # transpose array back if needed - self = self.transpose(backwards_transpose_axes) - return factories.array( - indexed_arr, is_split=output_split, device=self.device, copy=False + if key_is_mask_like: + send_key = tuple( + incoming_indices[:, i].reshape(-1) + for i in range(incoming_indices.shape[1]) + ) + send_buf = self.larray[send_key] + else: + send_key = list(key) + send_key[self.split] = incoming_indices + send_buf = self.larray[tuple(send_key)] + print(f"DEBUGGING: send_buf to {i} = {send_buf}") + # non-blocking send requested data to i + send_requests.append(self.comm.Isend(send_buf, dest=i)) + print("DEBUGGING: active_send_indices_to = ", active_send_indices_to) + tmp_recv_buf_shape = recv_buf_shape.copy() + tmp_recv_buf_shape[output_split] = recv_counts.max().item() + tmp_recv_buf = torch.zeros( + tuple(tmp_recv_buf_shape), dtype=self.larray.dtype, device=self.larray.device ) + for i in active_send_indices_to: + # non-blocking receive data from i + print("debugging:, i = ", i) + print("DEBUGGING: split_key = ", split_key) + tmp_recv_slice = [slice(None)] * tmp_recv_buf.ndim + tmp_recv_slice[output_split] = slice(0, recv_counts[i].item()) + self.comm.Recv(tmp_recv_buf[tmp_recv_slice], source=i) + print(f"DEBUGGING: tmp_recv_buf from {i} = {tmp_recv_buf}") + # write received data to appropriate portion of recv_buf + cond1 = split_key >= displs[i] + cond2 = split_key < displs[i] + counts[i] + recv_buf_indices = torch.nonzero(cond1 & cond2, as_tuple=False).flatten() + recv_buf_key = [slice(None)] * recv_buf.ndim + recv_buf_key[output_split] = recv_buf_indices + recv_buf[recv_buf_key] = tmp_recv_buf[tmp_recv_slice] + # wait for all non-blocking communication to finish + for req in send_requests: + req.Wait() - # share recv_counts among all processes - comm_matrix = torch.empty( - (self.comm.size, self.comm.size), dtype=recv_counts.dtype, device=recv_counts.device + # construct indexed array from recv_buf + indexed_arr = DNDarray( + recv_buf, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + comm=self.comm, + balanced=out_is_balanced, ) - self.comm.Allgather(recv_counts, comm_matrix) + # transpose array back if needed + self = self.transpose(backwards_transpose_axes) + return indexed_arr - outgoing_request_key_counts = comm_matrix[self.comm.rank] - outgoing_request_key_displs = torch.cat( - ( - torch.zeros( - (1,), - dtype=outgoing_request_key_counts.dtype, - device=outgoing_request_key_counts.device, - ), - outgoing_request_key_counts, - ), - dim=0, - ).cumsum(dim=0)[:-1] - incoming_request_key_counts = comm_matrix[:, self.comm.rank] - incoming_request_key_displs = torch.cat( - ( - torch.zeros( - (1,), - dtype=outgoing_request_key_counts.dtype, - device=outgoing_request_key_counts.device, - ), - incoming_request_key_counts, - ), - dim=0, - ).cumsum(dim=0)[:-1] + # # determine whether indexed array will be 1D or nD + # try: + # return_1d = getattr(key, "ndim") == self.ndim + # send_axis = 0 + # except AttributeError: + # # key is tuple of torch tensors + # key_shapes = [] + # for k in key: + # key_shapes.append(getattr(k, "shape", None)) + # return_1d = key_shapes.count(key_shapes[original_split]) == self.ndim + # # check for broadcasted indexing: key along split axis is not 1D + # broadcasted_indexing = ( + # key_shapes[original_split] is not None and len(key_shapes[original_split]) > 1 + # 
) + # if broadcasted_indexing: + # broadcast_shape = key_shapes[original_split] + # key = list(key) + # key[original_split] = key[original_split].flatten() + # key = tuple(key) + # send_axis = original_split + # else: + # send_axis = output_split - if return_1d: - incoming_request_key = torch.empty( - (incoming_request_key_counts.sum(), self.ndim), - dtype=outgoing_request_key_counts.dtype, - device=outgoing_request_key_counts.device, - ) - else: - incoming_request_key = torch.empty( - (incoming_request_key_counts.sum(), 1), - dtype=outgoing_request_key_counts.dtype, - device=outgoing_request_key_counts.device, - ) - # send and receive request keys - self.comm.Alltoallv( - ( - outgoing_request_key, - outgoing_request_key_counts.tolist(), - outgoing_request_key_displs.tolist(), - ), - ( - incoming_request_key, - incoming_request_key_counts.tolist(), - incoming_request_key_displs.tolist(), - ), - ) - if return_1d: - incoming_request_key = list(incoming_request_key[:, d] for d in range(self.ndim)) - incoming_request_key[original_split] -= displs[self.comm.rank] - else: - incoming_request_key -= displs[self.comm.rank] - incoming_request_key = ( - key[:original_split] - + (incoming_request_key.squeeze_(1),) - + key[original_split + 1 :] - ) + # # send and receive "request key" info on what data element to ship where + # recv_counts = torch.zeros((self.comm.size, 1), dtype=torch.int64, device=self.larray.device) - # calculate shape of local recv buffer - output_lshape = list(output_shape) - if getattr(key, "ndim", 0) == 1: - output_lshape[output_split] = key.shape[0] - else: - if broadcasted_indexing: - output_lshape = ( - output_lshape[:original_split] - + [torch.prod(torch.tensor(broadcast_shape, device=self.larray.device)).item()] - + output_lshape[output_split + 1 :] - ) - else: - output_lshape[output_split] = key[original_split].shape[0] - # allocate recv buffer - recv_buf = torch.empty( - tuple(output_lshape), dtype=self.larray.dtype, device=self.larray.device - ) + # # construct empty tensor that we'll append to later + # if return_1d: + # request_key_shape = (0, self.ndim) + # else: + # request_key_shape = (0, 1) + + # outgoing_request_key = torch.empty( + # tuple(request_key_shape), dtype=torch.int64, device=self.larray.device + # ) + # outgoing_request_key_counts = torch.zeros( + # (self.comm.size,), dtype=torch.int64, device=self.larray.device + # ) + + # # process-local: calculate which/how many elements will be received from what process + # if split_key_is_ordered == -1: + # # key is sorted in descending order (i.e. 
slicing w/ negative step): + # # shrink selection of active processes + # if key[original_split].numel() > 0: + # key_edges = torch.cat( + # (key[original_split][-1].reshape(-1), key[original_split][0].reshape(-1)), dim=0 + # ).unique() + # displs = torch.tensor(displs, device=self.larray.device) + # _, inverse, counts = torch.cat((displs, key_edges), dim=0).unique( + # sorted=True, return_inverse=True, return_counts=True + # ) + # if key_edges.numel() == 2: + # correction = counts[inverse[-2]] % 2 + # start_rank = inverse[-2] - correction + # correction += counts[inverse[-1]] % 2 + # end_rank = inverse[-1] - correction + 1 + # elif key_edges.numel() == 1: + # correction = counts[inverse[-1]] % 2 + # start_rank = inverse[-1] - correction + # end_rank = start_rank + 1 + # else: + # start_rank = 0 + # end_rank = 0 + # else: + # start_rank = 0 + # end_rank = self.comm.size + # all_local_indexing = torch.ones( + # (self.comm.size,), dtype=torch.bool, device=self.larray.device + # ) + # all_local_indexing[start_rank:end_rank] = False + # for i in range(start_rank, end_rank): + # try: + # cond1 = key >= displs[i] + # if i != self.comm.size - 1: + # cond2 = key < displs[i + 1] + # else: + # # cond2 is always true + # cond2 = torch.ones((key.shape[0],), dtype=torch.bool, device=self.larray.device) + # except TypeError: + # cond1 = key[original_split] >= displs[i] + # if i != self.comm.size - 1: + # cond2 = key[original_split] < displs[i + 1] + # else: + # # cond2 is always true + # cond2 = torch.ones( + # (key[original_split].shape[0],), dtype=torch.bool, device=self.larray.device + # ) + # if return_1d: + # # advanced indexing returning 1D array + # if isinstance(key, torch.Tensor): + # selection = key[cond1 & cond2] + # recv_counts[i, :] = selection.shape[0] + # if i == self.comm.rank: + # all_local_indexing[i] = selection.shape[0] == key.shape[0] + # selection.unsqueeze_(dim=1) + # else: + # # key is tuple of torch tensors + # selection = list(k[cond1 & cond2] for k in key) + # recv_counts[i, :] = selection[0].shape[0] + # if i == self.comm.rank: + # all_local_indexing[i] = selection[0].shape[0] == key[0].shape[0] + # selection = torch.stack(selection, dim=1) + # else: + # selection = key[original_split][cond1 & cond2] + # recv_counts[i, :] = selection.shape[0] + # if i == self.comm.rank: + # all_local_indexing[i] = selection.shape[0] == key[original_split].shape[0] + # selection.unsqueeze_(dim=1) + # outgoing_request_key = torch.cat((outgoing_request_key, selection), dim=0) + # all_local_indexing = factories.array( + # all_local_indexing, is_split=0, device=self.device, copy=False + # ) + # if all_local_indexing.all().item(): + # if broadcasted_indexing: + # key[original_split] = key[original_split].reshape(broadcast_shape) + # indexed_arr = self.larray[key] + # # transpose array back if needed + # self = self.transpose(backwards_transpose_axes) + # return factories.array( + # indexed_arr, is_split=output_split, device=self.device, copy=False + # ) - # index local data into send_buf. - send_empty = sum( - list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in incoming_request_key) - ) # incoming_request_key.count([]) - if send_empty: - # Edge case 1. empty slice along split axis: send_buf is 0-element tensor - empty_shape = list(output_shape) - empty_shape[output_split] = 0 - send_buf = torch.empty(empty_shape, dtype=self.larray.dtype, device=self.larray.device) - else: - send_buf = self.larray[incoming_request_key] - # Edge case 2. 
local single-element indexing results into local loss of split axis - if send_buf.ndim < len(output_lshape): - all_keys_scalar = sum( - list( - np.isscalar(k) or k.numel() == 1 and getattr(k, "ndim", 2) < 2 - for k in incoming_request_key - ) - ) == len(incoming_request_key) - if not all_keys_scalar: - send_buf = send_buf.unsqueeze_(dim=output_split) - - recv_counts = torch.squeeze(recv_counts, dim=1).tolist() - recv_displs = outgoing_request_key_displs.tolist() - send_counts = incoming_request_key_counts.tolist() - send_displs = incoming_request_key_displs.tolist() - self.comm.Alltoallv( - (send_buf, send_counts, send_displs), - (recv_buf, recv_counts, recv_displs), - send_axis=send_axis, - ) - # transpose original array back if needed, all further indexing on recv_buf - self = self.transpose(backwards_transpose_axes) + # # share recv_counts among all processes + # comm_matrix = torch.empty( + # (self.comm.size, self.comm.size), dtype=recv_counts.dtype, device=recv_counts.device + # ) + # self.comm.Allgather(recv_counts, comm_matrix) + + # outgoing_request_key_counts = comm_matrix[self.comm.rank] + # outgoing_request_key_displs = torch.cat( + # ( + # torch.zeros( + # (1,), + # dtype=outgoing_request_key_counts.dtype, + # device=outgoing_request_key_counts.device, + # ), + # outgoing_request_key_counts, + # ), + # dim=0, + # ).cumsum(dim=0)[:-1] + # incoming_request_key_counts = comm_matrix[:, self.comm.rank] + # incoming_request_key_displs = torch.cat( + # ( + # torch.zeros( + # (1,), + # dtype=outgoing_request_key_counts.dtype, + # device=outgoing_request_key_counts.device, + # ), + # incoming_request_key_counts, + # ), + # dim=0, + # ).cumsum(dim=0)[:-1] + + # if return_1d: + # incoming_request_key = torch.empty( + # (incoming_request_key_counts.sum(), self.ndim), + # dtype=outgoing_request_key_counts.dtype, + # device=outgoing_request_key_counts.device, + # ) + # else: + # incoming_request_key = torch.empty( + # (incoming_request_key_counts.sum(), 1), + # dtype=outgoing_request_key_counts.dtype, + # device=outgoing_request_key_counts.device, + # ) + # # send and receive request keys + # self.comm.Alltoallv( + # ( + # outgoing_request_key, + # outgoing_request_key_counts.tolist(), + # outgoing_request_key_displs.tolist(), + # ), + # ( + # incoming_request_key, + # incoming_request_key_counts.tolist(), + # incoming_request_key_displs.tolist(), + # ), + # ) + # if return_1d: + # incoming_request_key = list(incoming_request_key[:, d] for d in range(self.ndim)) + # incoming_request_key[original_split] -= displs[self.comm.rank] + # else: + # incoming_request_key -= displs[self.comm.rank] + # incoming_request_key = ( + # key[:original_split] + # + (incoming_request_key.squeeze_(1),) + # + key[original_split + 1 :] + # ) - # reorganize incoming counts according to original key order along split axis - if return_1d: - if isinstance(key, tuple): - key = torch.stack(key, dim=1) - _, key_inverse = key.unique(dim=0, sorted=True, return_inverse=True) - # if _.shape == key.shape: - _, ork_inverse = outgoing_request_key.unique(dim=0, sorted=True, return_inverse=True) - map = ork_inverse.argsort(stable=True)[ - key_inverse.argsort(stable=True).argsort(stable=True) - ] - indexed_arr = recv_buf[map] - return factories.array(indexed_arr, is_split=output_split, copy=False) - - outgoing_request_key = outgoing_request_key.squeeze_(1) - map = [slice(None)] * recv_buf.ndim - # print("DEBUGGING: outgoing_request_key = ", outgoing_request_key) - # print("DEBUGGING: key[original_split] = ", key[original_split]) - 
if broadcasted_indexing: - map[original_split] = outgoing_request_key.argsort(stable=True)[ - key[original_split].argsort(stable=True).argsort(stable=True) - ] - map[original_split] = map[original_split].reshape(broadcast_shape) - else: - map[output_split] = outgoing_request_key.argsort(stable=True)[ - key[original_split].argsort(stable=True).argsort(stable=True) - ] - indexed_arr = recv_buf[map] - return factories.array(indexed_arr, is_split=output_split, copy=False) + # # calculate shape of local recv buffer + # output_lshape = list(output_shape) + # if getattr(key, "ndim", 0) == 1: + # output_lshape[output_split] = key.shape[0] + # else: + # if broadcasted_indexing: + # output_lshape = ( + # output_lshape[:original_split] + # + [torch.prod(torch.tensor(broadcast_shape, device=self.larray.device)).item()] + # + output_lshape[output_split + 1 :] + # ) + # else: + # output_lshape[output_split] = key[original_split].shape[0] + # # allocate recv buffer + # recv_buf = torch.empty( + # tuple(output_lshape), dtype=self.larray.dtype, device=self.larray.device + # ) + + # # index local data into send_buf. + # send_empty = sum( + # list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in incoming_request_key) + # ) + # if send_empty: + # # Edge case 1. empty slice along split axis: send_buf is 0-element tensor + # empty_shape = list(output_shape) + # empty_shape[output_split] = 0 + # send_buf = torch.empty(empty_shape, dtype=self.larray.dtype, device=self.larray.device) + # else: + # send_buf = self.larray[incoming_request_key] + # # Edge case 2. local single-element indexing results into local loss of split axis + # if send_buf.ndim < len(output_lshape): + # all_keys_scalar = sum( + # list( + # np.isscalar(k) or k.numel() == 1 and getattr(k, "ndim", 2) < 2 + # for k in incoming_request_key + # ) + # ) == len(incoming_request_key) + # if not all_keys_scalar: + # send_buf = send_buf.unsqueeze_(dim=output_split) + + # recv_counts = torch.squeeze(recv_counts, dim=1).tolist() + # recv_displs = outgoing_request_key_displs.tolist() + # send_counts = incoming_request_key_counts.tolist() + # send_displs = incoming_request_key_displs.tolist() + # self.comm.Alltoallv( + # (send_buf, send_counts, send_displs), + # (recv_buf, recv_counts, recv_displs), + # send_axis=send_axis, + # ) + # # transpose original array back if needed, all further indexing on recv_buf + # self = self.transpose(backwards_transpose_axes) + + # # reorganize incoming counts according to original key order along split axis + # if return_1d: + # if isinstance(key, tuple): + # key = torch.stack(key, dim=1) + # _, key_inverse = key.unique(dim=0, sorted=True, return_inverse=True) + # _, ork_inverse = outgoing_request_key.unique(dim=0, sorted=True, return_inverse=True) + # map = ork_inverse.argsort(stable=True)[ + # key_inverse.argsort(stable=True).argsort(stable=True) + # ] + # indexed_arr = recv_buf[map] + # return factories.array(indexed_arr, is_split=output_split, copy=False) + + # outgoing_request_key = outgoing_request_key.squeeze_(1) + # map = [slice(None)] * recv_buf.ndim + # # print("DEBUGGING: outgoing_request_key = ", outgoing_request_key) + # # print("DEBUGGING: key[original_split] = ", key[original_split]) + # if broadcasted_indexing: + # map[original_split] = outgoing_request_key.argsort(stable=True)[ + # key[original_split].argsort(stable=True).argsort(stable=True) + # ] + # map[original_split] = map[original_split].reshape(broadcast_shape) + # else: + # map[output_split] = outgoing_request_key.argsort(stable=True)[ + # 
key[original_split].argsort(stable=True).argsort(stable=True) + # ] + # indexed_arr = recv_buf[map] + # return factories.array(indexed_arr, is_split=output_split, copy=False) if torch.cuda.device_count() > 0: @@ -2556,7 +2704,8 @@ def __set( output_shape, output_split, split_key_is_ordered, - out_is_balanced, + key_is_mask_like, + _, root, backwards_transpose_axes, ) = self.__process_key(key, return_local_indices=True, op="set") @@ -2638,20 +2787,12 @@ def __set( if split_key_is_ordered == 0: # key along split axis is unordered, communication needed - # key along the split axis is 1-D torch tensor, but indices are GLOBAL + # key along the split axis is torch tensor, indices are GLOBAL counts, displs = self.counts_displs() - # rank, size = self.comm.rank, self.comm.size - rank = self.comm.rank + rank, _ = self.comm.rank, self.comm.size # key_is_single_tensor = isinstance(key, torch.Tensor) - key_is_mask_like = False - # define key as mask_like if each element of key is a torch.Tensor and all elements of key are of the same shape - if not key_is_single_tensor: - key_is_mask_like = ( - all(isinstance(k, torch.Tensor) for k in key) - and len(set(k.shape for k in key)) == 1 - ) if not value.is_distributed(): if key_is_single_tensor: # key is a single torch.Tensor From 15cce447fe30bfce600cf8bb469ba07441c9b44a Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Sun, 18 Feb 2024 06:22:03 +0100 Subject: [PATCH 117/132] define nonzero_size in non-distr case --- heat/core/indexing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/heat/core/indexing.py b/heat/core/indexing.py index 0e7ee3d0a0..0a521608e3 100644 --- a/heat/core/indexing.py +++ b/heat/core/indexing.py @@ -62,7 +62,7 @@ def nonzero(x: DNDarray) -> Tuple[DNDarray, ...]: # nonzero indices as tuple lcl_nonzero = torch.nonzero(input=local_x, as_tuple=True) # bookkeeping for final DNDarray construct - output_shape = (lcl_nonzero[0].shape,) + nonzero_size = lcl_nonzero[0].shape[0] output_split = None if x.split is None else 0 output_balanced = True else: @@ -98,10 +98,11 @@ def nonzero(x: DNDarray) -> Tuple[DNDarray, ...]: # return indices as tuple of columns lcl_nonzero = lcl_nonzero.split(1, dim=1) output_balanced = False + nonzero_size = nonzero_size.item() # return global_nonzero as tuple of DNDarrays global_nonzero = list(lcl_nonzero) - output_shape = (nonzero_size.item(),) + output_shape = (nonzero_size,) output_split = 0 for i, nz_tensor in enumerate(global_nonzero): if nz_tensor.ndim > 1: From 09fb199219b9b43a4eab3dfc469d9b764f099ca5 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Sun, 18 Feb 2024 06:25:36 +0100 Subject: [PATCH 118/132] handle split_bookkeeping when key is mask-like --- heat/core/dndarray.py | 140 +++++++++--------------------------------- 1 file changed, 28 insertions(+), 112 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index a76bb00918..158214f792 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -910,112 +910,7 @@ def __process_key( except TypeError: # key is np.ndarray or DNDarray key = key.nonzero() - # key is a tuple of arrays/tensors, will be treated as advanced indexing - - # try: - # # key is DNDarray or ndarray - # key = key.copy() - # except AttributeError: - # # key is torch tensor - # key = key.clone() - # if not arr_is_distributed: - # try: - # # key is DNDarray, extract torch tensor - # key = key.larray - # except AttributeError: - # 
pass - # try: - # # key is torch tensor - # key = key.nonzero(as_tuple=True) - # except TypeError: - # # key is np.ndarray - # key = key.nonzero() - # # convert to torch tensor - # key = tuple(torch.tensor(k, device=arr.larray.device) for k in key) - # output_shape = tuple(key[0].shape) + arr.shape[key_ndim:] - # new_split = None if arr.split is None else 0 - # out_is_balanced = True - # split_key_is_ordered = 1 - # return ( - # arr, - # key, - # output_shape, - # new_split, - # split_key_is_ordered, - # out_is_balanced, - # root, - # backwards_transpose_axes, - # ) - - # # arr is distributed - # if not isinstance(key, DNDarray) or ( - # isinstance(key, DNDarray) and not key.is_distributed() - # ): - # key = factories.array( - # key, split=arr.split, device=arr.device, comm=arr.comm, copy=None - # ) - # else: - # if key.split != arr.split: - # raise IndexError( - # "Boolean index does not match distribution scheme of indexed array. index.split is {}, array.split is {}".format( - # key.split, arr.split - # ) - # ) - # if arr.split == 0: - # # ensure arr and key are aligned - # key.redistribute_(target_map=arr.lshape_map) - # # transform key to sequence of indexing (1-D) arrays - # key = list(key.nonzero()) - # output_shape = key[0].shape - # new_split = 0 - # split_key_is_ordered = 1 - # out_is_balanced = False - # for i, k in enumerate(key): - # key[i] = k.larray - # if return_local_indices: - # key[arr.split] -= displs[arr.comm.rank] - # key = tuple(key) - # else: - # key = key.larray.nonzero(as_tuple=False) - # # construct global key array - # nz_size = torch.tensor(key.shape[0], device=key.device, dtype=key.dtype) - # arr.comm.Allreduce(MPI.IN_PLACE, nz_size, MPI.SUM) - # key_gshape = (nz_size.item(), arr.ndim) - # key[:, arr.split] += displs[arr.comm.rank] - # key_split = 0 - # key = DNDarray( - # key, - # gshape=key_gshape, - # dtype=canonical_heat_type(key.dtype), - # split=key_split, - # device=arr.device, - # comm=arr.comm, - # balanced=False, - # ) - # key.balance_() - # # set output parameters - # output_shape = (key.gshape[0],) - # new_split = 0 - # split_key_is_ordered = 0 - # out_is_balanced = True - # # vectorized sorting of key along axis 0 - # key = manipulations.unique(key, axis=0, return_inverse=False) - # # return tuple key of torch tensors - # key = list(key.larray.split(1, dim=1)) - # for i, k in enumerate(key): - # key[i] = k.squeeze(1) - # key = tuple(key) - - # return ( - # arr, - # key, - # output_shape, - # new_split, - # split_key_is_ordered, - # out_is_balanced, - # root, - # backwards_transpose_axes, - # ) + key_is_mask_like = True else: # advanced indexing on first dimension: first dim will expand to shape of key output_shape = tuple(list(key.shape) + output_shape[1:]) @@ -1273,8 +1168,8 @@ def __process_key( if advanced_indexing: # adv indexing key elements are DNDarrays: extract torch tensors - # options: 1. key is mask-like (covers boolean mask as well), 2. key along arr.split is DNDarray, 3. everything else - # 1. define key as mask_like if each element of key is a DNDarray, and all elements of key are of the same shape, and the advanced-indexing dimensions are consecutive + # options: 1. key is mask-like (covers boolean mask as well), 2. adv indexing along split axis, 3. everything else + # 1. 
define key as mask-like if each element of key is a DNDarray, and all elements of key are of the same shape, and the advanced-indexing dimensions are consecutive key_is_mask_like = ( all(isinstance(k, DNDarray) for k in key) and len(set(k.shape for k in key)) == 1 @@ -1363,11 +1258,32 @@ def __process_key( advanced_indexing_dims[0] : advanced_indexing_dims[0] + len(advanced_indexing_dims) ] = broadcasted_shape - split_bookkeeping = ( - split_bookkeeping[: advanced_indexing_dims[0]] - + [None] * add_dims - + split_bookkeeping[advanced_indexing_dims[0] :] + print( + "DEBUGGING: broadcasted_shape, split_bookkeeping = ", + broadcasted_shape, + split_bookkeeping, ) + if key_is_mask_like: + # advanced indexing dimensions will be collapsed into one dimension + if ( + "split" in split_bookkeeping + and split_bookkeeping.index("split") in advanced_indexing_dims + ): + split_bookkeeping[ + advanced_indexing_dims[0] : advanced_indexing_dims[0] + + len(advanced_indexing_dims) + ] = ["split"] + else: + split_bookkeeping[ + advanced_indexing_dims[0] : advanced_indexing_dims[0] + + len(advanced_indexing_dims) + ] = [None] + else: + split_bookkeeping = ( + split_bookkeeping[: advanced_indexing_dims[0]] + + [None] * add_dims + + split_bookkeeping[advanced_indexing_dims[0] :] + ) print("ADV IND output_shape = ", output_shape) else: # advanced-indexing dimensions are not consecutive: From 9c8d05151b0c65b363a6eee26183dcd6f87ee10d Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Sun, 18 Feb 2024 07:17:56 +0100 Subject: [PATCH 119/132] fix key type mismatch in advanced indexing --- heat/core/dndarray.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 158214f792..c3b5f741f8 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1175,17 +1175,17 @@ def __process_key( and len(set(k.shape for k in key)) == 1 and torch.tensor(advanced_indexing_dims).diff().eq(1).all() ) + print("KEY_IS_MASK_LIKE = ", key_is_mask_like) + # if split axis is affected by advanced indexing, keep track of non-split dimensions for later + if arr.is_distributed() and arr.split in advanced_indexing_dims: + non_split_dims = list(advanced_indexing_dims).copy() + if arr.split is not None: + non_split_dims.remove(arr.split) + # 1. key is mask-like if key_is_mask_like: key = list(key) key_splits = [k.split for k in key] - non_split_dims = list(advanced_indexing_dims).copy() - print( - "DEBUGGING: advanced_indexing_dims, arr.split = ", - advanced_indexing_dims, - arr.split, - ) if arr.split is not None: - non_split_dims.remove(arr.split) if not key_splits.count(key_splits[arr.split]) == len(key_splits): if ( key_splits[arr.split] is not None @@ -1210,7 +1210,7 @@ def __process_key( f"Indexing arrays must be distributed along the same dimension, got splits {key_splits}." ) # all key elements are now DNDarrays of the same shape, same split axis - # 2. key along arr.split is DNDarray + # 2. 
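
To illustrate the "mask-like" case defined above: converting a boolean mask via nonzero() yields one 1-D index tensor per dimension, all of the same shape, and indexing with such a tuple collapses the masked dimensions into a single output dimension. A local PyTorch sketch:

    import torch

    x = torch.arange(12).reshape(3, 4)
    key = (x % 5 == 0).nonzero(as_tuple=True)   # tuple of equally shaped 1-D tensors
    assert len(set(k.shape for k in key)) == 1  # the mask-like property
    print(x[key])                               # tensor([ 0,  5, 10])
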
advanced indexing along split axis if arr.is_distributed() and arr.split in advanced_indexing_dims: if split_key_is_ordered == 1: # extract torch tensors, keep process-local indices only @@ -1221,10 +1221,12 @@ def __process_key( if return_local_indices: k -= displs[arr.comm.rank] key[arr.split] = k - if key_is_mask_like: - # select the same elements along non-split dimensions - for i in non_split_dims: + for i in non_split_dims: + if key_is_mask_like: + # select the same elements along non-split dimensions key[i] = key[i].larray[cond1 & cond2] + else: + key[i] = key[i].larray elif split_key_is_ordered == 0: # extract torch tensors, any other communication + mask-like case are handled in __getitem__ or __setitem__ for i in advanced_indexing_dims: @@ -1539,6 +1541,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar return indexed_arr # root is None, i.e. indexing does not affect split axis, apply as is + print("DEBUGGING: key = ", key) indexed_arr = self.larray[key] # transpose array back if needed self = self.transpose(backwards_transpose_axes) From 41fba0a9ab398be7bd89e5c8f87a53da5ad9f934 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Sun, 18 Feb 2024 08:48:07 +0100 Subject: [PATCH 120/132] getitem: address n-D key along split axis, free memory --- heat/core/dndarray.py | 336 +++++------------------------------------- 1 file changed, 39 insertions(+), 297 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index c3b5f741f8..acd1ab60f4 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1541,7 +1541,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar return indexed_arr # root is None, i.e. indexing does not affect split axis, apply as is - print("DEBUGGING: key = ", key) indexed_arr = self.larray[key] # transpose array back if needed self = self.transpose(backwards_transpose_axes) @@ -1555,44 +1554,43 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar comm=self.comm, ) - # key along split axis is unordered, communication needed - # key along the split axis is torch tensor, indices are GLOBAL + # key along split axis is not ordered, indices are GLOBAL + # prepare for communication of indices and data counts, displs = self.counts_displs() rank, size = self.comm.rank, self.comm.size - # determine what elements of the local array will be received from what process key_is_single_tensor = isinstance(key, torch.Tensor) if key_is_single_tensor: split_key = key else: split_key = key[self.split] + # split_key might be multi-dimensional, flatten it for communication if split_key.ndim > 1: - # original_split_key_shape = split_key.shape + original_split_key_shape = split_key.shape + communication_split = output_split - (split_key.ndim - 1) split_key = split_key.flatten() + else: + communication_split = output_split + + # determine the number of elements to be received from each process recv_counts = torch.zeros((size, 1), dtype=torch.int64, device=self.larray.device) if key_is_mask_like: recv_indices = torch.zeros( - (len(split_key), len(key)), dtype=torch.int64, device=self.larray.device + (len(split_key), len(key)), dtype=split_key.dtype, device=self.larray.device ) else: recv_indices = torch.zeros( - (split_key.shape), dtype=torch.int64, device=self.larray.device + (split_key.shape), dtype=split_key.dtype, device=self.larray.device ) - print("DEBUGGING: SPLIY_KEY = ", split_key) - print("DEBUGGING: counts, 
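
The local-index arithmetic above relies on the per-rank counts and displacements of the block distribution along the split axis. A self-contained sketch of that bookkeeping, assuming the common convention that the first (global_size % nprocs) ranks hold one extra element (Heat's counts_displs() provides the real values):

    def counts_displs(global_size, nprocs):
        # block distribution: first (global_size % nprocs) ranks get one extra element
        base, rem = divmod(global_size, nprocs)
        counts = [base + 1 if r < rem else base for r in range(nprocs)]
        displs = [sum(counts[:r]) for r in range(nprocs)]
        return counts, displs

    counts, displs = counts_displs(10, 4)  # ([3, 3, 2, 2], [0, 3, 6, 8])
    g = 7                                  # a global index along the split axis...
    owner = max(r for r in range(4) if displs[r] <= g)  # ...lives on rank 2
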
displs = ", counts, displs) for p in range(size): cond1 = split_key >= displs[p] cond2 = split_key < displs[p] + counts[p] indices_from_p = torch.nonzero(cond1 & cond2, as_tuple=False) incoming_indices = split_key[indices_from_p].flatten() recv_counts[p, 0] = incoming_indices.numel() - print("DEBUGGING: P, RECV_COUNTS = ", p, incoming_indices.numel(), recv_counts) # store incoming indices in appropiate slice of recv_indices - # TODO: this is a bit of a convenience solution, but it doubles the memory footprint of split_key start = recv_counts[:p].sum().item() stop = start + recv_counts[p].item() - # print("DEBUGGING: incoming_indices = ", incoming_indices) - # print("DEBUGGING: start, stop = ", start, stop) if incoming_indices.numel() > 0: if key_is_mask_like: # apply selection to all dimensions @@ -1601,7 +1599,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar recv_indices[start:stop, self.split] -= displs[p] else: recv_indices[start:stop] = incoming_indices - displs[p] - print("DEBUGGING: AFTER: INC_INDICES = ", recv_indices[start:stop]) # build communication matrix by sharing recv_counts with all processes # comm_matrix rows contain the send_counts for each process, columns contain the recv_counts comm_matrix = torch.zeros((size, size), dtype=torch.int64, device=self.larray.device) @@ -1622,22 +1619,30 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # allocate recv_buf for incoming data recv_buf_shape = list(output_shape) - recv_buf_shape[output_split] = recv_counts.sum().item() + if communication_split != output_split: + # split key was flattened, flatten corresponding dims in recv_buf accordingly + recv_buf_shape = ( + recv_buf_shape[:communication_split] + + [recv_counts.sum().item()] + + recv_buf_shape[output_split + 1 :] + ) + else: + recv_buf_shape[communication_split] = recv_counts.sum().item() recv_buf = torch.zeros( tuple(recv_buf_shape), dtype=self.larray.dtype, device=self.larray.device ) - print("DEBUGGING: recv_counts, send_counts = ", recv_counts, send_counts) - print("DEBUGGING: comm_matrix = ", comm_matrix) if rank_is_active: - # non-blocking send indices to active_send_indices_to + # non-blocking send indices to `active_send_indices_to` send_requests = [] for i in active_send_indices_to: start = recv_counts[:i].sum().item() stop = start + recv_counts[i].item() outgoing_indices = recv_indices[start:stop] send_requests.append(self.comm.Isend(outgoing_indices, dest=i)) + del outgoing_indices + del recv_indices for i in active_recv_indices_from: - # receive indices from active_recv_indices_from + # receive indices from `active_recv_indices_from` if key_is_mask_like: incoming_indices = torch.zeros( (send_counts[i].item(), len(key)), @@ -1663,33 +1668,39 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar send_key = list(key) send_key[self.split] = incoming_indices send_buf = self.larray[tuple(send_key)] - print(f"DEBUGGING: send_buf to {i} = {send_buf}") # non-blocking send requested data to i send_requests.append(self.comm.Isend(send_buf, dest=i)) - print("DEBUGGING: active_send_indices_to = ", active_send_indices_to) + del send_buf + # allocate temporary recv_buf to receive data from all active processes tmp_recv_buf_shape = recv_buf_shape.copy() - tmp_recv_buf_shape[output_split] = recv_counts.max().item() + tmp_recv_buf_shape[communication_split] = recv_counts.max().item() tmp_recv_buf = torch.zeros( tuple(tmp_recv_buf_shape), dtype=self.larray.dtype, 
device=self.larray.device ) for i in active_send_indices_to: - # non-blocking receive data from i - print("debugging:, i = ", i) - print("DEBUGGING: split_key = ", split_key) + # receive data from i tmp_recv_slice = [slice(None)] * tmp_recv_buf.ndim - tmp_recv_slice[output_split] = slice(0, recv_counts[i].item()) + tmp_recv_slice[communication_split] = slice(0, recv_counts[i].item()) self.comm.Recv(tmp_recv_buf[tmp_recv_slice], source=i) - print(f"DEBUGGING: tmp_recv_buf from {i} = {tmp_recv_buf}") # write received data to appropriate portion of recv_buf cond1 = split_key >= displs[i] cond2 = split_key < displs[i] + counts[i] recv_buf_indices = torch.nonzero(cond1 & cond2, as_tuple=False).flatten() recv_buf_key = [slice(None)] * recv_buf.ndim - recv_buf_key[output_split] = recv_buf_indices + recv_buf_key[communication_split] = recv_buf_indices recv_buf[recv_buf_key] = tmp_recv_buf[tmp_recv_slice] + del tmp_recv_buf # wait for all non-blocking communication to finish for req in send_requests: req.Wait() + if communication_split != output_split: + # split_key has been flattened, bring back recv_buf to intended shape + original_local_shape = ( + output_shape[:communication_split] + + original_split_key_shape + + output_shape[output_split + 1 :] + ) + recv_buf = recv_buf.reshape(original_local_shape) # construct indexed array from recv_buf indexed_arr = DNDarray( @@ -1705,275 +1716,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar self = self.transpose(backwards_transpose_axes) return indexed_arr - # # determine whether indexed array will be 1D or nD - # try: - # return_1d = getattr(key, "ndim") == self.ndim - # send_axis = 0 - # except AttributeError: - # # key is tuple of torch tensors - # key_shapes = [] - # for k in key: - # key_shapes.append(getattr(k, "shape", None)) - # return_1d = key_shapes.count(key_shapes[original_split]) == self.ndim - # # check for broadcasted indexing: key along split axis is not 1D - # broadcasted_indexing = ( - # key_shapes[original_split] is not None and len(key_shapes[original_split]) > 1 - # ) - # if broadcasted_indexing: - # broadcast_shape = key_shapes[original_split] - # key = list(key) - # key[original_split] = key[original_split].flatten() - # key = tuple(key) - # send_axis = original_split - # else: - # send_axis = output_split - - # # send and receive "request key" info on what data element to ship where - # recv_counts = torch.zeros((self.comm.size, 1), dtype=torch.int64, device=self.larray.device) - - # # construct empty tensor that we'll append to later - # if return_1d: - # request_key_shape = (0, self.ndim) - # else: - # request_key_shape = (0, 1) - - # outgoing_request_key = torch.empty( - # tuple(request_key_shape), dtype=torch.int64, device=self.larray.device - # ) - # outgoing_request_key_counts = torch.zeros( - # (self.comm.size,), dtype=torch.int64, device=self.larray.device - # ) - - # # process-local: calculate which/how many elements will be received from what process - # if split_key_is_ordered == -1: - # # key is sorted in descending order (i.e. 
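
The flatten-then-reshape handling of a multi-dimensional key along the split axis has a simple local analogue: flatten the key for the exchange, then fold the key's shape back into the result afterwards. A PyTorch sketch:

    import torch

    x = torch.arange(20).reshape(4, 5)
    key = torch.tensor([[3, 1], [0, 2]])        # 2-D key along dim 0
    flat = x[key.flatten()]                     # (4, 5): communication-friendly
    out = flat.reshape(*key.shape, x.shape[1])  # (2, 2, 5): intended output shape
    assert torch.equal(out, x[key])
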
slicing w/ negative step): - # # shrink selection of active processes - # if key[original_split].numel() > 0: - # key_edges = torch.cat( - # (key[original_split][-1].reshape(-1), key[original_split][0].reshape(-1)), dim=0 - # ).unique() - # displs = torch.tensor(displs, device=self.larray.device) - # _, inverse, counts = torch.cat((displs, key_edges), dim=0).unique( - # sorted=True, return_inverse=True, return_counts=True - # ) - # if key_edges.numel() == 2: - # correction = counts[inverse[-2]] % 2 - # start_rank = inverse[-2] - correction - # correction += counts[inverse[-1]] % 2 - # end_rank = inverse[-1] - correction + 1 - # elif key_edges.numel() == 1: - # correction = counts[inverse[-1]] % 2 - # start_rank = inverse[-1] - correction - # end_rank = start_rank + 1 - # else: - # start_rank = 0 - # end_rank = 0 - # else: - # start_rank = 0 - # end_rank = self.comm.size - # all_local_indexing = torch.ones( - # (self.comm.size,), dtype=torch.bool, device=self.larray.device - # ) - # all_local_indexing[start_rank:end_rank] = False - # for i in range(start_rank, end_rank): - # try: - # cond1 = key >= displs[i] - # if i != self.comm.size - 1: - # cond2 = key < displs[i + 1] - # else: - # # cond2 is always true - # cond2 = torch.ones((key.shape[0],), dtype=torch.bool, device=self.larray.device) - # except TypeError: - # cond1 = key[original_split] >= displs[i] - # if i != self.comm.size - 1: - # cond2 = key[original_split] < displs[i + 1] - # else: - # # cond2 is always true - # cond2 = torch.ones( - # (key[original_split].shape[0],), dtype=torch.bool, device=self.larray.device - # ) - # if return_1d: - # # advanced indexing returning 1D array - # if isinstance(key, torch.Tensor): - # selection = key[cond1 & cond2] - # recv_counts[i, :] = selection.shape[0] - # if i == self.comm.rank: - # all_local_indexing[i] = selection.shape[0] == key.shape[0] - # selection.unsqueeze_(dim=1) - # else: - # # key is tuple of torch tensors - # selection = list(k[cond1 & cond2] for k in key) - # recv_counts[i, :] = selection[0].shape[0] - # if i == self.comm.rank: - # all_local_indexing[i] = selection[0].shape[0] == key[0].shape[0] - # selection = torch.stack(selection, dim=1) - # else: - # selection = key[original_split][cond1 & cond2] - # recv_counts[i, :] = selection.shape[0] - # if i == self.comm.rank: - # all_local_indexing[i] = selection.shape[0] == key[original_split].shape[0] - # selection.unsqueeze_(dim=1) - # outgoing_request_key = torch.cat((outgoing_request_key, selection), dim=0) - # all_local_indexing = factories.array( - # all_local_indexing, is_split=0, device=self.device, copy=False - # ) - # if all_local_indexing.all().item(): - # if broadcasted_indexing: - # key[original_split] = key[original_split].reshape(broadcast_shape) - # indexed_arr = self.larray[key] - # # transpose array back if needed - # self = self.transpose(backwards_transpose_axes) - # return factories.array( - # indexed_arr, is_split=output_split, device=self.device, copy=False - # ) - - # # share recv_counts among all processes - # comm_matrix = torch.empty( - # (self.comm.size, self.comm.size), dtype=recv_counts.dtype, device=recv_counts.device - # ) - # self.comm.Allgather(recv_counts, comm_matrix) - - # outgoing_request_key_counts = comm_matrix[self.comm.rank] - # outgoing_request_key_displs = torch.cat( - # ( - # torch.zeros( - # (1,), - # dtype=outgoing_request_key_counts.dtype, - # device=outgoing_request_key_counts.device, - # ), - # outgoing_request_key_counts, - # ), - # dim=0, - # ).cumsum(dim=0)[:-1] - # 
incoming_request_key_counts = comm_matrix[:, self.comm.rank] - # incoming_request_key_displs = torch.cat( - # ( - # torch.zeros( - # (1,), - # dtype=outgoing_request_key_counts.dtype, - # device=outgoing_request_key_counts.device, - # ), - # incoming_request_key_counts, - # ), - # dim=0, - # ).cumsum(dim=0)[:-1] - - # if return_1d: - # incoming_request_key = torch.empty( - # (incoming_request_key_counts.sum(), self.ndim), - # dtype=outgoing_request_key_counts.dtype, - # device=outgoing_request_key_counts.device, - # ) - # else: - # incoming_request_key = torch.empty( - # (incoming_request_key_counts.sum(), 1), - # dtype=outgoing_request_key_counts.dtype, - # device=outgoing_request_key_counts.device, - # ) - # # send and receive request keys - # self.comm.Alltoallv( - # ( - # outgoing_request_key, - # outgoing_request_key_counts.tolist(), - # outgoing_request_key_displs.tolist(), - # ), - # ( - # incoming_request_key, - # incoming_request_key_counts.tolist(), - # incoming_request_key_displs.tolist(), - # ), - # ) - # if return_1d: - # incoming_request_key = list(incoming_request_key[:, d] for d in range(self.ndim)) - # incoming_request_key[original_split] -= displs[self.comm.rank] - # else: - # incoming_request_key -= displs[self.comm.rank] - # incoming_request_key = ( - # key[:original_split] - # + (incoming_request_key.squeeze_(1),) - # + key[original_split + 1 :] - # ) - - # # calculate shape of local recv buffer - # output_lshape = list(output_shape) - # if getattr(key, "ndim", 0) == 1: - # output_lshape[output_split] = key.shape[0] - # else: - # if broadcasted_indexing: - # output_lshape = ( - # output_lshape[:original_split] - # + [torch.prod(torch.tensor(broadcast_shape, device=self.larray.device)).item()] - # + output_lshape[output_split + 1 :] - # ) - # else: - # output_lshape[output_split] = key[original_split].shape[0] - # # allocate recv buffer - # recv_buf = torch.empty( - # tuple(output_lshape), dtype=self.larray.dtype, device=self.larray.device - # ) - - # # index local data into send_buf. - # send_empty = sum( - # list(isinstance(k, torch.Tensor) and k.numel() == 0 for k in incoming_request_key) - # ) - # if send_empty: - # # Edge case 1. empty slice along split axis: send_buf is 0-element tensor - # empty_shape = list(output_shape) - # empty_shape[output_split] = 0 - # send_buf = torch.empty(empty_shape, dtype=self.larray.dtype, device=self.larray.device) - # else: - # send_buf = self.larray[incoming_request_key] - # # Edge case 2. 
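
For orientation, the commented-out request-key exchange above is the textbook Alltoallv pattern. A hypothetical, much-simplified mpi4py analogue (Heat's communicator wraps MPI with torch-tensor support and a send_axis argument, so the names and signatures below are illustrative only):

    from mpi4py import MPI
    import numpy as np

    comm = MPI.COMM_WORLD
    size, rank = comm.Get_size(), comm.Get_rank()
    send_buf = np.full(size, rank, dtype=np.int64)  # one element per peer
    counts, displs = [1] * size, list(range(size))
    recv_buf = np.empty(size, dtype=np.int64)
    comm.Alltoallv([send_buf, counts, displs, MPI.INT64_T],
                   [recv_buf, counts, displs, MPI.INT64_T])
    # recv_buf now holds one value from every rank: [0, 1, ..., size - 1]

In the real code the counts and displacements differ per rank, which is exactly what the comm_matrix bookkeeping establishes.
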
local single-element indexing results into local loss of split axis - # if send_buf.ndim < len(output_lshape): - # all_keys_scalar = sum( - # list( - # np.isscalar(k) or k.numel() == 1 and getattr(k, "ndim", 2) < 2 - # for k in incoming_request_key - # ) - # ) == len(incoming_request_key) - # if not all_keys_scalar: - # send_buf = send_buf.unsqueeze_(dim=output_split) - - # recv_counts = torch.squeeze(recv_counts, dim=1).tolist() - # recv_displs = outgoing_request_key_displs.tolist() - # send_counts = incoming_request_key_counts.tolist() - # send_displs = incoming_request_key_displs.tolist() - # self.comm.Alltoallv( - # (send_buf, send_counts, send_displs), - # (recv_buf, recv_counts, recv_displs), - # send_axis=send_axis, - # ) - # # transpose original array back if needed, all further indexing on recv_buf - # self = self.transpose(backwards_transpose_axes) - - # # reorganize incoming counts according to original key order along split axis - # if return_1d: - # if isinstance(key, tuple): - # key = torch.stack(key, dim=1) - # _, key_inverse = key.unique(dim=0, sorted=True, return_inverse=True) - # _, ork_inverse = outgoing_request_key.unique(dim=0, sorted=True, return_inverse=True) - # map = ork_inverse.argsort(stable=True)[ - # key_inverse.argsort(stable=True).argsort(stable=True) - # ] - # indexed_arr = recv_buf[map] - # return factories.array(indexed_arr, is_split=output_split, copy=False) - - # outgoing_request_key = outgoing_request_key.squeeze_(1) - # map = [slice(None)] * recv_buf.ndim - # # print("DEBUGGING: outgoing_request_key = ", outgoing_request_key) - # # print("DEBUGGING: key[original_split] = ", key[original_split]) - # if broadcasted_indexing: - # map[original_split] = outgoing_request_key.argsort(stable=True)[ - # key[original_split].argsort(stable=True).argsort(stable=True) - # ] - # map[original_split] = map[original_split].reshape(broadcast_shape) - # else: - # map[output_split] = outgoing_request_key.argsort(stable=True)[ - # key[original_split].argsort(stable=True).argsort(stable=True) - # ] - # indexed_arr = recv_buf[map] - # return factories.array(indexed_arr, is_split=output_split, copy=False) - if torch.cuda.device_count() > 0: def gpu(self) -> DNDarray: From e4a90deef50992fc7afd88d7ee5edf8f54f06e39 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Sun, 18 Feb 2024 10:18:21 +0100 Subject: [PATCH 121/132] balance indexed array before eq() --- heat/core/tests/test_dndarray.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 44c073b0bf..8dfa57544c 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1551,6 +1551,7 @@ def test_setitem(self): value = ht.array([99, 98, 97, 96], split=0) x[k1, k2, k3] = value self.assertTrue((x[k1, k2, k3] == ht.array([96, 98, 97, 96], split=0)).all().item()) + # # advanced indexing on non-consecutive dimensions # x = ht.arange(60, split=0).reshape(5, 3, 4, new_split=1) # x_copy = x.copy() @@ -1618,7 +1619,9 @@ def test_setitem(self): arr_split0 = ht.array(arr, split=0) mask_split0 = ht.array(mask, split=0) arr_split0[mask_split0] = value[mask] - self.assertTrue((arr_split0[mask_split0] == value[mask]).all().item()) + indexed_arr = arr_split0[mask_split0] + indexed_arr.balance_() + self.assertTrue((indexed_arr == value[mask]).all().item()) arr_split1 = ht.array(arr, split=1) mask_split1 = ht.array(mask, split=1) arr_split1[mask_split1] = value[mask] From 
c8967e7b0de027fa8998d01c353a7b16b0f22668 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Sun, 18 Feb 2024 10:19:22 +0100 Subject: [PATCH 122/132] remove print statements --- heat/core/dndarray.py | 340 +----------------------------------------- 1 file changed, 1 insertion(+), 339 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index acd1ab60f4..d87574d95d 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -157,7 +157,6 @@ def larray(self, array: torch.Tensor): ----------- Please use this function with care, as it might corrupt/invalidate the metadata in the ``DNDarray`` instance. """ - print("DEBUGGING: larray setter") # sanitize tensor input sanitation.sanitize_in_tensor(array) # verify consistency of tensor shape with global DNDarray @@ -914,7 +913,6 @@ def __process_key( else: # advanced indexing on first dimension: first dim will expand to shape of key output_shape = tuple(list(key.shape) + output_shape[1:]) - print("DEBUGGING ADV IND: output_shape = ", output_shape) # adjust split axis accordingly if arr_is_distributed: if arr.split != 0: @@ -1030,7 +1028,6 @@ def __process_key( expand_key[:ellipsis_index] = key[:ellipsis_index] expand_key[ellipsis_index + ellipsis_dims :] = key[ellipsis_index + 1 :] key = expand_key - print("DEBUGGING: ELLIPSIS: ", key) while add_dims > 0: # expand array dims: output_shape, split_bookkeeping to reflect newaxis # replace newaxis with slice(None) in key @@ -1052,7 +1049,6 @@ def __process_key( new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None transpose_axes, backwards_transpose_axes = tuple(range(arr.ndim)), tuple(range(arr.ndim)) # check for advanced indexing and slices - print("DEBUGGING: key = ", key) advanced_indexing_dims = [] advanced_indexing_shapes = [] lose_dims = 0 @@ -1109,7 +1105,6 @@ def __process_key( if step is None: step = 1 if step < 0 and start > stop: - print("TEST LOCAL SLICE: ", arr.__get_local_slice(k)) # PyTorch doesn't support negative step as of 1.13 # Lazy solution, potentially large memory footprint # TODO: implement ht.fromiter (implemented in ASSET_ht) @@ -1130,7 +1125,6 @@ def __process_key( key[i] = factories.array( key[i], split=0, device=arr.device, copy=False ).larray - print("DEBUGGING: key[i] = ", key[i]) out_is_balanced = True elif step > 0 and start < stop: output_shape[i] = int(torch.tensor((stop - start) / step).ceil().item()) @@ -1175,7 +1169,6 @@ def __process_key( and len(set(k.shape for k in key)) == 1 and torch.tensor(advanced_indexing_dims).diff().eq(1).all() ) - print("KEY_IS_MASK_LIKE = ", key_is_mask_like) # if split axis is affected by advanced indexing, keep track of non-split dimensions for later if arr.is_distributed() and arr.split in advanced_indexing_dims: non_split_dims = list(advanced_indexing_dims).copy() @@ -1236,8 +1229,6 @@ def __process_key( # advanced indexing does not affect split axis, return torch tensors for i in advanced_indexing_dims: key[i] = key[i].larray - print("ADV IND KEY = ", key) - print("DEBUGGING: advanced_indexing_shapes = ", advanced_indexing_shapes) # all adv indexing keys are now torch tensors # shapes of adv indexing arrays must be broadcastable @@ -1260,11 +1251,6 @@ def __process_key( advanced_indexing_dims[0] : advanced_indexing_dims[0] + len(advanced_indexing_dims) ] = broadcasted_shape - print( - "DEBUGGING: broadcasted_shape, split_bookkeeping = ", - broadcasted_shape, - split_bookkeeping, - ) if key_is_mask_like: # advanced indexing 
dimensions will be collapsed into one dimension if ( @@ -1286,7 +1272,6 @@ def __process_key( + [None] * add_dims + split_bookkeeping[advanced_indexing_dims[0] :] ) - print("ADV IND output_shape = ", output_shape) else: # advanced-indexing dimensions are not consecutive: # transpose array to make the advanced-indexing dimensions consecutive as the first dimensions @@ -1322,14 +1307,6 @@ def __process_key( split_bookkeeping = split_bookkeeping[:lost_dim] + split_bookkeeping[lost_dim + 1 :] output_shape = tuple(output_shape) new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None - print( - "key, output_shape, new_split, split_key_is_ordered, out_is_balanced = ", - key, - output_shape, - new_split, - split_key_is_ordered, - out_is_balanced, - ) return ( arr, key, @@ -1443,8 +1420,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar (2/2) >>> tensor([0., 0.]) """ # key can be: int, tuple, list, slice, DNDarray, torch tensor, numpy array, or sequence thereof - # Trivial cases - # print("DEBUGGING: RAW KEY = ", key, type(key)) if key is None: return self.expand_dims(0) @@ -1501,7 +1476,6 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar if not self.is_distributed(): # key is torch-proof, index underlying torch tensor - # print("DEBUGGING: key = ", key) indexed_arr = self.larray[key] # transpose array back if needed self = self.transpose(backwards_transpose_axes) @@ -2291,21 +2265,6 @@ def __set( """ Setter for not advanced indexing, i.e. when arr[key] is an in-place view of arr. """ - # # need information on indexed array, use proxy to limit memory usage - # subarray = arr.__torch_proxy__()[key] - # subarray_shape, subarray_ndim = tuple(subarray.shape), subarray.ndim - # while value.ndim < subarray_ndim: # broadcasting - # value = value.expand_dims(0) - # try: - # value_shape = tuple(torch.broadcast_shapes(value_shape, subarray_shape)) - # except RuntimeError: - # raise ValueError( - # f"could not broadcast input array from shape {value.shape} into shape {arr.shape}" - # ) - # # TODO: take this out of this function - # sanitation.sanitize_out(subarray, value_shape, value.split, value.device, value.comm) - # arr.larray[None] = value.larray - # only assign values if key does not contain empty slices process_is_inactive = arr.larray[key].numel() == 0 if not process_is_inactive: @@ -2373,11 +2332,6 @@ def __set( ) = self.__process_key(key, return_local_indices=True, op="set") # match dimensions - print( - "DEBUGGING: BEFORE BROADCAST: OUTPUT_SHAPE, SPLIT_KEY_IS_ORDERED = ", - output_shape, - split_key_is_ordered, - ) value, value_is_scalar = __broadcast_value(self, key, value, output_shape=output_shape) # early out for non-distributed case @@ -2516,10 +2470,7 @@ def __set( split_key, is_split=0, device=self.device, comm=self.comm, copy=False ) target_map = value.lshape_map - print("DEBUGGING: target_map = ", target_map) - print("DEBUGGING: global_split_key.lshape_map = ", global_split_key.lshape_map) target_map[:, value.split] = global_split_key.lshape_map[:, 0] - print("DEBUGGING: target_map AFTER = ", target_map) value.redistribute_(target_map=target_map) else: # redistribute split-axis `key` to match distribution of `value` in one pass @@ -2558,20 +2509,14 @@ def __set( send_buf_shape[-1] += len(key) else: send_buf_shape[-1] += 1 - print("DEBUGGING: send_buf_shape = ", send_buf_shape) send_buf = torch.zeros( send_buf_shape, dtype=value.dtype.torch_type(), device=self.device.torch_device ) - 
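
The communication matrix assembled in the setter follows a general pattern: every rank contributes its row of send counts, and column r of the gathered matrix tells rank r how much to expect from each peer. A hypothetical mpi4py sketch of just this step:

    from mpi4py import MPI
    import numpy as np

    comm = MPI.COMM_WORLD
    size, rank = comm.Get_size(), comm.Get_rank()
    send_counts = np.zeros(size, dtype=np.int64)  # filled from the local key (omitted)
    comm_matrix = np.empty((size, size), dtype=np.int64)
    comm.Allgather(send_counts, comm_matrix)      # row p holds rank p's send_counts
    recv_counts = comm_matrix[:, rank]            # what each peer will send here
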
print("DEBUGGING: BEFORE LOOP: counts, displs = ", counts, displs) - print("debugging: key_is_mask_like = ", key_is_mask_like) for proc in range(self.comm.size): # calculate what local elements of `value` belong on process `proc` send_indices = torch.nonzero( (split_key >= displs[proc]) & (split_key < displs[proc] + counts[proc]) ).flatten() - print( - "DEBUGGING: proc, send_indices = ", proc, send_indices, split_key[send_indices] - ) # calculate outgoing counts and displacements for each process send_counts[proc] = send_indices.numel() send_displs[proc] = send_counts[:proc].sum() @@ -2603,7 +2548,6 @@ def __set( device=self.device.torch_device, ) self.comm.Allgather(send_counts, comm_matrix) - print("DEBUGGING:, RANK, SEND_BUF = ", self.comm.rank, send_buf) # comm_matrix columns contain recv_counts for each process recv_counts = comm_matrix[:, self.comm.rank].squeeze(0) recv_displs = torch.zeros_like(recv_counts) @@ -2619,7 +2563,6 @@ def __set( else: recv_buf_shape[-1] += 1 recv_buf_shape = tuple(recv_buf_shape) - print("DEBUGGING: recv_buf_shape = ", recv_buf_shape) recv_buf = torch.zeros( recv_buf_shape, dtype=value.dtype.torch_type(), device=self.device.torch_device ) @@ -2630,20 +2573,11 @@ def __set( recv_counts.tolist(), recv_displs.tolist(), ) - print("DEBUGGING: send_buf.shape, recv_buf.shape = ", send_buf.shape, recv_buf.shape) - print( - "DEBUGGING: send_counts, send_displs, recv_counts, recv_displs = ", - send_counts, - send_displs, - recv_counts, - recv_displs, - ) self.comm.Alltoallv( (send_buf, send_counts, send_displs), (recv_buf, recv_counts, recv_displs) ) del send_buf, comm_matrix key = list(key) - print("DEBUGGING: recv_buf = ", recv_buf) if key_is_mask_like: # extract incoming indices from recv_buf recv_indices = recv_buf[..., -len(key) :] @@ -2666,8 +2600,6 @@ def __set( value = value.transpose(transpose_axes) if value.ndim < 2: recv_buf.squeeze_(1) - print("DEBUGGING: transpose_axes = ", transpose_axes) - print("DEBUGGING: value.shape, recv_buf.shape = ", value.shape, recv_buf.shape) recv_buf = DNDarray( recv_buf.permute(*transpose_axes), gshape=value.gshape, @@ -2679,277 +2611,7 @@ def __set( ) # set local elements of `self` to corresponding elements of `value` __set(self, key, recv_buf) - - # if advanced_indexing: - # raise Exception("Advanced indexing is not supported yet") - - # split = self.split - # if not self.is_distributed() or key[split] == slice(None): - # return __set(self[key], value) - - # if isinstance(key[split], slice): - # return __set(self[key], value) - - # if np.isscalar(key[split]): - # key = list(key) - # idx = int(key[split]) - # key[split] = slice(idx, idx + 1) - # return __set(self[tuple(key)], value) - - # key = getattr(key, "copy()", key) - # try: - # if value.split != self.split: - # val_split = int(value.split) - # sp = self.split - # warnings.warn( - # f"\nvalue.split {val_split} not equal to this DNDarray's split:" - # f" {sp}. this may cause errors or unwanted behavior", - # category=RuntimeWarning, - # ) - # except (AttributeError, TypeError): - # pass - - # # NOTE: for whatever reason, there is an inplace op which interferes with the abstraction - # # of this next block of code. this is shared with __getitem__. I attempted to abstract it - # # in a standard way, but it was causing errors in the test suite. 
If someone else is - # # motived to do this they are welcome to, but i have no time right now - # # print(key) - # if isinstance(key, DNDarray) and key.ndim == self.ndim: - # """if the key is a DNDarray and it has as many dimensions as self, then each of the - # entries in the 0th dim refer to a single element. To handle this, the key is split - # into the torch tensors for each dimension. This signals that advanced indexing is - # to be used.""" - # key = manipulations.resplit(key) - # if key.larray.dtype in [torch.bool, torch.uint8]: - # key = indexing.nonzero(key) - - # if key.ndim > 1: - # key = list(key.larray.split(1, dim=1)) - # # key is now a list of tensors with dimensions (key.ndim, 1) - # # squeeze singleton dimension: - # key = [key[i].squeeze_(1) for i in range(len(key))] - # else: - # key = [key] - # elif not isinstance(key, tuple): - # """this loop handles all other cases. DNDarrays which make it to here refer to - # advanced indexing slices, as do the torch tensors. Both DNDaarrys and torch.Tensors - # are cast into lists here by PyTorch. lists mean advanced indexing will be used""" - # h = [slice(None, None, None)] * self.ndim - # if isinstance(key, DNDarray): - # key = manipulations.resplit(key) - # if key.larray.dtype in [torch.bool, torch.uint8]: - # h[0] = torch.nonzero(key.larray).flatten() # .tolist() - # else: - # h[0] = key.larray.tolist() - # elif isinstance(key, torch.Tensor): - # if key.dtype in [torch.bool, torch.uint8]: - # # (coquelin77) im not sure why this works without being a list...but it does...for now - # h[0] = torch.nonzero(key).flatten() # .tolist() - # else: - # h[0] = key.tolist() - # else: - # h[0] = key - # key = list(h) - - # # key must be torch-proof - # if isinstance(key, (list, tuple)): - # key = list(key) - # for i, k in enumerate(key): - # try: # extract torch tensor - # k = manipulations.resplit(k) - # key[i] = k.larray - # except AttributeError: - # pass - # # remove bools from a torch tensor in favor of indexes - # try: - # if key[i].dtype in [torch.bool, torch.uint8]: - # key[i] = torch.nonzero(key[i]).flatten() - # except (AttributeError, TypeError): - # pass - - # key = list(key) - - # # ellipsis stuff - # key_classes = [type(n) for n in key] - # # if any(isinstance(n, ellipsis) for n in key): - # n_elips = key_classes.count(type(...)) - # if n_elips > 1: - # raise ValueError("key can only contain 1 ellipsis") - # elif n_elips == 1: - # # get which item is the ellipsis - # ell_ind = key_classes.index(type(...)) - # kst = key[:ell_ind] - # kend = key[ell_ind + 1 :] - # slices = [slice(None)] * (self.ndim - (len(kst) + len(kend))) - # key = kst + slices + kend - # # ---------- end ellipsis stuff ------------- - - # for c, k in enumerate(key): - # try: - # key[c] = k.item() - # except (AttributeError, ValueError, RuntimeError): - # pass - - # rank = self.comm.rank - # if self.split is not None: - # counts, chunk_starts = self.counts_displs() - # else: - # counts, chunk_starts = 0, [0] * self.comm.size - # counts = torch.tensor(counts, device=self.device.torch_device) - # chunk_starts = torch.tensor(chunk_starts, device=self.device.torch_device) - # chunk_ends = chunk_starts + counts - # chunk_start = chunk_starts[rank] - # chunk_end = chunk_ends[rank] - # # determine which elements are on the local process (if the key is a torch tensor) - # try: - # # if isinstance(key[self.split], torch.Tensor): - # filter_key = torch.nonzero( - # (chunk_start <= key[self.split]) & (key[self.split] < chunk_end) - # ) - # for k in range(len(key)): - # 
try: - # key[k] = key[k][filter_key].flatten() - # except TypeError: - # pass - # except TypeError: # this will happen if the key doesnt have that many - # pass - - # key = tuple(key) - - # if not self.is_distributed(): - # return self.__setter(key, value) # returns None - - # # raise RuntimeError("split axis of array and the target value are not equal") removed - # # this will occur if the local shapes do not match - # rank = self.comm.rank - # ends = [] - # for pr in range(self.comm.size): - # _, _, e = self.comm.chunk(self.shape, self.split, rank=pr) - # ends.append(e[self.split].stop - e[self.split].start) - # ends = torch.tensor(ends, device=self.device.torch_device) - # chunk_ends = ends.cumsum(dim=0) - # chunk_starts = torch.tensor([0] + chunk_ends.tolist(), device=self.device.torch_device) - # _, _, chunk_slice = self.comm.chunk(self.shape, self.split) - # chunk_start = chunk_slice[self.split].start - # chunk_end = chunk_slice[self.split].stop - - # self_proxy = self.__torch_proxy__() - - # # if the value is a DNDarray, the divisions need to be balanced: - # # this means that we need to know how much data is where for both DNDarrays - # # if the value data is not in the right place, then it will need to be moved - - # if isinstance(key[self.split], slice): - # key = list(key) - # key_start = key[self.split].start if key[self.split].start is not None else 0 - # key_stop = ( - # key[self.split].stop - # if key[self.split].stop is not None - # else self.gshape[self.split] - # ) - # if key_stop < 0: - # key_stop = self.gshape[self.split] + key[self.split].stop - # key_step = key[self.split].step - # og_key_start = key_start - # st_pr = torch.where(key_start < chunk_ends)[0] - # st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size - # sp_pr = torch.where(key_stop >= chunk_starts)[0] - # sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0 - # actives = list(range(st_pr, sp_pr + 1)) - - # if ( - # isinstance(value, type(self)) - # and value.split is not None - # and value.shape[self.split] != self.shape[self.split] - # ): - # # setting elements in self with a DNDarray which is not the same size in the - # # split dimension - # local_keys = [] - # # below is used if the target needs to be reshaped - # target_reshape_map = torch.zeros( - # (self.comm.size, self.ndim), dtype=torch.int64, device=self.device.torch_device - # ) - # for r in range(self.comm.size): - # if r not in actives: - # loc_key = key.copy() - # loc_key[self.split] = slice(0, 0, 0) - # else: - # key_start_l = 0 if r != actives[0] else key_start - chunk_starts[r] - # key_stop_l = ends[r] if r != actives[-1] else key_stop - chunk_starts[r] - # key_start_l, key_stop_l = self.__xitem_get_key_start_stop( - # r, actives, key_start_l, key_stop_l, key_step, chunk_ends, og_key_start - # ) - # loc_key = key.copy() - # loc_key[self.split] = slice(key_start_l, key_stop_l, key_step) - - # gout_full = torch.tensor( - # self_proxy[loc_key].shape, device=self.device.torch_device - # ) - # target_reshape_map[r] = gout_full - # local_keys.append(loc_key) - - # key = local_keys[rank] - # value = value.redistribute(target_map=target_reshape_map) - - # if rank not in actives: - # return # non-active ranks can exit here - - # chunk_starts_v = target_reshape_map[:, self.split] - # value_slice = [slice(None, None, None)] * value.ndim - # step2 = key_step if key_step is not None else 1 - # key_start = (chunk_starts_v[rank] - og_key_start).item() - - # key_start = max(key_start, 0) - # key_stop = key_start + key_stop - # slice_loc = min(self.split, 
value.ndim - 1) - # value_slice[slice_loc] = slice( - # key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 - # ) - - # self.__setter(tuple(key), value.larray) - # return - - # # if rank in actives: - # if rank not in actives: - # return # non-active ranks can exit here - # key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] - # key_stop = ends[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] - # key_start, key_stop = self.__xitem_get_key_start_stop( - # rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start - # ) - # key[self.split] = slice(key_start, key_stop, key_step) - - # # todo: need to slice the values to be the right size... - # if isinstance(value, (torch.Tensor, type(self))): - # # if its a torch tensor, it is assumed to exist on all processes - # value_slice = [slice(None, None, None)] * value.ndim - # step2 = key_step if key_step is not None else 1 - # key_start = (chunk_starts[rank] - og_key_start).item() - # key_start = max(key_start, 0) - # key_stop = key_start + key_stop - # slice_loc = min(self.split, value.ndim - 1) - # value_slice[slice_loc] = slice( - # key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 - # ) - # self.__setter(tuple(key), value[tuple(value_slice)]) - # else: - # self.__setter(tuple(key), value) - # elif isinstance(key[self.split], (torch.Tensor, list)): - # key = list(key) - # key[self.split] -= chunk_start - # if len(key[self.split]) != 0: - # self.__setter(tuple(key), value) - - # elif key[self.split] in range(chunk_start, chunk_end): - # key = list(key) - # key[self.split] = key[self.split] - chunk_start - # self.__setter(tuple(key), value) - - # elif key[self.split] < 0: - # key = list(key) - # if self.gshape[self.split] + key[self.split] in range(chunk_start, chunk_end): - # key[self.split] = key[self.split] + self.shape[self.split] - chunk_start - # self.__setter(tuple(key), value) + self = self.transpose(backwards_transpose_axes) def __setter( self, From 95eaaeb341fc43c78cc3a9581c03c71fc89f0502 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Sun, 18 Feb 2024 10:33:03 +0100 Subject: [PATCH 123/132] test adv ind on non-consecutive dims --- heat/core/tests/test_dndarray.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 8dfa57544c..12bc61f7ec 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1552,20 +1552,20 @@ def test_setitem(self): x[k1, k2, k3] = value self.assertTrue((x[k1, k2, k3] == ht.array([96, 98, 97, 96], split=0)).all().item()) - # # advanced indexing on non-consecutive dimensions - # x = ht.arange(60, split=0).reshape(5, 3, 4, new_split=1) - # x_copy = x.copy() - # x_np = np.arange(60).reshape(5, 3, 4) - # k1 = np.array([0, 4, 1, 0]) - # k2 = 0 - # k3 = np.array([1, 2, 3, 1]) - # key = (k1, k2, k3) - # self.assert_array_equal(x[key], x_np[key]) - # # check that x is unchanged after internal manipulation - # self.assertTrue(x.shape == x_copy.shape) - # self.assertTrue(x.split == x_copy.split) - # self.assertTrue(x.lshape == x_copy.lshape) - # self.assertTrue((x == x_copy).all().item()) + # advanced indexing on non-consecutive dimensions, split dimension will be lost + x = ht.arange(60, split=0).reshape(5, 3, 4, new_split=1) + x_copy = x.copy() + k1 = np.array([0, 4, 1, 2]) + k2 = 0 + k3 = np.array([1, 2, 3, 1]) + key = (k1, k2, k3) + value = ht.array([99, 
98, 97, 96]) + x[key] = value + self.assertTrue((x[key] == ht.array([99, 98, 97, 96])).all().item()) + # check that x is unchanged after internal manipulation + self.assertTrue(x.shape == x_copy.shape) + self.assertTrue(x.split == x_copy.split) + self.assertTrue(x.lshape == x_copy.lshape) # # broadcasting shapes # x.resplit_(axis=0) From 835a13fd542860a4b560cad530744f13a9bd93b0 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 20 Feb 2024 05:03:29 +0100 Subject: [PATCH 124/132] remove print statement --- heat/core/dndarray.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index d87574d95d..6847dbb7d8 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1133,13 +1133,6 @@ def __process_key( out_is_balanced = False local_arr_end = displs[arr.comm.rank] + counts[arr.comm.rank] if stop > displs[arr.comm.rank] and start < local_arr_end: - print( - "stop, start, displs[arr.comm.rank], displs[arr.comm.rank] + counts[arr.comm.rank] = ", - stop, - start, - displs[arr.comm.rank], - displs[arr.comm.rank] + counts[arr.comm.rank], - ) index_in_cycle = (displs[arr.comm.rank] - start) % step if start >= displs[arr.comm.rank]: # slice begins on current rank From 216a1a01d99b9cba8a65c9a401709b6e1e33550f Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 20 Feb 2024 05:57:04 +0100 Subject: [PATCH 125/132] setitem: mixed indexing w. shape broadcasting --- heat/core/dndarray.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 6847dbb7d8..3c3bfe5b3c 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2343,15 +2343,21 @@ def __set( self.larray[key] = value.larray.type(self.dtype.torch_type()) else: # indexed elements are process-local - if self.is_distributed() and not value_is_scalar and not value.is_distributed(): - # work with distributed `value` - value = factories.array( - value.larray, - dtype=value.dtype, - split=output_split, - device=self.device, - comm=self.comm, - ) + if self.is_distributed() and not value_is_scalar: + if not value.is_distributed(): + # work with distributed `value` + value = factories.array( + value.larray, + dtype=value.dtype, + split=output_split, + device=self.device, + comm=self.comm, + ) + else: + if value.split != output_split: + raise RuntimeError( + f"Cannot assign distributed `value` with split axis {value.split} to indexed DNDarray with split axis {output_split}." + ) # verify that `self[key]` and `value` distribution are aligned target_shape = torch.tensor( tuple(self.larray[key].shape), device=self.device.torch_device From b62bad2eacf702ced86ebe001ac97b3930d29f78 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 20 Feb 2024 05:57:47 +0100 Subject: [PATCH 126/132] expand tests for mixed indexing w. 
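
The index_in_cycle arithmetic kept in the slicing branch above determines where a strided global slice first touches a rank's local block. A self-contained sketch of that calculation (positive step only, simplified):

    def first_local_index(start, step, displ):
        # first index of the progression start, start + step, ... that is >= displ
        if displ <= start:
            return start
        index_in_cycle = (displ - start) % step
        return displ + (-index_in_cycle) % step

    assert first_local_index(1, 3, 10) == 10  # 1, 4, 7, 10 lands on the block edge
    assert first_local_index(1, 3, 11) == 13  # next hit after the edge
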
broadcasting --- heat/core/tests/test_dndarray.py | 41 ++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 12bc61f7ec..3e3bd41ab7 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1567,24 +1567,29 @@ def test_setitem(self): self.assertTrue(x.split == x_copy.split) self.assertTrue(x.lshape == x_copy.lshape) - # # broadcasting shapes - # x.resplit_(axis=0) - # self.assert_array_equal(x[ht.array(k1, split=0), ht.array(1), 2], x_np[k1, 1, 2]) - # # test exception: broadcasting mismatching shapes - # k2 = np.array([0, 2, 1]) - # with self.assertRaises(IndexError): - # x[k1, k2, k3] - - # # more broadcasting - # x_np = np.arange(12).reshape(4, 3) - # rows = np.array([0, 3]) - # cols = np.array([0, 2]) - # x = ht.arange(12).reshape(4, 3) - # x.resplit_(1) - # x_np_indexed = x_np[rows[:, np.newaxis], cols] - # x_indexed = x[ht.array(rows)[:, np.newaxis], cols] - # self.assert_array_equal(x_indexed, x_np_indexed) - # self.assertTrue(x_indexed.split == 1) + # broadcasting shapes + x.resplit_(axis=0) + key = (ht.array(k1, split=0), ht.array(1), 2) + value = ht.array([99, 98, 97, 96], split=0) + x[key] = value + self.assertTrue((x[key] == value).all().item()) + # test exception: broadcasting mismatching shapes + k2 = np.array([0, 2, 1]) + with self.assertRaises(IndexError): + x[k1, k2, k3] = value + + # more broadcasting + x = ht.arange(12).reshape(4, 3) + x.resplit_(1) + rows = np.array([0, 3]) + cols = np.array([0, 2]) + key = (ht.array(rows)[:, np.newaxis], cols) + value = ht.array([[99, 98], [97, 96]], split=1) + x[key] = value + self.assertTrue((x[key] == value).all().item()) + with self.assertRaises(RuntimeError): + value = ht.array([[99, 98], [97, 96]], split=0) + x[key] = value # # combining advanced and basic indexing # y_np = np.arange(35).reshape(5, 7) From 435ff0c74bd94b5cea75b32fa204af1ca43aeec4 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 20 Feb 2024 09:37:26 +0100 Subject: [PATCH 127/132] reinstate tests for specific bugs --- heat/core/tests/test_dndarray.py | 55 ++++++++++++++++---------------- 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 3e3bd41ab7..e775562396 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -676,6 +676,13 @@ def test_getitem(self): self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) self.assertTrue(x_3d_sliced.split == 0) + # tests for bug 730: + a = ht.ones((10, 25, 30), split=1) + if a.comm.size > 1: + self.assertEqual(a[0].split, 0) + self.assertEqual(a[:, 0, :].split, None) + self.assertEqual(a[:, :, 0].split, 1) + # DIMENSIONAL INDEXING # ellipsis x_np = np.array([[[1], [2], [3]], [[4], [5], [6]]]) @@ -1638,34 +1645,26 @@ def test_setitem(self): # TODO boolean mask, distributed, distributed `value` - # def test_setitem_getitem(self): - # # tests for bug #825 - # a = ht.ones((102, 102), split=0) - # setting = ht.zeros((100, 100), split=0) - # a[1:-1, 1:-1] = setting - # self.assertTrue(ht.all(a[1:-1, 1:-1] == 0)) - - # a = ht.ones((102, 102), split=1) - # setting = ht.zeros((30, 100), split=1) - # a[-30:, 1:-1] = setting - # self.assertTrue(ht.all(a[-30:, 1:-1] == 0)) - - # a = ht.ones((102, 102), split=1) - # setting = ht.zeros((100, 100), split=1) - # a[1:-1, 1:-1] = setting - # self.assertTrue(ht.all(a[1:-1, 1:-1] 
== 0))
-
-        # a = ht.ones((102, 102), split=1)
-        # setting = ht.zeros((100, 20), split=1)
-        # a[1:-1, :20] = setting
-        # self.assertTrue(ht.all(a[1:-1, :20] == 0))
-
-        # # tests for bug 730:
-        # a = ht.ones((10, 25, 30), split=1)
-        # if a.comm.size > 1:
-        #     self.assertEqual(a[0].split, 0)
-        #     self.assertEqual(a[:, 0, :].split, None)
-        #     self.assertEqual(a[:, :, 0].split, 1)
+        # tests for bug #825
+        a = ht.ones((102, 102), split=0)
+        setting = ht.zeros((100, 100), split=0)
+        a[1:-1, 1:-1] = setting
+        self.assertTrue(ht.all(a[1:-1, 1:-1] == 0).item())
+
+        a = ht.ones((102, 102), split=1)
+        setting = ht.zeros((30, 100), split=1)
+        a[-30:, 1:-1] = setting
+        self.assertTrue(ht.all(a[-30:, 1:-1] == 0).item())
+
+        a = ht.ones((102, 102), split=1)
+        setting = ht.zeros((100, 100), split=1)
+        a[1:-1, 1:-1] = setting
+        self.assertTrue(ht.all(a[1:-1, 1:-1] == 0).item())
+
+        a = ht.ones((102, 102), split=1)
+        setting = ht.zeros((100, 20), split=1)
+        a[1:-1, :20] = setting
+        self.assertTrue(ht.all(a[1:-1, :20] == 0).item())

         # # set and get single value
         # a = ht.zeros((13, 5), split=0)

From ad9682211393ccae5a7a9235963463c689b8982a Mon Sep 17 00:00:00 2001
From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com>
Date: Wed, 10 Apr 2024 05:59:46 +0200
Subject: [PATCH 128/132] prep send_buffer - expand value dimension if necessary

---
 heat/core/dndarray.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index 3c3bfe5b3c..6c41fdac95 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -2521,9 +2521,15 @@ def __set(
             send_displs[proc] = send_counts[:proc].sum()
             # compose send buffer: stack local elements of `value` according to destination process
             if send_indices.numel() > 0:
-                send_buf[send_displs[proc] : send_displs[proc] + send_counts[proc], :-1] = (
-                    value.larray[send_indices]
-                )
+                if value.ndim < 2:
+                    # temporarily add a singleton dimension to value to accommodate column dimension for send_indices
+                    send_buf[send_displs[proc] : send_displs[proc] + send_counts[proc], :-1] = (
+                        value.larray[send_indices].unsqueeze(1)
+                    )
+                else:
+                    send_buf[send_displs[proc] : send_displs[proc] + send_counts[proc], :-1] = (
+                        value.larray[send_indices]
+                    )
             # store outgoing GLOBAL indices in the last column of send_buf
             # TODO: if key_is_mask_like: apply send_indices to all dimensions of key
             if key_is_mask_like:

From c9d44aebbd6457e1121b8e3a245658848d58d3d3 Mon Sep 17 00:00:00 2001
From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com>
Date: Wed, 10 Apr 2024 06:20:11 +0200
Subject: [PATCH 129/132] fix send_indices dims when key is not mask-like

---
 heat/core/dndarray.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index 6c41fdac95..f83dab5afa 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -2538,10 +2538,7 @@ def __set(
                         send_displs[proc] : send_displs[proc] + send_counts[proc], i
                     ] = key[i + len(key)][send_indices]
                 else:
-                    while send_indices.ndim < send_buf.ndim:
-                        send_indices = split_key[send_indices]
-                    # broadcast send_indices to correct shape
-                    send_indices = send_indices.unsqueeze(-1)
+                    send_indices = split_key[send_indices]
                     send_buf[send_displs[proc] : send_displs[proc] + send_counts[proc], -1] = (
                         send_indices
                     )

From cc7040007fe1e3aaaa449b1e73230936fccd5d28 Mon Sep 17 00:00:00 2001
From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com>
Date: Thu, 11 Apr 2024 04:57:24 +0200
Subject: [PATCH 130/132]
From cc7040007fe5aaaa449b1e73230936fccd5d28 Mon Sep 17 00:00:00 2001
From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com>
Date: Thu, 11 Apr 2024 04:57:24 +0200
Subject: [PATCH 130/132] test split mismatch on comm.size > 1

---
 heat/core/tests/test_dndarray.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py
index e775562396..7e6a8f7d8e 100644
--- a/heat/core/tests/test_dndarray.py
+++ b/heat/core/tests/test_dndarray.py
@@ -1594,9 +1594,10 @@ def test_setitem(self):
         value = ht.array([[99, 98], [97, 96]], split=1)
         x[key] = value
         self.assertTrue((x[key] == value).all().item())
-        with self.assertRaises(RuntimeError):
-            value = ht.array([[99, 98], [97, 96]], split=0)
-            x[key] = value
+        if x.comm.size > 1:
+            with self.assertRaises(RuntimeError):
+                value = ht.array([[99, 98], [97, 96]], split=0)
+                x[key] = value

From b78de30c25fe2d8ccd26e22e6deb0a5858c1fd1c Mon Sep 17 00:00:00 2001
From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com>
Date: Thu, 11 Apr 2024 09:40:45 +0200
Subject: [PATCH 131/132] broadcasting assignment along split axis

---
 heat/core/dndarray.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index f83dab5afa..f6a6d1a698 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -2212,15 +2212,13 @@ def __broadcast_value(
         # assess whether the shapes are compatible, starting from the trailing dimension
         for i in range(1, min(len(value_shape), len(output_shape)) + 1):
             if i == 1:
-                if value_shape[-i] != output_shape[-i]:
+                if value_shape[-i] != output_shape[-i] and not value_shape[-i] == 1:
                     # shapes are not compatible, raise error
                     raise ValueError(
                         f"could not broadcast input array from shape {value_shape} into shape {output_shape}"
                     )
             else:
-                if value_shape[-i] != output_shape[-i] and (
-                    not value_shape[-i] == 1 or not output_shape[-i] == 1
-                ):
+                if value_shape[-i] != output_shape[-i] and (not value_shape[-i] == 1):
                     # shapes are not compatible, raise error
                     raise ValueError(
                         f"could not broadcast input from shape {value_shape} into shape {output_shape}"
                     )
@@ -2424,6 +2422,19 @@ def __set(
             return
         # key is a sequence of torch.Tensors
         split_key = key[self.split]
+        split_key_dims = split_key.ndim
+        if split_key_dims > 1:
+            # flatten `split_key`
+            split_key = split_key.flatten()
+            # flatten split_key dimensions of `value`:
+            new_shape = list(value.shape)
+            new_shape = (
+                new_shape[: output_split - (split_key_dims - 1)]
+                + [-1]
+                + new_shape[output_split + 1 :]
+            )
+            value = value.reshape(new_shape)
+            output_split -= split_key_dims - 1
         # find elements of `split_key` that are local to this process
         local_indices = torch.nonzero(
             (split_key >= displs[rank]) & (split_key < displs[rank] + counts[rank])
@@ -2446,7 +2457,7 @@ def __set(
             self.larray[key] = value.larray[local_indices].type(self.dtype.torch_type())
         else:
             # keep local indexing key and correct for displacements along split dimension
-            key[self.split] = key[self.split][local_indices] - displs[rank]
+            key[self.split] = split_key[local_indices] - displs[rank]
             key = tuple(key)
             value_key = tuple(
                 [
@@ -2476,7 +2487,7 @@ def __set(
         if key_is_single_tensor:
             # key is a single torch.Tensor
             split_key = key
-        elif not key_is_mask_like:
+        else:
             split_key = key[self.split]
         global_split_key = factories.array(
             split_key, is_split=0, device=self.device, comm=self.comm, copy=False
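The flattening introduced above reflects a general identity for advanced indexing: a d-dimensional index tensor along one axis contributes all d of its dimensions to the output, so the key can be flattened as long as the matching dimensions of `value` are merged into one. A non-distributed sketch of that equivalence in plain PyTorch (shapes chosen purely for illustration):

    import torch

    x = torch.zeros(4, 3)
    split_key = torch.tensor([[0, 3], [1, 2]])   # 2-D advanced index along dim 0
    value = torch.arange(12.0).reshape(2, 2, 3)  # matches the indexed shape (2, 2, 3)

    # equivalent 1-D formulation: flatten the key and merge the two
    # key dimensions of `value` into one
    flat_key = split_key.flatten()               # shape (4,)
    flat_value = value.reshape(-1, 3)            # shape (4, 3)
    x[flat_key] = flat_value

    assert torch.equal(x[split_key], value)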
From a8f2d57462890939c028e6bb62259d8ac27d4d37 Mon Sep 17 00:00:00 2001
From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com>
Date: Thu, 11 Apr 2024 09:41:09 +0200
Subject: [PATCH 132/132] expand tests

---
 heat/core/tests/test_dndarray.py | 36 ++++++++++++++++++--------------
 1 file changed, 20 insertions(+), 16 deletions(-)

diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py
index 7e6a8f7d8e..6d0776c3b5 100644
--- a/heat/core/tests/test_dndarray.py
+++ b/heat/core/tests/test_dndarray.py
@@ -1599,22 +1599,26 @@ def test_setitem(self):
                 value = ht.array([[99, 98], [97, 96]], split=0)
                 x[key] = value
 
-        # # combining advanced and basic indexing
-        # y_np = np.arange(35).reshape(5, 7)
-        # y_np_indexed = y_np[np.array([0, 2, 4]), 1:3]
-        # y = ht.array(y_np, split=1)
-        # y_indexed = y[ht.array([0, 2, 4]), 1:3]
-        # self.assert_array_equal(y_indexed, y_np_indexed)
-        # self.assertTrue(y_indexed.split == 1)
-
-        # x_np = np.arange(10 * 20 * 30).reshape(10, 20, 30)
-        # x = ht.array(x_np, split=1)
-        # ind_array = ht.random.randint(0, 20, (2, 3, 4), dtype=ht.int64)
-        # ind_array_np = ind_array.numpy()
-        # x_np_indexed = x_np[..., ind_array_np, :]
-        # x_indexed = x[..., ind_array, :]
-        # self.assert_array_equal(x_indexed, x_np_indexed)
-        # self.assertTrue(x_indexed.split == 3)
+        # combining advanced and basic indexing
+
+        y = ht.arange(35).reshape(5, 7)
+        y.resplit_(1)
+        y_copy = y.copy()
+        # assign non-distributed value
+        value = ht.arange(6).reshape(3, 2)
+        y[ht.array([0, 2, 4]), 1:3] = value
+        self.assertTrue((y[ht.array([0, 2, 4]), 1:3] == value).all().item())
+        # assign distributed value
+        value.resplit_(1)
+        y_copy[ht.array([0, 2, 4]), 1:3] = value
+        self.assertTrue((y_copy[ht.array([0, 2, 4]), 1:3] == value).all().item())
+
+        x = ht.arange(10 * 20 * 30).reshape(10, 20, 30)
+        x.resplit_(1)
+        ind_array = ht.random.randint(0, 20, (2, 3, 4), dtype=ht.int64)
+        value = ht.ones((1, 2, 3, 4, 1))
+        x[..., ind_array, :] = value
+        self.assertTrue((x[..., ind_array, :] == value).all().item())
 
         # boolean mask, local
         arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5)
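The reinstated tests above rely on NumPy's rule for mixing advanced and basic indexing: each advanced index contributes its (broadcast) shape in place of the axis it indexes, while slices and ellipses keep theirs. Two shape checks in plain NumPy illustrating the cases tested (array contents are irrelevant here, only shapes matter):

    import numpy as np

    # advanced index in the middle: its (2, 3, 4) shape replaces axis 1
    x = np.arange(10 * 20 * 30).reshape(10, 20, 30)
    ind = np.random.randint(0, 20, (2, 3, 4))
    assert x[..., ind, :].shape == (10, 2, 3, 4, 30)

    # advanced index plus slice: rows picked by the index, columns by the slice
    y = np.arange(35).reshape(5, 7)
    assert y[np.array([0, 2, 4]), 1:3].shape == (3, 2)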