Skip to content

Commit

Permalink
LandmarkMatcher class
Browse files Browse the repository at this point in the history
  • Loading branch information
clbarnes committed Jul 5, 2023
1 parent 2a452c3 commit 19c7e2f
Showing 1 changed file with 240 additions and 0 deletions.
240 changes: 240 additions & 0 deletions pymaid/fetch/landmarks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import typing as tp
from functools import cache

import numpy as np
import pandas as pd

from pymaid.client import CatmaidInstance

from ..utils import _eval_remote_instance, DataFrameBuilder


Expand Down Expand Up @@ -39,6 +43,7 @@ def get_landmarks(
url = cm.make_url(
cm.project_id,
"landmarks",
# todo: don't download locations if we're not going to use them?
with_locations="true",
)
landmark_builder = DataFrameBuilder(
Expand Down Expand Up @@ -172,3 +177,238 @@ def get_landmark_groups(
locations = None

return groups, locations, members


class LandmarkMatcher:
"""Class for finding matching pairs of landmark locations between two groups."""
def __init__(
self,
landmarks: pd.DataFrame,
landmark_locations: pd.DataFrame,
groups: pd.DataFrame,
group_locations: pd.DataFrame,
group_members: dict[int, tp.Iterable[int]],
):
self.landmarks = landmarks
self.landmark_locations = landmark_locations
self.groups = groups
self.group_locations = group_locations
self.group_members: dict[int, set[int]] = {
k: set(v) for k, v in group_members.items()
}

@classmethod
def from_catmaid(cls, remote_instance=None):
cm = _eval_remote_instance(remote_instance)
landmarks, landmark_locations = get_landmarks(True, cm)
groups, group_locations, members = get_landmark_groups(True, True, remote_instance=cm)
return cls(landmarks, landmark_locations, groups, group_locations, members)

@cache
def _group_name_to_id(self) -> dict[str, int]:
return dict(zip(self.groups["name"], self.groups["group_id"]))

def _locations_in_group(self, group_id: int):
idx = self.group_locations["group_id"] == group_id
return self.group_locations["location_id"][idx]

def _locations_in_landmark(self, landmark_id: int):
idx = self.landmark_locations["landmark_id"] = landmark_id
return self.landmark_locations["location_id"][idx]

def _unique_location(self, group_id: int, landmark_id: int) -> int:
if landmark_id not in self.group_members[group_id]:
return None

landmark_ids = set(self._locations_in_group(group_id)).intersection(self._locations_in_landmark(landmark_id))
if len(landmark_ids) == 1:
return landmark_ids.pop()
return None

@cache
def _landmark_id_to_name(self) -> dict[int, str]:
return dict(zip(self.landmarks["landmark_id"], self.landmarks["name"]))

@cache
def _locations(self) -> dict[int, tuple[int, int, int]]:
d = dict()
for loc_id, x, y, z, _ in self.group_locations.itertuples(index=False):
d[loc_id] = (x, y, z)

for loc_id, x, y, z, _ in self.landmark_locations.itertuples(index=False):
d[loc_id] = (x, y, z)

return d

def _as_group_id(self, group: tp.Union[str, int]) -> int:
if isinstance(group, int):
return group
if isinstance(group, str):
return self._group_name_to_id()[group]
raise ValueError("Given group must be int ID or str name")

def match(self, group1: tp.Union[str, int], group2: tp.Union[str, int]) -> pd.DataFrame:
"""Get matching pairs of landmarks for two groups.
Return the paired locations of landmarks
which are a members of both groups,
and have a single location in each group.
Parameters
----------
group1 : tp.Union[str, int]
First group name (as str) or ID (as int)
group2 : tp.Union[str, int]
Second group name (as str) or ID (as int)
Returns
-------
pd.DataFrame
Columns landmark_name, landmark_id, location_id1, x1, y1, z1, location_id2, x2, y2, z2
"""
group_to_id = self._group_name_to_id()
group1_id = group1 if isinstance(group1, int) else group_to_id[group1]
group2_id = group2 if isinstance(group2, int) else group_to_id[group2]

group1_locs = set(self._locations_in_group(group1_id))
group2_locs = set(self._locations_in_group(group2_id))

locs = self._locations()

dtypes = {"landmark_name": str, "landmark_id": np.uint64}
for g in ["1", "2"]:
dtypes["location_id" + g] = np.uint64
for d in "xyz":
dtypes[d + g] = np.float64

lm_id_to_name = self._landmark_id_to_name()

rows = []
for lm_id in set(self.group_members[group1_id]).intersection(self.group_members[group2_id]):
lm_locs = self._locations_in_landmark(lm_id)
g1_locs = group1_locs.intersection(lm_locs)
if not len(g1_locs) == 1:
continue
g2_locs = group2_locs.intersection(lm_locs)
if not len(g1_locs) == 1:
continue

row = [lm_id, lm_id_to_name[lm_id]]
row.append(g1_locs.pop())
row.extend(locs[row[-1]])
row.append(g2_locs.pop())
row.extend(locs[row[-1]])

rows.append(row)

df = pd.DataFrame(rows, columns=list(dtypes), dtype=object)
return df.astype(dtypes)


class CrossInstanceLandmarkMatcher:
def __init__(self, this_lms: LandmarkMatcher, other_lms: LandmarkMatcher):
self.this_m: LandmarkMatcher = this_lms
self.other_m: LandmarkMatcher = other_lms

@classmethod
def from_catmaid(cls, other_remote_instance: CatmaidInstance, this_remote_instance=None):
this_remote_instance = _eval_remote_instance(this_remote_instance)
return cls(
LandmarkMatcher.from_catmaid(this_remote_instance),
LandmarkMatcher.from_catmaid(other_remote_instance),
)

def _member_landmarks(self, group_id: int, this=True) -> dict[str, int]:
matcher = self.this_m if this else self.other_m
lmid_to_name = matcher._landmark_id_to_name()
out = dict()
for lmid in matcher.group_members[group_id]:
out[lmid_to_name[lmid]] = lmid
return out

def match(
self,
this_group: tp.Union[str, int],
other_group: tp.Optional[tp.Union[str, int]] = None
) -> pd.DataFrame:
"""Match landmark locations between two instance of CATMAID.
Looks through the members of the two groups to find landmarks with the same name.
If those landmarks have one location in each group, match those.
Parameters
----------
this_group : tp.Union[str, int]
Group name (str) or ID (int) on this instance.
other_group : tp.Optional[tp.Union[str, int]], optional
Group name (str) or ID (int) on the other instance.
If None (default) and ``this_group`` is a str name, use that name.
Returns
-------
pd.DataFrame
Columns landmark_name, x1, y1, z1, x2, y2, z2
where 1 is "this" and 2 is "other".
"""
this_group_id = self.this_m._as_group_id(this_group)
if other_group is None:
if isinstance(this_group, int):
raise ValueError("If other_group is None, this_group must be a str name")
other_group = this_group
other_group_id = self.other_m._as_group_id(other_group)

this_locs = self.this_m._locations()
other_locs = self.other_m._locations()

this_lms = self._member_landmarks(this_group_id)
other_lms = self._member_landmarks(other_group_id, False)

dtypes = {"landmark_name": str}
for g in "12":
for d in "xyz":
dtypes[d + g] = np.float64

rows = []

for lm_name, this_lmid in this_lms.items():
if lm_name not in other_lms:
continue
other_lmid = other_lms[lm_name]

this_loc_id = self.this_m._unique_location(this_group_id, this_lmid)
if this_loc_id is None:
continue

other_loc_id = self.other_m._unique_location(other_group_id, other_lmid)
if other_loc_id is None:
continue

row = [lm_name]
row.extend(this_locs[this_loc_id])
row.extend(other_locs[other_loc_id])

rows.append(row)

df = pd.DataFrame(rows, columns=list(dtypes), dtype=object)
return df.astype(dtypes)

def match_all(self) -> pd.DataFrame:
"""Match all landmark locations between two instances of CATMAID.
Looks for all groups which share a name,
then all landmark members of those groups which have a single location in each group.
Returns
-------
pd.DataFrame
Columns group_name, landmark_name, x1, y1, z1, x2, y2, z2
"""
shared_groups = set(
self.this_m.groups["name"]
).intersection(self.other_m.groups["name"])
dfs = []
for group_name in sorted(shared_groups):
df = self.match(group_name)
df.insert(0, "group_name", [group_name] * len(df))
dfs.append(df)
return pd.concat(dfs)

0 comments on commit 19c7e2f

Please sign in to comment.