From 04859bede6661c50e362b05b95d6961cd84c2095 Mon Sep 17 00:00:00 2001 From: "Lumberbot (aka Jack)" <39504233+meeseeksmachine@users.noreply.github.com> Date: Thu, 13 Feb 2025 07:26:10 -0800 Subject: [PATCH] Backport PR #1843: (fix): correct shape for empty element in concat with dask (#1861) Co-authored-by: Ilan Gold --- src/anndata/_core/merge.py | 32 ++++++++++++++++++++++---------- tests/test_concatenate.py | 5 +++-- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 1b11ad674..fb1c4ef07 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -564,7 +564,7 @@ def _apply_to_dask_array(self, el: DaskArray, *, axis, fill_value=None): shape[axis] = len(self.new_idx) return da.broadcast_to(fill_value, tuple(shape)) - indexer = self.old_idx.get_indexer(self.new_idx) + indexer = self.idx sub_el = _subset(el, make_slice(indexer, axis, len(shape))) if any(indexer == -1): @@ -607,7 +607,7 @@ def _apply_to_array(self, el, *, axis, fill_value=None): shape[axis] = len(self.new_idx) return np.broadcast_to(fill_value, tuple(shape)) - indexer = self.old_idx.get_indexer(self.new_idx) + indexer = self.idx # Indexes real fast, and does outer indexing return pd.api.extensions.take( @@ -705,7 +705,11 @@ def _apply_to_awkward(self, el: AwkArray, *, axis, fill_value=None): else: if len(self.new_idx) > len(self.old_idx): el = ak.pad_none(el, 1, axis=axis) # axis == 0 - return el[self.old_idx.get_indexer(self.new_idx)] + return el[self.idx] + + @property + def idx(self): + return self.old_idx.get_indexer(self.new_idx) def merge_indices( @@ -940,16 +944,11 @@ def missing_element( els: list[SpArray | sparse.csr_matrix | sparse.csc_matrix | np.ndarray | DaskArray], axis: Literal[0, 1] = 0, fill_value: Any | None = None, + off_axis_size: int = 0, ) -> np.ndarray | DaskArray: """Generates value to use when there is a missing element.""" should_return_dask = any(isinstance(el, DaskArray) for el in els) - try: - non_missing_elem = next(el for el in els if not_missing(el)) - except StopIteration: # pragma: no cover - msg = "All elements are missing when attempting to generate missing elements." - raise ValueError(msg) # 0 sized array for in-memory prevents allocating unnecessary memory while preserving broadcasting. - off_axis_size = 0 if not should_return_dask else non_missing_elem.shape[axis - 1] shape = (n, off_axis_size) if axis == 0 else (off_axis_size, n) if should_return_dask: import dask.array as da @@ -973,13 +972,26 @@ def outer_concat_aligned_mapping( else: cur_reindexers = reindexers + # Dask needs to create a full array and can't do the size-0 trick + off_axis_size = 0 + if any(isinstance(e, DaskArray) for e in els): + if not isinstance(cur_reindexers[0], Reindexer): # pragma: no cover + msg = "Cannot re-index a dask array without a Reindexer" + raise ValueError(msg) + off_axis_size = cur_reindexers[0].idx.shape[0] # Handling of missing values here is hacky for dataframes # We should probably just handle missing elements for all types result[k] = concat_arrays( [ el if not_missing(el) - else missing_element(n, axis=axis, els=els, fill_value=fill_value) + else missing_element( + n, + axis=axis, + els=els, + fill_value=fill_value, + off_axis_size=off_axis_size, + ) for el, n in zip(els, ns) ], cur_reindexers, diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index 70409ca4e..13d7c7617 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -1538,7 +1538,7 @@ def test_concat_missing_elem_dask_join(join_type): import anndata as ad - ad1 = ad.AnnData(X=np.ones((5, 5))) + ad1 = ad.AnnData(X=np.ones((5, 10))) ad2 = ad.AnnData(X=np.zeros((5, 5)), layers={"a": da.ones((5, 5))}) ad_in_memory_with_layers = ad2.to_memory() @@ -1554,11 +1554,12 @@ def test_impute_dask(axis_name): axis, _ = _resolve_axis(axis_name) els = [da.ones((5, 5))] - missing = missing_element(6, els, axis=axis) + missing = missing_element(6, els, axis=axis, off_axis_size=17) assert isinstance(missing, DaskArray) in_memory = missing.compute() assert np.all(np.isnan(in_memory)) assert in_memory.shape[axis] == 6 + assert in_memory.shape[axis - 1] == 17 def test_outer_concat_with_missing_value_for_df():