Skip to content

Commit

Permalink
Further enhancements
Browse files Browse the repository at this point in the history
  • Loading branch information
mbernstein committed Aug 21, 2023
1 parent 99c291f commit 179eda9
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 78 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"tqdm",
"scikit-learn>=0.22",
"omnipath>=1.0.7",
"tqdm>=4.64.1",
# for debug logging (referenced from the issue template)
"session-info",
]
Expand Down
156 changes: 108 additions & 48 deletions src/monkeybread/calc/_cell_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,31 @@
import itertools
from collections import Counter, defaultdict
from typing import Dict, List, Optional, Union

from tqdm import tqdm
import numpy as np
from anndata import AnnData
from sklearn.metrics import pairwise_distances
from scipy.spatial import cKDTree
from scipy.sparse import dok_matrix

def _sparse_distance_matrix(data, distance_threshold):
N = data.shape[0]

# Build a KD-Tree for fast nearest neighbor search
kdtree = cKDTree(data)

# Initialize a sparse matrix to store the distance matrix
distance_matrix = dok_matrix((N, N), dtype=float)

# Query the KD-Tree for neighbors within the distance threshold
# and set the corresponding entries in the distance matrix

for i in tqdm(range(N)):
neighbors = kdtree.query_ball_point(data[i], distance_threshold)
distances = np.linalg.norm(data[i] - data[neighbors], axis=1)
distance_matrix[i, neighbors] = distances
distance_matrix[neighbors, i] = distances
return distance_matrix


