From 19c7e2fe0aac9b8c960411c21295061752a4451d Mon Sep 17 00:00:00 2001
From: Chris Barnes
Date: Tue, 13 Jun 2023 14:06:26 +0100
Subject: [PATCH] LandmarkMatcher class

---
Changes made in review:
- _locations_in_landmark: compare with ``==`` instead of assigning to the column
- LandmarkMatcher.match: check g2_locs, and emit row values in column order
- per-instance lookup tables instead of functools.cache on instance methods
- _as_group_id: accept numpy integers; now also used by LandmarkMatcher.match
- match_all: handle the case of no shared groups, and reset the index
The post-image blob hash on the index line below was not regenerated.

 pymaid/fetch/landmarks.py | 346 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 346 insertions(+)

diff --git a/pymaid/fetch/landmarks.py b/pymaid/fetch/landmarks.py
index 25e3743..0df0b29 100644
--- a/pymaid/fetch/landmarks.py
+++ b/pymaid/fetch/landmarks.py
@@ -1,7 +1,10 @@
 import typing as tp
 
+import numpy as np
 import pandas as pd
 
+from pymaid.client import CatmaidInstance
+
 from ..utils import _eval_remote_instance, DataFrameBuilder
 
 
@@ -39,6 +42,7 @@ def get_landmarks(
     url = cm.make_url(
         cm.project_id,
         "landmarks",
+        # todo: don't download locations if we're not going to use them?
         with_locations="true",
     )
     landmark_builder = DataFrameBuilder(
@@ -172,3 +176,345 @@ def get_landmark_groups(
         locations = None
 
     return groups, locations, members
+
+
+class LandmarkMatcher:
+    """Class for finding matching pairs of landmark locations between two groups.
+
+    Lookup tables derived from the data frames are built on first use and
+    stored on the instance; they are not refreshed if the frames change.
+    """
+
+    def __init__(
+        self,
+        landmarks: pd.DataFrame,
+        landmark_locations: pd.DataFrame,
+        groups: pd.DataFrame,
+        group_locations: pd.DataFrame,
+        group_members: dict[int, tp.Iterable[int]],
+    ):
+        """
+        Parameters
+        ----------
+        landmarks : pd.DataFrame
+            Needs columns ``landmark_id`` and ``name``.
+        landmark_locations : pd.DataFrame
+            Needs columns ``location_id`` and ``landmark_id``. Rows are
+            unpacked positionally as location ID, x, y, z, one more column.
+        groups : pd.DataFrame
+            Needs columns ``group_id`` and ``name``.
+        group_locations : pd.DataFrame
+            Needs columns ``location_id`` and ``group_id``. Rows are
+            unpacked positionally as location ID, x, y, z, one more column.
+        group_members : dict[int, tp.Iterable[int]]
+            IDs of the member landmarks of each group, keyed by group ID.
+        """
+        self.landmarks = landmarks
+        self.landmark_locations = landmark_locations
+        self.groups = groups
+        self.group_locations = group_locations
+        self.group_members: dict[int, set[int]] = {
+            k: set(v) for k, v in group_members.items()
+        }
+
+        # Lookup tables, built on first use. They live on the instance rather
+        # than behind functools.cache, which would key on ``self`` and keep
+        # every matcher (and its data frames) alive indefinitely.
+        self._group_ids_by_name: tp.Optional[dict[str, int]] = None
+        self._landmark_names_by_id: tp.Optional[dict[int, str]] = None
+        self._location_coords: tp.Optional[dict[int, tuple[float, float, float]]] = None
+
+    @classmethod
+    def from_catmaid(cls, remote_instance=None):
+        """Build a matcher from landmarks and landmark groups fetched from CATMAID.
+
+        Parameters
+        ----------
+        remote_instance : optional
+            Resolved with ``_eval_remote_instance`` before fetching.
+        """
+        cm = _eval_remote_instance(remote_instance)
+        landmarks, landmark_locations = get_landmarks(True, cm)
+        groups, group_locations, members = get_landmark_groups(True, True, remote_instance=cm)
+        return cls(landmarks, landmark_locations, groups, group_locations, members)
+
+    def _group_name_to_id(self) -> dict[str, int]:
+        """Group name -> group ID."""
+        if self._group_ids_by_name is None:
+            self._group_ids_by_name = dict(
+                zip(self.groups["name"], self.groups["group_id"])
+            )
+        return self._group_ids_by_name
+
+    def _locations_in_group(self, group_id: int) -> pd.Series:
+        """IDs of the locations belonging to a group."""
+        idx = self.group_locations["group_id"] == group_id
+        return self.group_locations["location_id"][idx]
+
+    def _locations_in_landmark(self, landmark_id: int) -> pd.Series:
+        """IDs of the locations belonging to a landmark."""
+        # This has to be a comparison: a chained assignment here would
+        # overwrite the whole ``landmark_id`` column instead of masking it.
+        idx = self.landmark_locations["landmark_id"] == landmark_id
+        return self.landmark_locations["location_id"][idx]
+
+    def _unique_location(self, group_id: int, landmark_id: int) -> tp.Optional[int]:
+        """ID of the only location a landmark has in a group.
+
+        Returns None if the landmark is not a member of the group,
+        or does not have exactly one location in it.
+        """
+        if landmark_id not in self.group_members[group_id]:
+            return None
+
+        location_ids = set(self._locations_in_group(group_id)).intersection(
+            self._locations_in_landmark(landmark_id)
+        )
+        if len(location_ids) == 1:
+            return location_ids.pop()
+        return None
+
+    def _landmark_id_to_name(self) -> dict[int, str]:
+        """Landmark ID -> landmark name."""
+        if self._landmark_names_by_id is None:
+            self._landmark_names_by_id = dict(
+                zip(self.landmarks["landmark_id"], self.landmarks["name"])
+            )
+        return self._landmark_names_by_id
+
+    def _locations(self) -> dict[int, tuple[float, float, float]]:
+        """Location ID -> (x, y, z), for group and landmark locations alike."""
+        if self._location_coords is None:
+            coords = dict()
+            # Rows of both tables are unpacked positionally; the fifth
+            # column is not needed here.
+            for table in (self.group_locations, self.landmark_locations):
+                for loc_id, x, y, z, _ in table.itertuples(index=False):
+                    coords[loc_id] = (x, y, z)
+            self._location_coords = coords
+        return self._location_coords
+
+    def _as_group_id(self, group: tp.Union[str, int]) -> int:
+        """Resolve a group given as int ID or str name to its ID.
+
+        Raises
+        ------
+        KeyError
+            If ``group`` is a str which is not the name of any group.
+        ValueError
+            If ``group`` is neither an integer nor a str.
+        """
+        # np.integer: IDs read out of a numpy-backed column are not
+        # instances of the builtin int.
+        if isinstance(group, (int, np.integer)):
+            return int(group)
+        if isinstance(group, str):
+            return self._group_name_to_id()[group]
+        raise ValueError("Given group must be int ID or str name")
+
+    def match(self, group1: tp.Union[str, int], group2: tp.Union[str, int]) -> pd.DataFrame:
+        """Get matching pairs of landmarks for two groups.
+
+        Return the paired locations of landmarks
+        which are members of both groups,
+        and have a single location in each group.
+
+        Parameters
+        ----------
+        group1 : tp.Union[str, int]
+            First group name (as str) or ID (as int)
+        group2 : tp.Union[str, int]
+            Second group name (as str) or ID (as int)
+
+        Returns
+        -------
+        pd.DataFrame
+            Columns landmark_name, landmark_id, location_id1, x1, y1, z1, location_id2, x2, y2, z2
+        """
+        group1_id = self._as_group_id(group1)
+        group2_id = self._as_group_id(group2)
+
+        group1_locs = set(self._locations_in_group(group1_id))
+        group2_locs = set(self._locations_in_group(group2_id))
+
+        locs = self._locations()
+
+        dtypes = {"landmark_name": str, "landmark_id": np.uint64}
+        for g in ["1", "2"]:
+            dtypes["location_id" + g] = np.uint64
+            for d in "xyz":
+                dtypes[d + g] = np.float64
+
+        lm_id_to_name = self._landmark_id_to_name()
+
+        rows = []
+        for lm_id in self.group_members[group1_id].intersection(self.group_members[group2_id]):
+            lm_locs = self._locations_in_landmark(lm_id)
+            g1_locs = group1_locs.intersection(lm_locs)
+            if len(g1_locs) != 1:
+                continue
+            g2_locs = group2_locs.intersection(lm_locs)
+            if len(g2_locs) != 1:
+                continue
+
+            # Values must be in the column order defined by ``dtypes``.
+            row = [lm_id_to_name[lm_id], lm_id]
+            row.append(g1_locs.pop())
+            row.extend(locs[row[-1]])
+            row.append(g2_locs.pop())
+            row.extend(locs[row[-1]])
+
+            rows.append(row)
+
+        df = pd.DataFrame(rows, columns=list(dtypes), dtype=object)
+        return df.astype(dtypes)
+
+
+class CrossInstanceLandmarkMatcher:
+    """Class for matching landmark locations between two instances of CATMAID.
+
+    Groups and landmarks are paired between the instances by name.
+    """
+
+    def __init__(self, this_lms: LandmarkMatcher, other_lms: LandmarkMatcher):
+        """
+        Parameters
+        ----------
+        this_lms : LandmarkMatcher
+            Landmark data for "this" instance.
+        other_lms : LandmarkMatcher
+            Landmark data for the "other" instance.
+        """
+        self.this_m: LandmarkMatcher = this_lms
+        self.other_m: LandmarkMatcher = other_lms
+
+    @classmethod
+    def from_catmaid(cls, other_remote_instance: CatmaidInstance, this_remote_instance=None):
+        """Build a matcher from landmark data fetched from both instances.
+
+        Parameters
+        ----------
+        other_remote_instance : CatmaidInstance
+            The "other" instance.
+        this_remote_instance : optional
+            "This" instance; resolved with ``_eval_remote_instance``.
+        """
+        this_remote_instance = _eval_remote_instance(this_remote_instance)
+        return cls(
+            LandmarkMatcher.from_catmaid(this_remote_instance),
+            LandmarkMatcher.from_catmaid(other_remote_instance),
+        )
+
+    @staticmethod
+    def _match_dtypes() -> dict[str, type]:
+        """Column name -> dtype for the frames returned by ``match``."""
+        dtypes = {"landmark_name": str}
+        for g in "12":
+            for d in "xyz":
+                dtypes[d + g] = np.float64
+        return dtypes
+
+    def _member_landmarks(self, group_id: int, this=True) -> dict[str, int]:
+        """Landmark name -> landmark ID for the members of a group.
+
+        ``this`` selects "this" (True) or the "other" (False) instance.
+        If several members share a name, only one of them is kept.
+        """
+        matcher = self.this_m if this else self.other_m
+        lmid_to_name = matcher._landmark_id_to_name()
+        return {lmid_to_name[lmid]: lmid for lmid in matcher.group_members[group_id]}
+
+    def match(
+        self,
+        this_group: tp.Union[str, int],
+        other_group: tp.Optional[tp.Union[str, int]] = None
+    ) -> pd.DataFrame:
+        """Match landmark locations between two instances of CATMAID.
+
+        Looks through the members of the two groups to find landmarks with the same name.
+        If those landmarks have one location in each group, match those.
+
+        Parameters
+        ----------
+        this_group : tp.Union[str, int]
+            Group name (str) or ID (int) on this instance.
+        other_group : tp.Optional[tp.Union[str, int]], optional
+            Group name (str) or ID (int) on the other instance.
+            If None (default) and ``this_group`` is a str name, use that name.
+
+        Returns
+        -------
+        pd.DataFrame
+            Columns landmark_name, x1, y1, z1, x2, y2, z2
+            where 1 is "this" and 2 is "other".
+
+        Raises
+        ------
+        ValueError
+            If ``other_group`` is None and ``this_group`` is not a str name.
+        """
+        this_group_id = self.this_m._as_group_id(this_group)
+        if other_group is None:
+            if not isinstance(this_group, str):
+                raise ValueError("If other_group is None, this_group must be a str name")
+            other_group = this_group
+        other_group_id = self.other_m._as_group_id(other_group)
+
+        this_locs = self.this_m._locations()
+        other_locs = self.other_m._locations()
+
+        this_lms = self._member_landmarks(this_group_id)
+        other_lms = self._member_landmarks(other_group_id, False)
+
+        dtypes = self._match_dtypes()
+
+        rows = []
+
+        for lm_name, this_lmid in this_lms.items():
+            if lm_name not in other_lms:
+                continue
+            other_lmid = other_lms[lm_name]
+
+            this_loc_id = self.this_m._unique_location(this_group_id, this_lmid)
+            if this_loc_id is None:
+                continue
+
+            other_loc_id = self.other_m._unique_location(other_group_id, other_lmid)
+            if other_loc_id is None:
+                continue
+
+            row = [lm_name]
+            row.extend(this_locs[this_loc_id])
+            row.extend(other_locs[other_loc_id])
+
+            rows.append(row)
+
+        df = pd.DataFrame(rows, columns=list(dtypes), dtype=object)
+        return df.astype(dtypes)
+
+    def match_all(self) -> pd.DataFrame:
+        """Match all landmark locations between two instances of CATMAID.
+
+        Looks for all groups which share a name,
+        then all landmark members of those groups which have a single location in each group.
+
+        Returns
+        -------
+        pd.DataFrame
+            Columns group_name, landmark_name, x1, y1, z1, x2, y2, z2
+        """
+        shared_groups = set(
+            self.this_m.groups["name"]
+        ).intersection(self.other_m.groups["name"])
+        dfs = []
+        for group_name in sorted(shared_groups):
+            df = self.match(group_name)
+            df.insert(0, "group_name", [group_name] * len(df))
+            dfs.append(df)
+
+        if not dfs:
+            # pd.concat raises ValueError when there is nothing to concatenate.
+            dtypes = {"group_name": str, **self._match_dtypes()}
+            return pd.DataFrame([], columns=list(dtypes), dtype=object).astype(dtypes)
+
+        return pd.concat(dfs, ignore_index=True)