From 179eda9f6a2bd71c2953e8cddbb036b93429ae92 Mon Sep 17 00:00:00 2001
From: mbernstein
Date: Mon, 21 Aug 2023 21:22:59 +0000
Subject: [PATCH] Further enhancements

---
 pyproject.toml                              |   1 +
 src/monkeybread/calc/_cell_density.py       | 156 ++++++++++++++------
 src/monkeybread/calc/_ligand_receptor.py    |   1 +
 src/monkeybread/plot/_embedding_other.py    |   6 +
 src/monkeybread/plot/_shortest_distances.py |  13 +-
 src/monkeybread/stat/_ligand_receptor.py    |  11 +-
 src/monkeybread/util/__init__.py            |   1 +
 src/monkeybread/util/_load_merscope.py      |  27 ++--
 src/monkeybread/util/_omnipath.py           |  45 ++++++
 9 files changed, 183 insertions(+), 78 deletions(-)
 create mode 100644 src/monkeybread/util/_omnipath.py

diff --git a/pyproject.toml b/pyproject.toml
index 5bad4d7..c15b9c9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@
 dependencies = [
     "tqdm",
     "scikit-learn>=0.22",
     "omnipath>=1.0.7",
+    "tqdm>=4.64.1",  # progress bars
     "session-info",
 ]
diff --git a/src/monkeybread/calc/_cell_density.py b/src/monkeybread/calc/_cell_density.py
index dd85fa4..18801ad 100644
--- a/src/monkeybread/calc/_cell_density.py
+++ b/src/monkeybread/calc/_cell_density.py
@@ -5,10 +5,31 @@
 import itertools
 from collections import Counter, defaultdict
 from typing import Dict, List, Optional, Union
-
+from tqdm import tqdm
 import numpy as np
 from anndata import AnnData
 from sklearn.metrics import pairwise_distances
+from scipy.spatial import cKDTree
+from scipy.sparse import dok_matrix
+
+def _sparse_distance_matrix(data, distance_threshold):
+    N = data.shape[0]
+
+    # Build a KD-Tree for fast nearest neighbor search
+    kdtree = cKDTree(data)
+
+    # Initialize a sparse matrix to store the distance matrix
+    distance_matrix = dok_matrix((N, N), dtype=float)
+
+    # Query the KD-Tree for neighbors within the distance threshold
+    # and set the corresponding entries in the distance matrix
+
+    for i in tqdm(range(N)):
+        neighbors = kdtree.query_ball_point(data[i], distance_threshold)
+        distances = np.linalg.norm(data[i] - data[neighbors], axis=1)
+        distance_matrix[i, neighbors] = distances
+        distance_matrix[neighbors, i] = distances
+    return distance_matrix
 
 
 def cell_density(
@@ -18,17 +39,25 @@ def cell_density(
     groupname: Optional[str] = None,
     basis: Optional[str] = "X_spatial",
     bandwidth: Optional[float] = 1.0,
-    resolution: Optional[float] = 1.0
+    approx: Optional[bool] = True,
+    resolution: Optional[float] = 1.0,
+    radius_threshold: Optional[float] = 250
 ) -> Union[str, Dict[str, str]]:
-    """Calculates the spatial distribution of cells of a given cell type using a heuristic algorithm
-    based on kernel density estimation.
+    """Calculates the spatial distribution of cells of a given cell type using kernel density
+    estimation.
+
+    This can be computed via two approximations. The first is closest to the true kernel density
+    estimate; however, the kernel between a pair of cells is zeroed out if the cells are more
+    than `radius_threshold` apart. This sparsification allows the kernel estimation to run on
+    sparse distance matrices.
 
-    Specifically, a grid is overlayed onto the spatial coordinates, and each cell is assigned to its
-    nearest gridpoint. Each gridpoint is then used in a kernel density estimation calculation. The
-    density value for each gridpoint is then assigned back to each cell. Fineness of the grid can be
-    modified using the `resolution` parameter where a high resolution leads to a more accurate
-    calculation of density for each cell. Final density values are scaled between 0 and 1
-    corresponding to the minimum and maximum density, respectively.
+    The second approximation is faster and is used when `approx` is set to `True`. A grid is
+    overlaid onto the spatial coordinates, and each cell is assigned to its nearest gridpoint.
+    Each gridpoint is then used in a kernel density estimation calculation, and the density value
+    for each gridpoint is assigned back to its cells. Fineness of the grid can be modified using
+    the `resolution` parameter, where a higher resolution leads to a more accurate calculation of
+    density for each cell. Final density values are scaled between 0 and 1, corresponding to the
+    minimum and maximum density, respectively.
 
     Parameters
     ----------
@@ -45,8 +74,14 @@
         Coordinates in `adata.obsm[{basis}]` to use. Defaults to `spatial`
     bandwidth
         Bandwidth for kernel density estimation
+    approx
+        If `True`, approximate the kernel density estimation by grouping cells into bins
     resolution
-        How small each gridpoint is. Number of points per row/column = resolution * 10
+        How small each gridpoint is. Number of points per row/column = resolution * 10.
+        Used only if `approx` is `True`
+    radius_threshold
+        Sparsifies the pairwise distance matrix so that only pairs of cells within this
+        threshold distance of each other are considered. Used only if `approx` is `False`
 
     Returns
     -------
@@ -61,45 +96,70 @@
     if groupname is None:
         groupname = "_".join(groups)
 
-    # Multiply base resolution by 10 to get number of bins per row/column
-    res = int(10 * resolution)
-
-    # Calculate bounds of spatial coordinates, x and y increments for each bin, and number of bins
-    [x_coords, y_coords] = adata.obsm[basis].transpose()
-    (x_min, x_max) = (min(x_coords) - 0.01, max(x_coords) + 0.01)
-    (y_min, y_max) = (min(y_coords) - 0.01, max(y_coords) + 0.01)
-    (x_inc, y_inc) = ((x_max - x_min) / res, (y_max - y_min) / res)
-    num_bins = int(res**2)
-
-    # Calculate distance from the center of each bin to the center of every other bin
-    x_bin_centers = [x_min + x_inc * (i + 1) / 2 for i in range(res)]
-    y_bin_centers = [y_min + y_inc * (j + 1) / 2 for j in range(res)]
-    bin_centers = list(itertools.product(x_bin_centers, y_bin_centers))
-    bin_distances = pairwise_distances(bin_centers)
-
-    # Calculate kernel based on distances and bandwidth
-    kernel = 1 / np.exp(np.square(bin_distances / bandwidth))
-
-    # Determine the bin each cell belongs to based on its location
-    # Bin number corresponds to index in bin_distances
-    location_to_bin = lambda x, y: int((x - x_min) / x_inc) * res + int((y - y_min) / y_inc)
-    cell_to_bin = [location_to_bin(x, y) for (x, y) in adata.obsm[basis]]
-
-    bin_counts = np.zeros(num_bins, dtype=int)
-    if groupby == "all":
-        for bin_id, count in Counter(cell_to_bin).items():
-            bin_counts[bin_id] = count
+    if approx:
+        # Multiply base resolution by 10 to get number of bins per row/column
+        res = int(10 * resolution)
+
+        # Calculate bounds of spatial coordinates, x and y increments for each bin, and number of bins
+        [x_coords, y_coords] = adata.obsm[basis].transpose()
+        (x_min, x_max) = (min(x_coords) - 0.01, max(x_coords) + 0.01)
+        (y_min, y_max) = (min(y_coords) - 0.01, max(y_coords) + 0.01)
+        (x_inc, y_inc) = ((x_max - x_min) / res, (y_max - y_min) / res)
+        num_bins = int(res**2)
+
+        # Calculate distance from the center of each bin to the center of every other bin
+        x_bin_centers = [x_min + x_inc * (i + 1) / 2 for i in range(res)]
+        y_bin_centers = [y_min + y_inc * (j + 1) / 2 for j in range(res)]
+        bin_centers = list(itertools.product(x_bin_centers, y_bin_centers))
+        bin_distances = pairwise_distances(bin_centers)
+
+        # Calculate kernel based on distances and bandwidth
+        kernel = 1 / np.exp(np.square(bin_distances / bandwidth))
+
+        # Determine the bin each cell belongs to based on its location
+        # Bin number corresponds to index in bin_distances
+        location_to_bin = lambda x, y: int((x - x_min) / x_inc) * res + int((y - y_min) / y_inc)
+        cell_to_bin = [location_to_bin(x, y) for (x, y) in adata.obsm[basis]]
+
+        bin_counts = np.zeros(num_bins, dtype=int)
+        if groupby == "all":
+            for bin_id, count in Counter(cell_to_bin).items():
+                bin_counts[bin_id] = count
+        else:
+            # Iterate over each cell, counting the number in groups in each bin
+            for (cell_group, bin_id) in zip(adata.obs[groupby], cell_to_bin):
+                bin_counts[bin_id] += 1 if groupby == "all" or cell_group in groups else 0
+
+        # Compute bin densities
+        bin_densities = np.matmul(kernel, bin_counts)
+
+        # Scale between zero and one
+        min_density, max_density = min(bin_densities), max(bin_densities)
+        if min_density != max_density:
+            bin_densities = [
+                (d - min_density) / (max_density - min_density)
+                for d in bin_densities
+            ]
+
+        # Map kernel densities back to cells and assign to a column in obs
+        cell_densities = [bin_densities[bin_id] for bin_id in cell_to_bin]
     else:
-        # Iterate over each cell, counting the number in groups in each bin
-        for (cell_group, bin_id) in zip(adata.obs[groupby], cell_to_bin):
-            bin_counts[bin_id] += 1 if groupby == "all" or cell_group in groups else 0
+        # Compute kernel matrix
+        distances = _sparse_distance_matrix(adata.obsm[basis], radius_threshold)
+        kernel = np.expm1(-(distances / bandwidth).power(2))
+        kernel[kernel.nonzero()] += 1
+
+        # Get cell type binary vector
+        is_group = np.array(adata.obs[groupby].isin(groups)).astype(int)
 
-    bin_densities = np.matmul(kernel, bin_counts)
-    min_density, max_density = min(bin_densities), max(bin_densities)
-    if min_density != max_density:
-        bin_densities = [(d - min_density) / (max_density - min_density) for d in bin_densities]
+        # Compute densities
+        densities = kernel.dot(is_group)
+
+        # Scale between zero and one
+        cell_densities = (densities - densities.min()) / (densities.max() - densities.min())
 
-    # Map kernel densities back to cells and assign to a column in obs
-    cell_densities = [bin_densities[bin_id] for bin_id in cell_to_bin]
     adata.obs[f"{groupby}_density_{groupname}"] = cell_densities
     return f"{groupby}_density_{groupname}"
+
+
+
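For context, a minimal usage sketch of the two code paths above (not part of the patch; the `cell_type` column, group label, file name, and parameter values are illustrative, and the package is assumed to be imported as `mb`):

    import scanpy as sc
    import monkeybread as mb

    adata = sc.read_h5ad("merscope_sample.h5ad")  # hypothetical preprocessed MERSCOPE dataset

    # Grid-based approximation (default): cells are binned and densities are computed per gridpoint
    key_fast = mb.calc.cell_density(adata, groupby="cell_type", groups=["T cell"], bandwidth=75.0)

    # Sparse path added in this patch: per-cell kernels, zeroed out beyond radius_threshold
    key_sparse = mb.calc.cell_density(
        adata, groupby="cell_type", groups=["T cell"], bandwidth=75.0,
        approx=False, radius_threshold=250,
    )

    # Both calls return the name of the adata.obs column holding the scaled densities
    sc.pl.embedding(adata, basis="spatial", color=[key_fast, key_sparse])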
diff --git a/src/monkeybread/calc/_ligand_receptor.py b/src/monkeybread/calc/_ligand_receptor.py
index 96c6810..4c9d501 100644
--- a/src/monkeybread/calc/_ligand_receptor.py
+++ b/src/monkeybread/calc/_ligand_receptor.py
@@ -81,3 +81,4 @@ def ligand_receptor_score(
     lr_pair_to_score = {(ligand, receptor): calculate_score(ligand, receptor) for ligand, receptor in lr_pairs}
 
     return lr_pair_to_score
+
diff --git a/src/monkeybread/plot/_embedding_other.py b/src/monkeybread/plot/_embedding_other.py
index 797a07a..bb4101d 100644
--- a/src/monkeybread/plot/_embedding_other.py
+++ b/src/monkeybread/plot/_embedding_other.py
@@ -158,6 +158,12 @@ def embedding_zoom(
     if not all([left_pct, top_pct, width_pct, height_pct]):
         raise ValueError("Must provide left_pct, top_pct, width_pct, height_pct")
 
+    # Place in units of fractions rather than percents
+    left_pct /= 100
+    top_pct /= 100
+    width_pct /= 100
+    height_pct /= 100
+
    # Add dot size to key-word arguments
     if unzoom_s:
         unzoom_kwargs = dict(kwargs, s=unzoom_s)
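With the conversion above, the zoom window is now specified in percentages of the full embedding. A hedged sketch of a call (only the `*_pct` and `unzoom_s` arguments appear in this patch; `color` and the chosen values are assumptions):

    import monkeybread as mb

    # Zoom into a window starting 25% from the left and 10% from the top,
    # spanning 50% of the embedding's width and 40% of its height
    mb.plot.embedding_zoom(
        adata,
        left_pct=25,
        top_pct=10,
        width_pct=50,
        height_pct=40,
        color="cell_type",  # assumed to be forwarded to the underlying embedding plot
    )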
diff --git a/src/monkeybread/plot/_shortest_distances.py b/src/monkeybread/plot/_shortest_distances.py
index 87d26b6..29766d5 100644
--- a/src/monkeybread/plot/_shortest_distances.py
+++ b/src/monkeybread/plot/_shortest_distances.py
@@ -122,6 +122,7 @@
     fmt: Optional[str] = '.2f',
     order_x: Optional[List[str]] = None,
     order_y: Optional[List[str]] = None,
+    figsize: Optional[Tuple[float, float]] = None,
     show: Optional[bool] = True,
     ax: Optional[plt.Axes] = None,
     fig: Optional[plt.Figure] = None,
@@ -149,6 +150,8 @@
     order_y
         Order of labels along the y-axis (keys of the outer-dictionary within the
         `g1_to_g2_to_pval` argument.
+    figsize
+        Dimensions of the figure as `(width, height)`
     show
         If True, show the plot and don't return the `plt.Axes` object.
     ax
@@ -188,13 +191,15 @@
     )
     # Create heatmap
+    if not figsize:
+        figsize = (
+            1.17 * len(cell_types_x),
+            2 * len(cell_types_y)
+        )
     if ax is None:
         fig, ax = plt.subplots(
             1,
             1,
-            figsize=(
-                1.17*len(cell_types_x),
-                2*len(cell_types_y)
-            )
+            figsize=figsize
         )
     sns.heatmap(
         df_plot,
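The new `figsize` argument overrides the size heuristic above, which otherwise scales with the number of cell types on each axis. A sketch of how it might be passed (the positional input is an assumption based on the `g1_to_g2_to_pval` docstring reference; the contents are placeholders):

    import monkeybread as mb

    # Nested dict of p-values, {group_1: {group_2: p_value}}; assumed positional input
    g1_to_g2_to_pval = {"T cell": {"tumor": 0.01, "B cell": 0.34}}

    # Default: figure size derived from the number of cell types on each axis
    mb.plot.shortest_distances_pairwise(g1_to_g2_to_pval)

    # Explicit size in inches (width, height), plus a different annotation format
    mb.plot.shortest_distances_pairwise(g1_to_g2_to_pval, figsize=(6, 3), fmt=".1e")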
diff --git a/src/monkeybread/stat/_ligand_receptor.py b/src/monkeybread/stat/_ligand_receptor.py
index 93928ab..072fbf8 100644
--- a/src/monkeybread/stat/_ligand_receptor.py
+++ b/src/monkeybread/stat/_ligand_receptor.py
@@ -7,7 +7,7 @@
 import itertools
 from typing import Dict, Optional, Set, Tuple
-
+from tqdm import tqdm
 import numpy as np
 from anndata import AnnData
 
 
@@ -18,8 +18,7 @@ def ligand_receptor_score(
     adata: AnnData,
     contacts: Dict[str, Set[str]],
     actual_scores: Dict[Tuple[str, str], float],
-    n_perms: Optional[int] = 100,
-    verbose: Optional[bool] = False
+    n_perms: Optional[int] = 100
 ) -> Dict[Tuple[str, str], Tuple[np.ndarray, float]]:
     """Calculates statistical significance of the co-expression of ligand-receptor pairs between
     neighboring cells.
@@ -57,11 +56,7 @@ def ligand_receptor_score(
     ligand_index_starts = np.array([0] + list(itertools.accumulate(len(v) for v in contacts.values())))
 
     # Iterate over permutations
-    for i in range(n_perms):
-        if verbose:
-            if i % 10 == 0:
-                print(f"Generating permutation {i+1}...")
-
+    for i in tqdm(range(n_perms)):
         np.random.shuffle(receptor_cells)
         # Randomize linkages between ligand cells and receptor cells, pulling out the appropriate
         # number of linkages for each ligand and receptor
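With `verbose` removed, permutation progress is always reported through tqdm. A sketch of the call after this change (the `contacts` and `actual_scores` contents are placeholders shaped to match the type hints above):

    import monkeybread as mb

    # contacts: ligand-expressing cell -> set of neighboring receptor-expressing cells
    contacts = {"cell_0": {"cell_10", "cell_42"}}
    # actual_scores: (ligand, receptor) -> observed co-expression score, e.g. from mb.calc
    actual_scores = {("CXCL9", "CXCR3"): 1.8}

    result = mb.stat.ligand_receptor_score(adata, contacts, actual_scores, n_perms=100)
    perm_scores, p_value = result[("CXCL9", "CXCR3")]  # null distribution and p-value per pair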
diff --git a/src/monkeybread/util/__init__.py b/src/monkeybread/util/__init__.py
index fe95e8d..1e6e83b 100644
--- a/src/monkeybread/util/__init__.py
+++ b/src/monkeybread/util/__init__.py
@@ -2,3 +2,4 @@
 from ._neighbor_count import neighbor_count
 from ._load_merscope import load_merscope
 from ._subset_cells import subset_cells
+from ._omnipath import load_ligand_receptor_pairs_omnipath
diff --git a/src/monkeybread/util/_load_merscope.py b/src/monkeybread/util/_load_merscope.py
index bb57ba2..f6c6528 100644
--- a/src/monkeybread/util/_load_merscope.py
+++ b/src/monkeybread/util/_load_merscope.py
@@ -17,7 +17,6 @@
 def load_merscope(
     folder: Optional[str] = ".",
-    use_cache: Optional[str] = None,
     transcript_locations: Optional[bool] = None,
     paths: Optional[Dict[str, str]] = None,
 ) -> ad.AnnData:
@@ -46,23 +45,15 @@
     data: ad.AnnData
 
-    if use_cache is not None:
-        if use_cache == "all":
-            return ad.read(f"{folder}/{paths['cache']}")
-        elif use_cache == "spatial":
-            data = ad.read(f"{folder}/{paths['cache']}")
-        else:
-            raise ValueError("use_cache must be None, 'all', or 'spatial'")
-    else:
-        # If cached data doesn't exist, read gene expression and basic spatial data
-        counts = sc.read(f"{folder}/{paths['counts']}", first_column_names=True, cache=True)
-        coordinates = pd.read_csv(f"{folder}/{paths['coordinates']}")
-        coordinates = coordinates.rename({"Unnamed: 0": "cell_id"}, axis=1)
-        data = counts[coordinates.cell_id]  # Slice data by coordinate index
-        data.obsm["X_spatial"] = coordinates[["center_x", "center_y"]].to_numpy()
-        data.obs["width"] = coordinates["max_x"].to_numpy() - coordinates["min_x"].to_numpy()
-        data.obs["height"] = coordinates["max_y"].to_numpy() - coordinates["min_y"].to_numpy()
-        data.obs["fov"] = coordinates["fov"].to_numpy()
+    counts = sc.read(f"{folder}/{paths['counts']}", first_column_names=True, cache=True)
+    coordinates = pd.read_csv(f"{folder}/{paths['coordinates']}")
+    coordinates = coordinates.rename({"Unnamed: 0": "cell_id"}, axis=1)
+    data = counts[coordinates.cell_id]  # Slice data by coordinate index
+    data.obsm["X_spatial"] = coordinates[["center_x", "center_y"]].to_numpy()
+    data.obs["width"] = coordinates["max_x"].to_numpy() - coordinates["min_x"].to_numpy()
+    data.obs["height"] = coordinates["max_y"].to_numpy() - coordinates["min_y"].to_numpy()
+    data.obs["fov"] = coordinates["fov"].to_numpy()
+    data.obs.index = [str(x) for x in data.obs.index]  # Ensure cell indices are strings
 
     # Read transcripts
     if transcript_locations or (transcript_locations is None and os.path.exists(f"{folder}/{paths['transcripts']}")):
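Since `use_cache` is removed, callers that want caching can write the returned AnnData to disk themselves; a sketch (paths are illustrative):

    import monkeybread as mb

    adata = mb.util.load_merscope("path/to/merscope_output")  # reads counts and cell coordinates
    adata.write("merscope_cache.h5ad")  # caller-side replacement for the old use_cache behavior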
diff --git a/src/monkeybread/util/_omnipath.py b/src/monkeybread/util/_omnipath.py
new file mode 100644
index 0000000..6bbbe01
--- /dev/null
+++ b/src/monkeybread/util/_omnipath.py
@@ -0,0 +1,45 @@
+from typing import Optional
+from anndata import AnnData
+from omnipath.interactions import import_intercell_network
+
+def load_ligand_receptor_pairs_omnipath(
+    adata: AnnData,
+    require_gene: Optional[str] = None
+):
+    """
+    Load ligand-receptor pairs from the OmniPath database, keeping only pairs in which
+    both genes are present in the provided dataset.
+
+    Parameters
+    ----------
+    adata
+        An AnnData object
+    require_gene
+        If provided, only return ligand-receptor pairs in which either the ligand or
+        the receptor matches this gene
+
+    Returns
+    -------
+    List of ligand-receptor pair tuples
+    """
+    interactions = import_intercell_network(
+        transmitter_params={"categories": "ligand"},
+        receiver_params={"categories": "receptor"}
+    )
+    genes_in_data = set(adata.var_names)
+    lr_pairs = [
+        (s, t)
+        for s, t in zip(
+            interactions['genesymbol_intercell_source'],
+            interactions['genesymbol_intercell_target']
+        )
+        if s in genes_in_data and t in genes_in_data
+    ]
+
+    if require_gene:
+        lr_pairs = [
+            (s, t)
+            for s, t in lr_pairs
+            if s == require_gene or t == require_gene
+        ]
+    return lr_pairs
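A usage sketch tying the new helper to downstream ligand-receptor scoring (the `mb.calc.ligand_receptor_score` call shape is an assumption, since its full signature is not shown in this patch; gene names are illustrative):

    import monkeybread as mb

    # All OmniPath ligand-receptor pairs in which both genes are measured in this dataset
    lr_pairs = mb.util.load_ligand_receptor_pairs_omnipath(adata)

    # Restrict to pairs involving a particular gene
    cxcl9_pairs = mb.util.load_ligand_receptor_pairs_omnipath(adata, require_gene="CXCL9")

    # Assumed downstream use: score each pair over neighboring cells ("contacts" computed elsewhere)
    scores = mb.calc.ligand_receptor_score(adata, contacts, lr_pairs)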