Simple aggregation redux #13
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,12 @@ | ||
from .aggregator import ArbitraryAggregator as ArbitraryAggregator | ||
from .aggregator import ( | ||
BasicPhylogeneticAggregator as BasicPhylogeneticAggregator, | ||
) | ||
from .aggregator import SerialAggregator as SerialAggregator | ||
from .nomenclature import PangoSc2Nomenclature as PangoSc2Nomenclature | ||
from .taxon import Taxon as Taxon | ||
from .taxon_utils import read_taxa as read_taxa | ||
from .taxon_utils import sort_taxa as sort_taxa | ||
from .taxonomy_scheme import ( | ||
PhylogeneticTaxonomyScheme as PhylogeneticTaxonomyScheme, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
from abc import ABC, abstractmethod | ||
from collections.abc import Iterable | ||
from warnings import warn | ||
|
||
from .taxon import Taxon | ||
from .taxon_utils import sort_taxa | ||
from .taxonomy_scheme import PhylogeneticTaxonomyScheme | ||
from .utils import table | ||
|
||
|
||
class Aggregation(dict[Taxon, Taxon]): | ||
""" | ||
An object for aggregations, basically just a dictionary. | ||
""" | ||
|
||
def _validate( | ||
self, input_taxa: Iterable[Taxon], taxon_map: dict[Taxon, Taxon] | ||
): | ||
""" | ||
Checks that all input taxa have been mapped exactly once. | ||
""" | ||
if set(taxon_map.keys()) != set(input_taxa): | ||
raise RuntimeError( | ||
"Mismatch between aggregated taxa and input taxa. Input taxa are: " | ||
+ str(input_taxa) | ||
+ " but aggregated taxa are " | ||
+ str(taxon_map.keys()) | ||
) | ||
tab = table(taxon_map) | ||
if not all(v == 1 for v in tab.values()): | ||
raise RuntimeError( | ||
"Found following taxa mapped more than once: " | ||
+ str([k for k, v in tab.items() if v > 1]) | ||
) | ||
|
||
def __init__( | ||
self, input_taxa: Iterable[Taxon], taxon_map: dict[Taxon, Taxon] | ||
): | ||
self._validate(input_taxa, taxon_map) | ||
super().__init__(taxon_map) | ||
|
||
def to_str(self): | ||
""" | ||
Return a `str: str` map of taxon names (input name to aggregated name). | ||
""" | ||
return {k.name: v.name for k, v in self.items()} | ||
|
||
|
||
class Aggregator(ABC): | ||
Reviewer comment: As per my comment above, I could see a static function that validates that an output taxon map has all the input taxa in it. Then you could count on code writers to know that their [...] Reply: I want some minimal validation automatically. That every taxon should be mapped exactly once is the minimum correctness standard, though some [...] The choices seemed to be what I did (make the map check) or have [...] I'm also going to dispute the above "then the taxon map can be just a map, with no further functionality." Exporting the [...] |
||
""" | ||
Aggregators return Aggregations, maps of input_taxon : aggregated_taxon | ||
""" | ||
|
||
@abstractmethod | ||
def aggregate(self, input_taxa: Iterable[Taxon]) -> Aggregation: | ||
pass | ||
|
||
|
||
class ArbitraryAggregator(Aggregator): | ||
""" | ||
Aggregation via a user-provided dictionary. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
map: dict[Taxon, Taxon], | ||
): | ||
""" | ||
ArbitraryAggregator constructor. | ||
|
||
Parameters | ||
---------- | ||
map : dict[Taxon, Taxon] | ||
Dictionary mapping the input taxa to their aggregated taxa. | ||
""" | ||
self.map = map | ||
|
||
def aggregate(self, input_taxa: Iterable[Taxon]) -> Aggregation: | ||
return Aggregation( | ||
input_taxa, {taxon: self.map[taxon] for taxon in input_taxa} | ||
) | ||
|
||
|
||
class BasicPhylogeneticAggregator(Aggregator): | ||
Reviewer comment: Personally I would prefer just "PhylogeneticAggregator," because I can imagine some more "basic" forms, as per my comments about composability below. Reply: The name can change, but we have to be able to clearly distinguish between this implicitly fixed-map "put these in those" phylogenetic aggregation and the (probably multiple different) dynamic "put these in something" aggregator(s) that we want to make. My instinct is to qualify all of them (this one being "basic" or "fixed" or something of the sort), but I could be convinced that only the more complex ones need more explicit names. |
||
""" | ||
An aggregator which maps a set of input taxa to a fixed set of aggregation targets using a tree. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
targets: Iterable[Taxon], | ||
taxonomy_scheme: PhylogeneticTaxonomyScheme, | ||
sort_clades: bool = True, | ||
off_target: str = "other", | ||
warn: bool = True, | ||
): | ||
""" | ||
BasicPhylogeneticAggregator constructor. | ||
|
||
Parameters | ||
---------- | ||
targets : Iterable[Taxon] | ||
The taxa into which we wish to aggregate the input taxa. | ||
|
||
taxonomy_scheme : PhylogeneticTaxonomyScheme | ||
The tree which we use to do the mapping. | ||
|
||
sort_clades : bool | ||
If False, mapping is done using the taxa as ordered in `targets`. | ||
If True, `targets` are taxonomically sorted so that larger | ||
`targets` do not override smaller ones. For example, if BA.2 and | ||
BA.2.86 are both aggregation targets, sort_clades = True would handle | ||
BA.2.86 first, such that JN.1 would map to BA.2.86, while BG.1 would | ||
map to BA.2. If BA.2 is processed first, both will map to it. | ||
|
||
off_target : str | ||
Specifies what to do with taxa which do not belong to any target. | ||
Options are "other" for aggregating all such taxa into Taxon("other"), | ||
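and "self" for aggregating all such taxa into themselves. | ||
|
||
warn : bool | ||
If True, `aggregate` emits a warning listing any targets to which | ||
no input taxon was mapped. | ||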
""" | ||
self.taxonomy_scheme = taxonomy_scheme | ||
self._validate_inputs(targets) | ||
self.targets = [taxon for taxon in targets] | ||
off_target_options = ["self", "other"] | ||
if off_target not in off_target_options: | ||
raise RuntimeError( | ||
f"Unrecognized value for `off_target`, options are: {off_target_options}" | ||
) | ||
self.off_target = off_target | ||
self.warn = warn | ||
if sort_clades: | ||
self.targets = sort_taxa(self.targets, self.taxonomy_scheme) | ||
|
||
def _validate_inputs(self, input_taxa: Iterable[Taxon]) -> None: | ||
invalid_taxa = [ | ||
taxon | ||
for taxon in input_taxa | ||
if not self.taxonomy_scheme.is_valid_taxon(taxon) | ||
] | ||
if len(invalid_taxa) > 0: | ||
raise ValueError( | ||
"The following taxa are not valid taxa according to the provided taxonomy scheme: " | ||
+ str(invalid_taxa) | ||
) | ||
|
||
def _check_missing(self, agg_map: dict[Taxon, Taxon]): | ||
if self.warn: | ||
used_targets = set(agg_map.values()) | ||
unused_targets = [ | ||
target for target in self.targets if target not in used_targets | ||
] | ||
if len(unused_targets) > 0: | ||
warn( | ||
f"The aggregation does not make use of the following input targets: {unused_targets}." | ||
) | ||
|
||
def aggregate(self, input_taxa: Iterable[Taxon]) -> Aggregation: | ||
self._validate_inputs(input_taxa) | ||
agg_map: dict[Taxon, Taxon] = {} | ||
stack = set(input_taxa) | ||
for target in self.targets: | ||
children = self.taxonomy_scheme.descendants(target, True) | ||
sub_map = {taxon: target for taxon in stack if taxon in children} | ||
agg_map = agg_map | sub_map | ||
stack.difference_update(set(agg_map.keys())) | ||
|
||
if len(stack) > 0: | ||
if self.off_target == "other": | ||
cleanup = HomogenousAggregator( | ||
Taxon("other", False) | ||
).aggregate(stack) | ||
else: | ||
cleanup = SelfAggregator().aggregate(stack) | ||
agg_map = agg_map | cleanup | ||
|
||
self._check_missing(agg_map) | ||
|
||
return Aggregation(input_taxa, agg_map) | ||
|
||
|
||
class HomogenousAggregator(Aggregator): | ||
""" | ||
Aggregation of every taxon to some catch-all taxon. | ||
""" | ||
|
||
def __init__(self, taxon: Taxon): | ||
self.agg_taxon = taxon | ||
|
||
def aggregate(self, input_taxa: Iterable[Taxon]) -> Aggregation: | ||
return Aggregation( | ||
input_taxa, {taxon: self.agg_taxon for taxon in input_taxa} | ||
) | ||
|
||
|
||
class SelfAggregator(Aggregator): | ||
""" | ||
Aggregation of every taxon to itself. | ||
""" | ||
|
||
def __init__(self): | ||
pass | ||
|
||
def aggregate(self, input_taxa: Iterable[Taxon]) -> Aggregation: | ||
return Aggregation(input_taxa, {taxon: taxon for taxon in input_taxa}) | ||
|
||
|
||
class SerialAggregator(Aggregator): | ||
""" | ||
A number of aggregators chained in serial. | ||
""" | ||
|
||
def __init__(self, aggregators: Iterable[Aggregator]): | ||
self.aggregators = aggregators | ||
|
||
def aggregate(self, input_taxa: Iterable[Taxon]) -> Aggregation: | ||
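input_taxa = list(input_taxa)  # materialize once; a one-shot iterator would otherwise be exhausted below | ||
taxa = input_taxa | ||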
comp_agg = SelfAggregator().aggregate(input_taxa) | ||
|
||
for aggregator in self.aggregators: | ||
agg = aggregator.aggregate(taxa) | ||
taxa = set(agg.values()) | ||
comp_agg = {taxon: agg[comp_agg[taxon]] for taxon in input_taxa} | ||
|
||
return Aggregation(input_taxa, comp_agg) |
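The `sort_clades` docstring in `BasicPhylogeneticAggregator` describes how the order of `targets` changes the result. The following is a minimal, self-contained sketch of that greedy assignment, not the code from this PR: taxa are plain strings rather than `Taxon` objects, and `PARENT`, `descendants`, and `greedy_aggregate` are hypothetical stand-ins for the taxonomy scheme and the aggregator. The toy tree is simplified for illustration.

```python
# Toy child -> parent tree (simplified, for illustration only).
PARENT = {"BA.2": None, "BA.2.86": "BA.2", "JN.1": "BA.2.86", "BG.1": "BA.2"}


def descendants(taxon: str) -> set[str]:
    """Return `taxon` plus everything below it in the toy tree."""
    found = {taxon}
    grew = True
    while grew:
        below = {child for child, parent in PARENT.items() if parent in found}
        grew = not below <= found
        found |= below
    return found


def greedy_aggregate(inputs: list[str], targets: list[str]) -> dict[str, str]:
    """Assign each input to the first target, in order, whose clade contains it."""
    agg_map: dict[str, str] = {}
    remaining = set(inputs)
    for target in targets:
        clade = descendants(target)
        agg_map |= {taxon: target for taxon in remaining if taxon in clade}
        remaining -= set(agg_map)
    # Anything left over is off-target; this sketch lumps it into "other".
    agg_map |= {taxon: "other" for taxon in remaining}
    return agg_map


inputs = ["JN.1", "BG.1"]
# Nested target first (what sorting is meant to achieve).
nested_first = greedy_aggregate(inputs, ["BA.2.86", "BA.2"])
# Broad target first: BA.2 claims everything before BA.2.86 is considered.
broad_first = greedy_aggregate(inputs, ["BA.2", "BA.2.86"])
print(nested_first)
print(broad_first)
```

With the nested target handled first, JN.1 lands in BA.2.86 and BG.1 in BA.2; with BA.2 first, both land in BA.2, which matches the behavior the docstring describes.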
I'm confused why we need both concepts "aggregation" and "taxon map"?
A taxon map could be a dictionary, which guarantees unique keys (the second of the two checks).
It feels like the first check (that the input taxa match the keys in the taxon map) is functionality that should live in the aggregator. The aggregator gets handed input taxa, and so it should check that it's outputting a taxon map that has all those input taxa as keys.
Right now it seems like the aggregator has to hand the input taxa and the taxon map to the aggregation. If the aggregator does one of those checks, then the taxon map can be just a map, with no further functionality?
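One way the alternative raised in this comment might look is sketched below. It is an illustration of the reviewer's suggestion under stated assumptions, not the design in this PR: taxa are plain strings instead of `Taxon`, the taxon map is an ordinary `dict`, and the `validate` and `_build_map` names are hypothetical. Because the base class calls `validate` inside `aggregate`, subclasses still get the key check automatically.

```python
from abc import ABC, abstractmethod
from collections.abc import Iterable


class Aggregator(ABC):
    """Sketch: the aggregator validates its own output and returns a plain dict."""

    @staticmethod
    def validate(input_taxa: Iterable[str], taxon_map: dict[str, str]) -> None:
        """Raise if the keys of `taxon_map` are not exactly the input taxa."""
        expected = set(input_taxa)
        keys = set(taxon_map)
        missing = expected - keys
        extra = keys - expected
        if missing or extra:
            raise ValueError(
                f"Unmapped input taxa: {sorted(missing)}; "
                f"unexpected keys: {sorted(extra)}"
            )

    @abstractmethod
    def _build_map(self, input_taxa: list[str]) -> dict[str, str]:
        """Subclasses implement the actual mapping here."""

    def aggregate(self, input_taxa: Iterable[str]) -> dict[str, str]:
        taxa = list(input_taxa)  # materialize once so iterators are safe to reuse
        taxon_map = self._build_map(taxa)
        self.validate(taxa, taxon_map)
        return taxon_map


class SelfAggregator(Aggregator):
    """Maps every taxon to itself."""

    def _build_map(self, input_taxa: list[str]) -> dict[str, str]:
        return {taxon: taxon for taxon in input_taxa}


class DropsLastAggregator(Aggregator):
    """Deliberately broken subclass, used only to show the check firing."""

    def _build_map(self, input_taxa: list[str]) -> dict[str, str]:
        return {taxon: "other" for taxon in input_taxa[:-1]}


print(SelfAggregator().aggregate(iter(["JN.1", "BG.1"])))
```

The trade-off is that extras such as `to_str` would have to live elsewhere, since the return value is a bare dictionary.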