Skip to content

Commit

Permalink
Merge pull request #158 from YeoLab/spatial-multimap
Browse files Browse the repository at this point in the history
multimapping shapes
  • Loading branch information
ckmah authored Dec 17, 2024
2 parents b0ddc03 + 4bccd91 commit 1046e01
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 20 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
Expand Down
7 changes: 6 additions & 1 deletion bento/geometry/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def overlay(
name: str,
how: str = "intersection",
make_valid: bool = True,
instance_map_type: str = "1to1",
):
"""Overlay two shape elements in a SpatialData object and store the result as a new shape element.
Expand All @@ -37,6 +38,8 @@ def overlay(
If True, correct invalid geometries with GeoPandas, by default True
instance_key : str
Name of the shape element to use as the instance for indexing, by default "cell_boundaries". If None, no indexing is performed.
instance_map_type : str, optional
Type of instance mapping to use. Options are "1to1", "1tomany", by default "1to1".
Returns
-------
Expand All @@ -46,7 +49,8 @@ def overlay(
shape1 = sdata[s1]
shape2 = sdata[s2]

new_shape = shape1.overlay(shape2, how=how, make_valid=make_valid)
new_shape = shape1.overlay(shape2, how=how, make_valid=make_valid)[["geometry"]]
new_shape.index = new_shape.index.astype(str)
new_shape.attrs = {}

transform = shape1.attrs
Expand All @@ -58,6 +62,7 @@ def overlay(
shape_keys=[name],
instance_key=get_instance_key(sdata),
feature_key=get_feature_key(sdata),
instance_map_type=instance_map_type,
)


Expand Down
46 changes: 38 additions & 8 deletions bento/io/_index.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List
from typing import List, Union

import pandas as pd
import geopandas as gpd
from spatialdata._core.spatialdata import SpatialData

from spatialdata.models import ShapesModel
from .._utils import (
get_points,
set_points_metadata,
Expand Down Expand Up @@ -68,7 +68,12 @@ def _sjoin_points(
return sdata


def _sjoin_shapes(sdata: SpatialData, instance_key: str, shape_keys: List[str]):
def _sjoin_shapes(
sdata: SpatialData,
instance_key: str,
shape_keys: List[str],
instance_map_type: Union[str, dict],
):
"""Adds polygon indexes to sdata.shapes[instance_key][shape_key] for point feature analysis.
Parameters
Expand Down Expand Up @@ -102,17 +107,36 @@ def _sjoin_shapes(sdata: SpatialData, instance_key: str, shape_keys: List[str]):

# sjoin shapes to instance_key shape
for shape_key in shape_keys:
child_shape = sdata.shapes[shape_key]
child_shape = sdata.shapes[shape_key].copy()
child_attrs = child_shape.attrs
# Hack for polygons that are 99% contained in parent shape or have shared boundaries
child_shape = gpd.GeoDataFrame(geometry=child_shape.buffer(-10e-6))

# Map child shape index to parent shape and process the result

if instance_map_type == "1tomany":
child_shape = (
child_shape.sjoin(
parent_shape.reset_index(drop=True),
how="left",
predicate="covered_by",
)
.dissolve(by="index_right", observed=True, dropna=False)
.reset_index(drop=True)[["geometry"]]
)
child_shape.index = child_shape.index.astype(str)
child_shape = ShapesModel.parse(child_shape)
child_shape.attrs = child_attrs
sdata.shapes[shape_key] = child_shape