def cell_density(
Expand All @@ -18,17 +39,25 @@ def cell_density(
groupname: Optional[str] = None,
basis: Optional[str] = "X_spatial",
bandwidth: Optional[float] = 1.0,
resolution: Optional[float] = 1.0
approx: Optional[bool] = True,
resolution: Optional[float] = 1.0,
radius_threshold: Optional[float] = 250
) -> Union[str, Dict[str, str]]:
"""Calculates the spatial distribution of cells of a given cell type using a heuristic algorithm
based on kernel density estimation.
"""Calculates the spatial distribution of cells of a given cell type using kernel density
estimation.
This can be computed via two approximations: the first approximation is closest to the true
kernel density estimate, however, the kernel between pairs of cells are zeroed out if they
are greater than `radius_threshold` from eachother. This enables the kernel estimation to run
on sparse matrices.
Specifically, a grid is overlayed onto the spatial coordinates, and each cell is assigned to its
nearest gridpoint. Each gridpoint is then used in a kernel density estimation calculation. The
density value for each gridpoint is then assigned back to each cell. Fineness of the grid can be
modified using the `resolution` parameter where a high resolution leads to a more accurate
calculation of density for each cell. Final density values are scaled between 0 and 1
corresponding to the minimum and maximum density, respectively.
The second approximation is faster and can be run if `approx` is set to true. A grid is
overlayed onto the spatial coordinates, and each cell is assigned to its nearest gridpoint.
Each gridpoint is then used in a kernel density estimation calculation. The density value for
each gridpoint is then assigned back to each cell. Fineness of the grid can be modified using
the `resolution` parameter where a high resolution leads to a more accurate calculation of
density for each cell. Final density values are scaled between 0 and 1 corresponding to the
minimum and maximum density, respectively.
Parameters
----------
Expand All @@ -45,8 +74,14 @@ def cell_density(
Coordinates in `adata.obsm[{basis}]` to use. Defaults to `spatial`
bandwidth
Bandwidth for kernel density estimation
approx
If true, approximate the kernel density estimation by grouping cells into bins
resolution
How small each gridpoint is. Number of points per row/column = resolution * 10
How small each gridpoint is. Number of points per row/column = resolution * 10.
Used only if `approx` is `True`
radius_threshold
Sparsifies the pairwise distance matrix to only consider cells if they are within
this threshold distance from eachother. Only used if `approx` is False.
Returns
-------
Expand All @@ -61,45 +96,70 @@ def cell_density(
if groupname is None:
groupname = "_".join(groups)

# Multiply base resolution by 10 to get number of bins per row/column
res = int(10 * resolution)

# Calculate bounds of spatial coordinates, x and y increments for each bin, and number of bins
[x_coords, y_coords] = adata.obsm[basis].transpose()
(x_min, x_max) = (min(x_coords) - 0.01, max(x_coords) + 0.01)
(y_min, y_max) = (min(y_coords) - 0.01, max(y_coords) + 0.01)
(x_inc, y_inc) = ((x_max - x_min) / res, (y_max - y_min) / res)
num_bins = int(res**2)

# Calculate distance from the center of each bin to the center of every other bin
x_bin_centers = [x_min + x_inc * (i + 1) / 2 for i in range(res)]
y_bin_centers = [y_min + y_inc * (j + 1) / 2 for j in range(res)]
bin_centers = list(itertools.product(x_bin_centers, y_bin_centers))
bin_distances = pairwise_distances(bin_centers)

# Calculate kernel based on distances and bandwidth
kernel = 1 / np.exp(np.square(bin_distances / bandwidth))

# Determine the bin each cell belongs to based on its location
# Bin number corresponds to index in bin_distancesx_c
location_to_bin = lambda x, y: int((x - x_min) / x_inc) * res + int((y - y_min) / y_inc)
cell_to_bin = [location_to_bin(x, y) for (x, y) in adata.obsm[basis]]

bin_counts = np.zeros(num_bins, dtype=int)
if groupby == "all":
for bin_id, count in Counter(cell_to_bin).items():
bin_counts[bin_id] = count
if approx:
# Multiply base resolution by 10 to get number of bins per row/column
res = int(10 * resolution)

# Calculate bounds of spatial coordinates, x and y increments for each bin, and number of bins
[x_coords, y_coords] = adata.obsm[basis].transpose()
(x_min, x_max) = (min(x_coords) - 0.01, max(x_coords) + 0.01)
(y_min, y_max) = (min(y_coords) - 0.01, max(y_coords) + 0.01)
(x_inc, y_inc) = ((x_max - x_min) / res, (y_max - y_min) / res)
num_bins = int(res**2)

# Calculate distance from the center of each bin to the center of every other bin
x_bin_centers = [x_min + x_inc * (i + 1) / 2 for i in range(res)]
y_bin_centers = [y_min + y_inc * (j + 1) / 2 for j in range(res)]
bin_centers = list(itertools.product(x_bin_centers, y_bin_centers))
bin_distances = pairwise_distances(bin_centers)

# Calculate kernel based on distances and bandwidth
kernel = 1 / np.exp(np.square(bin_distances / bandwidth))

# Determine the bin each cell belongs to based on its location
# Bin number corresponds to index in bin_distancesx_c
location_to_bin = lambda x, y: int((x - x_min) / x_inc) * res + int((y - y_min) / y_inc)
cell_to_bin = [location_to_bin(x, y) for (x, y) in adata.obsm[basis]]

bin_counts = np.zeros(num_bins, dtype=int)
if groupby == "all":
for bin_id, count in Counter(cell_to_bin).items():
bin_counts[bin_id] = count
else:
# Iterate over each cell, counting the number in groups in each bin
for (cell_group, bin_id) in zip(adata.obs[groupby], cell_to_bin):
bin_counts[bin_id] += 1 if groupby == "all" or cell_group in groups else 0

# Compute bin densities
bin_densities = np.matmul(kernel, bin_counts)

# Scale between zero and one
min_density, max_density = min(bin_densities), max(bin_densities)
if min_density != max_density:
bin_densities = [
(d - min_density) / (max_density - min_density)
for d in bin_densities
]

# Map kernel densities back to cells and assign to a column in obs
cell_densities = [bin_densities[bin_id] for bin_id in cell_to_bin]
else:
# Iterate over each cell, counting the number in groups in each bin
for (cell_group, bin_id) in zip(adata.obs[groupby], cell_to_bin):
bin_counts[bin_id] += 1 if groupby == "all" or cell_group in groups else 0
# Compute kernel matrix
distances = _sparse_distance_matrix(adata.obsm[basis], radius_threshold)
kernel = np.expm1(-(distances/bandwidth).power(2))
kernel[kernel.nonzero()] += 1

# Get cell type binary vector
is_group = np.array(adata.obs[groupby].isin(groups)).astype(int)

bin_densities = np.matmul(kernel, bin_counts)
min_density, max_density = min(bin_densities), max(bin_densities)
if min_density != max_density:
bin_densities = [(d - min_density) / (max_density - min_density) for d in bin_densities]
# Compute densities
densities = kernel.dot(is_group)

# Scale between zero and one
cell_densities = (densities - densities.min()) / (densities.max() - densities.min())

# Map kernel densities back to cells and assign to a column in obs
cell_densities = [bin_densities[bin_id] for bin_id in cell_to_bin]
adata.obs[f"{groupby}_density_{groupname}"] = cell_densities
return f"{groupby}_density_{groupname}"



1 change: 1 addition & 0 deletions src/monkeybread/calc/_ligand_receptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,4 @@ def ligand_receptor_score(
lr_pair_to_score = {(ligand, receptor): calculate_score(ligand, receptor) for ligand, receptor in lr_pairs}

return lr_pair_to_score

6 changes: 6 additions & 0 deletions src/monkeybread/plot/_embedding_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ def embedding_zoom(
if not all([left_pct, top_pct, width_pct, height_pct]):
raise ValueError("Must provide left_pct, top_pct, width_pct, height_pct")

# Place in units of fractions rather than percents
left_pct /= 100
top_pct /= 100
width_pct /= 100
height_pct /= 100

# Add dot size to key-word arguments
if unzoom_s:
unzoom_kwargs = dict(kwargs, s=unzoom_s)
Expand Down
13 changes: 9 additions & 4 deletions src/monkeybread/plot/_shortest_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def shortest_distances_pairwise(
fmt: Optional[str] = '.2f',
order_x: Optional[List[str]] = None,
order_y: Optional[List[str]] = None,
figsize: Optional[Tuple[float, float]] = None,
show: Optional[bool] = True,
ax: Optional[plt.Axes] = None,
fig: Optional[plt.Figure] = None,
Expand Down Expand Up @@ -149,6 +150,8 @@ def shortest_distances_pairwise(
order_y
Order of labels along the y-axis (keys of the outer-dictionary within the `g1_to_g2_to_pval`
argument.
figsize
Dimensions of figure
show
If True, show the plot and don't return the `plt.Axes` object.
ax
Expand Down Expand Up @@ -188,13 +191,15 @@ def shortest_distances_pairwise(
)

# Create heatmap
if not figsize:
figsize=(
1.17*len(cell_types_x),
2*len(cell_types_y)
)
if ax is None:
fig, ax = plt.subplots(
1, 1,
figsize=(
1.17*len(cell_types_x),
2*len(cell_types_y)
)
figsize=figsize
)
sns.heatmap(
df_plot,
Expand Down
11 changes: 3 additions & 8 deletions src/monkeybread/stat/_ligand_receptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import itertools
from typing import Dict, Optional, Set, Tuple

from tqdm import tqdm
import numpy as np
from anndata import AnnData

Expand All @@ -18,8 +18,7 @@ def ligand_receptor_score(
adata: AnnData,
contacts: Dict[str, Set[str]],
actual_scores: Dict[Tuple[str, str], float],
n_perms: Optional[int] = 100,
verbose: Optional[bool] = False
n_perms: Optional[int] = 100
) -> Dict[Tuple[str, str], Tuple[np.ndarray, float]]:
"""Calculates statistical significance of the co-expression of
ligand-receptor pairs between neighboring cells.
Expand Down Expand Up @@ -57,11 +56,7 @@ def ligand_receptor_score(
ligand_index_starts = np.array([0] + list(itertools.accumulate(len(v) for v in contacts.values())))

# Iterate over permutations
for i in range(n_perms):
if verbose:
if i % 10 == 0:
print(f"Generating permutation {i+1}...")

for i in tqdm(range(n_perms)):
np.random.shuffle(receptor_cells)
# Randomize linkages between ligand cells and receptor cells, pulling out the appropriate
# number of linkages for each ligand and receptor
Expand Down
1 change: 1 addition & 0 deletions src/monkeybread/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from ._neighbor_count import neighbor_count
from ._load_merscope import load_merscope
from ._subset_cells import subset_cells
from ._omnipath import load_ligand_receptor_pairs_omnipath
27 changes: 9 additions & 18 deletions src/monkeybread/util/_load_merscope.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

def load_merscope(
folder: Optional[str] = ".",
use_cache: Optional[str] = None,
transcript_locations: Optional[bool] = None,
paths: Optional[Dict[str, str]] = None,
) -> ad.AnnData:
Expand Down Expand Up @@ -46,23 +45,15 @@ def load_merscope(

data: ad.AnnData

if use_cache is not None:
if use_cache == "all":
return ad.read(f"{folder}/{paths['cache']}")
elif use_cache == "spatial":
data = ad.read(f"{folder}/{paths['cache']}")
else:
raise ValueError("use_cache must be None, 'all', or 'spatial'")
else:
# If cached data doesn't exist, read gene expression and basic spatial data
counts = sc.read(f"{folder}/{paths['counts']}", first_column_names=True, cache=True)
coordinates = pd.read_csv(f"{folder}/{paths['coordinates']}")
coordinates = coordinates.rename({"Unnamed: 0": "cell_id"}, axis=1)
data = counts[coordinates.cell_id] # Slice data by coordinate index
data.obsm["X_spatial"] = coordinates[["center_x", "center_y"]].to_numpy()
data.obs["width"] = coordinates["max_x"].to_numpy() - coordinates["min_x"].to_numpy()
data.obs["height"] = coordinates["max_y"].to_numpy() - coordinates["min_y"].to_numpy()
data.obs["fov"] = coordinates["fov"].to_numpy()
counts = sc.read(f"{folder}/{paths['counts']}", first_column_names=True, cache=True)
coordinates = pd.read_csv(f"{folder}/{paths['coordinates']}")
coordinates = coordinates.rename({"Unnamed: 0": "cell_id"}, axis=1)
data = counts[coordinates.cell_id] # Slice data by coordinate index
data.obsm["X_spatial"] = coordinates[["center_x", "center_y"]].to_numpy()
data.obs["width"] = coordinates["max_x"].to_numpy() - coordinates["min_x"].to_numpy()
data.obs["height"] = coordinates["max_y"].to_numpy() - coordinates["min_y"].to_numpy()
data.obs["fov"] = coordinates["fov"].to_numpy()
data.obs.index = [str(x) for x in data.obs.index] # Ensure cell indices are strings

# Read transcripts
if transcript_locations or (transcript_locations is None and os.path.exists(f"{folder}/{paths['transcripts']}")):
Expand Down
45 changes: 45 additions & 0 deletions src/monkeybread/util/_omnipath.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Optional
from anndata import AnnData
from omnipath.interactions import import_intercell_network

def load_ligand_receptor_pairs_omnipath(
adata: AnnData,
require_gene: Optional[str]=None
):
"""
Load ligand-receptor pairs from the Omnipath database that are also
in the provided dataset.
Parameters
----------
adata
An AnnData object
require_gene
Only return ligand-receptor pairs where either the ligand or
receptor is this argument
Returns
-------
List of ligand-receptor pair tuples
"""
interactions = import_intercell_network(
transmitter_params={"categories": "ligand"},
receiver_params={"categories": "receptor"}
)
genes_in_data = set(adata.var_names)
lr_pairs = [
(s, t)
for s, t in zip(
interactions['genesymbol_intercell_source'],
interactions['genesymbol_intercell_target']
)
if s in genes_in_data and t in genes_in_data
]

if require_gene:
lr_pairs = [
(s, t)
for s, t in lr_pairs
if s == require_gene or t == require_gene
]
return lr_pairs

0 comments on commit 179eda9

Please sign in to comment.