From 3070d6a16b5d10a48ce599d8709296ef6fd73b83 Mon Sep 17 00:00:00 2001
From: Mike Lin
Date: Mon, 9 Sep 2024 09:23:26 -1000
Subject: [PATCH] [python] similarity search API: optimize predict_obs_metadata
 (#1257)

* squash for PR

* use DEFAULT_TILEDB_CONFIGURATION

* workaround

* workaround

* fix

* resolve indexes through JSONs

* lint

* API refactoring

* Update api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py

Co-authored-by: Isaac Virshup

* fixups

* Update api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py

Co-authored-by: Isaac Virshup

---------

Co-authored-by: Isaac Virshup
---
 .../experimental/_embedding_search.py | 31 ++++++++++++++-------
 1 file changed, 24 insertions(+), 7 deletions(-)

diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py
index 8d095a08c..de09e2060 100644
--- a/api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py
+++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py
@@ -10,6 +10,7 @@
 import pandas as pd
 import tiledb.vector_search as vs
 import tiledbsoma as soma
+from scipy import sparse
 
 from .._experiment import _get_experiment_name
 from .._open import DEFAULT_TILEDB_CONFIGURATION, open_soma
@@ -157,12 +158,28 @@ def predict_obs_metadata(
         # step through query cells to generate prediction for each column as the plurality value
         # found among its neighbors, with a confidence score based on the simple fraction (for now)
         # TODO: something more intelligent for numeric columns! also use distances, etc.
-        out: dict[str, list[Any]] = {}
-        for i in range(neighbors.neighbor_ids.shape[0]):
-            neighbors_i = neighbor_obs.loc[neighbors.neighbor_ids[i]]
-            for col in column_names:
-                col_value_counts = neighbors_i[col].value_counts(normalize=True)
-                out.setdefault(col, []).append(col_value_counts.idxmax())
-                out.setdefault(col + "_confidence", []).append(col_value_counts.max())
+        max_joinid = neighbor_obs.index.max()
+        out: dict[str, pd.Series[Any]] = {}
+        indices = np.broadcast_to(np.arange(neighbors.neighbor_ids.shape[0]), (10, neighbors.neighbor_ids.shape[0])).T
+        g = sparse.csr_matrix(
+            (
+                np.broadcast_to(1, neighbors.neighbor_ids.shape[0] * 10),
+                (
+                    indices.flatten(),
+                    neighbors.neighbor_ids.astype(np.int64).flatten(),
+                ),
+            ),
+            shape=(neighbors.neighbor_ids.shape[0], max_joinid + 1),
+        )
+        for col in column_names:
+            col_categorical = neighbor_obs[col].astype("category")
+            joinid2category = sparse.coo_matrix(
+                (np.broadcast_to(1, len(neighbor_obs)), (neighbor_obs.index, col_categorical.cat.codes)),
+                shape=(max_joinid + 1, len(col_categorical.cat.categories)),
+            )
+            counts = g @ joinid2category
+            rel_counts = counts / counts.sum(axis=1)
+            out[col] = col_categorical.cat.categories[rel_counts.argmax(axis=1).A.flatten()].astype(object)
+            out[f"{col}_confidence"] = rel_counts.max(axis=1).toarray().flatten()
 
         return pd.DataFrame.from_dict(out)
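
The densest new line is the `indices` construction: it builds the row coordinates for the CSR triplets, i.e. each query's row number repeated once per neighbor. A tiny self-contained sketch of that equivalence (the `n`/`k` toy sizes are illustrative; the patch pins the neighbor count at 10, which `neighbors.neighbor_ids.shape[1]` would express without the constant):

import numpy as np

# Toy sizes for illustration: n query cells, k neighbors apiece.
n, k = 4, 10
# The patch's construction: broadcast 0..n-1 across k columns, transpose, flatten.
by_broadcast = np.broadcast_to(np.arange(n), (k, n)).T.flatten()
# An equivalent, perhaps more familiar spelling.
by_repeat = np.repeat(np.arange(n), k)
assert (by_broadcast == by_repeat).all()  # [0,0,...,0, 1,1,...,1, ...]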
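
Putting the technique together end to end: the sketch below is a minimal standalone rendition of the sparse plurality vote, assuming toy `neighbor_ids`/`neighbor_obs` data invented for illustration. Nothing in it touches cellxgene_census, and unlike the patch it densifies the counts with `.toarray()` for simplicity:

import numpy as np
import pandas as pd
from scipy import sparse

# Toy input: 3 query cells with k=2 nearest neighbors each, given as obs joinids,
# plus metadata for every joinid that appears among the neighbors.
neighbor_ids = np.array([[0, 1], [1, 2], [2, 3]])
neighbor_obs = pd.DataFrame(
    {"cell_type": ["B cell", "B cell", "T cell", "T cell"]},
    index=pd.Index([0, 1, 2, 3], name="soma_joinid"),
)

n_queries, k = neighbor_ids.shape  # derive k from the data rather than hardcoding it
max_joinid = neighbor_obs.index.max()

# g[i, j] = 1 iff joinid j is among query i's k neighbors.
g = sparse.csr_matrix(
    (np.ones(n_queries * k), (np.repeat(np.arange(n_queries), k), neighbor_ids.flatten())),
    shape=(n_queries, max_joinid + 1),
)

# One-hot matrix sending each joinid to its category code for this column.
cats = neighbor_obs["cell_type"].astype("category")
onehot = sparse.coo_matrix(
    (np.ones(len(neighbor_obs)), (neighbor_obs.index, cats.cat.codes)),
    shape=(max_joinid + 1, len(cats.cat.categories)),
)

# One sparse matmul replaces the per-query value_counts() loop:
# counts[i, c] = how many of query i's neighbors carry category c.
counts = (g @ onehot).toarray()
rel_counts = counts / counts.sum(axis=1, keepdims=True)

prediction = cats.cat.categories[rel_counts.argmax(axis=1)]
confidence = rel_counts.max(axis=1)
print(list(prediction))     # ['B cell', 'B cell', 'T cell'] (query 1 is a 50/50 tie)
print(confidence.tolist())  # [1.0, 0.5, 1.0]

The speedup in the patch comes from this same shape of computation: two sparse constructions plus one matmul per column, in place of a pandas value_counts() call per query cell.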