(fix): correct shape for empty element in concat with dask #1843

Merged: 9 commits, Feb 13, 2025

src/anndata/_core/merge.py: 31 changes (21 additions, 10 deletions)

@@ -564,7 +564,7 @@
             shape[axis] = len(self.new_idx)
             return da.broadcast_to(fill_value, tuple(shape))

-        indexer = self.old_idx.get_indexer(self.new_idx)
+        indexer = self.get_new_idx_from_old_idx()
         sub_el = _subset(el, make_slice(indexer, axis, len(shape)))

         if any(indexer == -1):
@@ -607,7 +607,7 @@
             shape[axis] = len(self.new_idx)
             return np.broadcast_to(fill_value, tuple(shape))

-        indexer = self.old_idx.get_indexer(self.new_idx)
+        indexer = self.get_new_idx_from_old_idx()

         # Indexes real fast, and does outer indexing
         return pd.api.extensions.take(
@@ -705,7 +705,10 @@
         else:
             if len(self.new_idx) > len(self.old_idx):
                 el = ak.pad_none(el, 1, axis=axis)  # axis == 0
-            return el[self.old_idx.get_indexer(self.new_idx)]
+            return el[self.get_new_idx_from_old_idx()]
+
+    def get_new_idx_from_old_idx(self):
Review comment (Member):
hm, is that the correct name? wouldn't the accurate name contain an "into" or do I understand this API wrong (pandas docs are down, so I can't check)

Reply (Contributor, author):
I'm not sure about "into", but perhaps just "indexer", and make this a property.

+        return self.old_idx.get_indexer(self.new_idx)


 def merge_indices(
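Some context for the thread above: pd.Index.get_indexer returns, for each label of the target index, its integer position in the calling index, with -1 marking labels that are absent. Below is a minimal sketch of that behaviour plus the property-style rename floated in the reply; the Reindexer class here is a stripped-down stand-in for illustration, not anndata's actual implementation.

```python
import numpy as np
import pandas as pd

# get_indexer maps each label of the target index to its position in the
# source index, returning -1 for labels the source does not contain.
old_idx = pd.Index(["a", "b", "c"])
new_idx = pd.Index(["b", "d", "a"])
print(old_idx.get_indexer(new_idx))  # [ 1 -1  0]


class Reindexer:
    """Stripped-down stand-in for anndata's Reindexer, for illustration only."""

    def __init__(self, old_idx: pd.Index, new_idx: pd.Index) -> None:
        self.old_idx = old_idx
        self.new_idx = new_idx

    @property
    def indexer(self) -> np.ndarray:
        # The rename suggested in the review: a property named for what the
        # value is (positions of new_idx labels in old_idx) rather than for
        # how it is derived.
        return self.old_idx.get_indexer(self.new_idx)
```

Whether it ends up as get_new_idx_from_old_idx, indexer, or a name containing "into" is exactly the naming question the thread leaves open.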
@@ -940,16 +943,11 @@
     els: list[SpArray | sparse.csr_matrix | sparse.csc_matrix | np.ndarray | DaskArray],
     axis: Literal[0, 1] = 0,
     fill_value: Any | None = None,
+    off_axis_size: int = 0,
 ) -> np.ndarray | DaskArray:
     """Generates value to use when there is a missing element."""
     should_return_dask = any(isinstance(el, DaskArray) for el in els)
-    try:
-        non_missing_elem = next(el for el in els if not_missing(el))
-    except StopIteration:  # pragma: no cover
-        msg = "All elements are missing when attempting to generate missing elements."
-        raise ValueError(msg)
-    # 0 sized array for in-memory prevents allocating unnecessary memory while preserving broadcasting.
-    off_axis_size = 0 if not should_return_dask else non_missing_elem.shape[axis - 1]
     shape = (n, off_axis_size) if axis == 0 else (off_axis_size, n)
     if should_return_dask:
         import dask.array as da
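A brief aside on why off_axis_size is now a parameter: for in-memory arrays, broadcasting a scalar to a zero-width placeholder costs no memory and later reindexing expands it to the real width, while a dask placeholder has to carry its true off-axis size from the start. A minimal sketch, assuming NaN as the fill value (as in the tests below):

```python
import numpy as np
import dask.array as da

n, off_axis_size = 6, 17

# In-memory path: a scalar broadcast to a zero-width array allocates no data;
# the subsequent reindexing step expands it to the real width with the fill value.
placeholder = np.broadcast_to(np.nan, (n, 0))
assert placeholder.shape == (n, 0) and placeholder.nbytes == 0

# Dask path: the placeholder is built lazily at the real off-axis size, which
# is why missing_element now receives that size explicitly instead of peeking
# at a non-missing element's shape.
lazy_placeholder = da.broadcast_to(np.nan, (n, off_axis_size))
assert lazy_placeholder.shape == (n, off_axis_size)
assert np.isnan(lazy_placeholder.compute()).all()
```

The test change further down exercises exactly this: missing_element(6, els, axis=axis, off_axis_size=17) must yield a dask array of NaNs whose off-axis length is 17.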
@@ -973,13 +971,26 @@
         else:
             cur_reindexers = reindexers

+        # Dask needs to create a full array and can't do the size-0 trick
+        off_axis_size = 0
+        if any(isinstance(e, DaskArray) for e in els):
+            if not isinstance(cur_reindexers[0], Reindexer):
+                msg = "Cannot re-index a dask array without a Reindexer"
+                raise ValueError(msg)

[Codecov / codecov/patch check warning, src/anndata/_core/merge.py: added lines #L978-L979 were not covered by tests]
+            off_axis_size = cur_reindexers[0].get_new_idx_from_old_idx().shape[0]
         # Handling of missing values here is hacky for dataframes
         # We should probably just handle missing elements for all types
         result[k] = concat_arrays(
             [
                 el
                 if not_missing(el)
-                else missing_element(n, axis=axis, els=els, fill_value=fill_value)
+                else missing_element(
+                    n,
+                    axis=axis,
+                    els=els,
+                    fill_value=fill_value,
+                    off_axis_size=off_axis_size,
+                )
                 for el, n in zip(els, ns)
             ],
             cur_reindexers,
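Putting the pieces together, here is a usage sketch mirroring the adjusted test_concat_missing_elem_dask_join below; it assumes an outer join along obs with AnnData's default var names, so the exact shape is illustrative rather than guaranteed.

```python
import numpy as np
import dask.array as da
import anndata as ad

# One object carries a dask-backed layer "a", the other does not; with an
# outer join the missing block for ad1 has to come out at the layer's full
# off-axis width, which is what the off_axis_size plumbing above provides.
ad1 = ad.AnnData(X=np.ones((5, 10)))
ad2 = ad.AnnData(X=np.zeros((5, 5)), layers={"a": da.ones((5, 5))})

out = ad.concat([ad1, ad2], join="outer")
layer = out.layers["a"]
print(type(layer), layer.shape)
# The concatenated layer should line up with the result's overall shape.
assert layer.shape == (out.n_obs, out.n_vars)
```

As far as the diff shows, the placeholder's shape previously came from a non-missing element's pre-reindex width rather than from the reindexer, which is what the PR title refers to.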
tests/test_concatenate.py: 5 changes (3 additions, 2 deletions)

@@ -1540,7 +1540,7 @@ def test_concat_missing_elem_dask_join(join_type):

     import anndata as ad

-    ad1 = ad.AnnData(X=np.ones((5, 5)))
+    ad1 = ad.AnnData(X=np.ones((5, 10)))
     ad2 = ad.AnnData(X=np.zeros((5, 5)), layers={"a": da.ones((5, 5))})
     ad_in_memory_with_layers = ad2.to_memory()

@@ -1556,11 +1556,12 @@ def test_impute_dask(axis_name):

     axis, _ = _resolve_axis(axis_name)
     els = [da.ones((5, 5))]
-    missing = missing_element(6, els, axis=axis)
+    missing = missing_element(6, els, axis=axis, off_axis_size=17)
     assert isinstance(missing, DaskArray)
     in_memory = missing.compute()
     assert np.all(np.isnan(in_memory))
     assert in_memory.shape[axis] == 6
+    assert in_memory.shape[axis - 1] == 17


 def test_outer_concat_with_missing_value_for_df():