Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple aggregation redux #13

Merged
merged 17 commits into from
Oct 9, 2024
6 changes: 6 additions & 0 deletions cladecombiner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from .aggregator import ArbitraryAggregator as ArbitraryAggregator
from .aggregator import (
BasicPhylogeneticAggregator as BasicPhylogeneticAggregator,
)
from .aggregator import SerialAggregator as SerialAggregator
from .nomenclature import PangoSc2Nomenclature as PangoSc2Nomenclature
from .taxon import Taxon as Taxon
from .taxon_utils import read_taxa as read_taxa
from .taxon_utils import sort_taxa as sort_taxa
from .taxonomy_scheme import (
PhylogeneticTaxonomyScheme as PhylogeneticTaxonomyScheme,
)
224 changes: 224 additions & 0 deletions cladecombiner/aggregator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from warnings import warn

from .taxon import Taxon
from .taxon_utils import sort_taxa
from .taxonomy_scheme import PhylogeneticTaxonomyScheme
from .utils import table


class Aggregation(dict[Taxon, Taxon]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused why we need both concepts "aggregation" and "taxon map"?

A taxon map could be a dictionary, which guarantees unique keys (the second of the two checks).

It feels like the first check (that the input taxa match the keys in the taxon map) is functionality that should live in the aggregator. The aggregator gets handed input taxa, and so it should check that it's outputting a taxon map that has all those input taxa as keys.

Right now it seems like the aggregator has to hand the input taxa and the taxon map to the aggregation. If the aggregator does one of those checks, then the taxon map can be just a map, with no further functionality?

"""
An object for aggregations, basically just a dictionary.
"""

def _validate(
self, input_taxa: Iterable[Taxon], taxon_map: dict[Taxon, Taxon]
):
"""
Checks that all input taxa have been mapped exactly once.
"""
if set(taxon_map.keys()) != set(input_taxa):
raise RuntimeError(
"Mismatch between aggregated taxa and input taxa. Input taxa are: "
+ str(input_taxa)
+ " but aggregated taxa are "
+ str(taxon_map.keys())
)
tab = table(taxon_map)
afmagee42 marked this conversation as resolved.
Show resolved Hide resolved
if not all(v == 1 for v in tab.values()):
raise RuntimeError(
"Found following taxa mapped more than once: "
+ str([k for k, v in tab.items() if v > 1])
)

def __init__(
self, input_taxa: Iterable[Taxon], taxon_map: dict[Taxon, Taxon]
):
self._validate(input_taxa, taxon_map)
super().__init__(taxon_map)

def to_str(self):
"""
Get str : str map of taxa names
"""
return {k.name: v.name for k, v in self.items()}


class Aggregator(ABC):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As per my comment above, I could see a static function that validates that an output taxon map has all the input taxa in it

Then you could count on code writers to know that their .aggregate() methods should call .validate_map() before returning the map.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want some minimal validation automatically. Every taxon should be mapped exactly once is the minimum correctness standard, though some Aggregators might know more about what to check.

The choices seemed to be what I did (make the map check) or have Aggregator.aggregate() be a non-abstract method calling some actual aggregation function Aggregator._aggregate() and then Aggregator.validate(). Neither is ideal, but I'm open to being told one is less awkward than the other.

I'm also going to dispute the above "then the taxon map can be just a map, with no further functionality." Exporting the Taxon : Taxon dictto a str : str dict is going to be one of the most common usage patterns for the existence of this class, so it might as well be a method. Taxon objects are our internal representations, but everywhere else a user is going to need to interact with taxa, they'll be strings.

"""
Aggregators return Aggregations, maps of input_taxon : aggregated_taxon
"""

@abstractmethod
def aggregate(self, input_taxa: Iterable[Taxon]) -> Aggregation:
pass


class ArbitraryAggregator(Aggregator):
"""
Aggregation via a user-provided dictionary.
"""

def __init__(
self,
map: dict[Taxon, Taxon],
):
"""
FixedAggregator constructor.

Parameters
----------
map : dict[Taxon, Taxon]
Dictionary mapping the input taxa to their aggregated taxa.
"""
self.map = map

def aggregate(self, input_taxa: Iterable[Taxon]) -> Aggregation:
return Aggregation(
input_taxa, {taxon: self.map[taxon] for taxon in input_taxa}
)


class BasicPhylogeneticAggregator(Aggregator):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I would prefer just "PhylogeneticAggregator," because I can imagine some more "basic" forms, as per my comments about composability below

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name can change, but we have to be able to clearly distinguish between this implicitly fixed-map "put these in those" phylogenetic aggregation and the (probably multiple different) dynamic "put these in something" aggregator(s) that we want to make. My instinct is to qualify all of them (this one being "basic" or "fixed" or something of the sort), but I could be convinced that only the more complex ones need more explicit names.

"""
An aggregator which maps a set of input taxa to a fixed set of aggregation targets using a tree.
"""

def __init__(
self,
targets: Iterable[Taxon],
taxonomy_scheme: PhylogeneticTaxonomyScheme,
sort_clades: bool = True,
off_target: str = "other",
warn: bool = True,
):
"""
BasicPhylogeneticAggregator constructor.

Parameters
----------
targets : Iterable[Taxon]
The taxa into which we wish to aggregate the input taxa.

taxonomy_scheme : PhylogeneticTaxonomyScheme
The tree which we use to do the mapping.

sort_clades : bool
afmagee42 marked this conversation as resolved.
Show resolved Hide resolved
If False, mapping is done using the taxa as ordered in `targets`.
If True, `targets` are taxonomically sorted so that so that larger
`targets` do not override smaller ones. For example, if BA.2 and
BA.2.86 are both aggregation targets, sort_clades = True would handle
BA.2.86 first, such that JN.1 would map to BA.2.86, while BG.1 would
map to BA.2. If BA.2 is processed first, both will map to it.

off_target : str
Specifies what to do with taxa which do not belong to any target.
Options are "other" for aggregating all such taxa into Taxon("other"),
and "self" for aggregating all such taxa into themselves.
"""
self.taxonomy_scheme = taxonomy_scheme
self._validate_inputs(targets)
self.targets = [taxon for taxon in targets]
off_target_options = ["self", "other"]
if off_target not in off_target_options:
raise RuntimeError(
f"Unrecognized value for `off_target`, options are:{off_target}"
)
self.off_target = off_target
self.warn = warn
if sort_clades:
self.targets = sort_taxa(self.targets, self.taxonomy_scheme)

def _validate_inputs(self, input_taxa: Iterable[Taxon]) -> None:
invalid_taxa = [
taxon
for taxon in input_taxa
if not self.taxonomy_scheme.is_valid_taxon(taxon)
]
if len(invalid_taxa) > 0:
raise ValueError(
"The following taxa are not valid taxa according to the provided taxonomy scheme: "
+ str(invalid_taxa)
)

def _check_missing(self, agg_map: dict[Taxon, Taxon]):
if self.warn:
used_targets = set(agg_map.values())
unused_targets = [
target for target in self.targets if target not in used_targets
]
if len(unused_targets) > 0:
warn(
f"The aggregation does not make use of the following input targets: {unused_targets}."
)

def aggregate(self, input_taxa: Iterable[Taxon]) -> Aggregation:
self._validate_inputs(input_taxa)
agg_map: dict[Taxon, Taxon] = {}
stack = set(input_taxa)
for target in self.targets:
children = self.taxonomy_scheme.descendants(target, True)
sub_map = {taxon: target for taxon in stack if taxon in children}
agg_map = agg_map | sub_map
stack.difference_update(set(agg_map.keys()))

if len(stack) > 0:
if self.off_target == "other":
cleanup = HomogenousAggregator(
afmagee42 marked this conversation as resolved.
Show resolved Hide resolved
Taxon("other", False)
).aggregate(stack)
else:
cleanup = SelfAggregator().aggregate(stack)
agg_map = agg_map | cleanup

self._check_missing(agg_map)

return Aggregation(input_taxa, agg_map)


class HomogenousAggregator(Aggregator):
"""
Aggregation of every taxon to some catch-all taxon.
"""

def __init__(self, taxon: Taxon):
self.agg_taxon = taxon

def aggregate(self, input_taxa: Iterable[Taxon]) -> Aggregation:
return Aggregation(
input_taxa, {taxon: self.agg_taxon for taxon in input_taxa}
)


class SelfAggregator(Aggregator):
afmagee42 marked this conversation as resolved.
Show resolved Hide resolved
"""
Aggregation of every taxon to itself
"""

def __init__(self):
pass

def aggregate(self, input_taxa: Iterable[Taxon]) -> Aggregation:
return Aggregation(input_taxa, {taxon: taxon for taxon in input_taxa})


class SerialAggregator(Aggregator):
"""
A number of aggregators chained in serial.
"""

def __init__(self, aggregators: Iterable[Aggregator]):
self.aggregators = aggregators

def aggregate(self, input_taxa: Iterable[Taxon]) -> Aggregation:
taxa = list(input_taxa)
comp_agg = SelfAggregator().aggregate(input_taxa)

for aggregator in self.aggregators:
agg = aggregator.aggregate(taxa)
taxa = set(agg.values())
comp_agg = {taxon: agg[comp_agg[taxon]] for taxon in input_taxa}

return Aggregation(input_taxa, comp_agg)
2 changes: 1 addition & 1 deletion cladecombiner/taxon.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class Taxon:
"""
Representation of taxonomic units
Representation of taxonomic units.
"""

def __init__(self, name: str, is_tip: bool, data: Any = None):
Expand Down
33 changes: 28 additions & 5 deletions cladecombiner/taxon_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from collections.abc import Sequence
from collections.abc import Iterable, Sequence
from functools import cmp_to_key
from os import path
from typing import Optional

from .nomenclature import Nomenclature
from .taxon import Taxon
from .taxonomy_scheme import TaxonomyScheme
from .taxonomy_scheme import TaxonomyScheme, TreelikeTaxonomyScheme


def read_taxa(
fp: str,
is_tip: bool | Sequence[bool],
nomenclature: Optional[Nomenclature],
taxonomy_scheme: Optional[TaxonomyScheme],
is_tip: bool | Sequence[bool] = True,
nomenclature: Optional[Nomenclature] = None,
taxonomy_scheme: Optional[TaxonomyScheme] = None,
) -> Sequence[Taxon]:
"""
Reads in taxa as a list of Taxon objects.
Expand All @@ -36,6 +37,11 @@ def read_taxa(
Sequence[Taxon]
Container of the taxa as Taxon objects.
"""
assert nomenclature is None or isinstance(nomenclature, Nomenclature)

assert taxonomy_scheme is None or isinstance(
taxonomy_scheme, TaxonomyScheme
)

ext = path.splitext(fp)[1]
taxa = []
Expand Down Expand Up @@ -91,3 +97,20 @@ def printable_taxon_list(taxa: Sequence[Taxon], sep: str = "\n") -> str:
for taxon in taxa:
print_str += str(taxon) + sep
return print_str


def sort_taxa(taxa: Iterable[Taxon], taxonomy_scheme: TreelikeTaxonomyScheme):
assert all(isinstance(taxon, Taxon) for taxon in taxa)
unknown = [
taxon for taxon in taxa if not taxonomy_scheme.is_valid_taxon(taxon)
afmagee42 marked this conversation as resolved.
Show resolved Hide resolved
]
if len(unknown) > 0:
raise ValueError(
f"Cannot sort the following taxa which are unknown to the taxonomy scheme: {unknown}"
)
return sorted(
taxa,
key=cmp_to_key(
lambda x, y: 1 if taxonomy_scheme.contains(x, y) else -1
),
)
4 changes: 3 additions & 1 deletion cladecombiner/taxonomy_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ class TaxonomyScheme(ABC):
Allows hybridization-induced multiple ancestry.
"""

@abstractmethod
def ancestors(self, taxon: Taxon) -> Collection[Taxon]:
"""
All taxa which are between this taxon and the root (including the root).
Expand Down Expand Up @@ -312,6 +311,9 @@ def contains(self, focal: Taxon, target: Taxon) -> bool:
return True
node = node.parent_node

if node is node_x:
return True

return False

def descendants(self, taxon: Taxon, tip_only: bool) -> Collection[Taxon]:
Expand Down
Loading
Loading