diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 07b4e48418..f6a6d1a698 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -10,7 +10,7 @@ from inspect import stack from mpi4py import MPI from pathlib import Path -from typing import List, Union, Tuple, TypeVar, Optional +from typing import List, Union, Tuple, TypeVar, Optional, Iterable warnings.simplefilter("always", ResourceWarning) @@ -632,8 +632,8 @@ def counts_displs(self) -> Tuple[Tuple[int], Tuple[int]]: counts = self.lshape_map[:, self.split] displs = [0] + torch.cumsum(counts, dim=0)[:-1].tolist() return tuple(counts.tolist()), tuple(displs) - else: - raise ValueError("Non-distributed DNDarray. Cannot calculate counts and displacements.") + + raise ValueError("Non-distributed DNDarray. Cannot calculate counts and displacements.") def cpu(self) -> DNDarray: """ @@ -825,6 +825,563 @@ def fill_diagonal(self, value: float) -> DNDarray: return self + def __process_key( + arr: DNDarray, + key: Union[Tuple[int, ...], List[int, ...]], + return_local_indices: Optional[bool] = False, + op: Optional[str] = None, + ) -> Tuple: + """ + Private method to process the key used for indexing a ``DNDarray`` so that it can be applied to the process-local data, i.e. `key` must be "torch-proof". + In a processed key: + - any ellipses or newaxis have been replaced with the appropriate number of slice objects + - ndarrays and DNDarrays have been converted to torch tensors + - the dimensionality is the same as the ``DNDarray`` it indexes + This function also manipulates `arr` if necessary, inserting and/or transposing dimensions as indicated by `key`. It calculates the output shape, split axis and balanced status of the indexed array. + + Parameters + ---------- + arr : DNDarray + The ``DNDarray`` to be indexed + key : int, Tuple[int, ...], List[int, ...] + The key used for indexing + return_local_indices : bool, optional + Whether to return the process-local indices of the key in the split dimension. 
This is only possible when the indexing key in the split dimension is ordered e.g. `split_key_is_ordered == 1`. Default: False + op : str, optional + The indexing operation that the key is being processed for. Can be "get" for `__getitem__` or "set" for `__setitem__`. Default: None, which is handled like "get". + + Returns + ------- + arr : DNDarray + The ``DNDarray`` to be indexed. Its dimensions might have been modified if advanced, dimensional, broadcasted indexing is used. + key : Union(Tuple[Any, ...], DNDarray, np.ndarray, torch.Tensor, slice, int, List[int, ...]) + The processed key ready for indexing ``arr``. Its dimensions match the (potentially modified) dimensions of ``arr``. + Note: the key indices along the split axis are LOCAL indices, i.e. refer to the process-local data, if ordered indexing is used. Otherwise, they are GLOBAL indices, referring to the global memory-distributed DNDarray. Communication to extract the non-ordered elements of the input ``DNDarray`` is handled by the ``__getitem__`` function. + output_shape : Tuple[int, ...] + The shape of the output ``DNDarray`` + new_split : int + The new split axis + split_key_is_ordered : int + Whether the split key is sorted or ordered. Can be 1: ascending, 0: not ordered, -1: descending order. + out_is_balanced : bool + Whether the output ``DNDarray`` is balanced + root : int + The root process for the ``MPI.Bcast`` call when single-element indexing along the split axis is used + backwards_transpose_axes : Tuple[int, ...] 
+ The axes to transpose the input ``DNDarray`` back to its original shape if it has been transposed for advanced indexing + """ + output_shape = list(arr.gshape) + split_bookkeeping = [None] * arr.ndim + new_split = arr.split + arr_is_distributed = False + if arr.split is not None: + split_bookkeeping[arr.split] = "split" + if arr.is_distributed(): + counts, displs = arr.counts_displs() + arr_is_distributed = True + + advanced_indexing = False + split_key_is_ordered = 1 + key_is_mask_like = False + out_is_balanced = False + root = None + backwards_transpose_axes = tuple(range(arr.ndim)) + + if isinstance(key, list): + try: + key = torch.tensor(key, device=arr.larray.device) + except RuntimeError: + raise IndexError("Invalid indices: expected a list of integers, got {}".format(key)) + if isinstance(key, (DNDarray, torch.Tensor, np.ndarray)): + if key.dtype in (bool, uint8, torch.bool, torch.uint8, np.bool_, np.uint8): + # boolean indexing: shape must be consistent with arr.shape + key_ndim = key.ndim + if not tuple(key.shape) == arr.shape[:key_ndim]: + raise IndexError( + "Boolean index of shape {} does not match indexed array of shape {}".format( + tuple(key.shape), arr.shape + ) + ) + # extract non-zero elements + try: + # key is torch tensor + key = key.nonzero(as_tuple=True) + except TypeError: + # key is np.ndarray or DNDarray + key = key.nonzero() + key_is_mask_like = True + else: + # advanced indexing on first dimension: first dim will expand to shape of key + output_shape = tuple(list(key.shape) + output_shape[1:]) + # adjust split axis accordingly + if arr_is_distributed: + if arr.split != 0: + # split axis is not affected + split_bookkeeping = [None] * key.ndim + split_bookkeeping[1:] + new_split = ( + split_bookkeeping.index("split") + if "split" in split_bookkeeping + else None + ) + out_is_balanced = arr.balanced + else: + # split axis is affected + if key.ndim > 1: + try: + key_numel = key.numel() + except AttributeError: + key_numel = key.size + if 
key_numel == arr.shape[0]: + new_split = tuple(key.shape).index(arr.shape[0]) + else: + new_split = key.ndim - 1 + try: + key_split = key[new_split].larray + sorted, _ = key_split.sort(stable=True) + except AttributeError: + key_split = key[new_split] + sorted = key_split.sort() + else: + new_split = 0 + # assess if key is sorted along split axis + try: + # DNDarray key + sorted, _ = torch.sort(key.larray, stable=True) + split_key_is_ordered = torch.tensor( + (key.larray == sorted).all(), + dtype=torch.uint8, + device=key.larray.device, + ) + if key.split is not None: + out_is_balanced = key.balanced + split_key_is_ordered = ( + factories.array( + [split_key_is_ordered], + is_split=0, + device=arr.device, + copy=False, + ) + .all() + .astype(types.canonical_heat_types.uint8) + .item() + ) + else: + split_key_is_ordered = split_key_is_ordered.item() + key = key.larray + except AttributeError: + # torch or ndarray key + try: + sorted, _ = torch.sort(key, stable=True) + except TypeError: + # ndarray key + sorted = torch.tensor(np.sort(key), device=arr.larray.device) + split_key_is_ordered = torch.tensor( + key == sorted, dtype=torch.uint8 + ).item() + if not split_key_is_ordered: + # prepare for distributed non-ordered indexing: distribute torch/numpy key + key = factories.array(key, split=0, device=arr.device).larray + out_is_balanced = True + if split_key_is_ordered: + # extract local key + cond1 = key >= displs[arr.comm.rank] + cond2 = key < displs[arr.comm.rank] + counts[arr.comm.rank] + key = key[cond1 & cond2] + if return_local_indices: + key -= displs[arr.comm.rank] + out_is_balanced = False + else: + try: + out_is_balanced = key.balanced + new_split = key.split + key = key.larray + except AttributeError: + # torch or numpy key, non-distributed indexed array + out_is_balanced = True + new_split = None + return ( + arr, + key, + output_shape, + new_split, + split_key_is_ordered, + key_is_mask_like, + out_is_balanced, + root, + backwards_transpose_axes, + ) + + 
key = list(key) if isinstance(key, Iterable) else [key] + + # check for ellipsis, newaxis. NB: (np.newaxis is None)==True + add_dims = sum(k is None for k in key) + ellipsis = sum(isinstance(k, type(...)) for k in key) + if ellipsis > 1: + raise ValueError("indexing key can only contain 1 Ellipsis (...)") + if ellipsis: + # key contains exactly 1 ellipsis + # replace with explicit `slice(None)` for affected dimensions + # output_shape, split_bookkeeping not affected + expand_key = [slice(None)] * (arr.ndim + add_dims) + ellipsis_index = key.index(...) + ellipsis_dims = arr.ndim - (len(key) - ellipsis - add_dims) + expand_key[:ellipsis_index] = key[:ellipsis_index] + expand_key[ellipsis_index + ellipsis_dims :] = key[ellipsis_index + 1 :] + key = expand_key + while add_dims > 0: + # expand array dims: output_shape, split_bookkeeping to reflect newaxis + # replace newaxis with slice(None) in key + for i, k in reversed(list(enumerate(key))): + if k is None: + key[i] = slice(None) + arr = arr.expand_dims(i - add_dims + 1) + output_shape = ( + output_shape[: i - add_dims + 1] + [1] + output_shape[i - add_dims + 1 :] + ) + split_bookkeeping = ( + split_bookkeeping[: i - add_dims + 1] + + [None] + + split_bookkeeping[i - add_dims + 1 :] + ) + add_dims -= 1 + + # recalculate new_split, transpose_axes after dimensions manipulation + new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None + transpose_axes, backwards_transpose_axes = tuple(range(arr.ndim)), tuple(range(arr.ndim)) + # check for advanced indexing and slices + advanced_indexing_dims = [] + advanced_indexing_shapes = [] + lose_dims = 0 + for i, k in enumerate(key): + if np.isscalar(k) or getattr(k, "ndim", 1) == 0: + # single-element indexing along axis i + try: + output_shape[i], split_bookkeeping[i] = None, None + except IndexError: + raise IndexError( + f"Too many indices for DNDarray: DNDarray is {arr.ndim}-dimensional, but {len(key)} dimensions were indexed" + ) + lose_dims += 
1 + if i == arr.split: + key[i], root = arr.__process_scalar_key( + k, indexed_axis=i, return_local_indices=return_local_indices + ) + else: + key[i], _ = arr.__process_scalar_key( + k, indexed_axis=i, return_local_indices=False + ) + elif isinstance(k, Iterable) or isinstance(k, DNDarray): + advanced_indexing = True + advanced_indexing_dims.append(i) + # work with DNDarrays to assess distribution + # torch tensors will be extracted in the advanced indexing section below + k = factories.array(k, device=arr.device, comm=arr.comm, copy=None) + advanced_indexing_shapes.append(k.gshape) + if arr_is_distributed and i == arr.split: + if ( + not k.is_distributed() + and k.ndim == 1 + and (k.larray == torch.sort(k.larray, stable=True)[0]).all() + ): + split_key_is_ordered = 1 + out_is_balanced = None + else: + split_key_is_ordered = 0 + # redistribute key along last axis to match split axis of indexed array + k = k.resplit(-1) + out_is_balanced = True + key[i] = k + + elif isinstance(k, slice) and k != slice(None): + start, stop, step = k.start, k.stop, k.step + if start is None: + start = 0 + elif start < 0: + start += arr.gshape[i] + if stop is None: + stop = arr.gshape[i] + elif stop < 0: + stop += arr.gshape[i] + if step is None: + step = 1 + if step < 0 and start > stop: + # PyTorch doesn't support negative step as of 1.13 + # Lazy solution, potentially large memory footprint + # TODO: implement ht.fromiter (implemented in ASSET_ht) + key[i] = torch.tensor(list(range(start, stop, step)), device=arr.larray.device) + output_shape[i] = len(key[i]) + split_key_is_ordered = -1 + if arr_is_distributed and new_split == i: + if op == "set": + # setitem: flip key and keep process-local indices + key[i] = key[i].flip(0) + cond1 = key[i] >= displs[arr.comm.rank] + cond2 = key[i] < displs[arr.comm.rank] + counts[arr.comm.rank] + key[i] = key[i][cond1 & cond2] + if return_local_indices: + key[i] -= displs[arr.comm.rank] + else: + # getitem: distribute key and proceed with 
non-ordered indexing + key[i] = factories.array( + key[i], split=0, device=arr.device, copy=False + ).larray + out_is_balanced = True + elif step > 0 and start < stop: + output_shape[i] = int(torch.tensor((stop - start) / step).ceil().item()) + if arr_is_distributed and new_split == i: + split_key_is_ordered = 1 + out_is_balanced = False + local_arr_end = displs[arr.comm.rank] + counts[arr.comm.rank] + if stop > displs[arr.comm.rank] and start < local_arr_end: + index_in_cycle = (displs[arr.comm.rank] - start) % step + if start >= displs[arr.comm.rank]: + # slice begins on current rank + local_start = start - displs[arr.comm.rank] + else: + local_start = 0 if index_in_cycle == 0 else step - index_in_cycle + if stop <= local_arr_end: + # slice ends on current rank + local_stop = stop - displs[arr.comm.rank] + else: + local_stop = local_arr_end + key[i] = slice(local_start, local_stop, step) + else: + key[i] = slice(0, 0) + elif step == 0: + raise ValueError("Slice step cannot be zero") + else: + key[i] = slice(0, 0) + output_shape[i] = 0 + + if advanced_indexing: + # adv indexing key elements are DNDarrays: extract torch tensors + # options: 1. key is mask-like (covers boolean mask as well), 2. adv indexing along split axis, 3. everything else + # 1. define key as mask-like if each element of key is a DNDarray, and all elements of key are of the same shape, and the advanced-indexing dimensions are consecutive + key_is_mask_like = ( + all(isinstance(k, DNDarray) for k in key) + and len(set(k.shape for k in key)) == 1 + and torch.tensor(advanced_indexing_dims).diff().eq(1).all() + ) + # if split axis is affected by advanced indexing, keep track of non-split dimensions for later + if arr.is_distributed() and arr.split in advanced_indexing_dims: + non_split_dims = list(advanced_indexing_dims).copy() + if arr.split is not None: + non_split_dims.remove(arr.split) + # 1. 
key is mask-like + if key_is_mask_like: + key = list(key) + key_splits = [k.split for k in key] + if arr.split is not None: + if not key_splits.count(key_splits[arr.split]) == len(key_splits): + if ( + key_splits[arr.split] is not None + and key_splits.count(None) == len(key_splits) - 1 + ): + for i in non_split_dims: + key[i] = factories.array( + key[i], + split=key_splits[arr.split], + device=arr.device, + comm=arr.comm, + copy=None, + ) + else: + raise IndexError( + f"Indexing arrays must be distributed along the same dimension, got splits {key_splits}." + ) + else: + # all key_splits must be the same, otherwise raise IndexError + if not key_splits.count(key_splits[0]) == len(key_splits): + raise IndexError( + f"Indexing arrays must be distributed along the same dimension, got splits {key_splits}." + ) + # all key elements are now DNDarrays of the same shape, same split axis + # 2. advanced indexing along split axis + if arr.is_distributed() and arr.split in advanced_indexing_dims: + if split_key_is_ordered == 1: + # extract torch tensors, keep process-local indices only + k = key[arr.split].larray + cond1 = k >= displs[arr.comm.rank] + cond2 = k < displs[arr.comm.rank] + counts[arr.comm.rank] + k = k[cond1 & cond2] + if return_local_indices: + k -= displs[arr.comm.rank] + key[arr.split] = k + for i in non_split_dims: + if key_is_mask_like: + # select the same elements along non-split dimensions + key[i] = key[i].larray[cond1 & cond2] + else: + key[i] = key[i].larray + elif split_key_is_ordered == 0: + # extract torch tensors, any other communication + mask-like case are handled in __getitem__ or __setitem__ + for i in advanced_indexing_dims: + key[i] = key[i].larray + # split_key_is_ordered == -1 not treated here as it is slicing, not advanced indexing + else: + # advanced indexing does not affect split axis, return torch tensors + for i in advanced_indexing_dims: + key[i] = key[i].larray + # all adv indexing keys are now torch tensors + + # shapes of adv 
indexing arrays must be broadcastable + try: + broadcasted_shape = torch.broadcast_shapes(*advanced_indexing_shapes) + except RuntimeError: + raise IndexError( + "Shape mismatch: indexing arrays could not be broadcast together with shapes: {}".format( + advanced_indexing_shapes + ) + ) + add_dims = len(broadcasted_shape) - len(advanced_indexing_dims) + if ( + len(advanced_indexing_dims) == 1 + or list(range(advanced_indexing_dims[0], advanced_indexing_dims[-1] + 1)) + == advanced_indexing_dims + ): + # dimensions affected by advanced indexing are consecutive: + output_shape[ + advanced_indexing_dims[0] : advanced_indexing_dims[0] + + len(advanced_indexing_dims) + ] = broadcasted_shape + if key_is_mask_like: + # advanced indexing dimensions will be collapsed into one dimension + if ( + "split" in split_bookkeeping + and split_bookkeeping.index("split") in advanced_indexing_dims + ): + split_bookkeeping[ + advanced_indexing_dims[0] : advanced_indexing_dims[0] + + len(advanced_indexing_dims) + ] = ["split"] + else: + split_bookkeeping[ + advanced_indexing_dims[0] : advanced_indexing_dims[0] + + len(advanced_indexing_dims) + ] = [None] + else: + split_bookkeeping = ( + split_bookkeeping[: advanced_indexing_dims[0]] + + [None] * add_dims + + split_bookkeeping[advanced_indexing_dims[0] :] + ) + else: + # advanced-indexing dimensions are not consecutive: + # transpose array to make the advanced-indexing dimensions consecutive as the first dimensions + non_adv_ind_dims = list( + i for i in range(arr.ndim) if i not in advanced_indexing_dims + ) + # keep track of transpose axes order, to be able to transpose back later + transpose_axes = tuple(advanced_indexing_dims + non_adv_ind_dims) + arr = arr.transpose(transpose_axes) + backwards_transpose_axes = tuple( + torch.tensor(transpose_axes, device=arr.larray.device) + .argsort(stable=True) + .tolist() + ) + # output shape and split bookkeeping + output_shape = list(output_shape[i] for i in transpose_axes) + output_shape[: 
len(advanced_indexing_dims)] = broadcasted_shape + split_bookkeeping = list(split_bookkeeping[i] for i in transpose_axes) + split_bookkeeping = [None] * add_dims + split_bookkeeping + # modify key to match the new dimension order + key = [key[i] for i in advanced_indexing_dims] + [key[i] for i in non_adv_ind_dims] + # update advanced-indexing dims + advanced_indexing_dims = list(range(len(advanced_indexing_dims))) + + # expand key to match the number of dimensions of the DNDarray + if arr.ndim > len(key): + key += [slice(None)] * (arr.ndim - len(key)) + + key = tuple(key) + for i in range(output_shape.count(None)): + lost_dim = output_shape.index(None) + output_shape.remove(None) + split_bookkeeping = split_bookkeeping[:lost_dim] + split_bookkeeping[lost_dim + 1 :] + output_shape = tuple(output_shape) + new_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None + return ( + arr, + key, + output_shape, + new_split, + split_key_is_ordered, + key_is_mask_like, + out_is_balanced, + root, + backwards_transpose_axes, + ) + + def __process_scalar_key( + arr: DNDarray, + key: Union[int, DNDarray, torch.Tensor, np.ndarray], + indexed_axis: int, + return_local_indices: Optional[bool] = False, + ) -> Tuple(int, int): + """ + Private method to process a single-item scalar key used for indexing a ``DNDarray``. + + """ + device = arr.larray.device + try: + # is key an ndarray or DNDarray or torch.Tensor? 
+ key = key.item() + except AttributeError: + # key is already an integer, do nothing + pass + if not arr.is_distributed(): + root = None + return key, root + if arr.split == indexed_axis: + # adjust negative key + if key < 0: + key += arr.shape[0] + # work out active process + _, displs = arr.counts_displs() + if key in displs: + root = displs.index(key) + else: + displs = torch.cat( + ( + torch.tensor(displs, device=device), + torch.tensor(key, device=device).reshape(-1), + ), + dim=0, + ) + _, sorted_indices = displs.unique(sorted=True, return_inverse=True) + root = sorted_indices[-1].item() - 1 + displs = displs.tolist() + # correct key for rank-specific displacement + if return_local_indices: + if arr.comm.rank == root: + key -= displs[root] + else: + root = None + return key, root + + def __get_local_slice(self, key: slice): + split = self.split + if split is None: + return key + key = stride_tricks.sanitize_slice(key, self.shape[split]) + start, stop, step = key.start, key.stop, key.step + if step < 0: # NOT supported by torch, should be filtered by torch_proxy + key = self.__get_local_slice(slice(stop + 1, start + 1, abs(step))) + if key is None: + return None + start, stop, step = key.start, key.stop, key.step + return slice(key.stop - 1, key.start - 1, -1 * key.step) + + _, offsets = self.counts_displs() + offset = offsets[self.comm.rank] + range_proxy = range(self.lshape[split]) + local_inds = range_proxy[start - offset : stop - offset] # only works if stop - offset > 0 + local_inds = local_inds[max(offset - start, 0) % step :: step] + if len(local_inds) and stop > offset: + # otherwise if (stop-offset) > -self.lshape[split] this can index into the local chunk despite ending before it + return slice(local_inds.start, local_inds.stop, local_inds.step) + return None + def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray: """ Global getter function for DNDarrays. 
@@ -855,231 +1412,276 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar (1/2) >>> tensor([0.]) (2/2) >>> tensor([0., 0.]) """ - key = getattr(key, "copy()", key) - l_dtype = self.dtype.torch_type() - advanced_ind = False - if isinstance(key, DNDarray) and key.ndim == self.ndim: - """if the key is a DNDarray and it has as many dimensions as self, then each of the - entries in the 0th dim refer to a single element. To handle this, the key is split - into the torch tensors for each dimension. This signals that advanced indexing is - to be used.""" - # NOTE: this gathers the entire key on every process!! - # TODO: remove this resplit!! - key = manipulations.resplit(key) - if key.larray.dtype in [torch.bool, torch.uint8]: - key = indexing.nonzero(key) - - if key.ndim > 1: - key = list(key.larray.split(1, dim=1)) - # key is now a list of tensors with dimensions (key.ndim, 1) - # squeeze singleton dimension: - key = [key[i].squeeze_(1) for i in range(len(key))] - else: - key = [key] - advanced_ind = True - elif not isinstance(key, tuple): - """this loop handles all other cases. DNDarrays which make it to here refer to - advanced indexing slices, as do the torch tensors. Both DNDaarrys and torch.Tensors - are cast into lists here by PyTorch. lists mean advanced indexing will be used""" - h = [slice(None, None, None)] * max(self.ndim, 1) - if isinstance(key, DNDarray): - key = manipulations.resplit(key) - if key.larray.dtype in [torch.bool, torch.uint8]: - h[0] = torch.nonzero(key.larray).flatten() # .tolist() - else: - h[0] = key.larray.tolist() - elif isinstance(key, torch.Tensor): - if key.dtype in [torch.bool, torch.uint8]: - # (coquelin77) i am not certain why this works without being a list. 
but it works...for now - h[0] = torch.nonzero(key).flatten() # .tolist() - else: - h[0] = key.tolist() - else: - h[0] = key - - key = list(h) + # key can be: int, tuple, list, slice, DNDarray, torch tensor, numpy array, or sequence thereof - if isinstance(key, (list, tuple)): - key = list(key) - for i, k in enumerate(key): - # this might be a good place to check if the dtype is there - try: - k = manipulations.resplit(k) - key[i] = k.larray - except AttributeError: - pass - - # ellipsis - key = list(key) - key_classes = [type(n) for n in key] - # if any(isinstance(n, ellipsis) for n in key): - n_elips = key_classes.count(type(...)) - if n_elips > 1: - raise ValueError("key can only contain 1 ellipsis") - elif n_elips == 1: - # get which item is the ellipsis - ell_ind = key_classes.index(type(...)) - kst = key[:ell_ind] - kend = key[ell_ind + 1 :] - slices = [slice(None)] * (self.ndim - (len(kst) + len(kend))) - key = kst + slices + kend - else: - key = key + [slice(None)] * (self.ndim - len(key)) + if key is None: + return self.expand_dims(0) + if ( + key is ... 
or isinstance(key, slice) and key == slice(None) + ): # latter doesnt work with torch for 0-dim tensors + return self - self_proxy = self.__torch_proxy__() - for i in range(len(key)): - if self.__key_adds_dimension(key, i, self_proxy): - key[i] = slice(None) - return self.expand_dims(i)[tuple(key)] + original_split = self.split - key = tuple(key) - # assess final global shape - gout_full = list(self_proxy[key].shape) - - # calculate new split axis - new_split = self.split - # when slicing, squeezed singleton dimensions may affect new split axis - if self.split is not None and len(gout_full) < self.ndim: - if advanced_ind: - new_split = 0 + # Single-element indexing + scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 + if scalar: + # single-element indexing on axis 0 + if self.ndim == 0: + raise IndexError( + "Too many indices for DNDarray: DNDarray is 0-dimensional, but 1 were indexed" + ) + output_shape = self.gshape[1:] + if original_split is None or original_split == 0: + output_split = None else: - for i in range(len(key[: self.split + 1])): - if self.__key_is_singular(key, i, self_proxy): - new_split = None if i == self.split else new_split - 1 + output_split = original_split - 1 + split_key_is_ordered = 1 + out_is_balanced = True + backwards_transpose_axes = tuple(range(self.ndim)) + key, root = self.__process_scalar_key(key, indexed_axis=0, return_local_indices=True) + if root is None: + # early out for single-element indexing not affecting split axis + indexed_arr = self.larray[key] + indexed_arr = DNDarray( + indexed_arr, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + comm=self.comm, + balanced=out_is_balanced, + ) + return indexed_arr + else: + # process multi-element key + ( + self, + key, + output_shape, + output_split, + split_key_is_ordered, + key_is_mask_like, + out_is_balanced, + root, + backwards_transpose_axes, + ) = self.__process_key(key, return_local_indices=True) - key = tuple(key) if not 
self.is_distributed(): - arr = self.__array[key].reshape(gout_full) + # key is torch-proof, index underlying torch tensor + indexed_arr = self.larray[key] + # transpose array back if needed + self = self.transpose(backwards_transpose_axes) return DNDarray( - arr, tuple(gout_full), self.dtype, new_split, self.device, self.comm, self.balanced + indexed_arr, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + comm=self.comm, + balanced=out_is_balanced, ) - # else: (DNDarray is distributed) - arr = torch.tensor([], dtype=self.__array.dtype, device=self.__array.device) - rank = self.comm.rank - counts, chunk_starts = self.counts_displs() - counts, chunk_starts = torch.tensor(counts), torch.tensor(chunk_starts) - chunk_ends = chunk_starts + counts - chunk_start = chunk_starts[rank] - chunk_end = chunk_ends[rank] + if split_key_is_ordered == 1: + if root is not None: + # single-element indexing along split axis + # prepare for Bcast: allocate buffer on all processes + if self.comm.rank == root: + indexed_arr = self.larray[key] + else: + indexed_arr = torch.zeros( + output_shape, dtype=self.larray.dtype, device=self.larray.device + ) + # broadcast result to all processes + self.comm.Bcast(indexed_arr, root=root) + indexed_arr = DNDarray( + indexed_arr, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + comm=self.comm, + balanced=out_is_balanced, + ) + # transpose array back if needed + self = self.transpose(backwards_transpose_axes) + return indexed_arr + + # root is None, i.e. indexing does not affect split axis, apply as is + indexed_arr = self.larray[key] + # transpose array back if needed + self = self.transpose(backwards_transpose_axes) + return DNDarray( + indexed_arr, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + balanced=out_is_balanced, + comm=self.comm, + ) - if len(key) == 0: # handle empty list - # this will return an array of shape (0, ...) 
- arr = self.__array[key] + # key along split axis is not ordered, indices are GLOBAL + # prepare for communication of indices and data + counts, displs = self.counts_displs() + rank, size = self.comm.rank, self.comm.size - """ At the end of the following if/elif/elif block the output array will be set. - each block handles the case where the element of the key along the split axis - is a different type and converts the key from global indices to local indices. """ - lout = gout_full.copy() + key_is_single_tensor = isinstance(key, torch.Tensor) + if key_is_single_tensor: + split_key = key + else: + split_key = key[self.split] + # split_key might be multi-dimensional, flatten it for communication + if split_key.ndim > 1: + original_split_key_shape = split_key.shape + communication_split = output_split - (split_key.ndim - 1) + split_key = split_key.flatten() + else: + communication_split = output_split - if ( - isinstance(key[self.split], (list, torch.Tensor, DNDarray, np.ndarray)) - and len(key[self.split]) > 1 - ): - # advanced indexing, elements in the split dimension are adjusted to the local indices - lkey = list(key) - if isinstance(key[self.split], DNDarray): - lkey[self.split] = key[self.split].larray - - if not isinstance(lkey[self.split], torch.Tensor): - inds = torch.tensor( - lkey[self.split], dtype=torch.long, device=self.device.torch_device - ) - elif lkey[self.split].dtype in [torch.bool, torch.uint8]: # or torch.byte? - # need to convert the bools to indices - inds = torch.nonzero(lkey[self.split]) - else: - inds = lkey[self.split] - # todo: remove where in favor of nonzero? might be a speed upgrade. 
testing required - loc_inds = torch.where((inds >= chunk_start) & (inds < chunk_end)) - # if there are no local indices on a process, then `arr` is empty - # if local indices exist: - if len(loc_inds[0]) != 0: - # select same local indices for other (non-split) dimensions if necessary - for i, k in enumerate(lkey): - if isinstance(k, (list, torch.Tensor, DNDarray)) and i != self.split: - lkey[i] = k[loc_inds] - # correct local indices for offset - inds = inds[loc_inds] - chunk_start - lkey[self.split] = inds - lout[new_split] = len(inds) - arr = self.__array[tuple(lkey)].reshape(tuple(lout)) - elif len(loc_inds[0]) == 0: - if new_split is not None: - lout[new_split] = len(loc_inds[0]) + # determine the number of elements to be received from each process + recv_counts = torch.zeros((size, 1), dtype=torch.int64, device=self.larray.device) + if key_is_mask_like: + recv_indices = torch.zeros( + (len(split_key), len(key)), dtype=split_key.dtype, device=self.larray.device + ) + else: + recv_indices = torch.zeros( + (split_key.shape), dtype=split_key.dtype, device=self.larray.device + ) + for p in range(size): + cond1 = split_key >= displs[p] + cond2 = split_key < displs[p] + counts[p] + indices_from_p = torch.nonzero(cond1 & cond2, as_tuple=False) + incoming_indices = split_key[indices_from_p].flatten() + recv_counts[p, 0] = incoming_indices.numel() + # store incoming indices in appropiate slice of recv_indices + start = recv_counts[:p].sum().item() + stop = start + recv_counts[p].item() + if incoming_indices.numel() > 0: + if key_is_mask_like: + # apply selection to all dimensions + for i in range(len(key)): + recv_indices[start:stop, i] = key[i][indices_from_p].flatten() + recv_indices[start:stop, self.split] -= displs[p] else: - lout = [0] * len(gout_full) - arr = torch.tensor([], dtype=self.larray.dtype, device=self.larray.device).reshape( - tuple(lout) - ) - - elif isinstance(key[self.split], slice): - # standard slicing along the split axis, - # adjust the slice 
start, stop, and step, then run it on the processes which have the requested data - key = list(key) - key[self.split] = stride_tricks.sanitize_slice(key[self.split], self.gshape[self.split]) - key_start, key_stop, key_step = ( - key[self.split].start, - key[self.split].stop, - key[self.split].step, + recv_indices[start:stop] = incoming_indices - displs[p] + # build communication matrix by sharing recv_counts with all processes + # comm_matrix rows contain the send_counts for each process, columns contain the recv_counts + comm_matrix = torch.zeros((size, size), dtype=torch.int64, device=self.larray.device) + self.comm.Allgather(recv_counts, comm_matrix) + send_counts = comm_matrix[:, rank] + + # active rank pairs: + active_rank_pairs = torch.nonzero(comm_matrix, as_tuple=False) + + # Communication build-up: + active_recv_indices_from = active_rank_pairs[torch.where(active_rank_pairs[:, 1] == rank)][ + :, 0 + ] + active_send_indices_to = active_rank_pairs[torch.where(active_rank_pairs[:, 0] == rank)][ + :, 1 + ] + rank_is_active = active_recv_indices_from.numel() > 0 or active_send_indices_to.numel() > 0 + + # allocate recv_buf for incoming data + recv_buf_shape = list(output_shape) + if communication_split != output_split: + # split key was flattened, flatten corresponding dims in recv_buf accordingly + recv_buf_shape = ( + recv_buf_shape[:communication_split] + + [recv_counts.sum().item()] + + recv_buf_shape[output_split + 1 :] ) - og_key_start = key_start - st_pr = torch.where(key_start < chunk_ends)[0] - st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size - sp_pr = torch.where(key_stop >= chunk_starts)[0] - sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0 - actives = list(range(st_pr, sp_pr + 1)) - if rank in actives: - key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] - key_stop = counts[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] - key_start, key_stop = self.__xitem_get_key_start_stop( - rank, actives, key_start, 
key_stop, key_step, chunk_ends, og_key_start - ) - key[self.split] = slice(key_start, key_stop, key_step) - lout[new_split] = ( - math.ceil((key_stop - key_start) / key_step) - if key_step is not None - else key_stop - key_start - ) - arr = self.__array[tuple(key)].reshape(lout) - else: - lout[new_split] = 0 - arr = torch.empty(lout, dtype=self.__array.dtype, device=self.__array.device) - - elif self.__key_is_singular(key, self.split, self_proxy): - # getting one item along split axis: - key = list(key) - if isinstance(key[self.split], list): - key[self.split] = key[self.split].pop() - elif isinstance(key[self.split], (torch.Tensor, DNDarray, np.ndarray)): - key[self.split] = key[self.split].item() - # translate negative index - if key[self.split] < 0: - key[self.split] += self.gshape[self.split] - - active_rank = torch.where(key[self.split] >= chunk_starts)[0][-1].item() - # slice `self` on `active_rank`, allocate `arr` on all other ranks in preparation for Bcast - if rank == active_rank: - key[self.split] -= chunk_start.item() - arr = self.__array[tuple(key)].reshape(tuple(lout)) - else: - arr = torch.empty(tuple(lout), dtype=self.larray.dtype, device=self.larray.device) - # broadcast result - # TODO: Replace with `self.comm.Bcast(arr, root=active_rank)` after fixing #784 - arr = self.comm.bcast(arr, root=active_rank) - if arr.device != self.larray.device: - # todo: remove when unnecessary (also after #784) - arr = arr.to(device=self.larray.device) - - return DNDarray( - arr.type(l_dtype), - gout_full if isinstance(gout_full, tuple) else tuple(gout_full), - self.dtype, - new_split, - self.device, - self.comm, - balanced=True if new_split is None else None, + else: + recv_buf_shape[communication_split] = recv_counts.sum().item() + recv_buf = torch.zeros( + tuple(recv_buf_shape), dtype=self.larray.dtype, device=self.larray.device ) + if rank_is_active: + # non-blocking send indices to `active_send_indices_to` + send_requests = [] + for i in active_send_indices_to: 
+ start = recv_counts[:i].sum().item() + stop = start + recv_counts[i].item() + outgoing_indices = recv_indices[start:stop] + send_requests.append(self.comm.Isend(outgoing_indices, dest=i)) + del outgoing_indices + del recv_indices + for i in active_recv_indices_from: + # receive indices from `active_recv_indices_from` + if key_is_mask_like: + incoming_indices = torch.zeros( + (send_counts[i].item(), len(key)), + dtype=torch.int64, + device=self.larray.device, + ) + else: + incoming_indices = torch.zeros( + send_counts[i].item(), dtype=torch.int64, device=self.larray.device + ) + self.comm.Recv(incoming_indices, source=i) + # prepare send_buf for outgoing data + if key_is_single_tensor: + send_buf = self.larray[incoming_indices] + else: + if key_is_mask_like: + send_key = tuple( + incoming_indices[:, i].reshape(-1) + for i in range(incoming_indices.shape[1]) + ) + send_buf = self.larray[send_key] + else: + send_key = list(key) + send_key[self.split] = incoming_indices + send_buf = self.larray[tuple(send_key)] + # non-blocking send requested data to i + send_requests.append(self.comm.Isend(send_buf, dest=i)) + del send_buf + # allocate temporary recv_buf to receive data from all active processes + tmp_recv_buf_shape = recv_buf_shape.copy() + tmp_recv_buf_shape[communication_split] = recv_counts.max().item() + tmp_recv_buf = torch.zeros( + tuple(tmp_recv_buf_shape), dtype=self.larray.dtype, device=self.larray.device + ) + for i in active_send_indices_to: + # receive data from i + tmp_recv_slice = [slice(None)] * tmp_recv_buf.ndim + tmp_recv_slice[communication_split] = slice(0, recv_counts[i].item()) + self.comm.Recv(tmp_recv_buf[tmp_recv_slice], source=i) + # write received data to appropriate portion of recv_buf + cond1 = split_key >= displs[i] + cond2 = split_key < displs[i] + counts[i] + recv_buf_indices = torch.nonzero(cond1 & cond2, as_tuple=False).flatten() + recv_buf_key = [slice(None)] * recv_buf.ndim + recv_buf_key[communication_split] = recv_buf_indices + 
recv_buf[recv_buf_key] = tmp_recv_buf[tmp_recv_slice] + del tmp_recv_buf + # wait for all non-blocking communication to finish + for req in send_requests: + req.Wait() + if communication_split != output_split: + # split_key has been flattened, bring back recv_buf to intended shape + original_local_shape = ( + output_shape[:communication_split] + + original_split_key_shape + + output_shape[output_split + 1 :] + ) + recv_buf = recv_buf.reshape(original_local_shape) + + # construct indexed array from recv_buf + indexed_arr = DNDarray( + recv_buf, + gshape=output_shape, + dtype=self.dtype, + split=output_split, + device=self.device, + comm=self.comm, + balanced=out_is_balanced, + ) + # transpose array back if needed + self = self.transpose(backwards_transpose_axes) + return indexed_arr if torch.cuda.device_count() > 0: @@ -1164,7 +1766,11 @@ def __len__(self) -> int: """ The length of the ``DNDarray``, i.e. the number of items in the first dimension. """ - return self.shape[0] + try: + len = self.shape[0] + return len + except IndexError: + raise TypeError("len() of unsized DNDarray") def numpy(self) -> np.array: """ @@ -1442,8 +2048,6 @@ def resplit_(self, axis: int = None): # early out for unchanged content if self.comm.size == 1: self.__split = axis - if axis is None: - self.__partitions_dict__ = None if axis == self.split: return self @@ -1569,260 +2173,458 @@ def __setitem__( (2/2) >>> tensor([[0., 1., 0., 0., 0.], [0., 1., 0., 0., 0.]]) """ - key = getattr(key, "copy()", key) - try: - if value.split != self.split: - val_split = int(value.split) - sp = self.split - warnings.warn( - f"\nvalue.split {val_split} not equal to this DNDarray's split:" - f" {sp}. this may cause errors or unwanted behavior", - category=RuntimeWarning, - ) - except (AttributeError, TypeError): - pass - # NOTE: for whatever reason, there is an inplace op which interferes with the abstraction - # of this next block of code. this is shared with __getitem__. 
I attempted to abstract it - # in a standard way, but it was causing errors in the test suite. If someone else is - # motived to do this they are welcome to, but i have no time right now - # print(key) - if isinstance(key, DNDarray) and key.ndim == self.ndim: - """if the key is a DNDarray and it has as many dimensions as self, then each of the - entries in the 0th dim refer to a single element. To handle this, the key is split - into the torch tensors for each dimension. This signals that advanced indexing is - to be used.""" - key = manipulations.resplit(key) - if key.larray.dtype in [torch.bool, torch.uint8]: - key = indexing.nonzero(key) - - if key.ndim > 1: - key = list(key.larray.split(1, dim=1)) - # key is now a list of tensors with dimensions (key.ndim, 1) - # squeeze singleton dimension: - key = [key[i].squeeze_(1) for i in range(len(key))] + def __broadcast_value( + arr: DNDarray, + key: Union[int, Tuple[int, ...], slice], + value: DNDarray, + **kwargs, + ): + """ + Broadcasts the assignment DNDarray `value` to the shape of the indexed array `arr[key]` if necessary. + """ + is_scalar = ( + np.isscalar(value) + or getattr(value, "ndim", 1) == 0 + or (value.shape == (1,) and value.split is None) + ) + if is_scalar: + # no need to broadcast + return value, is_scalar + # need information on indexed array + output_shape = kwargs.get("output_shape", None) + if output_shape is not None: + indexed_dims = len(output_shape) else: - key = [key] - elif not isinstance(key, tuple): - """this loop handles all other cases. DNDarrays which make it to here refer to - advanced indexing slices, as do the torch tensors. Both DNDaarrys and torch.Tensors - are cast into lists here by PyTorch. 
lists mean advanced indexing will be used""" - h = [slice(None, None, None)] * self.ndim - if isinstance(key, DNDarray): - key = manipulations.resplit(key) - if key.larray.dtype in [torch.bool, torch.uint8]: - h[0] = torch.nonzero(key.larray).flatten() # .tolist() + if isinstance(key, (int, tuple)): + # direct indexing, output_shape has not been calculated + # use proxy to avoid MPI communication and limit memory usage + indexed_proxy = arr.__torch_proxy__()[key] + indexed_dims = indexed_proxy.ndim + output_shape = tuple(indexed_proxy.shape) else: - h[0] = key.larray.tolist() - elif isinstance(key, torch.Tensor): - if key.dtype in [torch.bool, torch.uint8]: - # (coquelin77) im not sure why this works without being a list...but it does...for now - h[0] = torch.nonzero(key).flatten() # .tolist() + raise RuntimeError( + "Not enough information to broadcast value to indexed array, please provide `output_shape`" + ) + value_shape = value.shape + # check if value needs to be broadcasted + if value_shape != output_shape: + # assess whether the shapes are compatible, starting from the trailing dimension + for i in range(1, min(len(value_shape), len(output_shape)) + 1): + if i == 1: + if value_shape[-i] != output_shape[-i] and not value_shape[-i] == 1: + # shapes are not compatible, raise error + raise ValueError( + f"could not broadcast input array from shape {value_shape} into shape {output_shape}" + ) + else: + if value_shape[-i] != output_shape[-i] and (not value_shape[-i] == 1): + # shapes are not compatible, raise error + raise ValueError( + f"could not broadcast input from shape {value_shape} into shape {output_shape}" + ) + # value has more dimensions than indexed array + if value.ndim > indexed_dims: + # check if all dimensions except the indexed ones are singletons + all_singletons = value.shape[: value.ndim - indexed_dims] == (1,) * ( + value.ndim - indexed_dims + ) + if not all_singletons: + raise ValueError( + f"could not broadcast input array from shape 
{value_shape} into shape {output_shape}" + ) + # squeeze out singleton dimensions + value = value.squeeze(tuple(range(value.ndim - indexed_dims))) else: - h[0] = key.tolist() - else: - h[0] = key - key = list(h) - - # key must be torch-proof - if isinstance(key, (list, tuple)): - key = list(key) - for i, k in enumerate(key): - try: # extract torch tensor - k = manipulations.resplit(k) - key[i] = k.larray - except AttributeError: - pass - # remove bools from a torch tensor in favor of indexes - try: - if key[i].dtype in [torch.bool, torch.uint8]: - key[i] = torch.nonzero(key[i]).flatten() - except (AttributeError, TypeError): - pass - - key = list(key) - - # ellipsis stuff - key_classes = [type(n) for n in key] - # if any(isinstance(n, ellipsis) for n in key): - n_elips = key_classes.count(type(...)) - if n_elips > 1: - raise ValueError("key can only contain 1 ellipsis") - elif n_elips == 1: - # get which item is the ellipsis - ell_ind = key_classes.index(type(...)) - kst = key[:ell_ind] - kend = key[ell_ind + 1 :] - slices = [slice(None)] * (self.ndim - (len(kst) + len(kend))) - key = kst + slices + kend - # ---------- end ellipsis stuff ------------- - - for c, k in enumerate(key): - try: - key[c] = k.item() - except (AttributeError, ValueError, RuntimeError): - pass + while value.ndim < indexed_dims: + # broadcasting + # expand missing dimensions to align split axis + value = value.expand_dims(0) + try: + value_shape = tuple(torch.broadcast_shapes(value.shape, output_shape)) + except RuntimeError: + raise ValueError( + f"could not broadcast input array from shape {value_shape} into shape {output_shape}" + ) + return value, is_scalar + + def __set( + arr: DNDarray, + key: Union[int, Tuple[int, ...], List[int, ...]], + value: Union[DNDarray, torch.Tensor, np.ndarray, float, int, list, tuple], + ): + """ + Setter for not advanced indexing, i.e. when arr[key] is an in-place view of arr. 
+ """ + # only assign values if key does not contain empty slices + process_is_inactive = arr.larray[key].numel() == 0 + if not process_is_inactive: + # make sure value is same datatype as arr + arr.larray[key] = value.larray.type(arr.dtype.torch_type()) + return - rank = self.comm.rank - if self.split is not None: - counts, chunk_starts = self.counts_displs() - else: - counts, chunk_starts = 0, [0] * self.comm.size - counts = torch.tensor(counts, device=self.device.torch_device) - chunk_starts = torch.tensor(chunk_starts, device=self.device.torch_device) - chunk_ends = chunk_starts + counts - chunk_start = chunk_starts[rank] - chunk_end = chunk_ends[rank] - # determine which elements are on the local process (if the key is a torch tensor) + # make sure `value` is a DNDarray try: - # if isinstance(key[self.split], torch.Tensor): - filter_key = torch.nonzero( - (chunk_start <= key[self.split]) & (key[self.split] < chunk_end) - ) - for k in range(len(key)): - try: - key[k] = key[k][filter_key].flatten() - except TypeError: - pass - except TypeError: # this will happen if the key doesnt have that many - pass - - key = tuple(key) - - if not self.is_distributed(): - return self.__setter(key, value) # returns None - - # raise RuntimeError("split axis of array and the target value are not equal") removed - # this will occur if the local shapes do not match - rank = self.comm.rank - ends = [] - for pr in range(self.comm.size): - _, _, e = self.comm.chunk(self.shape, self.split, rank=pr) - ends.append(e[self.split].stop - e[self.split].start) - ends = torch.tensor(ends, device=self.device.torch_device) - chunk_ends = ends.cumsum(dim=0) - chunk_starts = torch.tensor([0] + chunk_ends.tolist(), device=self.device.torch_device) - _, _, chunk_slice = self.comm.chunk(self.shape, self.split) - chunk_start = chunk_slice[self.split].start - chunk_end = chunk_slice[self.split].stop - - self_proxy = self.__torch_proxy__() - - # if the value is a DNDarray, the divisions need to be 
balanced: - # this means that we need to know how much data is where for both DNDarrays - # if the value data is not in the right place, then it will need to be moved - - if isinstance(key[self.split], slice): - key = list(key) - key_start = key[self.split].start if key[self.split].start is not None else 0 - key_stop = ( - key[self.split].stop - if key[self.split].stop is not None - else self.gshape[self.split] - ) - if key_stop < 0: - key_stop = self.gshape[self.split] + key[self.split].stop - key_step = key[self.split].step - og_key_start = key_start - st_pr = torch.where(key_start < chunk_ends)[0] - st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size - sp_pr = torch.where(key_stop >= chunk_starts)[0] - sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0 - actives = list(range(st_pr, sp_pr + 1)) + value = factories.array(value) + except TypeError: + raise TypeError(f"Cannot assign object of type {type(value)} to DNDarray.") + + # workaround for Heat issue #1292. TODO: remove when issue is fixed + if not isinstance(key, DNDarray): + if key is None or key is ... 
or key is slice(None): + # match dimensions + value, _ = __broadcast_value(self, key, value) + # make sure `self` and `value` distribution are aligned + value = sanitation.sanitize_distribution(value, target=self) + return __set(self, key, value) + + # single-element key + scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0 + if scalar: + key, root = self.__process_scalar_key(key, indexed_axis=0, return_local_indices=True) + # match dimensions + value, _ = __broadcast_value(self, key, value) + # `root` will be None when the indexed axis is not the split axis, or when the + # indexed axis is the split axis but the indexed element is not local + if root is not None: + if self.comm.rank == root: + # verify that `self[key]` and `value` distribution are aligned + # do not index `self` with `key` directly here, as this would MPI-broadcast to all ranks + indexed_proxy = self.__torch_proxy__()[key] + if indexed_proxy.names.count("split") != 0: + # distribution map of indexed subarray is the same as the lshape_map of the original array after losing the first dimension + indexed_lshape_map = self.lshape_map[:, 1:] + if value.lshape_map != indexed_lshape_map: + try: + value.redistribute_(target_map=indexed_lshape_map) + except ValueError: + raise ValueError( + f"cannot assign value to indexed DNDarray because distribution schemes do not match: {value.lshape_map} vs. {indexed_lshape_map}" + ) + __set(self, key, value) + else: + # `root` is None, i.e. 
the indexed element is local on each process + # verify that `self[key]` and `value` distribution are aligned + value = sanitation.sanitize_distribution(value, target=self[key]) + __set(self, key, value) + return - if ( - isinstance(value, type(self)) - and value.split is not None - and value.shape[self.split] != self.shape[self.split] - ): - # setting elements in self with a DNDarray which is not the same size in the - # split dimension - local_keys = [] - # below is used if the target needs to be reshaped - target_reshape_map = torch.zeros( - (self.comm.size, self.ndim), dtype=torch.int64, device=self.device.torch_device - ) - for r in range(self.comm.size): - if r not in actives: - loc_key = key.copy() - loc_key[self.split] = slice(0, 0, 0) - else: - key_start_l = 0 if r != actives[0] else key_start - chunk_starts[r] - key_stop_l = ends[r] if r != actives[-1] else key_stop - chunk_starts[r] - key_start_l, key_stop_l = self.__xitem_get_key_start_stop( - r, actives, key_start_l, key_stop_l, key_step, chunk_ends, og_key_start - ) - loc_key = key.copy() - loc_key[self.split] = slice(key_start_l, key_stop_l, key_step) + # multi-element key, incl. 
slicing and striding, ordered and non-ordered advanced indexing + ( + self, + key, + output_shape, + output_split, + split_key_is_ordered, + key_is_mask_like, + _, + root, + backwards_transpose_axes, + ) = self.__process_key(key, return_local_indices=True, op="set") + + # match dimensions + value, value_is_scalar = __broadcast_value(self, key, value, output_shape=output_shape) + + # early out for non-distributed case + if not self.is_distributed() and not value.is_distributed(): + # no communication needed + __set(self, key, value) + self = self.transpose(backwards_transpose_axes) + return - gout_full = torch.tensor( - self_proxy[loc_key].shape, device=self.device.torch_device + # distributed case + if split_key_is_ordered == 1: + # key all local + if root is not None: + # single-element assignment along split axis, only one active process + if self.comm.rank == root: + self.larray[key] = value.larray.type(self.dtype.torch_type()) + else: + # indexed elements are process-local + if self.is_distributed() and not value_is_scalar: + if not value.is_distributed(): + # work with distributed `value` + value = factories.array( + value.larray, + dtype=value.dtype, + split=output_split, + device=self.device, + comm=self.comm, ) - target_reshape_map[r] = gout_full - local_keys.append(loc_key) - - key = local_keys[rank] - value = value.redistribute(target_map=target_reshape_map) - - if rank not in actives: - return # non-active ranks can exit here + else: + if value.split != output_split: + raise RuntimeError( + f"Cannot assign distributed `value` with split axis {value.split} to indexed DNDarray with split axis {output_split}." 
+ ) + # verify that `self[key]` and `value` distribution are aligned + target_shape = torch.tensor( + tuple(self.larray[key].shape), device=self.device.torch_device + ) + target_map = torch.zeros( + (self.comm.size, len(target_shape)), + dtype=torch.int64, + device=self.device.torch_device, + ) + self.comm.Allgather(target_shape, target_map) + value.redistribute_(target_map=target_map) + __set(self, key, value) + self = self.transpose(backwards_transpose_axes) + return - chunk_starts_v = target_reshape_map[:, self.split] - value_slice = [slice(None, None, None)] * value.ndim - step2 = key_step if key_step is not None else 1 - key_start = (chunk_starts_v[rank] - og_key_start).item() + if split_key_is_ordered == -1: + # key along split axis is in descending order, i.e. slice with negative step + # N.B. PyTorch doesn't support negative-step slices. Key has been processed into torch tensor. - key_start = max(key_start, 0) - key_stop = key_start + key_stop - slice_loc = min(self.split, value.ndim - 1) - value_slice[slice_loc] = slice( - key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 + # flip value, match value distribution to key's + # NB: `value.ndim` can be smaller than `self.ndim`, hence `value.split` nominally different from `self.split` + flipped_value = manipulations.flip(value, axis=output_split) + split_key = factories.array( + key[self.split], is_split=0, device=self.device, comm=self.comm + ) + if not flipped_value.is_distributed(): + # work with distributed `flipped_value` + flipped_value = factories.array( + flipped_value.larray, + dtype=flipped_value.dtype, + split=output_split, + device=self.device, + comm=self.comm, ) + # match `value` distribution to `self[key]` distribution + target_map = flipped_value.lshape_map + target_map[:, output_split] = split_key.lshape_map[:, 0] + flipped_value.redistribute_(target_map=target_map) + __set(self, key, flipped_value) + self = self.transpose(backwards_transpose_axes) + return - 
self.__setter(tuple(key), value.larray) + if split_key_is_ordered == 0: + # key along split axis is unordered, communication needed + # key along the split axis is torch tensor, indices are GLOBAL + counts, displs = self.counts_displs() + rank, _ = self.comm.rank, self.comm.size + + # + key_is_single_tensor = isinstance(key, torch.Tensor) + if not value.is_distributed(): + if key_is_single_tensor: + # key is a single torch.Tensor + split_key = key + # find elements of `split_key` that are local to this process + local_indices = torch.nonzero( + (split_key >= displs[rank]) & (split_key < displs[rank] + counts[rank]) + ).flatten() + # keep local indexing key only and correct for displacements along the split axis + key = key[local_indices] - displs[rank] + # set local elements of `self` to corresponding elements of `value` + self.larray[key] = value.larray[local_indices].type(self.dtype.torch_type()) + self = self.transpose(backwards_transpose_axes) + return + # key is a sequence of torch.Tensors + split_key = key[self.split] + split_key_dims = split_key.ndim + if split_key_dims > 1: + # flatten `split_key` + split_key = split_key.flatten() + # flatten split_key dimensions of `value`: + new_shape = list(value.shape) + new_shape = ( + new_shape[: output_split - (split_key_dims - 1)] + + [-1] + + new_shape[output_split + 1 :] + ) + value = value.reshape(new_shape) + output_split -= split_key_dims - 1 + # find elements of `split_key` that are local to this process + local_indices = torch.nonzero( + (split_key >= displs[rank]) & (split_key < displs[rank] + counts[rank]) + ).flatten() + key = list(key) + if key_is_mask_like: + # keep local indexing keys across all dimensions + # correct for displacements along the split axis + key = tuple( + [ + ( + key[i][local_indices] - displs[rank] + if i == self.split + else key[i][local_indices] + ) + for i in range(len(key)) + ] + ) + if not key[self.split].numel() == 0: + self.larray[key] = 
value.larray[local_indices].type(self.dtype.torch_type()) + else: + # keep local indexing key and correct for displacements along split dimension + key[self.split] = split_key[local_indices] - displs[rank] + key = tuple(key) + value_key = tuple( + [ + local_indices if i == output_split else slice(None) + for i in range(value.ndim) + ] + ) + # set local elements of `self` to corresponding elements of `value` + if not key[self.split].numel() == 0: + self.larray[key] = value.larray[value_key].type(self.dtype.torch_type()) + self = self.transpose(backwards_transpose_axes) return - # if rank in actives: - if rank not in actives: - return # non-active ranks can exit here - key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] - key_stop = ends[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] - key_start, key_stop = self.__xitem_get_key_start_stop( - rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start - ) - key[self.split] = slice(key_start, key_stop, key_step) - - # todo: need to slice the values to be the right size... 
- if isinstance(value, (torch.Tensor, type(self))): - # if its a torch tensor, it is assumed to exist on all processes - value_slice = [slice(None, None, None)] * value.ndim - step2 = key_step if key_step is not None else 1 - key_start = (chunk_starts[rank] - og_key_start).item() - key_start = max(key_start, 0) - key_stop = key_start + key_stop - slice_loc = min(self.split, value.ndim - 1) - value_slice[slice_loc] = slice( - key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 + # both `self` and `value` are distributed + # distribution of `key` and `value` must be aligned + if key_is_mask_like: + # redistribute `value` to match distribution of `key` in one pass + split_key = key[self.split] + global_split_key = factories.array( + split_key, is_split=0, device=self.device, comm=self.comm, copy=False ) - self.__setter(tuple(key), value[tuple(value_slice)]) + target_map = value.lshape_map + target_map[:, value.split] = global_split_key.lshape_map[:, 0] + value.redistribute_(target_map=target_map) else: - self.__setter(tuple(key), value) - elif isinstance(key[self.split], (torch.Tensor, list)): - key = list(key) - key[self.split] -= chunk_start - if len(key[self.split]) != 0: - self.__setter(tuple(key), value) - - elif key[self.split] in range(chunk_start, chunk_end): - key = list(key) - key[self.split] = key[self.split] - chunk_start - self.__setter(tuple(key), value) + # redistribute split-axis `key` to match distribution of `value` in one pass + if key_is_single_tensor: + # key is a single torch.Tensor + split_key = key + else: + split_key = key[self.split] + global_split_key = factories.array( + split_key, is_split=0, device=self.device, comm=self.comm, copy=False + ) + target_map = global_split_key.lshape_map + target_map[:, 0] = value.lshape_map[:, value.split] + global_split_key.redistribute_(target_map=target_map) + split_key = global_split_key.larray + + # key and value are now aligned + + # prepare for `value` Alltoallv: + # work along axis 0, 
transpose if necessary + transpose_axes = list(range(value.ndim)) + transpose_axes[0], transpose_axes[value.split] = ( + transpose_axes[value.split], + transpose_axes[0], + ) + value = value.transpose(transpose_axes) + send_counts = torch.zeros( + self.comm.size, dtype=torch.int64, device=self.device.torch_device + ) + send_displs = torch.zeros_like(send_counts) + # allocate send buffer: add 1 column to store sent indices + send_buf_shape = list(value.lshape) + if value.ndim < 2: + send_buf_shape.append(1) + if key_is_mask_like: + send_buf_shape[-1] += len(key) + else: + send_buf_shape[-1] += 1 + send_buf = torch.zeros( + send_buf_shape, dtype=value.dtype.torch_type(), device=self.device.torch_device + ) + for proc in range(self.comm.size): + # calculate what local elements of `value` belong on process `proc` + send_indices = torch.nonzero( + (split_key >= displs[proc]) & (split_key < displs[proc] + counts[proc]) + ).flatten() + # calculate outgoing counts and displacements for each process + send_counts[proc] = send_indices.numel() + send_displs[proc] = send_counts[:proc].sum() + # compose send buffer: stack local elements of `value` according to destination process + if send_indices.numel() > 0: + if value.ndim < 2: + # temporarily add a singleton dimension to value to accmodate column dimension for send_indices + send_buf[send_displs[proc] : send_displs[proc] + send_counts[proc], :-1] = ( + value.larray[send_indices].unsqueeze(1) + ) + else: + send_buf[send_displs[proc] : send_displs[proc] + send_counts[proc], :-1] = ( + value.larray[send_indices] + ) + # store outgoing GLOBAL indices in the last column of send_buf + # TODO: if key_is_mask_like: apply send_indices to all dimensions of key + if key_is_mask_like: + for i in range(-len(key), 0): + send_buf[ + send_displs[proc] : send_displs[proc] + send_counts[proc], i + ] = key[i + len(key)][send_indices] + else: + send_indices = split_key[send_indices] + send_buf[send_displs[proc] : send_displs[proc] + 
send_counts[proc], -1] = ( + send_indices + ) - elif key[self.split] < 0: + # compose communication matrix: share `send_counts` information with all processes + comm_matrix = torch.zeros( + (self.comm.size, self.comm.size), + dtype=torch.int64, + device=self.device.torch_device, + ) + self.comm.Allgather(send_counts, comm_matrix) + # comm_matrix columns contain recv_counts for each process + recv_counts = comm_matrix[:, self.comm.rank].squeeze(0) + recv_displs = torch.zeros_like(recv_counts) + recv_displs[1:] = recv_counts.cumsum(0)[:-1] + # allocate receive buffer, with 1 extra column for incoming indices + recv_buf_shape = value.lshape_map[self.comm.rank] + recv_buf_shape[value.split] = recv_counts.sum() + recv_buf_shape = recv_buf_shape.tolist() + if value.ndim < 2: + recv_buf_shape.append(1) + if key_is_mask_like: + recv_buf_shape[-1] += len(key) + else: + recv_buf_shape[-1] += 1 + recv_buf_shape = tuple(recv_buf_shape) + recv_buf = torch.zeros( + recv_buf_shape, dtype=value.dtype.torch_type(), device=self.device.torch_device + ) + # perform Alltoallv along the 0 axis + send_counts, send_displs, recv_counts, recv_displs = ( + send_counts.tolist(), + send_displs.tolist(), + recv_counts.tolist(), + recv_displs.tolist(), + ) + self.comm.Alltoallv( + (send_buf, send_counts, send_displs), (recv_buf, recv_counts, recv_displs) + ) + del send_buf, comm_matrix key = list(key) - if self.gshape[self.split] + key[self.split] in range(chunk_start, chunk_end): - key[self.split] = key[self.split] + self.shape[self.split] - chunk_start - self.__setter(tuple(key), value) + if key_is_mask_like: + # extract incoming indices from recv_buf + recv_indices = recv_buf[..., -len(key) :] + # correct split-axis indices for rank offset + recv_indices[:, 0] -= displs[rank] + key = recv_indices.split(1, dim=1) + key = [key[i].squeeze_(1) for i in range(len(key))] + # remove indices from recv_buf + recv_buf = recv_buf[..., : -len(key)] + else: + # store incoming indices in int 1-D tensor and 
correct for rank offset + recv_indices = recv_buf[..., -1].type(torch.int64) - displs[rank] + # remove last column from recv_buf + recv_buf = recv_buf[..., :-1] + # replace split-axis key with incoming local indices + key = list(key) + key[self.split] = recv_indices + key = tuple(key) + # transpose back value and recv_buf if necessary, wrap recv_buf in DNDarray + value = value.transpose(transpose_axes) + if value.ndim < 2: + recv_buf.squeeze_(1) + recv_buf = DNDarray( + recv_buf.permute(*transpose_axes), + gshape=value.gshape, + dtype=value.dtype, + split=value.split, + device=value.device, + comm=value.comm, + balanced=value.balanced, + ) + # set local elements of `self` to corresponding elements of `value` + __set(self, key, recv_buf) + self = self.transpose(backwards_transpose_axes) def __setter( self, @@ -1889,11 +2691,16 @@ def tolist(self, keepsplit: bool = False) -> List: def __torch_proxy__(self) -> torch.Tensor: """ - Return a 1-element `torch.Tensor` strided as the global `self` shape. - Used internally for sanitation purposes. + Return a 1-element `torch.Tensor` strided as the global `self` shape. The split axis of the initial DNDarray is stored in the `names` attribute of the returned tensor. + Used internally to lower memory footprint of sanitation. 
""" - return torch.ones((1,), dtype=torch.int8, device=self.larray.device).as_strided( - self.gshape, [0] * self.ndim + names = [None] * self.ndim + if self.split is not None: + names[self.split] = "split" + return ( + torch.ones((1,), dtype=torch.int8, device=self.larray.device) + .as_strided(self.gshape, [0] * self.ndim) + .refine_names(*names) ) @staticmethod @@ -1937,4 +2744,5 @@ def __xitem_get_key_start_stop( from .devices import Device from .stride_tricks import sanitize_axis -from .types import datatype, canonical_heat_type +import types +from .types import datatype, canonical_heat_type, bool, uint8 diff --git a/heat/core/indexing.py b/heat/core/indexing.py index 33d94c04d0..0a521608e3 100644 --- a/heat/core/indexing.py +++ b/heat/core/indexing.py @@ -9,16 +9,18 @@ from .dndarray import DNDarray from . import sanitation from . import types +from . import manipulations __all__ = ["nonzero", "where"] -def nonzero(x: DNDarray) -> DNDarray: +def nonzero(x: DNDarray) -> Tuple[DNDarray, ...]: """ - Return a :class:`~heat.core.dndarray.DNDarray` containing the indices of the elements that are non-zero.. (using ``torch.nonzero``) - If ``x`` is split then the result is split in the 0th dimension. However, this :class:`~heat.core.dndarray.DNDarray` + TODO: UPDATE DOCS! + Return a Tuple of :class:`~heat.core.dndarray.DNDarray`s, one for each dimension of ``x``, + containing the indices of the non-zero elements in that dimension. If ``x`` is split then + the result is split in the 0th dimension. However, this :class:`~heat.core.dndarray.DNDarray` can be UNBALANCED as it contains the indices of the non-zero elements on each node. - Returns an array with one entry for each dimension of ``x``, containing the indices of the non-zero elements in that dimension. The values in ``x`` are always tested and returned in row-major, C-style order. The corresponding non-zero values can be obtained with: ``x[nonzero(x)]``. 
@@ -32,10 +34,8 @@ def nonzero(x: DNDarray) -> DNDarray: >>> import heat as ht >>> x = ht.array([[3, 0, 0], [0, 4, 1], [0, 6, 0]], split=0) >>> ht.nonzero(x) - DNDarray([[0, 0], - [1, 1], - [1, 2], - [2, 1]], dtype=ht.int64, device=cpu:0, split=0) + (DNDarray([0, 1, 1, 2], dtype=ht.int64, device=cpu:0, split=None), + DNDarray([0, 1, 2, 1], dtype=ht.int64, device=cpu:0, split=None)) >>> y = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=0) >>> y > 3 DNDarray([[False, False, False], @@ -48,40 +48,79 @@ def nonzero(x: DNDarray) -> DNDarray: [2, 0], [2, 1], [2, 2]], dtype=ht.int64, device=cpu:0, split=0) + (DNDarray([1, 1, 1, 2, 2, 2], dtype=ht.int64, device=cpu:0, split=None), + DNDarray([0, 1, 2, 0, 1, 2], dtype=ht.int64, device=cpu:0, split=None)) >>> y[ht.nonzero(y > 3)] DNDarray([4, 5, 6, 7, 8, 9], dtype=ht.int64, device=cpu:0, split=0) """ - sanitation.sanitize_in(x) - - if x.split is None: - # if there is no split then just return the values from torch - lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False) - gout = list(lcl_nonzero.size()) - is_split = None + try: + local_x = x.larray + except AttributeError: + raise TypeError("Input must be a DNDarray, is {}".format(type(x))) + + if not x.is_distributed(): + # nonzero indices as tuple + lcl_nonzero = torch.nonzero(input=local_x, as_tuple=True) + # bookkeeping for final DNDarray construct + nonzero_size = lcl_nonzero[0].shape[0] + output_split = None if x.split is None else 0 + output_balanced = True else: - # a is split - lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False) - _, _, slices = x.comm.chunk(x.shape, x.split) - lcl_nonzero[..., x.split] += slices[x.split].start - gout = list(lcl_nonzero.size()) - gout[0] = x.comm.allreduce(gout[0], MPI.SUM) - is_split = 0 - - if x.ndim == 1: - lcl_nonzero = lcl_nonzero.squeeze(dim=1) - for g in range(len(gout) - 1, -1, -1): - if gout[g] == 1: - del gout[g] - - return DNDarray( - lcl_nonzero, - gshape=tuple(gout), - 
dtype=types.canonical_heat_type(lcl_nonzero.dtype), - split=is_split, - device=x.device, - comm=x.comm, - balanced=False, - ) + lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False) + nonzero_size = torch.tensor( + lcl_nonzero.shape[0], dtype=torch.int64, device=lcl_nonzero.device + ) + # global nonzero_size + x.comm.Allreduce(MPI.IN_PLACE, nonzero_size, MPI.SUM) + # correct indices along split axis + _, displs = x.counts_displs() + lcl_nonzero[:, x.split] += displs[x.comm.rank] + + if x.split != 0: + # construct global 2D DNDarray of nz indices: + shape_2d = (nonzero_size.item(), x.ndim) + global_nonzero = DNDarray( + lcl_nonzero, + gshape=shape_2d, + dtype=types.int64, + split=0, + device=x.device, + comm=x.comm, + balanced=False, + ) + # stabilize distributed result: vectorized sorting of nz indices along axis 0 + global_nonzero.balance_() + global_nonzero = manipulations.unique(global_nonzero, axis=0) + # return indices as tuple of columns + lcl_nonzero = global_nonzero.larray.split(1, dim=1) + output_balanced = True + else: + # return indices as tuple of columns + lcl_nonzero = lcl_nonzero.split(1, dim=1) + output_balanced = False + nonzero_size = nonzero_size.item() + + # return global_nonzero as tuple of DNDarrays + global_nonzero = list(lcl_nonzero) + output_shape = (nonzero_size,) + output_split = 0 + for i, nz_tensor in enumerate(global_nonzero): + if nz_tensor.ndim > 1: + # extra dimension in distributed case from usage of torch.split() + nz_tensor = nz_tensor.squeeze() + nz_array = DNDarray( + nz_tensor, + gshape=output_shape, + dtype=types.int64, + split=output_split, + device=x.device, + comm=x.comm, + balanced=output_balanced, + ) + global_nonzero[i] = nz_array + global_nonzero = tuple(global_nonzero) + + return tuple(global_nonzero) DNDarray.nonzero = lambda self: nonzero(self) diff --git a/heat/core/linalg/tests/test_basics.py b/heat/core/linalg/tests/test_basics.py index a9e8f291a1..d05d110ca6 100644 --- 
a/heat/core/linalg/tests/test_basics.py +++ b/heat/core/linalg/tests/test_basics.py @@ -237,6 +237,8 @@ def test_inv(self): self.assertTupleEqual(ainv.shape, a.shape) self.assertTrue(ht.allclose(ainv, ares, atol=1e-6)) + # distributed + # ares = ht.array([[2.0, 2, 1], [3, 4, 1], [0, 1, -1]], split=0) a = ht.array([[5.0, -3, 2], [-3, 2, -1], [-3, 2, -2]], split=0) ainv = ht.linalg.inv(a) self.assertEqual(ainv.split, a.split) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 1ba1c45608..6d0776c3b5 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -8,15 +8,15 @@ class TestDNDarray(TestCase): - @classmethod - def setUpClass(cls): - super(TestDNDarray, cls).setUpClass() - N = ht.MPI_WORLD.size - cls.reference_tensor = ht.zeros((N, N + 1, 2 * N)) + # @classmethod + # def setUpClass(cls): + # super(TestDNDarray, cls).setUpClass() + # N = ht.MPI_WORLD.size + # cls.reference_tensor = ht.zeros((N, N + 1, 2 * N)) - for n in range(N): - for m in range(N + 1): - cls.reference_tensor[n, m, :] = ht.arange(0, 2 * N) + m * 10 + n * 100 + # for n in range(N): + # for m in range(N + 1): + # cls.reference_tensor[n, m, :] = ht.arange(0, 2 * N) + m * 10 + n * 100 def test_and(self): int16_tensor = ht.array([[1, 1], [2, 2]], dtype=ht.int16) @@ -576,6 +576,255 @@ def test_float_cast(self): with self.assertRaises(TypeError): float(ht.full((ht.MPI_WORLD.size,), 2, split=0)) + def test_getitem(self): + # following https://numpy.org/doc/stable/user/basics.indexing.html + + # Single element indexing + # 1D, local + x = ht.arange(10) + self.assertTrue(x[2].item() == 2) + self.assertTrue(x[-2].item() == 8) + self.assertTrue(x[2].dtype == ht.int32) + # 1D, distributed + x = ht.arange(10, split=0, dtype=ht.float64) + self.assertTrue(x[2].item() == 2.0) + self.assertTrue(x[-2].item() == 8.0) + self.assertTrue(x[2].dtype == ht.float64) + self.assertTrue(x[2].split is None) + # 2D, local + x = ht.arange(10).reshape(2, 5) + 
self.assertTrue((x[0] == ht.arange(5)).all().item()) + self.assertTrue(x[0].dtype == ht.int32) + # 2D, distributed + x_split0 = ht.array(x, split=0) + self.assertTrue((x_split0[0] == ht.arange(5, split=None)).all().item()) + x_split1 = ht.array(x, split=1) + self.assertTrue((x_split1[-2] == ht.arange(5, split=0)).all().item()) + # 3D, local + x = ht.arange(27).reshape(3, 3, 3) + key = -2 + indexed = x[key] + self.assertTrue((indexed.larray == x.larray[key]).all()) + self.assertTrue(indexed.dtype == ht.int32) + self.assertTrue(indexed.split is None) + # 3D, distributed, split = 0 + x_split0 = ht.array(x, dtype=ht.float32, split=0) + indexed_split0 = x_split0[key] + self.assertTrue((indexed_split0.larray == x.larray[key]).all()) + self.assertTrue(indexed_split0.dtype == ht.float32) + self.assertTrue(indexed_split0.split is None) + # 3D, distributed split, != 0 + x_split2 = ht.array(x, dtype=ht.int64, split=2) + key = ht.array(2) + indexed_split2 = x_split2[key] + self.assertTrue((indexed_split2.numpy() == x.numpy()[key.item()]).all()) + self.assertTrue(indexed_split2.dtype == ht.int64) + self.assertTrue(indexed_split2.split == 1) + + # Slicing and striding + x = ht.arange(20, split=0) + x_sliced = x[1:11:3] + x_np = np.arange(20) + x_sliced_np = x_np[1:11:3] + self.assert_array_equal(x_sliced, x_sliced_np) + self.assertTrue(x_sliced.split == 0) + + # 1-element slice along split axis + x = ht.arange(20).reshape(4, 5) + x.resplit_(axis=1) + x_sliced = x[:, 2:3] + x_np = np.arange(20).reshape(4, 5) + x_sliced_np = x_np[:, 2:3] + self.assert_array_equal(x_sliced, x_sliced_np) + self.assertTrue(x_sliced.split == 1) + + # slicing with negative step along split axis 0 + shape = (20, 4, 3) + x_3d = ht.arange(20 * 4 * 3, split=0).reshape(shape) + x_3d_sliced = x_3d[17:2:-2, :2, ht.array(1)] + x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[17:2:-2, :2, 1] + self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) + self.assertTrue(x_3d_sliced.split == 0) + + # slicing with 
negative step along split 1 + shape = (4, 20, 3) + x_3d = ht.arange(20 * 4 * 3).reshape(shape) + x_3d.resplit_(axis=1) + key = (slice(None, 2), slice(17, 2, -2), 1) + x_3d_sliced = x_3d[key] + x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[:2, 17:2:-2, 1] + self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) + self.assertTrue(x_3d_sliced.split == 1) + + # slicing with negative step along split 2 and loss of axis < split + shape = (4, 3, 20) + x_3d = ht.arange(20 * 4 * 3).reshape(shape) + x_3d.resplit_(axis=2) + key = (slice(None, 2), 1, slice(17, 10, -2)) + x_3d_sliced = x_3d[key] + x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[:2, 1, 17:10:-2] + self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) + self.assertTrue(x_3d_sliced.split == 1) + + # slicing with negative step along split 2 and loss of all axes but split + shape = (4, 3, 20) + x_3d = ht.arange(20 * 4 * 3).reshape(shape) + x_3d.resplit_(axis=2) + key = (0, 1, slice(17, 13, -1)) + x_3d_sliced = x_3d[key] + x_3d_sliced_np = np.arange(20 * 4 * 3).reshape(shape)[0, 1, 17:13:-1] + self.assert_array_equal(x_3d_sliced, x_3d_sliced_np) + self.assertTrue(x_3d_sliced.split == 0) + + # tests for bug 730: + a = ht.ones((10, 25, 30), split=1) + if a.comm.size > 1: + self.assertEqual(a[0].split, 0) + self.assertEqual(a[:, 0, :].split, None) + self.assertEqual(a[:, :, 0].split, 1) + + # DIMENSIONAL INDEXING + # ellipsis + x_np = np.array([[[1], [2], [3]], [[4], [5], [6]]]) + x_np_ellipsis = x_np[..., 0] + x = ht.array([[[1], [2], [3]], [[4], [5], [6]]]) + + # local + x_ellipsis = x[..., 0] + x_slice = x[:, :, 0] + self.assert_array_equal(x_ellipsis, x_np_ellipsis) + self.assert_array_equal(x_slice, x_np_ellipsis) + + # distributed + x.resplit_(axis=1) + x_ellipsis = x[..., 0] + x_slice = x[:, :, 0] + self.assert_array_equal(x_ellipsis, x_np_ellipsis) + self.assert_array_equal(x_slice, x_np_ellipsis) + self.assertTrue(x_ellipsis.split == 1) + + # newaxis: local + x = ht.array([[[1], [2], [3]], [[4], [5], 
[6]]]) + x_np_newaxis = x_np[:, np.newaxis, :2, :] + x_newaxis = x[:, np.newaxis, :2, :] + x_none = x[:, None, :2, :] + self.assert_array_equal(x_newaxis, x_np_newaxis) + self.assert_array_equal(x_none, x_np_newaxis) + + # newaxis: distributed + x.resplit_(axis=1) + x_newaxis = x[:, np.newaxis, :2, :] + x_none = x[:, None, :2, :] + self.assert_array_equal(x_newaxis, x_np_newaxis) + self.assert_array_equal(x_none, x_np_newaxis) + self.assertTrue(x_newaxis.split == 2) + self.assertTrue(x_none.split == 2) + + x = ht.arange(5, split=0) + x_np = np.arange(5) + y = x[:, np.newaxis] + x[np.newaxis, :] + y_np = x_np[:, np.newaxis] + x_np[np.newaxis, :] + self.assert_array_equal(y, y_np) + self.assertTrue(y.split == 0) + + # ADVANCED INDEXING + # "x[(1, 2, 3),] is fundamentally different from x[(1, 2, 3)]" + + x_np = np.arange(60).reshape(5, 3, 4) + indexed_x_np = x_np[(1, 2, 3)] + adv_indexed_x_np = x_np[(1, 2, 3),] + x = ht.array(x_np, split=0) + indexed_x = x[(1, 2, 3)] + self.assertTrue(indexed_x.item() == np.array(indexed_x_np)) + adv_indexed_x = x[(1, 2, 3),] + self.assert_array_equal(adv_indexed_x, adv_indexed_x_np) + + # 1d + x = ht.arange(10, 1, -1, split=0) + x_np = np.arange(10, 1, -1) + x_adv_ind = x[np.array([3, 3, 1, 8])] + x_np_adv_ind = x_np[np.array([3, 3, 1, 8])] + self.assert_array_equal(x_adv_ind, x_np_adv_ind) + + # 3d, split 0, non-unique, non-ordered key along split axis + x = ht.arange(60, split=0).reshape(5, 3, 4) + x_np = np.arange(60).reshape(5, 3, 4) + k1 = np.array([0, 4, 1, 0]) + k2 = np.array([0, 2, 1, 0]) + k3 = np.array([1, 2, 3, 1]) + self.assert_array_equal( + x[ht.array(k1, split=0), ht.array(k2, split=0), ht.array(k3, split=0)], x_np[k1, k2, k3] + ) + # advanced indexing on non-consecutive dimensions + x = ht.arange(60, split=0).reshape(5, 3, 4, new_split=1) + x_copy = x.copy() + x_np = np.arange(60).reshape(5, 3, 4) + k1 = np.array([0, 4, 1, 0]) + k2 = 0 + k3 = np.array([1, 2, 3, 1]) + key = (k1, k2, k3) + 
self.assert_array_equal(x[key], x_np[key]) + # check that x is unchanged after internal manipulation + self.assertTrue(x.shape == x_copy.shape) + self.assertTrue(x.split == x_copy.split) + self.assertTrue(x.lshape == x_copy.lshape) + self.assertTrue((x == x_copy).all().item()) + + # broadcasting shapes + x.resplit_(axis=0) + self.assert_array_equal(x[ht.array(k1, split=0), ht.array(1), 2], x_np[k1, 1, 2]) + # test exception: broadcasting mismatching shapes + k2 = np.array([0, 2, 1]) + with self.assertRaises(IndexError): + x[k1, k2, k3] + + # more broadcasting + x_np = np.arange(12).reshape(4, 3) + rows = np.array([0, 3]) + cols = np.array([0, 2]) + x = ht.arange(12).reshape(4, 3) + x.resplit_(1) + x_np_indexed = x_np[rows[:, np.newaxis], cols] + x_indexed = x[ht.array(rows)[:, np.newaxis], cols] + self.assert_array_equal(x_indexed, x_np_indexed) + self.assertTrue(x_indexed.split == 1) + + # combining advanced and basic indexing + y_np = np.arange(35).reshape(5, 7) + y_np_indexed = y_np[np.array([0, 2, 4]), 1:3] + y = ht.array(y_np, split=1) + y_indexed = y[ht.array([0, 2, 4]), 1:3] + self.assert_array_equal(y_indexed, y_np_indexed) + self.assertTrue(y_indexed.split == 1) + + x_np = np.arange(10 * 20 * 30).reshape(10, 20, 30) + x = ht.array(x_np, split=1) + ind_array = ht.random.randint(0, 20, (2, 3, 4), dtype=ht.int64) + ind_array_np = ind_array.numpy() + x_np_indexed = x_np[..., ind_array_np, :] + x_indexed = x[..., ind_array, :] + self.assert_array_equal(x_indexed, x_np_indexed) + self.assertTrue(x_indexed.split == 3) + + # boolean mask, local + arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5) + np.random.seed(42) + mask = np.random.randint(0, 2, arr.shape, dtype=bool) + self.assertTrue((arr[mask].numpy() == arr.numpy()[mask]).all()) + + # boolean mask, distributed + arr_split0 = ht.array(arr, split=0) + mask_split0 = ht.array(mask, split=0) + self.assertTrue((arr_split0[mask_split0].numpy() == arr.numpy()[mask]).all()) + + arr_split1 = ht.array(arr, split=1) + 
mask_split1 = ht.array(mask, split=1) + self.assert_array_equal(arr_split1[mask_split1], arr.numpy()[mask]) + + arr_split2 = ht.array(arr, split=2) + mask_split2 = ht.array(mask, split=2) + self.assert_array_equal(arr_split2[mask_split2], arr.numpy()[mask]) + def test_int_cast(self): # simple scalar tensor a = ht.ones(1) @@ -1106,445 +1355,732 @@ def test_resplit(self): self.assertTrue(ht.all(t1_sub == res)) self.assertEqual(t1_sub.split, None) - def test_setitem_getitem(self): + def test_setitem(self): + # following https://numpy.org/doc/stable/user/basics.indexing.html + + # Single element indexing + # 1D, local + x = ht.zeros(10) + x[2] = 2 + x[-2] = 8 + self.assertTrue(x[2].item() == 2) + self.assertTrue(x[-2].item() == 8) + self.assertTrue(x[2].dtype == ht.float32) + # 1D, distributed + x = ht.zeros(10, split=0, dtype=ht.float64) + x[2] = 2 + x[-2] = 8 + self.assertTrue(x[2].item() == 2.0) + self.assertTrue(x[-2].item() == 8.0) + self.assertTrue(x[2].dtype == ht.float64) + self.assertTrue(x.split == 0) + # 2D, local + x = ht.zeros(10).reshape(2, 5) + x[0] = ht.arange(5) + self.assertTrue((x[0] == ht.arange(5)).all().item()) + self.assertTrue(x[0].dtype == ht.float32) + # 2D, distributed + x_split0 = ht.zeros(10, split=0).reshape(2, 5) + x_split0[0] = ht.arange(5) + self.assertTrue((x_split0[0] == ht.arange(5, split=None)).all().item()) + x_split1 = ht.zeros(10, split=0).reshape(2, 5, new_split=1) + x_split1[-2] = ht.arange(5) + self.assertTrue((x_split1[-2] == ht.arange(5, split=0)).all().item()) + # 3D, distributed, split = 0 + x_split0 = ht.zeros(27, split=0).reshape(3, 3, 3) + key = -2 + x_split0[key] = ht.arange(3) + self.assertTrue((x_split0[key].larray == torch.arange(3)).all()) + self.assertTrue(x_split0[key].dtype == ht.float32) + self.assertTrue(x_split0.split == 0) + # 3D, distributed split, != 0 + x_split2 = ht.zeros(27, dtype=ht.int64, split=0).reshape(3, 3, 3, new_split=2) + key = ht.array(2) + x_split2[key] = [6, 7, 8] + indexed_split2 = 
x_split2[key] + self.assertTrue((indexed_split2.numpy()[0] == np.array([6, 7, 8])).all()) + self.assertTrue(indexed_split2.dtype == ht.int64) + self.assertTrue(x_split2.split == 2) + + # Slicing and striding + x = ht.arange(20, split=0) + x[1:11:3] = ht.array([10, 40, 70, 100]) + x_np = np.arange(20) + x_np[1:11:3] = np.array([10, 40, 70, 100]) + self.assert_array_equal(x, x_np) + self.assertTrue(x.split == 0) + + # 1-element slice along split axis + x = ht.arange(20).reshape(4, 5) + x.resplit_(axis=1) + x[:, 2:3] = ht.array([10, 40, 70, 100]).reshape(4, 1) + x_np = np.arange(20).reshape(4, 5) + x_np[:, 2:3] = np.array([10, 40, 70, 100]).reshape(4, 1) + self.assert_array_equal(x, x_np) + self.assertTrue(x.split == 1) + with self.assertRaises(ValueError): + x[:, 2:3] = ht.array([10, 40, 70, 100]) + + # slicing with negative step along split axis 0 + # assign different dtype + shape = (20, 4, 3) + x_3d = ht.arange(20 * 4 * 3, split=0).reshape(shape) + value = ht.random.randn(8, 2) + x_3d[17:2:-2, :2, ht.array(1)] = value + x_3d_sliced = x_3d[17:2:-2, :2, ht.array(1)] + self.assertTrue(ht.allclose(x_3d_sliced, value.astype(x_3d.dtype))) + self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) + + # slicing with negative step along split 1 + shape = (4, 20, 3) + x_3d = ht.arange(20 * 4 * 3, dtype=ht.float32).reshape(shape) + x_3d.resplit_(axis=1) + key = (slice(None, 2), slice(17, 2, -2), 1) + value = ht.random.randn(2, 8) + x_3d[key] = value + x_3d_sliced = x_3d[key] + self.assertTrue(ht.allclose(x_3d_sliced, value.astype(x_3d.dtype))) + self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) + + # slicing with negative step along split 2 and loss of axis < split + shape = (4, 3, 20) + x_3d = ht.arange(20 * 4 * 3, dtype=ht.float64).reshape(shape) + x_3d.resplit_(axis=2) + key = (slice(None, 2), 1, slice(17, 10, -2)) + value = ht.random.randn(2, 4) + x_3d[key] = value + x_3d_sliced = x_3d[key] + self.assertTrue(ht.allclose(x_3d_sliced, value.astype(x_3d.dtype))) + 
self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) + + # slicing with negative step along split 2 and loss of all axes but split + shape = (4, 3, 20) + x_3d = ht.arange(20 * 4 * 3).reshape(shape) + x_3d.resplit_(axis=2) + key = (0, 1, slice(17, 13, -1)) + value = ht.random.randint( + 0, + 5, + ( + 1, + 4, + ), + split=1, + ) + x_3d[key] = value + x_3d_sliced = x_3d[key] + self.assertTrue(ht.allclose(x_3d_sliced, value.squeeze(0).astype(x_3d.dtype))) + self.assertTrue(x_3d_sliced.dtype == x_3d.dtype) + + # DIMENSIONAL INDEXING + + # ellipsis + x = ht.array([[[1], [2], [3]], [[4], [5], [6]]]) + # local + value = x.squeeze() + 7 + x[..., 0] = value + self.assertTrue(ht.all(x[..., 0] == value).item()) + value -= 7 + x[:, :, 0] = value + self.assertTrue(ht.all(x[:, :, 0] == value).item()) + + # distributed + x.resplit_(axis=1) + value *= 2 + x[..., 0] = value + x_ellipsis = x[..., 0] + self.assertTrue(ht.all(x_ellipsis == value).item()) + value += 2 + x[:, :, 0] = value + self.assertTrue(ht.all(x[:, :, 0] == value).item()) + self.assertTrue(x_ellipsis.split == 1) + + # newaxis: local, w. broadcasting and different dtype + x = ht.array([[[1], [2], [3]], [[4], [5], [6]]]) + value = ht.array([10.0, 20.0]).reshape(2, 1) + x[:, None, :2, :] = value + x_newaxis = x[:, None, :2, :] + self.assertTrue(ht.all(x_newaxis == value).item()) + value += 2 + x[:, None, :2, :] = value + self.assertTrue(ht.all(x[:, None, :2, :] == value).item()) + self.assertTrue(x[:, None, :2, :].dtype == x.dtype) + + # newaxis: distributed w. 
broadcasting and different dtype + x.resplit_(axis=1) + value = ht.array([30.0, 40.0]).reshape(1, 2, 1) + x[:, np.newaxis, :2, :] = value + x_newaxis = x[:, np.newaxis, :2, :] + self.assertTrue(ht.all(x_newaxis == value).item()) + value += 2 + x[:, None, :2, :] = value + x_none = x[:, None, :2, :] + self.assertTrue(ht.all(x_none == value).item()) + self.assertTrue(x_none.dtype == x.dtype) + + # distributed value + x = ht.arange(6).reshape(1, 1, 2, 3) + x.resplit_(axis=-1) + value = ht.arange(3).reshape(1, 3) + value.resplit_(axis=1) + x[..., 0, :] = value + self.assertTrue(ht.all(x[..., 0, :] == value).item()) + + # ADVANCED INDEXING + # "x[(1, 2, 3),] is fundamentally different from x[(1, 2, 3)]" + + x = ht.arange(60, split=0).reshape(5, 3, 4) + value = 99.0 + x[(1, 2, 3)] = value + indexed_x = x[(1, 2, 3)] + self.assertTrue((indexed_x == value).item()) + self.assertTrue(indexed_x.dtype == x.dtype) + x[(1, 2, 3),] = value + adv_indexed_x = x[(1, 2, 3),] + self.assertTrue(ht.all(adv_indexed_x == value).item()) + self.assertTrue(adv_indexed_x.dtype == x.dtype) + + # 1d + x = ht.arange(10, 1, -1, split=0) + value = ht.arange(4) + x[ht.array([3, 2, 1, 8])] = value + x_adv_ind = x[np.array([3, 2, 1, 8])] + self.assertTrue(ht.all(x_adv_ind == value).item()) + self.assertTrue(x_adv_ind.dtype == x.dtype) + + # TODO: n-d value + + # 3d, split 0, non-unique, non-ordered key along split axis, key mask-like + x = ht.arange(60, split=0).reshape(5, 3, 4) + k1 = np.array([0, 4, 1, 0]) + k2 = np.array([0, 2, 1, 0]) + k3 = np.array([1, 2, 3, 1]) + value = ht.array([99, 98, 97, 96], split=0) + x[k1, k2, k3] = value + self.assertTrue((x[k1, k2, k3] == ht.array([96, 98, 97, 96], split=0)).all().item()) + + # advanced indexing on non-consecutive dimensions, split dimension will be lost + x = ht.arange(60, split=0).reshape(5, 3, 4, new_split=1) + x_copy = x.copy() + k1 = np.array([0, 4, 1, 2]) + k2 = 0 + k3 = np.array([1, 2, 3, 1]) + key = (k1, k2, k3) + value = ht.array([99, 98, 97, 
96]) + x[key] = value + self.assertTrue((x[key] == ht.array([99, 98, 97, 96])).all().item()) + # check that x is unchanged after internal manipulation + self.assertTrue(x.shape == x_copy.shape) + self.assertTrue(x.split == x_copy.split) + self.assertTrue(x.lshape == x_copy.lshape) + + # broadcasting shapes + x.resplit_(axis=0) + key = (ht.array(k1, split=0), ht.array(1), 2) + value = ht.array([99, 98, 97, 96], split=0) + x[key] = value + self.assertTrue((x[key] == value).all().item()) + # test exception: broadcasting mismatching shapes + k2 = np.array([0, 2, 1]) + with self.assertRaises(IndexError): + x[k1, k2, k3] = value + + # more broadcasting + x = ht.arange(12).reshape(4, 3) + x.resplit_(1) + rows = np.array([0, 3]) + cols = np.array([0, 2]) + key = (ht.array(rows)[:, np.newaxis], cols) + value = ht.array([[99, 98], [97, 96]], split=1) + x[key] = value + self.assertTrue((x[key] == value).all().item()) + if x.comm.size > 1: + with self.assertRaises(RuntimeError): + value = ht.array([[99, 98], [97, 96]], split=0) + x[key] = value + + # combining advanced and basic indexing + + y = ht.arange(35).reshape(5, 7) + y.resplit_(1) + y_copy = y.copy() + # assign non-distributed value + value = ht.arange(6).reshape(3, 2) + y[ht.array([0, 2, 4]), 1:3] = value + self.assertTrue((y[ht.array([0, 2, 4]), 1:3] == value).all().item()) + # assign distributed value + value.resplit_(1) + y_copy[ht.array([0, 2, 4]), 1:3] = value + self.assertTrue((y_copy[ht.array([0, 2, 4]), 1:3] == value).all().item()) + + x = ht.arange(10 * 20 * 30).reshape(10, 20, 30) + x.resplit_(1) + ind_array = ht.random.randint(0, 20, (2, 3, 4), dtype=ht.int64) + value = ht.ones((1, 2, 3, 4, 1)) + x[..., ind_array, :] = value + self.assertTrue((x[..., ind_array, :] == value).all().item()) + + # boolean mask, local + arr = ht.arange(3 * 4 * 5).reshape(3, 4, 5) + np.random.seed(42) + mask = np.random.randint(0, 2, arr.shape, dtype=bool) + value = 99.0 + arr[mask] = value + self.assertTrue((arr[mask] == 
value).all().item()) + self.assertTrue(arr[mask].dtype == arr.dtype) + value = ht.ones_like(arr) + arr[mask] = value[mask] + self.assertTrue((arr[mask] == value[mask]).all().item()) + + # boolean mask, distributed, non-distributed `value` + arr_split0 = ht.array(arr, split=0) + mask_split0 = ht.array(mask, split=0) + arr_split0[mask_split0] = value[mask] + indexed_arr = arr_split0[mask_split0] + indexed_arr.balance_() + self.assertTrue((indexed_arr == value[mask]).all().item()) + arr_split1 = ht.array(arr, split=1) + mask_split1 = ht.array(mask, split=1) + arr_split1[mask_split1] = value[mask] + self.assertTrue((arr_split1[mask_split1] == value[mask]).all().item()) + arr_split2 = ht.array(arr, split=2) + mask_split2 = ht.array(mask, split=2) + arr_split2[mask_split2] = value[mask] + self.assertTrue((arr_split2[mask_split2] == value[mask]).all().item()) + + # TODO boolean mask, distributed, distributed `value` + # tests for bug #825 a = ht.ones((102, 102), split=0) setting = ht.zeros((100, 100), split=0) a[1:-1, 1:-1] = setting - self.assertTrue(ht.all(a[1:-1, 1:-1] == 0)) + self.assertTrue(ht.all(a[1:-1, 1:-1] == 0).item()) a = ht.ones((102, 102), split=1) setting = ht.zeros((30, 100), split=1) a[-30:, 1:-1] = setting - self.assertTrue(ht.all(a[-30:, 1:-1] == 0)) + self.assertTrue(ht.all(a[-30:, 1:-1] == 0).item()) a = ht.ones((102, 102), split=1) setting = ht.zeros((100, 100), split=1) a[1:-1, 1:-1] = setting - self.assertTrue(ht.all(a[1:-1, 1:-1] == 0)) + self.assertTrue(ht.all(a[1:-1, 1:-1] == 0).item()) a = ht.ones((102, 102), split=1) setting = ht.zeros((100, 20), split=1) a[1:-1, :20] = setting - self.assertTrue(ht.all(a[1:-1, :20] == 0)) - - # tests for bug 730: - a = ht.ones((10, 25, 30), split=1) - if a.comm.size > 1: - self.assertEqual(a[0].split, 0) - self.assertEqual(a[:, 0, :].split, None) - self.assertEqual(a[:, :, 0].split, 1) - - # set and get single value - a = ht.zeros((13, 5), split=0) - # set value on one node - a[10, np.array(0)] = 1 - 
self.assertEqual(a[10, 0], 1) - self.assertEqual(a[10, 0].dtype, ht.float32) - - a = ht.zeros((13, 5), split=0) - a[10] = 1 - b = a[torch.tensor(10)] - self.assertTrue((b == 1).all()) - self.assertEqual(b.dtype, ht.float32) - self.assertEqual(b.gshape, (5,)) - - a = ht.zeros((13, 5), split=0) - a[-1] = 1 - b = a[-1] - self.assertTrue((b == 1).all()) - self.assertEqual(b.dtype, ht.float32) - self.assertEqual(b.gshape, (5,)) - - # slice in 1st dim only on 1 node - a = ht.zeros((13, 5), split=0) - a[1:4] = 1 - self.assertTrue((a[1:4] == 1).all()) - self.assertEqual(a[1:4].gshape, (3, 5)) - self.assertEqual(a[1:4].split, 0) - self.assertEqual(a[1:4].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[1:4].lshape, (3, 5)) - else: - self.assertEqual(a[1:4].lshape, (0, 5)) - - a = ht.zeros((13, 5), split=0) - a[1:2] = 1 - self.assertTrue((a[1:2] == 1).all()) - self.assertEqual(a[1:2].gshape, (1, 5)) - self.assertEqual(a[1:2].split, 0) - self.assertEqual(a[1:2].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[1:2].lshape, (1, 5)) - else: - self.assertEqual(a[1:2].lshape, (0, 5)) - - # slice in 1st dim only on 1 node w/ singular second dim - a = ht.zeros((13, 5), split=0) - a[1:4, 1] = 1 - b = a[1:4, np.int64(1)] - self.assertTrue((b == 1).all()) - self.assertEqual(b.gshape, (3,)) - self.assertEqual(b.split, 0) - self.assertEqual(b.dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(b.lshape, (3,)) - else: - self.assertEqual(b.lshape, (0,)) - - # slice in 1st dim across both nodes (2 node case) w/ singular second dim - a = ht.zeros((13, 5), split=0) - a[1:11, 1] = 1 - self.assertTrue((a[1:11, 1] == 1).all()) - self.assertEqual(a[1:11, 1].gshape, (10,)) - self.assertEqual(a[1:11, torch.tensor(1)].split, 0) - self.assertEqual(a[1:11, 1].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 1: - self.assertEqual(a[1:11, 1].lshape, (4,)) - if a.comm.rank == 0: - 
self.assertEqual(a[1:11, 1].lshape, (6,)) - - # slice in 1st dim across 1 node (2nd) w/ singular second dim - c = ht.zeros((13, 5), split=0) - c[8:12, ht.array(1)] = 1 - b = c[8:12, np.int64(1)] - self.assertTrue((b == 1).all()) - self.assertEqual(b.gshape, (4,)) - self.assertEqual(b.split, 0) - self.assertEqual(b.dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 1: - self.assertEqual(b.lshape, (4,)) - if a.comm.rank == 0: - self.assertEqual(b.lshape, (0,)) - - # slice in both directions - a = ht.zeros((13, 5), split=0) - a[3:13, 2:5:2] = 1 - self.assertTrue((a[3:13, 2:5:2] == 1).all()) - self.assertEqual(a[3:13, 2:5:2].gshape, (10, 2)) - self.assertEqual(a[3:13, 2:5:2].split, 0) - self.assertEqual(a[3:13, 2:5:2].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 1: - self.assertEqual(a[3:13, 2:5:2].lshape, (6, 2)) - if a.comm.rank == 0: - self.assertEqual(a[3:13, 2:5:2].lshape, (4, 2)) - - # setting with heat tensor - a = ht.zeros((4, 5), split=0) - a[1, 0:4] = ht.arange(4) - # if a.comm.size == 2: - for c, i in enumerate(range(4)): - self.assertEqual(a[1, c], i) - - # setting with torch tensor - a = ht.zeros((4, 5), split=0) - a[1, 0:4] = torch.arange(4, device=self.device.torch_device) - # if a.comm.size == 2: - for c, i in enumerate(range(4)): - self.assertEqual(a[1, c], i) - - ################################################### - a = ht.zeros((13, 5), split=1) - # # set value on one node - a[10] = 1 - self.assertEqual(a[10].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[10].lshape, (3,)) - if a.comm.rank == 1: - self.assertEqual(a[10].lshape, (2,)) - - a = ht.zeros((13, 5), split=1) - # # set value on one node - a[10, 0] = 1 - self.assertEqual(a[10, 0], 1) - self.assertEqual(a[10, 0].dtype, ht.float32) - - # slice in 1st dim only on 1 node - a = ht.zeros((13, 5), split=1) - a[1:4] = 1 - self.assertTrue((a[1:4] == 1).all()) - self.assertEqual(a[1:4].gshape, (3, 5)) - self.assertEqual(a[1:4].split, 1) 
- self.assertEqual(a[1:4].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[1:4].lshape, (3, 3)) - if a.comm.rank == 1: - self.assertEqual(a[1:4].lshape, (3, 2)) - - # slice in 1st dim only on 1 node w/ singular second dim - a = ht.zeros((13, 5), split=1) - a[1:4, 1] = 1 - self.assertTrue((a[1:4, 1] == 1).all()) - self.assertEqual(a[1:4, 1].gshape, (3,)) - self.assertEqual(a[1:4, 1].split, None) - self.assertEqual(a[1:4, 1].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[1:4, 1].lshape, (3,)) - if a.comm.rank == 1: - self.assertEqual(a[1:4, 1].lshape, (3,)) - - # slice in 2st dim across both nodes (2 node case) w/ singular fist dim - a = ht.zeros((13, 5), split=1) - a[11, 1:5] = 1 - self.assertTrue((a[11, 1:5] == 1).all()) - self.assertEqual(a[11, 1:5].gshape, (4,)) - self.assertEqual(a[11, 1:5].split, 0) - self.assertEqual(a[11, 1:5].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 1: - self.assertEqual(a[11, 1:5].lshape, (2,)) - if a.comm.rank == 0: - self.assertEqual(a[11, 1:5].lshape, (2,)) - - # slice in 1st dim across 1 node (2nd) w/ singular second dim - a = ht.zeros((13, 5), split=1) - a[8:12, 1] = 1 - self.assertTrue((a[8:12, 1] == 1).all()) - self.assertEqual(a[8:12, 1].gshape, (4,)) - self.assertEqual(a[8:12, 1].split, None) - self.assertEqual(a[8:12, 1].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[8:12, 1].lshape, (4,)) - if a.comm.rank == 1: - self.assertEqual(a[8:12, 1].lshape, (4,)) - - # slice in both directions - a = ht.zeros((13, 5), split=1) - a[3:13, 2::2] = 1 - self.assertTrue((a[3:13, 2:5:2] == 1).all()) - self.assertEqual(a[3:13, 2:5:2].gshape, (10, 2)) - self.assertEqual(a[3:13, 2:5:2].split, 1) - self.assertEqual(a[3:13, 2:5:2].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 1: - self.assertEqual(a[3:13, 2:5:2].lshape, (10, 1)) - if a.comm.rank == 0: - self.assertEqual(a[3:13, 2:5:2].lshape, (10, 1)) 
- - a = ht.zeros((13, 5), split=1) - a[..., 2::2] = 1 - self.assertTrue((a[:, 2:5:2] == 1).all()) - self.assertEqual(a[..., 2:5:2].gshape, (13, 2)) - self.assertEqual(a[..., 2:5:2].split, 1) - self.assertEqual(a[..., 2:5:2].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 1: - self.assertEqual(a[..., 2:5:2].lshape, (13, 1)) - if a.comm.rank == 0: - self.assertEqual(a[:, 2:5:2].lshape, (13, 1)) - - # setting with heat tensor - a = ht.zeros((4, 5), split=1) - a[1, 0:4] = ht.arange(4) - for c, i in enumerate(range(4)): - b = a[1, c] - if b.larray.numel() > 0: - self.assertEqual(b.item(), i) - - # setting with torch tensor - a = ht.zeros((4, 5), split=1) - a[1, 0:4] = torch.arange(4, device=self.device.torch_device) - for c, i in enumerate(range(4)): - self.assertEqual(a[1, c], i) - - #################################################### - a = ht.zeros((13, 5, 7), split=2) - # # set value on one node - a[10, :, :] = 1 - self.assertEqual(a[10, :, :].dtype, ht.float32) - self.assertEqual(a[10, :, :].gshape, (5, 7)) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[10, :, :].lshape, (5, 4)) - if a.comm.rank == 1: - self.assertEqual(a[10, :, :].lshape, (5, 3)) - - a = ht.zeros((13, 5, 7), split=2) - # # set value on one node - a[10, ...] 
= 1 - self.assertEqual(a[10, ...].dtype, ht.float32) - self.assertEqual(a[10, ...].gshape, (5, 7)) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[10, ...].lshape, (5, 4)) - if a.comm.rank == 1: - self.assertEqual(a[10, ...].lshape, (5, 3)) - - a = ht.zeros((13, 5, 8), split=2) - # # set value on one node - a[10, 0, 0] = 1 - self.assertEqual(a[10, 0, 0], 1) - self.assertEqual(a[10, 0, 0].dtype, ht.float32) - - # # slice in 1st dim only on 1 node - a = ht.zeros((13, 5, 7), split=2) - a[1:4] = 1 - self.assertTrue((a[1:4] == 1).all()) - self.assertEqual(a[1:4].gshape, (3, 5, 7)) - self.assertEqual(a[1:4].split, 2) - self.assertEqual(a[1:4].dtype, ht.float32) - if a.comm.size == 2: - if a.comm.rank == 0: - self.assertEqual(a[1:4].lshape, (3, 5, 4)) - if a.comm.rank == 1: - self.assertEqual(a[1:4].lshape, (3, 5, 3)) - - # slice in 1st dim only on 1 node w/ singular second dim - a = ht.zeros((13, 5, 7), split=2) - a[1:4, 1, :] = 1 - self.assertTrue((a[1:4, 1, :] == 1).all()) - self.assertEqual(a[1:4, 1, :].gshape, (3, 7)) - if a.comm.size == 2: - self.assertEqual(a[1:4, 1, :].split, 1) - self.assertEqual(a[1:4, 1, :].dtype, ht.float32) - if a.comm.rank == 0: - self.assertEqual(a[1:4, 1, :].lshape, (3, 4)) - if a.comm.rank == 1: - self.assertEqual(a[1:4, 1, :].lshape, (3, 3)) - - # slice in both directions - a = ht.zeros((13, 5, 7), split=2) - a[3:13, 2:5:2, 1:7:3] = 1 - self.assertTrue((a[3:13, 2:5:2, 1:7:3] == 1).all()) - self.assertEqual(a[3:13, 2:5:2, 1:7:3].split, 2) - self.assertEqual(a[3:13, 2:5:2, 1:7:3].dtype, ht.float32) - self.assertEqual(a[3:13, 2:5:2, 1:7:3].gshape, (10, 2, 2)) - if a.comm.size == 2: - out = ht.ones((4, 5, 5), split=1) - self.assertEqual(out[0].gshape, (5, 5)) - if a.comm.rank == 1: - self.assertEqual(a[3:13, 2:5:2, 1:7:3].lshape, (10, 2, 1)) - self.assertEqual(out[0].lshape, (2, 5)) - if a.comm.rank == 0: - self.assertEqual(a[3:13, 2:5:2, 1:7:3].lshape, (10, 2, 1)) - self.assertEqual(out[0].lshape, (3, 5)) - - a = 
ht.ones((4, 5), split=0).tril() - a[0] = [6, 6, 6, 6, 6] - self.assertTrue((a[0] == 6).all()) - - a = ht.ones((4, 5), split=0).tril() - a[0] = (6, 6, 6, 6, 6) - self.assertTrue((a[0] == 6).all()) - - a = ht.ones((4, 5), split=0).tril() - a[0] = np.array([6, 6, 6, 6, 6]) - self.assertTrue((a[0] == 6).all()) - - a = ht.ones((4, 5), split=0).tril() - a[0] = ht.array([6, 6, 6, 6, 6]) - self.assertTrue((a[ht.array((0,))] == 6).all()) - - a = ht.ones((4, 5), split=0).tril() - a[0] = ht.array([6, 6, 6, 6, 6]) - self.assertTrue((a[ht.array((0,))] == 6).all()) - - # ======================= indexing with bools ================================= - split = None - arr = ht.random.random((20, 20)).resplit(split) - np_arr = arr.numpy() - np_key = np_arr < 0.5 - ht_key = ht.array(np_key, split=split) - arr[ht_key] = 10.0 - np_arr[np_key] = 10.0 - self.assertTrue(np.all(arr.numpy() == np_arr)) - self.assertTrue(ht.all(arr[ht_key] == 10.0)) - - split = 0 - arr = ht.random.random((20, 20)).resplit(split) - np_arr = arr.numpy() - np_key = (np_arr < 0.5)[0] - ht_key = ht.array(np_key, split=split) - arr[ht_key] = 10.0 - np_arr[np_key] = 10.0 - self.assertTrue(np.all(arr.numpy() == np_arr)) - self.assertTrue(ht.all(arr[ht_key] == 10.0)) - - # key -> tuple(ht.bool, int) - split = 0 - arr = ht.random.random((20, 20)).resplit(split) - np_arr = arr.numpy() - np_key = (np_arr < 0.5)[0] - ht_key = ht.array(np_key, split=split) - arr[ht_key, 4] = 10.0 - np_arr[np_key, 4] = 10.0 - self.assertTrue(np.all(arr.numpy() == np_arr)) - self.assertTrue(ht.all(arr[ht_key, 4] == 10.0)) - - # key -> tuple(torch.bool, int) - split = 0 - arr = ht.random.random((20, 20)).resplit(split) - np_arr = arr.numpy() - np_key = (np_arr < 0.5)[0] - t_key = torch.tensor(np_key, device=arr.larray.device) - arr[t_key, 4] = 10.0 - np_arr[np_key, 4] = 10.0 - self.assertTrue(np.all(arr.numpy() == np_arr)) - self.assertTrue(ht.all(arr[t_key, 4] == 10.0)) - - # key -> torch.bool - split = 0 - arr = ht.random.random((20, 
20)).resplit(split) - np_arr = arr.numpy() - np_key = (np_arr < 0.5)[0] - t_key = torch.tensor(np_key, device=arr.larray.device) - arr[t_key] = 10.0 - np_arr[np_key] = 10.0 - self.assertTrue(np.all(arr.numpy() == np_arr)) - self.assertTrue(ht.all(arr[t_key] == 10.0)) - - split = 1 - arr = ht.random.random((20, 20, 10)).resplit(split) - np_arr = arr.numpy() - np_key = np_arr < 0.5 - ht_key = ht.array(np_key, split=split) - arr[ht_key] = 10.0 - np_arr[np_key] = 10.0 - self.assertTrue(np.all(arr.numpy() == np_arr)) - self.assertTrue(ht.all(arr[ht_key] == 10.0)) - - split = 2 - arr = ht.random.random((15, 20, 20)).resplit(split) - np_arr = arr.numpy() - np_key = np_arr < 0.5 - ht_key = ht.array(np_key, split=split) - arr[ht_key] = 10.0 - np_arr[np_key] = 10.0 - self.assertTrue(np.all(arr.numpy() == np_arr)) - self.assertTrue(ht.all(arr[ht_key] == 10.0)) - - with self.assertRaises(ValueError): - a[..., ...] - with self.assertRaises(ValueError): - a[..., ...] = 1 - if a.comm.size > 1: - with self.assertRaises(ValueError): - x = ht.ones((10, 10), split=0) - setting = ht.zeros((8, 8), split=1) - x[1:-1, 1:-1] = setting - - for split in [None, 0, 1, 2]: - for new_dim in [0, 1, 2]: - for add in [np.newaxis, None]: - arr = ht.ones((4, 3, 2), split=split, dtype=ht.int32) - check = torch.ones((4, 3, 2), dtype=torch.int32) - idx = [slice(None), slice(None), slice(None)] - idx[new_dim] = add - idx = tuple(idx) - arr = arr[idx] - check = check[idx] - self.assertTrue(arr.shape == check.shape) - self.assertTrue(arr.lshape[new_dim] == 1) + self.assertTrue(ht.all(a[1:-1, :20] == 0).item()) + + # # set and get single value + # a = ht.zeros((13, 5), split=0) + # # set value on one node + # a[10, np.array(0)] = 1 + # self.assertEqual(a[10, 0], 1) + # self.assertEqual(a[10, 0].dtype, ht.float32) + + # a = ht.zeros((13, 5), split=0) + # a[10] = 1 + # b = a[torch.tensor(10)] + # self.assertTrue((b == 1).all()) + # self.assertEqual(b.dtype, ht.float32) + # self.assertEqual(b.gshape, (5,)) + 
+ # a = ht.zeros((13, 5), split=0) + # a[-1] = 1 + # b = a[-1] + # self.assertTrue((b == 1).all()) + # self.assertEqual(b.dtype, ht.float32) + # self.assertEqual(b.gshape, (5,)) + + # # slice in 1st dim only on 1 node + # a = ht.zeros((13, 5), split=0) + # a[1:4] = 1 + # self.assertTrue((a[1:4] == 1).all()) + # self.assertEqual(a[1:4].gshape, (3, 5)) + # self.assertEqual(a[1:4].split, 0) + # self.assertEqual(a[1:4].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[1:4].lshape, (3, 5)) + # else: + # self.assertEqual(a[1:4].lshape, (0, 5)) + + # a = ht.zeros((13, 5), split=0) + # a[1:2] = 1 + # self.assertTrue((a[1:2] == 1).all()) + # self.assertEqual(a[1:2].gshape, (1, 5)) + # self.assertEqual(a[1:2].split, 0) + # self.assertEqual(a[1:2].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[1:2].lshape, (1, 5)) + # else: + # self.assertEqual(a[1:2].lshape, (0, 5)) + + # # slice in 1st dim only on 1 node w/ singular second dim + # a = ht.zeros((13, 5), split=0) + # a[1:4, 1] = 1 + # b = a[1:4, np.int64(1)] + # self.assertTrue((b == 1).all()) + # self.assertEqual(b.gshape, (3,)) + # self.assertEqual(b.split, 0) + # self.assertEqual(b.dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(b.lshape, (3,)) + # else: + # self.assertEqual(b.lshape, (0,)) + + # # slice in 1st dim across both nodes (2 node case) w/ singular second dim + # a = ht.zeros((13, 5), split=0) + # a[1:11, 1] = 1 + # self.assertTrue((a[1:11, 1] == 1).all()) + # self.assertEqual(a[1:11, 1].gshape, (10,)) + # self.assertEqual(a[1:11, torch.tensor(1)].split, 0) + # self.assertEqual(a[1:11, 1].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 1: + # self.assertEqual(a[1:11, 1].lshape, (4,)) + # if a.comm.rank == 0: + # self.assertEqual(a[1:11, 1].lshape, (6,)) + + # # slice in 1st dim across 1 node (2nd) w/ singular second dim + # c = ht.zeros((13, 5), split=0) + # c[8:12, 
ht.array(1)] = 1 + # b = c[8:12, np.int64(1)] + # self.assertTrue((b == 1).all()) + # self.assertEqual(b.gshape, (4,)) + # self.assertEqual(b.split, 0) + # self.assertEqual(b.dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 1: + # self.assertEqual(b.lshape, (4,)) + # if a.comm.rank == 0: + # self.assertEqual(b.lshape, (0,)) + + # # slice in both directions + # a = ht.zeros((13, 5), split=0) + # a[3:13, 2:5:2] = 1 + # self.assertTrue((a[3:13, 2:5:2] == 1).all()) + # self.assertEqual(a[3:13, 2:5:2].gshape, (10, 2)) + # self.assertEqual(a[3:13, 2:5:2].split, 0) + # self.assertEqual(a[3:13, 2:5:2].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 1: + # self.assertEqual(a[3:13, 2:5:2].lshape, (6, 2)) + # if a.comm.rank == 0: + # self.assertEqual(a[3:13, 2:5:2].lshape, (4, 2)) + + # # setting with heat tensor + # a = ht.zeros((4, 5), split=0) + # a[1, 0:4] = ht.arange(4) + # # if a.comm.size == 2: + # for c, i in enumerate(range(4)): + # self.assertEqual(a[1, c], i) + + # # setting with torch tensor + # a = ht.zeros((4, 5), split=0) + # a[1, 0:4] = torch.arange(4, device=self.device.torch_device) + # # if a.comm.size == 2: + # for c, i in enumerate(range(4)): + # self.assertEqual(a[1, c], i) + + # ################################################### + # a = ht.zeros((13, 5), split=1) + # # # set value on one node + # a[10] = 1 + # self.assertEqual(a[10].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[10].lshape, (3,)) + # if a.comm.rank == 1: + # self.assertEqual(a[10].lshape, (2,)) + + # a = ht.zeros((13, 5), split=1) + # # # set value on one node + # a[10, 0] = 1 + # self.assertEqual(a[10, 0], 1) + # self.assertEqual(a[10, 0].dtype, ht.float32) + + # # slice in 1st dim only on 1 node + # a = ht.zeros((13, 5), split=1) + # a[1:4] = 1 + # self.assertTrue((a[1:4] == 1).all()) + # self.assertEqual(a[1:4].gshape, (3, 5)) + # self.assertEqual(a[1:4].split, 1) + # self.assertEqual(a[1:4].dtype, 
ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[1:4].lshape, (3, 3)) + # if a.comm.rank == 1: + # self.assertEqual(a[1:4].lshape, (3, 2)) + + # # slice in 1st dim only on 1 node w/ singular second dim + # a = ht.zeros((13, 5), split=1) + # a[1:4, 1] = 1 + # self.assertTrue((a[1:4, 1] == 1).all()) + # self.assertEqual(a[1:4, 1].gshape, (3,)) + # self.assertEqual(a[1:4, 1].split, None) + # self.assertEqual(a[1:4, 1].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[1:4, 1].lshape, (3,)) + # if a.comm.rank == 1: + # self.assertEqual(a[1:4, 1].lshape, (3,)) + + # # slice in 2st dim across both nodes (2 node case) w/ singular fist dim + # a = ht.zeros((13, 5), split=1) + # a[11, 1:5] = 1 + # self.assertTrue((a[11, 1:5] == 1).all()) + # self.assertEqual(a[11, 1:5].gshape, (4,)) + # self.assertEqual(a[11, 1:5].split, 0) + # self.assertEqual(a[11, 1:5].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 1: + # self.assertEqual(a[11, 1:5].lshape, (2,)) + # if a.comm.rank == 0: + # self.assertEqual(a[11, 1:5].lshape, (2,)) + + # # slice in 1st dim across 1 node (2nd) w/ singular second dim + # a = ht.zeros((13, 5), split=1) + # a[8:12, 1] = 1 + # self.assertTrue((a[8:12, 1] == 1).all()) + # self.assertEqual(a[8:12, 1].gshape, (4,)) + # self.assertEqual(a[8:12, 1].split, None) + # self.assertEqual(a[8:12, 1].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[8:12, 1].lshape, (4,)) + # if a.comm.rank == 1: + # self.assertEqual(a[8:12, 1].lshape, (4,)) + + # # slice in both directions + # a = ht.zeros((13, 5), split=1) + # a[3:13, 2::2] = 1 + # self.assertTrue((a[3:13, 2:5:2] == 1).all()) + # self.assertEqual(a[3:13, 2:5:2].gshape, (10, 2)) + # self.assertEqual(a[3:13, 2:5:2].split, 1) + # self.assertEqual(a[3:13, 2:5:2].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 1: + # self.assertEqual(a[3:13, 2:5:2].lshape, (10, 1)) + # 
if a.comm.rank == 0: + # self.assertEqual(a[3:13, 2:5:2].lshape, (10, 1)) + + # a = ht.zeros((13, 5), split=1) + # a[..., 2::2] = 1 + # self.assertTrue((a[:, 2:5:2] == 1).all()) + # self.assertEqual(a[..., 2:5:2].gshape, (13, 2)) + # self.assertEqual(a[..., 2:5:2].split, 1) + # self.assertEqual(a[..., 2:5:2].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 1: + # self.assertEqual(a[..., 2:5:2].lshape, (13, 1)) + # if a.comm.rank == 0: + # self.assertEqual(a[:, 2:5:2].lshape, (13, 1)) + + # # setting with heat tensor + # a = ht.zeros((4, 5), split=1) + # a[1, 0:4] = ht.arange(4) + # for c, i in enumerate(range(4)): + # b = a[1, c] + # if b.larray.numel() > 0: + # self.assertEqual(b.item(), i) + + # # setting with torch tensor + # a = ht.zeros((4, 5), split=1) + # a[1, 0:4] = torch.arange(4, device=self.device.torch_device) + # for c, i in enumerate(range(4)): + # self.assertEqual(a[1, c], i) + + # #################################################### + # a = ht.zeros((13, 5, 7), split=2) + # # # set value on one node + # a[10, :, :] = 1 + # self.assertEqual(a[10, :, :].dtype, ht.float32) + # self.assertEqual(a[10, :, :].gshape, (5, 7)) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[10, :, :].lshape, (5, 4)) + # if a.comm.rank == 1: + # self.assertEqual(a[10, :, :].lshape, (5, 3)) + + # a = ht.zeros((13, 5, 7), split=2) + # # # set value on one node + # a[10, ...] 
= 1 + # self.assertEqual(a[10, ...].dtype, ht.float32) + # self.assertEqual(a[10, ...].gshape, (5, 7)) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[10, ...].lshape, (5, 4)) + # if a.comm.rank == 1: + # self.assertEqual(a[10, ...].lshape, (5, 3)) + + # a = ht.zeros((13, 5, 8), split=2) + # # # set value on one node + # a[10, 0, 0] = 1 + # self.assertEqual(a[10, 0, 0], 1) + # self.assertEqual(a[10, 0, 0].dtype, ht.float32) + + # # # slice in 1st dim only on 1 node + # a = ht.zeros((13, 5, 7), split=2) + # a[1:4] = 1 + # self.assertTrue((a[1:4] == 1).all()) + # self.assertEqual(a[1:4].gshape, (3, 5, 7)) + # self.assertEqual(a[1:4].split, 2) + # self.assertEqual(a[1:4].dtype, ht.float32) + # if a.comm.size == 2: + # if a.comm.rank == 0: + # self.assertEqual(a[1:4].lshape, (3, 5, 4)) + # if a.comm.rank == 1: + # self.assertEqual(a[1:4].lshape, (3, 5, 3)) + + # # slice in 1st dim only on 1 node w/ singular second dim + # a = ht.zeros((13, 5, 7), split=2) + # a[1:4, 1, :] = 1 + # self.assertTrue((a[1:4, 1, :] == 1).all()) + # self.assertEqual(a[1:4, 1, :].gshape, (3, 7)) + # if a.comm.size == 2: + # self.assertEqual(a[1:4, 1, :].split, 1) + # self.assertEqual(a[1:4, 1, :].dtype, ht.float32) + # if a.comm.rank == 0: + # self.assertEqual(a[1:4, 1, :].lshape, (3, 4)) + # if a.comm.rank == 1: + # self.assertEqual(a[1:4, 1, :].lshape, (3, 3)) + + # # slice in both directions + # a = ht.zeros((13, 5, 7), split=2) + # a[3:13, 2:5:2, 1:7:3] = 1 + # self.assertTrue((a[3:13, 2:5:2, 1:7:3] == 1).all()) + # self.assertEqual(a[3:13, 2:5:2, 1:7:3].split, 2) + # self.assertEqual(a[3:13, 2:5:2, 1:7:3].dtype, ht.float32) + # self.assertEqual(a[3:13, 2:5:2, 1:7:3].gshape, (10, 2, 2)) + # if a.comm.size == 2: + # out = ht.ones((4, 5, 5), split=1) + # self.assertEqual(out[0].gshape, (5, 5)) + # if a.comm.rank == 1: + # self.assertEqual(a[3:13, 2:5:2, 1:7:3].lshape, (10, 2, 1)) + # self.assertEqual(out[0].lshape, (2, 5)) + # if a.comm.rank == 0: + # 
self.assertEqual(a[3:13, 2:5:2, 1:7:3].lshape, (10, 2, 1)) + # self.assertEqual(out[0].lshape, (3, 5)) + + # a = ht.ones((4, 5), split=0).tril() + # a[0] = [6, 6, 6, 6, 6] + # self.assertTrue((a[0] == 6).all()) + + # a = ht.ones((4, 5), split=0).tril() + # a[0] = (6, 6, 6, 6, 6) + # self.assertTrue((a[0] == 6).all()) + + # a = ht.ones((4, 5), split=0).tril() + # a[0] = np.array([6, 6, 6, 6, 6]) + # self.assertTrue((a[0] == 6).all()) + + # a = ht.ones((4, 5), split=0).tril() + # a[0] = ht.array([6, 6, 6, 6, 6]) + # self.assertTrue((a[ht.array((0,))] == 6).all()) + + # a = ht.ones((4, 5), split=0).tril() + # a[0] = ht.array([6, 6, 6, 6, 6]) + # self.assertTrue((a[ht.array((0,))] == 6).all()) + + # # ======================= indexing with bools ================================= + # split = None + # arr = ht.random.random((20, 20)).resplit(split) + # np_arr = arr.numpy() + # np_key = np_arr < 0.5 + # ht_key = ht.array(np_key, split=split) + # arr[ht_key] = 10.0 + # np_arr[np_key] = 10.0 + # self.assertTrue(np.all(arr.numpy() == np_arr)) + # self.assertTrue(ht.all(arr[ht_key] == 10.0)) + + # split = 0 + # arr = ht.random.random((20, 20)).resplit(split) + # np_arr = arr.numpy() + # np_key = (np_arr < 0.5)[0] + # ht_key = ht.array(np_key, split=split) + # arr[ht_key] = 10.0 + # np_arr[np_key] = 10.0 + # self.assertTrue(np.all(arr.numpy() == np_arr)) + # self.assertTrue(ht.all(arr[ht_key] == 10.0)) + + # # key -> tuple(ht.bool, int) + # split = 0 + # arr = ht.random.random((20, 20)).resplit(split) + # np_arr = arr.numpy() + # np_key = (np_arr < 0.5)[0] + # ht_key = ht.array(np_key, split=split) + # arr[ht_key, 4] = 10.0 + # np_arr[np_key, 4] = 10.0 + # self.assertTrue(np.all(arr.numpy() == np_arr)) + # self.assertTrue(ht.all(arr[ht_key, 4] == 10.0)) + + # # key -> tuple(torch.bool, int) + # split = 0 + # arr = ht.random.random((20, 20)).resplit(split) + # np_arr = arr.numpy() + # np_key = (np_arr < 0.5)[0] + # t_key = torch.tensor(np_key, device=arr.larray.device) + # 
arr[t_key, 4] = 10.0 + # np_arr[np_key, 4] = 10.0 + # self.assertTrue(np.all(arr.numpy() == np_arr)) + # self.assertTrue(ht.all(arr[t_key, 4] == 10.0)) + + # # key -> torch.bool + # split = 0 + # arr = ht.random.random((20, 20)).resplit(split) + # np_arr = arr.numpy() + # np_key = (np_arr < 0.5)[0] + # t_key = torch.tensor(np_key, device=arr.larray.device) + # arr[t_key] = 10.0 + # np_arr[np_key] = 10.0 + # self.assertTrue(np.all(arr.numpy() == np_arr)) + # self.assertTrue(ht.all(arr[t_key] == 10.0)) + + # split = 1 + # arr = ht.random.random((20, 20, 10)).resplit(split) + # np_arr = arr.numpy() + # np_key = np_arr < 0.5 + # ht_key = ht.array(np_key, split=split) + # arr[ht_key] = 10.0 + # np_arr[np_key] = 10.0 + # self.assertTrue(np.all(arr.numpy() == np_arr)) + # self.assertTrue(ht.all(arr[ht_key] == 10.0)) + + # split = 2 + # arr = ht.random.random((15, 20, 20)).resplit(split) + # np_arr = arr.numpy() + # np_key = np_arr < 0.5 + # ht_key = ht.array(np_key, split=split) + # arr[ht_key] = 10.0 + # np_arr[np_key] = 10.0 + # self.assertTrue(np.all(arr.numpy() == np_arr)) + # self.assertTrue(ht.all(arr[ht_key] == 10.0)) + + # with self.assertRaises(ValueError): + # a[..., ...] + # with self.assertRaises(ValueError): + # a[..., ...] 
= 1 + # if a.comm.size > 1: + # with self.assertRaises(ValueError): + # x = ht.ones((10, 10), split=0) + # setting = ht.zeros((8, 8), split=1) + # x[1:-1, 1:-1] = setting + + # for split in [None, 0, 1, 2]: + # for new_dim in [0, 1, 2]: + # for add in [np.newaxis, None]: + # arr = ht.ones((4, 3, 2), split=split, dtype=ht.int32) + # check = torch.ones((4, 3, 2), dtype=torch.int32) + # idx = [slice(None), slice(None), slice(None)] + # idx[new_dim] = add + # idx = tuple(idx) + # arr = arr[idx] + # check = check[idx] + # self.assertTrue(arr.shape == check.shape) + # self.assertTrue(arr.lshape[new_dim] == 1) def test_size_gnumel(self): a = ht.zeros((10, 10, 10), split=None) @@ -1737,6 +2273,7 @@ def test_torch_proxy(self): dndarray_proxy.storage().size() * dndarray_proxy.storage().element_size() ) self.assertTrue(dndarray_proxy_nbytes == 1) + self.assertTrue(dndarray_proxy.names.index("split") == 1) def test_xor(self): int16_tensor = ht.array([[1, 1], [2, 2]], dtype=ht.int16) diff --git a/heat/core/tests/test_indexing.py b/heat/core/tests/test_indexing.py index 4707aa28ab..58c7410456 100644 --- a/heat/core/tests/test_indexing.py +++ b/heat/core/tests/test_indexing.py @@ -9,18 +9,18 @@ def test_nonzero(self): a = ht.array([[1, 2, 3], [4, 5, 2], [7, 8, 9]], split=None) cond = a > 3 nz = ht.nonzero(cond) - self.assertEqual(nz.gshape, (5, 2)) - self.assertEqual(nz.dtype, ht.int64) - self.assertEqual(nz.split, None) + self.assertEqual(len(nz), 2) + self.assertEqual(len(nz[0]), 5) + self.assertEqual(nz[0].dtype, ht.int64) # split a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1) cond = a > 3 nz = cond.nonzero() - self.assertEqual(nz.gshape, (6, 2)) - self.assertEqual(nz.dtype, ht.int64) - self.assertEqual(nz.split, 0) - a[nz] = 10.0 + self.assertEqual(len(nz), 2) + self.assertEqual(len(nz[0]), 6) + self.assertEqual(nz[0].dtype, ht.int64) + a[nz] = 10 self.assertEqual(ht.all(a[nz] == 10), 1) def test_where(self):