Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(fix): correct shape for empty element in concat with dask #1843

Merged
9 commits merged on Feb 13, 2025
32 changes: 22 additions & 10 deletions src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@
shape[axis] = len(self.new_idx)
return da.broadcast_to(fill_value, tuple(shape))

indexer = self.old_idx.get_indexer(self.new_idx)
indexer = self.idx
sub_el = _subset(el, make_slice(indexer, axis, len(shape)))

if any(indexer == -1):
Expand Down Expand Up @@ -607,7 +607,7 @@
shape[axis] = len(self.new_idx)
return np.broadcast_to(fill_value, tuple(shape))

indexer = self.old_idx.get_indexer(self.new_idx)
indexer = self.idx

# Indexes real fast, and does outer indexing
return pd.api.extensions.take(
Expand Down Expand Up @@ -705,7 +705,11 @@
else:
if len(self.new_idx) > len(self.old_idx):
el = ak.pad_none(el, 1, axis=axis) # axis == 0
return el[self.old_idx.get_indexer(self.new_idx)]
return el[self.idx]

@property
def idx(self):
    """Indexer obtained by looking up ``new_idx`` in ``old_idx``.

    Recomputed on every access via ``old_idx.get_indexer(new_idx)``; nothing
    is cached. Code using this value in this file compares it against ``-1``
    (presumably the pandas ``Index.get_indexer`` marker for labels missing
    from ``old_idx`` -- confirm against the class definition).
    """
    lookup = self.old_idx.get_indexer
    return lookup(self.new_idx)


def merge_indices(
Expand Down Expand Up @@ -940,16 +944,11 @@
els: list[SpArray | sparse.csr_matrix | sparse.csc_matrix | np.ndarray | DaskArray],
axis: Literal[0, 1] = 0,
fill_value: Any | None = None,
off_axis_size: int = 0,
) -> np.ndarray | DaskArray:
"""Generates value to use when there is a missing element."""
should_return_dask = any(isinstance(el, DaskArray) for el in els)
try:
non_missing_elem = next(el for el in els if not_missing(el))
except StopIteration: # pragma: no cover
msg = "All elements are missing when attempting to generate missing elements."
raise ValueError(msg)
# 0 sized array for in-memory prevents allocating unnecessary memory while preserving broadcasting.
off_axis_size = 0 if not should_return_dask else non_missing_elem.shape[axis - 1]
shape = (n, off_axis_size) if axis == 0 else (off_axis_size, n)
if should_return_dask:
import dask.array as da
Expand All @@ -973,13 +972,26 @@
else:
cur_reindexers = reindexers

# Dask needs to create a full array and can't do the size-0 trick
off_axis_size = 0
if any(isinstance(e, DaskArray) for e in els):
if not isinstance(cur_reindexers[0], Reindexer):
msg = "Cannot re-index a dask array without a Reindexer"
raise ValueError(msg)

[Codecov / codecov/patch annotation — warning on line 980 of src/anndata/_core/merge.py: added lines #L979-L980 were not covered by tests.]
off_axis_size = cur_reindexers[0].idx.shape[0]
# Handling of missing values here is hacky for dataframes
# We should probably just handle missing elements for all types
result[k] = concat_arrays(
[
el
if not_missing(el)
else missing_element(n, axis=axis, els=els, fill_value=fill_value)
else missing_element(
n,
axis=axis,
els=els,
fill_value=fill_value,
off_axis_size=off_axis_size,
)
for el, n in zip(els, ns)
],
cur_reindexers,
Expand Down
5 changes: 3 additions & 2 deletions tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,7 +1540,7 @@ def test_concat_missing_elem_dask_join(join_type):

import anndata as ad

ad1 = ad.AnnData(X=np.ones((5, 5)))
ad1 = ad.AnnData(X=np.ones((5, 10)))
ad2 = ad.AnnData(X=np.zeros((5, 5)), layers={"a": da.ones((5, 5))})
ad_in_memory_with_layers = ad2.to_memory()

Expand All @@ -1556,11 +1556,12 @@ def test_impute_dask(axis_name):

axis, _ = _resolve_axis(axis_name)
els = [da.ones((5, 5))]
missing = missing_element(6, els, axis=axis)
missing = missing_element(6, els, axis=axis, off_axis_size=17)
assert isinstance(missing, DaskArray)
in_memory = missing.compute()
assert np.all(np.isnan(in_memory))
assert in_memory.shape[axis] == 6
assert in_memory.shape[axis - 1] == 17


def test_outer_concat_with_missing_value_for_df():
Expand Down