parent_shape = (
parent_shape.sjoin(child_shape, how="left", predicate="covers")
.reset_index()
.drop_duplicates(subset="index", keep="last")
.reset_index() # ignore any user defined index name
.drop_duplicates(
subset="index", keep="last"
) # Remove multiple child shapes mapped to same parent shape
.set_index("index")
.assign(
.assign( # can this just be fillna on index_right?
index_right=lambda df: df.loc[
~df["index_right"].duplicated(keep="first"), "index_right"
]
Expand All @@ -121,14 +145,20 @@ def _sjoin_shapes(sdata: SpatialData, instance_key: str, shape_keys: List[str]):
)
.rename(columns={"index_right": shape_key})
)

if (
parent_shape[shape_key].dtype == "category"
and "" not in parent_shape[shape_key].cat.categories
):
parent_shape[shape_key] = parent_shape[shape_key].cat.add_categories([""])
parent_shape[shape_key] = parent_shape[shape_key].fillna("")

# Save shape index as column in instance_key shape
set_shape_metadata(
sdata, shape_key=instance_key, metadata=parent_shape[shape_key]
)

# Add instance_key shape index to shape
# Add instance_key shape index to child shape
instance_index = (
parent_shape.drop_duplicates(subset=shape_key)
.reset_index()
Expand Down
48 changes: 39 additions & 9 deletions bento/io/_io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import warnings
from typing import List
from typing import List, Union

import emoji
import spatialdata as sd
from anndata.utils import make_index_unique
from spatialdata import SpatialData
from spatialdata.models import TableModel
Expand All @@ -19,6 +20,7 @@ def prep(
feature_key: str = "feature_name",
instance_key: str = "cell_boundaries",
shape_keys: List[str] = ["cell_boundaries", "nucleus_boundaries"],
instance_map_type: Union[dict, str] = "1to1",
) -> SpatialData:
"""Computes spatial indices for elements in SpatialData to enable usage of bento-tools.
Expand All @@ -36,6 +38,11 @@ def prep(
Key for the shape that will be used as the instance for all indexing. Usually the cell shape.
shape_keys : str, list
List of shape names to index points to
instance_map_type : str, dict
Type of mapping to use for the instance shape. If "1to1", each instance shape will be mapped to a single shape at most.
If "1tomany", each instance shape will be mapped to one or more shapes;
multiple shapes mapped to the same instance shape will be merged into a single MultiPolygon.
Use a dict to specify different mapping types for each shape.
Returns
-------
Expand All @@ -51,6 +58,24 @@ def prep(
shape_gdf[shape_key] = shape_gdf["geometry"]
shape_gdf.index = make_index_unique(shape_gdf.index.astype(str))

transform = {
"global": sd.transformations.get_transformation(sdata.points[points_key])
}
if "global" in sdata.points[points_key].attrs["transform"]:
# Force points to 2D for Xenium data
if isinstance(transform["global"], sd.transformations.Scale):
transform = {
"global": sd.transformations.Scale(
scale=transform.to_scale_vector(["x", "y"]), axes=["x", "y"]
)
}
sdata.points[points_key] = sd.models.PointsModel.parse(
sdata.points[points_key].compute().reset_index(drop=True),
coordinates={"x": "x", "y": "y"},
feature_key=feature_key,
transformations=transform,
)

# sindex points and sjoin shapes if they have not been indexed or joined
point_sjoin = []
shape_sjoin = []
Expand All @@ -72,6 +97,19 @@ def prep(
sdata.points[points_key].attrs["spatialdata_attrs"]["instance_key"] = instance_key

pbar = tqdm(total=3)
if len(shape_sjoin) > 0:
pbar.set_description(
"Mapping shapes"
) # Map shapes must happen first; manyto1 mapping resets shape index
sdata = _sjoin_shapes(
sdata=sdata,
instance_key=instance_key,
shape_keys=shape_sjoin,
instance_map_type=instance_map_type,
)

pbar.update()

if len(point_sjoin) > 0:
pbar.set_description("Mapping points")
sdata = _sjoin_points(
Expand All @@ -82,14 +120,6 @@ def prep(

pbar.update()

if len(shape_sjoin) > 0:
pbar.set_description("Mapping shapes")
sdata = _sjoin_shapes(
sdata=sdata, instance_key=instance_key, shape_keys=shape_sjoin
)

pbar.update()

# Only keep points within instance_key shape
_sync_points(sdata, points_key)

Expand Down
1 change: 1 addition & 0 deletions bento/tools/_shape_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def density(sdata: SpatialData, shape_key: str, recompute: bool = False):
.query(f"{shape_key} != 'None'")[shape_key]
.value_counts()
.compute()
.reindex_like(sdata.shapes[shape_key])
)
area(sdata, shape_key)

Expand Down
Binary file added bento/tools/gene_sets/boyle2023.zip
Binary file not shown.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"upsetplot>=0.9.0",
"xgboost>=2.0.3",
"statsmodels>=0.14.1",
"scikit-learn>=1.4.2",
"scikit-learn<1.6.0",
"ipywidgets>=8.1.5",
]
license = "BSD-2-Clause"
Expand Down

0 comments on commit 1046e01

Please sign in to comment.