Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds ByTermSplit splitter #499

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/obnb/label/split/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Genearting data splits from the labelset collection."""
"""Generating data splits from the labelset collection."""

from obnb.label.split.explicit import ByTermSplit
from obnb.label.split.holdout import (
AllHoldout,
RandomRatioHoldout,
Expand All @@ -20,4 +21,5 @@
"RandomRatioPartition",
"RatioPartition",
"ThresholdPartition",
"ByTermSplit",
]
26 changes: 26 additions & 0 deletions src/obnb/label/split/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,32 @@ def __repr__(self) -> str:

return f"{self.__class__.__name__}({', '.join(attrs)})"

def __call__(
self,
ids: List[str],
y: np.ndarray,
) -> Iterator[Tuple[np.ndarray, ...]]:
"""Split the input ids into multiple splits, e.g. a test, train, validation
split. The means by which this splitting occurs should be defined by classes
that inherit from this base class.

Note:
Inheriting classes should yield the value instead of returning it,
to make it compatible with the sklearn split methods. See the
implementation of the BaseSortedSplit class for an example.

Args:
ids: List of entity IDs to put in each split.
y: Labels for each entity(?)

Yields:
Iterator of splits. Each split is a tuple of numpy arrays, where
each array contains the IDs of the entities in the split.

"""

raise NotImplementedError


class BaseSortedSplit(BaseSplit):
"""BaseSortedSplit object for splitting dataset based on sorting."""
Expand Down
123 changes: 123 additions & 0 deletions src/obnb/label/split/explicit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import Iterable, Iterator, List, Tuple

import numpy
from numpy import ndarray

from obnb.label.collection import LabelsetCollection
from obnb.label.split.base import BaseSplit


class ByTermSplit(BaseSplit):
"""Produces splits based on an explicit list of terms.

Genes which match each term will be placed in the split corresponding to
that term. A split with a single term '*' will act as a catch-all for any
genes that weren't matched by any of the other splits. This would allow you
to, e.g., only retain a specific set of genes in the training set, and place
all others in the test set.

Note that if the '*' split is not provided, any genes that don't match any
of the other splits will not be present in the returned splits at all.

"""

def __init__(
self,
labelset: LabelsetCollection,
split_terms: Iterable[Iterable[str]],
exclusive: bool = False,
) -> None:
"""Initialize ByTermSplit object with reference labels and terms for splits.

Args:
labelset: LabelsetCollection object containing terms for each
gene ID.
split_terms: a nested collection. The first level of nesting
indicates the splits; the second level within each identifies
terms that should be matched to place a gene in that split.
exclusive: if True, a gene can occur only once across all the
splits; it will belong to the first split in which it occurs.

"""
self.labelset = labelset
self.split_terms = [set(x) for x in split_terms]
self.exclusive = exclusive

# verify that there's only one catch-all split
if sum(1 for x in self.split_terms if x == {"*"}) > 1:
raise ValueError("Only one catch-all '*' split is allowed")

# convert labelset into a dataframe where one can search for
# the terms associated with each gene ID like so:
# self.long_df[self.long_df["Value"] == str(gene_id)]["Name"]
df = self.labelset.to_df()
self.long_df = df.melt(
id_vars=["Name"],
value_vars=df.columns.difference(["Info", "Size", "Name"]),
value_name="Value",
).dropna(subset=["Value"])

# group gene id and aggregate terms into a set, which makes
# it faster to retrieve all the terms for a given gene ID
self.gene_id_to_terms = (
self.long_df.groupby("Value")["Name"]
.apply(set)
.reset_index()
.rename(columns={"Value": "GeneID", "Name": "Terms"})
)

super().__init__()

def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]:
"""Produce splits based on the terms associated with each gene ID.

For each gene ID, look up the term it's associated with in the labelset,
and place it in the corresponding split.

Returns as many splits as there are elements in the split_terms tuple.

"""
# alias field to shorten the code below
gdf = self.gene_id_to_terms

# for each split, filter to the gene IDs that have at least one
# term in the split
result = [
(
{
gene_id
for gene_id in ids
if gdf[gdf["GeneID"] == str(gene_id)]["Terms"].values[0] & terms
}
if terms != {"*"}
else {"*"}
)
for terms in self.split_terms
]

# if one of the resulting splits ended up as the wildcard, we need to
# fill in that split with any gene that wasn't matched by any of
# the other splits
for idx in range(len(result)):
if result[idx] == {"*"}:
result[idx] = {
gene_id
for gene_id in ids
if not any(
gdf[gdf["GeneID"] == str(gene_id)]["Terms"].values[0] & terms
for terms in self.split_terms
)
}

if self.exclusive:
# if exclusive, remove genes in the current split that occurred
# in any previous split
# (we skip the first split since there's nothing with which to
# compare it)
for idx in range(1, len(result)):
result[idx] = result[idx] - set().union(*result[:idx])

# yield it in the format returned by other splitters, e.g. a tuple of
# numpy arrays, each of which contain indices into the 'ids' array
# passed into the splitter.
yield tuple([numpy.asarray([ids.index(v) for v in x]) for x in result])
Loading