From 2058def78db2d8d86b08793d27bda2066ce9b0ea Mon Sep 17 00:00:00 2001 From: James Mathews Date: Mon, 25 Sep 2023 16:49:27 -0400 Subject: [PATCH] Fix cell index handling in proximity calculation. --- .../ondemand/providers/proximity_provider.py | 11 +++++++++++ spatialprofilingtoolbox/workflow/common/proximity.py | 7 ++++--- .../workflow/phenotype_proximity/core.py | 1 + 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/spatialprofilingtoolbox/ondemand/providers/proximity_provider.py b/spatialprofilingtoolbox/ondemand/providers/proximity_provider.py index 1a7b2e6d2..3362e7c03 100644 --- a/spatialprofilingtoolbox/ondemand/providers/proximity_provider.py +++ b/spatialprofilingtoolbox/ondemand/providers/proximity_provider.py @@ -47,6 +47,13 @@ def _create_ball_trees(self) -> None: } for study_name, _data_arrays in self.data_arrays.items() } + self.index_lookups = { + study_name: { + sample_identifier: tuple(df.index.to_list()) + for sample_identifier, df in _data_arrays.items() + } + for study_name, _data_arrays in self.data_arrays.items() + } @classmethod def get_or_create_feature_specification( @@ -136,6 +143,7 @@ def have_feature_computed(self, feature_specification: str) -> None: radius, self.get_cells(sample_identifier, study_name), self._get_tree(sample_identifier, study_name), + self._get_index_lookup(sample_identifier, study_name), ) message = 'Computed one feature value of %s: %s, %s' logger.debug(message, feature_specification, sample_identifier, value) @@ -148,3 +156,6 @@ def have_feature_computed(self, feature_specification: str) -> None: def _get_tree(self, sample_identifier: str, study_name: str) -> BallTree: return self.trees[study_name][sample_identifier] + + def _get_index_lookup(self, sample_identifier: str, study_name: str) -> tuple[int, ...]: + return self.index_lookups[study_name][sample_identifier] diff --git a/spatialprofilingtoolbox/workflow/common/proximity.py b/spatialprofilingtoolbox/workflow/common/proximity.py index 42ea892b4..f253942f4 100644 --- a/spatialprofilingtoolbox/workflow/common/proximity.py +++ b/spatialprofilingtoolbox/workflow/common/proximity.py @@ -17,7 +17,8 @@ def compute_proximity_metric_for_signature_pair( signature2: PhenotypeCriteria, radius: float, cells: DataFrame, - tree: BallTree + tree: BallTree, + index_lookup: tuple[int, ...], ) -> float | None: cells = cells.rename({ column: (column[2:] if (column.startswith('C ') or column.startswith('P ')) else column) @@ -38,8 +39,8 @@ def compute_proximity_metric_for_signature_pair( return_distance=False, ) counts = [ - sum(mask2[index] for index in list(indices)) - for indices in within_radius_indices_list + sum(mask2[index_lookup[integer_index]] for integer_index in list(integer_indices)) + for integer_indices in within_radius_indices_list ] count = sum(counts) - sum(logical_and(mask1, mask2)) return count / source_count diff --git a/spatialprofilingtoolbox/workflow/phenotype_proximity/core.py b/spatialprofilingtoolbox/workflow/phenotype_proximity/core.py index 78e3e84ff..bccab9014 100644 --- a/spatialprofilingtoolbox/workflow/phenotype_proximity/core.py +++ b/spatialprofilingtoolbox/workflow/phenotype_proximity/core.py @@ -99,6 +99,7 @@ def calculate_proximity(self): r, cells, self.tree, + tuple(cells.index.to_list()), ) for f1, f2, r in self.get_cases(all_features) } self.write_table(proximity_metrics)