(fix): correct shape for empty element in concat with dask (#1843)
* (fix): correct shape for empty element in concat with dask

* (chore): add test

* (fix): handling of off axis size when no-dask array

* (refactor): make index a property

* (fix): coverage skip

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ilan-gold and pre-commit-ci[bot] authored Feb 13, 2025
1 parent 68c0966 commit 9838253
Showing 2 changed files with 25 additions and 12 deletions.
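
For context, here is a minimal sketch of the scenario this commit fixes, adapted from the test added below: an outer join where only one AnnData carries a dask-backed layer, so the other object's contribution to that layer has to be synthesized with the correct off-axis size. Illustrative only; it assumes anndata, numpy, and dask are installed and reuses the shapes from the test.

import anndata as ad
import dask.array as da
import numpy as np

# Only ad2 has the dask-backed layer "a"; an outer join must generate a
# placeholder for ad1's missing contribution to that layer.
ad1 = ad.AnnData(X=np.ones((5, 10)))
ad2 = ad.AnnData(X=np.zeros((5, 5)), layers={"a": da.ones((5, 5))})

# Before this fix the dask placeholder could be built with the wrong
# off-axis size; with it, the outer concat yields a consistently shaped
# layer (10 obs, vars = outer union of the two var indexes).
result = ad.concat([ad1, ad2], join="outer")
print(result.layers["a"].shape)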
32 changes: 22 additions & 10 deletions src/anndata/_core/merge.py

@@ -564,7 +564,7 @@ def _apply_to_dask_array(self, el: DaskArray, *, axis, fill_value=None):
             shape[axis] = len(self.new_idx)
             return da.broadcast_to(fill_value, tuple(shape))

-        indexer = self.old_idx.get_indexer(self.new_idx)
+        indexer = self.idx
         sub_el = _subset(el, make_slice(indexer, axis, len(shape)))

         if any(indexer == -1):
@@ -607,7 +607,7 @@ def _apply_to_array(self, el, *, axis, fill_value=None):
             shape[axis] = len(self.new_idx)
             return np.broadcast_to(fill_value, tuple(shape))

-        indexer = self.old_idx.get_indexer(self.new_idx)
+        indexer = self.idx

         # Indexes real fast, and does outer indexing
         return pd.api.extensions.take(
@@ -705,7 +705,11 @@ def _apply_to_awkward(self, el: AwkArray, *, axis, fill_value=None):
         else:
             if len(self.new_idx) > len(self.old_idx):
                 el = ak.pad_none(el, 1, axis=axis)  # axis == 0
-            return el[self.old_idx.get_indexer(self.new_idx)]
+            return el[self.idx]
+
+    @property
+    def idx(self):
+        return self.old_idx.get_indexer(self.new_idx)


 def merge_indices(
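
The new idx property above is a reusable spelling of pandas.Index.get_indexer, which maps each label of the target index to its position in the source index and marks labels absent from the source with -1 (which is why _apply_to_dask_array checks any(indexer == -1) and fills those positions). A standalone illustration using only pandas:

import pandas as pd

old_idx = pd.Index(["a", "b", "c"])
new_idx = pd.Index(["b", "c", "d"])

# Positions of new_idx labels within old_idx; -1 flags labels that are
# missing from old_idx and therefore need fill values after subsetting.
print(old_idx.get_indexer(new_idx))  # [ 1  2 -1]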
@@ -940,16 +944,11 @@ def missing_element(
     els: list[SpArray | sparse.csr_matrix | sparse.csc_matrix | np.ndarray | DaskArray],
     axis: Literal[0, 1] = 0,
     fill_value: Any | None = None,
+    off_axis_size: int = 0,
 ) -> np.ndarray | DaskArray:
     """Generates value to use when there is a missing element."""
     should_return_dask = any(isinstance(el, DaskArray) for el in els)
-    try:
-        non_missing_elem = next(el for el in els if not_missing(el))
-    except StopIteration:  # pragma: no cover
-        msg = "All elements are missing when attempting to generate missing elements."
-        raise ValueError(msg)
     # 0 sized array for in-memory prevents allocating unnecessary memory while preserving broadcasting.
-    off_axis_size = 0 if not should_return_dask else non_missing_elem.shape[axis - 1]
     shape = (n, off_axis_size) if axis == 0 else (off_axis_size, n)
     if should_return_dask:
         import dask.array as da
@@ -973,13 +972,26 @@ def outer_concat_aligned_mapping(
         else:
             cur_reindexers = reindexers

+        # Dask needs to create a full array and can't do the size-0 trick
+        off_axis_size = 0
+        if any(isinstance(e, DaskArray) for e in els):
+            if not isinstance(cur_reindexers[0], Reindexer):  # pragma: no cover
+                msg = "Cannot re-index a dask array without a Reindexer"
+                raise ValueError(msg)
+            off_axis_size = cur_reindexers[0].idx.shape[0]
         # Handling of missing values here is hacky for dataframes
         # We should probably just handle missing elements for all types
         result[k] = concat_arrays(
             [
                 el
                 if not_missing(el)
-                else missing_element(n, axis=axis, els=els, fill_value=fill_value)
+                else missing_element(
+                    n,
+                    axis=axis,
+                    els=els,
+                    fill_value=fill_value,
+                    off_axis_size=off_axis_size,
+                )
                 for el, n in zip(els, ns)
             ],
             cur_reindexers,
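
The off_axis_size plumbing above exists because an in-memory placeholder can be created with size 0 on the off axis and left to broadcasting at concatenation time, whereas a dask placeholder must be declared at its real off-axis size up front so the task graph carries the true shape. A rough standalone sketch of that shape rule (not the library function itself; NaN as the dense fill value is an assumption here):

import dask.array as da
import numpy as np

def placeholder(n, *, axis, off_axis_size, use_dask, fill_value=np.nan):
    # n goes on the concatenation axis, off_axis_size on the other one,
    # mirroring the shape expression in missing_element above.
    shape = (n, off_axis_size) if axis == 0 else (off_axis_size, n)
    if use_dask:
        return da.broadcast_to(fill_value, shape)
    return np.broadcast_to(fill_value, shape)

print(placeholder(6, axis=0, off_axis_size=17, use_dask=True).shape)   # (6, 17)
print(placeholder(6, axis=0, off_axis_size=0, use_dask=False).shape)   # (6, 0)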
5 changes: 3 additions & 2 deletions tests/test_concatenate.py

@@ -1540,7 +1540,7 @@ def test_concat_missing_elem_dask_join(join_type):

     import anndata as ad

-    ad1 = ad.AnnData(X=np.ones((5, 5)))
+    ad1 = ad.AnnData(X=np.ones((5, 10)))
     ad2 = ad.AnnData(X=np.zeros((5, 5)), layers={"a": da.ones((5, 5))})
     ad_in_memory_with_layers = ad2.to_memory()

@@ -1556,11 +1556,12 @@ def test_impute_dask(axis_name):

     axis, _ = _resolve_axis(axis_name)
     els = [da.ones((5, 5))]
-    missing = missing_element(6, els, axis=axis)
+    missing = missing_element(6, els, axis=axis, off_axis_size=17)
     assert isinstance(missing, DaskArray)
     in_memory = missing.compute()
     assert np.all(np.isnan(in_memory))
     assert in_memory.shape[axis] == 6
+    assert in_memory.shape[axis - 1] == 17


 def test_outer_concat_with_missing_value_for_df():
