From d1e7a4fec2ccf240c61b8cc22127f3a31ce33bb6 Mon Sep 17 00:00:00 2001 From: Faisal Alquaddoomi Date: Fri, 15 Nov 2024 11:24:44 -0700 Subject: [PATCH 01/11] Defined __call__ in BaseSplit so it can be overridden by splitters that don't sort --- src/obnb/label/split/base.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/obnb/label/split/base.py b/src/obnb/label/split/base.py index 6bb3259c..e52adab2 100644 --- a/src/obnb/label/split/base.py +++ b/src/obnb/label/split/base.py @@ -36,6 +36,32 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({', '.join(attrs)})" + def __call__( + self, + ids: List[str], + y: np.ndarray, + ) -> Iterator[Tuple[np.ndarray, ...]]: + """ + Split the input ids into multiple splits, e.g. a test, train, validation + split. The means by which this splitting occurs should be defined by + classes that inherit from this base class. + + Note: + Inheriting classes should yield the value instead of returning it, + to make it compatible with the sklearn split methods. See the + implementation of the BaseSortedSplit class for an example. + + Args: + ids: List of entity IDs to put in each split. + y: Labels for each entity(?) + + Yields: + Iterator of splits. Each split is a tuple of numpy arrays, where + each array contains the IDs of the entities in the split. + """ + + raise NotImplementedError + class BaseSortedSplit(BaseSplit): """BaseSortedSplit object for splitting dataset based on sorting.""" From 5d9b4611e76e2ebd39354cea4a8b1751c364bed7 Mon Sep 17 00:00:00 2001 From: Faisal Alquaddoomi Date: Fri, 15 Nov 2024 11:34:50 -0700 Subject: [PATCH 02/11] Adds ByTermSplit, a term-level gene splitter --- src/obnb/label/split/__init__.py | 4 ++ src/obnb/label/split/explicit.py | 98 ++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+) create mode 100644 src/obnb/label/split/explicit.py diff --git a/src/obnb/label/split/__init__.py b/src/obnb/label/split/__init__.py index 2657e2ea..7c2cf61f 100644 --- a/src/obnb/label/split/__init__.py +++ b/src/obnb/label/split/__init__.py @@ -10,6 +10,9 @@ RatioPartition, ThresholdPartition, ) +from obnb.label.split.explicit import ( + ByTermSplit +) __all__ = classes = [ "AllHoldout", @@ -19,4 +22,5 @@ "RandomRatioPartition", "RatioPartition", "ThresholdPartition", + "ByTermSplit", ] diff --git a/src/obnb/label/split/explicit.py b/src/obnb/label/split/explicit.py new file mode 100644 index 00000000..b3b3df83 --- /dev/null +++ b/src/obnb/label/split/explicit.py @@ -0,0 +1,98 @@ +from typing import Any, Iterator, List, Tuple + +import numpy + +from obnb.label.collection import LabelsetCollection +from obnb.label.split.base import BaseSplit +from numpy import ndarray + + +class ByTermSplit(BaseSplit): + """ + Produces splits based on an explicit list of terms. Genes + which match each term will be placed in the split corresponding + to that term. + + A split with a single term '*' will act as a catch-all for any + genes that weren't matched by any of the other splits. This would + allow you to, e.g., only retain a specific set of genes in the + training set, and place all others in the test set. + + Note that if the '*' split is not provided, any genes that don't + match any of the other splits will not be present in the returned + splits at all. + """ + + def __init__(self, labelset:LabelsetCollection, split_terms: Tuple[set[str]]) -> None: + """ + Initialize ByTermSplit object with reference labels and terms into + which to create splits. + + Args: + labelset: LabelsetCollection object containing terms for each + gene ID. + split_terms: Tuple of sets of terms. Each set of terms will + correspond to a split + """ + self.labelset = labelset + self.split_terms = [set(x) for x in split_terms] + + # convert labelset into a dataframe where one can search for + # the terms associated with each gene ID like so: + # self.long_df[self.long_df["Value"] == str(gene_id)]["Name"] + df = self.labelset.to_df() + self.long_df = df.melt( + id_vars=["Name"], + value_vars=df.columns.difference(["Info", "Size", "Name"]), + value_name="Value" + ).dropna(subset=["Value"]) + + # group by the integer value and aggregate names into a set, making + # it possible to retrieve all the terms for a given gene ID + self.gene_id_to_terms = ( + self.long_df.groupby("Value")["Name"] + .apply(set) + .reset_index() + .rename(columns={"Value": "GeneID", "Name": "Terms"}) + ) + + super().__init__() + + def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: + """ + For each gene ID, look up the term it's associated with + in the labelset, and place it in the corresponding split. + + Returns as many splits as there are elements in the split_terms + tuple. + """ + + # alias field to shorten the code below + gdf = self.gene_id_to_terms + + # for each split, filter to the gene IDs that have at least one + # term in the split + result = [ + ( + numpy.asarray([ + id for id in ids + if gdf[gdf["GeneID"] == str(id)]["Terms"].values[0] & terms + ]) if terms != {"*"} else None + ) + for terms in self.split_terms + ] + + # if one of the resulting splits ended up as 'None', we need to + # fill in that split with any gene that wasn't matched by any of + # the other splits + for idx, x in enumerate(result): + if x is None: + result[idx] = numpy.asarray([ + id for id in ids + if not any( + gdf[gdf["GeneID"] == str(id)]["Terms"].isin(terms).any() + for terms in self.split_terms + ) + ]) + + yield tuple(result) From 930c360bf693d3db7504bef4c2aadc95697f2550 Mon Sep 17 00:00:00 2001 From: Faisal Alquaddoomi Date: Fri, 15 Nov 2024 11:49:06 -0700 Subject: [PATCH 03/11] Added a check to ensure there's at most one catch-all split --- src/obnb/label/split/explicit.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/obnb/label/split/explicit.py b/src/obnb/label/split/explicit.py index b3b3df83..a7f45824 100644 --- a/src/obnb/label/split/explicit.py +++ b/src/obnb/label/split/explicit.py @@ -37,6 +37,10 @@ def __init__(self, labelset:LabelsetCollection, split_terms: Tuple[set[str]]) -> self.labelset = labelset self.split_terms = [set(x) for x in split_terms] + # verify that there's only one catch-all split + if sum(1 for x in self.split_terms if x == {"*"}) > 1: + raise ValueError("Only one catch-all '*' split is allowed") + # convert labelset into a dataframe where one can search for # the terms associated with each gene ID like so: # self.long_df[self.long_df["Value"] == str(gene_id)]["Name"] From 4dfd3a48128daf76b7d73038cfeb5ceff5a5cd4a Mon Sep 17 00:00:00 2001 From: Faisal Alquaddoomi Date: Fri, 15 Nov 2024 14:25:35 -0700 Subject: [PATCH 04/11] Adds 'exclusive' argument to ByTermSplit, tweaks docstring --- src/obnb/label/split/explicit.py | 41 ++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/src/obnb/label/split/explicit.py b/src/obnb/label/split/explicit.py index a7f45824..8fc64f57 100644 --- a/src/obnb/label/split/explicit.py +++ b/src/obnb/label/split/explicit.py @@ -1,4 +1,4 @@ -from typing import Any, Iterator, List, Tuple +from typing import Iterable, Iterator, List, Tuple import numpy @@ -23,7 +23,11 @@ class ByTermSplit(BaseSplit): splits at all. """ - def __init__(self, labelset:LabelsetCollection, split_terms: Tuple[set[str]]) -> None: + def __init__( + self, labelset: LabelsetCollection, + split_terms: Iterable[Iterable[str]], + exclusive: bool = False + ) -> None: """ Initialize ByTermSplit object with reference labels and terms into which to create splits. @@ -31,11 +35,16 @@ def __init__(self, labelset:LabelsetCollection, split_terms: Tuple[set[str]]) -> Args: labelset: LabelsetCollection object containing terms for each gene ID. - split_terms: Tuple of sets of terms. Each set of terms will - correspond to a split + split_terms: a nested collection. The first level of nesting + indicates the splits; the second level within each identifies + terms that should be matched to place a gene in that split. + exclusive: if True, a gene can occur only once across all the + splits; it will belong to the first split in which it occurs. """ + self.labelset = labelset self.split_terms = [set(x) for x in split_terms] + self.exclusive = exclusive # verify that there's only one catch-all split if sum(1 for x in self.split_terms if x == {"*"}) > 1: @@ -78,7 +87,7 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: # term in the split result = [ ( - numpy.asarray([ + set([ id for id in ids if gdf[gdf["GeneID"] == str(id)]["Terms"].values[0] & terms ]) if terms != {"*"} else None @@ -89,14 +98,26 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: # if one of the resulting splits ended up as 'None', we need to # fill in that split with any gene that wasn't matched by any of # the other splits - for idx, x in enumerate(result): - if x is None: - result[idx] = numpy.asarray([ + for idx in range(len(result)): + if result[idx] is None: + result[idx] = set([ id for id in ids if not any( - gdf[gdf["GeneID"] == str(id)]["Terms"].isin(terms).any() + gdf[gdf["GeneID"] == str(id)]["Terms"].values[0] & terms for terms in self.split_terms ) ]) - yield tuple(result) + if self.exclusive: + # if exclusive, remove genes in the current split that occurred + # in any previous split + # (we skip the first split since there's nothing with which to + # compare it) + for idx in range(1, len(result)): + result[idx] = result[idx] - set.union(*result[:idx]) + + # yield it in the format returned by other splitters, e.g. a tuple of + # numpy arrays. we cast to list because leaving it as a set would cause + # numpy.asarray() to create an array with a single element, the set, + # rather than an array with the elements of the list + yield tuple([ numpy.asarray(list(x)) for x in result ]) From 5749f97ed69fcd1f2326025fb26163d8404e2488 Mon Sep 17 00:00:00 2001 From: Faisal Alquaddoomi Date: Fri, 15 Nov 2024 14:28:34 -0700 Subject: [PATCH 05/11] More small comment tweaks --- src/obnb/label/split/explicit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/obnb/label/split/explicit.py b/src/obnb/label/split/explicit.py index 8fc64f57..1c7e249a 100644 --- a/src/obnb/label/split/explicit.py +++ b/src/obnb/label/split/explicit.py @@ -60,8 +60,8 @@ def __init__( value_name="Value" ).dropna(subset=["Value"]) - # group by the integer value and aggregate names into a set, making - # it possible to retrieve all the terms for a given gene ID + # group gene id and aggregate terms into a set, which makes + # it faster to retrieve all the terms for a given gene ID self.gene_id_to_terms = ( self.long_df.groupby("Value")["Name"] .apply(set) From 81c37ec09e95009a3bb311977611f3351c8c4e94 Mon Sep 17 00:00:00 2001 From: Faisal Alquaddoomi Date: Fri, 15 Nov 2024 14:58:17 -0700 Subject: [PATCH 06/11] Applied pre-commit, committed style changes --- src/obnb/label/split/__init__.py | 4 +- src/obnb/label/split/base.py | 10 ++--- src/obnb/label/split/explicit.py | 70 ++++++++++++++++---------------- 3 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/obnb/label/split/__init__.py b/src/obnb/label/split/__init__.py index 7c2cf61f..cce2b6d9 100644 --- a/src/obnb/label/split/__init__.py +++ b/src/obnb/label/split/__init__.py @@ -1,4 +1,5 @@ """Genearting data splits from the labelset collection.""" +from obnb.label.split.explicit import ByTermSplit from obnb.label.split.holdout import ( AllHoldout, RandomRatioHoldout, @@ -10,9 +11,6 @@ RatioPartition, ThresholdPartition, ) -from obnb.label.split.explicit import ( - ByTermSplit -) __all__ = classes = [ "AllHoldout", diff --git a/src/obnb/label/split/base.py b/src/obnb/label/split/base.py index e52adab2..db8a45a8 100644 --- a/src/obnb/label/split/base.py +++ b/src/obnb/label/split/base.py @@ -41,10 +41,9 @@ def __call__( ids: List[str], y: np.ndarray, ) -> Iterator[Tuple[np.ndarray, ...]]: - """ - Split the input ids into multiple splits, e.g. a test, train, validation - split. The means by which this splitting occurs should be defined by - classes that inherit from this base class. + """Split the input ids into multiple splits, e.g. a test, train, validation + split. The means by which this splitting occurs should be defined by classes + that inherit from this base class. Note: Inheriting classes should yield the value instead of returning it, @@ -58,8 +57,9 @@ def __call__( Yields: Iterator of splits. Each split is a tuple of numpy arrays, where each array contains the IDs of the entities in the split. + """ - + raise NotImplementedError diff --git a/src/obnb/label/split/explicit.py b/src/obnb/label/split/explicit.py index 1c7e249a..1122e4d2 100644 --- a/src/obnb/label/split/explicit.py +++ b/src/obnb/label/split/explicit.py @@ -1,36 +1,34 @@ from typing import Iterable, Iterator, List, Tuple import numpy +from numpy import ndarray from obnb.label.collection import LabelsetCollection from obnb.label.split.base import BaseSplit -from numpy import ndarray class ByTermSplit(BaseSplit): - """ - Produces splits based on an explicit list of terms. Genes - which match each term will be placed in the split corresponding - to that term. - - A split with a single term '*' will act as a catch-all for any - genes that weren't matched by any of the other splits. This would - allow you to, e.g., only retain a specific set of genes in the - training set, and place all others in the test set. - - Note that if the '*' split is not provided, any genes that don't - match any of the other splits will not be present in the returned - splits at all. + """Produces splits based on an explicit list of terms. Genes which match each term + will be placed in the split corresponding to that term. + + A split with a single term '*' will act as a catch-all for any genes that + weren't matched by any of the other splits. This would allow you to, e.g., + only retain a specific set of genes in the training set, and place all + others in the test set. + + Note that if the '*' split is not provided, any genes that don't match any + of the other splits will not be present in the returned splits at all. + """ def __init__( - self, labelset: LabelsetCollection, - split_terms: Iterable[Iterable[str]], - exclusive: bool = False - ) -> None: - """ - Initialize ByTermSplit object with reference labels and terms into - which to create splits. + self, + labelset: LabelsetCollection, + split_terms: Iterable[Iterable[str]], + exclusive: bool = False, + ) -> None: + """Initialize ByTermSplit object with reference labels and terms into which to + create splits. Args: labelset: LabelsetCollection object containing terms for each @@ -40,6 +38,7 @@ def __init__( terms that should be matched to place a gene in that split. exclusive: if True, a gene can occur only once across all the splits; it will belong to the first split in which it occurs. + """ self.labelset = labelset @@ -57,7 +56,7 @@ def __init__( self.long_df = df.melt( id_vars=["Name"], value_vars=df.columns.difference(["Info", "Size", "Name"]), - value_name="Value" + value_name="Value", ).dropna(subset=["Value"]) # group gene id and aggregate terms into a set, which makes @@ -72,12 +71,11 @@ def __init__( super().__init__() def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: - """ - For each gene ID, look up the term it's associated with - in the labelset, and place it in the corresponding split. + """For each gene ID, look up the term it's associated with in the labelset, and + place it in the corresponding split. + + Returns as many splits as there are elements in the split_terms tuple. - Returns as many splits as there are elements in the split_terms - tuple. """ # alias field to shorten the code below @@ -87,10 +85,13 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: # term in the split result = [ ( - set([ - id for id in ids + { + id + for id in ids if gdf[gdf["GeneID"] == str(id)]["Terms"].values[0] & terms - ]) if terms != {"*"} else None + } + if terms != {"*"} + else None ) for terms in self.split_terms ] @@ -100,13 +101,14 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: # the other splits for idx in range(len(result)): if result[idx] is None: - result[idx] = set([ - id for id in ids + result[idx] = { + id + for id in ids if not any( gdf[gdf["GeneID"] == str(id)]["Terms"].values[0] & terms for terms in self.split_terms ) - ]) + } if self.exclusive: # if exclusive, remove genes in the current split that occurred @@ -120,4 +122,4 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: # numpy arrays. we cast to list because leaving it as a set would cause # numpy.asarray() to create an array with a single element, the set, # rather than an array with the elements of the list - yield tuple([ numpy.asarray(list(x)) for x in result ]) + yield tuple([numpy.asarray(list(x)) for x in result]) From 0502513ff01866916bfddd29020c609627ef86dc Mon Sep 17 00:00:00 2001 From: Faisal Alquaddoomi Date: Wed, 4 Dec 2024 09:23:53 -0700 Subject: [PATCH 07/11] Replaces gene IDs in ByTermSplit splits with indices into the ids array --- src/obnb/label/split/explicit.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/obnb/label/split/explicit.py b/src/obnb/label/split/explicit.py index 1122e4d2..c28f4635 100644 --- a/src/obnb/label/split/explicit.py +++ b/src/obnb/label/split/explicit.py @@ -119,7 +119,10 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: result[idx] = result[idx] - set.union(*result[:idx]) # yield it in the format returned by other splitters, e.g. a tuple of - # numpy arrays. we cast to list because leaving it as a set would cause - # numpy.asarray() to create an array with a single element, the set, - # rather than an array with the elements of the list - yield tuple([numpy.asarray(list(x)) for x in result]) + # numpy arrays, each of which contain indices into the 'ids' array + # passed into the splitter. + yield tuple([ + numpy.asarray([ + ids.index(v) for v in x + ]) for x in result + ]) From f1e5dd7a522525aa4d3cfc60f5000ed33f51cd36 Mon Sep 17 00:00:00 2001 From: Faisal Alquaddoomi Date: Wed, 4 Dec 2024 10:06:09 -0700 Subject: [PATCH 08/11] Fixes flake8 errors raised by now-fixed test workflow --- src/obnb/label/split/explicit.py | 39 ++++++++++++++++---------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/src/obnb/label/split/explicit.py b/src/obnb/label/split/explicit.py index c28f4635..e3ac435b 100644 --- a/src/obnb/label/split/explicit.py +++ b/src/obnb/label/split/explicit.py @@ -8,17 +8,17 @@ class ByTermSplit(BaseSplit): - """Produces splits based on an explicit list of terms. Genes which match each term - will be placed in the split corresponding to that term. + """ + Produces splits based on an explicit list of terms. - A split with a single term '*' will act as a catch-all for any genes that - weren't matched by any of the other splits. This would allow you to, e.g., - only retain a specific set of genes in the training set, and place all - others in the test set. + Genes which match each term will be placed in the split corresponding to + that term. A split with a single term '*' will act as a catch-all for any + genes that weren't matched by any of the other splits. This would allow you + to, e.g., only retain a specific set of genes in the training set, and place + all others in the test set. Note that if the '*' split is not provided, any genes that don't match any of the other splits will not be present in the returned splits at all. - """ def __init__( @@ -27,8 +27,8 @@ def __init__( split_terms: Iterable[Iterable[str]], exclusive: bool = False, ) -> None: - """Initialize ByTermSplit object with reference labels and terms into which to - create splits. + """ + Initialize ByTermSplit object with reference labels and terms for splits. Args: labelset: LabelsetCollection object containing terms for each @@ -38,9 +38,7 @@ def __init__( terms that should be matched to place a gene in that split. exclusive: if True, a gene can occur only once across all the splits; it will belong to the first split in which it occurs. - """ - self.labelset = labelset self.split_terms = [set(x) for x in split_terms] self.exclusive = exclusive @@ -71,13 +69,14 @@ def __init__( super().__init__() def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: - """For each gene ID, look up the term it's associated with in the labelset, and - place it in the corresponding split. + """ + Produce splits based on the terms associated with each gene ID. - Returns as many splits as there are elements in the split_terms tuple. + For each gene ID, look up the term it's associated with in the labelset, + and place it in the corresponding split. + Returns as many splits as there are elements in the split_terms tuple. """ - # alias field to shorten the code below gdf = self.gene_id_to_terms @@ -86,8 +85,8 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: result = [ ( { - id - for id in ids + gene_id + for gene_id in ids if gdf[gdf["GeneID"] == str(id)]["Terms"].values[0] & terms } if terms != {"*"} @@ -102,10 +101,10 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: for idx in range(len(result)): if result[idx] is None: result[idx] = { - id - for id in ids + gene_id + for gene_id in ids if not any( - gdf[gdf["GeneID"] == str(id)]["Terms"].values[0] & terms + gdf[gdf["GeneID"] == str(gene_id)]["Terms"].values[0] & terms for terms in self.split_terms ) } From 0118675b8c4d36f4d90b27167757894c32a4db02 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Dec 2024 18:17:44 +0000 Subject: [PATCH 09/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/obnb/label/split/__init__.py | 1 + src/obnb/label/split/explicit.py | 18 +++++++----------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/obnb/label/split/__init__.py b/src/obnb/label/split/__init__.py index 0dbae090..2039f703 100644 --- a/src/obnb/label/split/__init__.py +++ b/src/obnb/label/split/__init__.py @@ -1,4 +1,5 @@ """Generating data splits from the labelset collection.""" + from obnb.label.split.explicit import ByTermSplit from obnb.label.split.holdout import ( AllHoldout, diff --git a/src/obnb/label/split/explicit.py b/src/obnb/label/split/explicit.py index e3ac435b..99b44b8d 100644 --- a/src/obnb/label/split/explicit.py +++ b/src/obnb/label/split/explicit.py @@ -8,8 +8,7 @@ class ByTermSplit(BaseSplit): - """ - Produces splits based on an explicit list of terms. + """Produces splits based on an explicit list of terms. Genes which match each term will be placed in the split corresponding to that term. A split with a single term '*' will act as a catch-all for any @@ -19,6 +18,7 @@ class ByTermSplit(BaseSplit): Note that if the '*' split is not provided, any genes that don't match any of the other splits will not be present in the returned splits at all. + """ def __init__( @@ -27,8 +27,7 @@ def __init__( split_terms: Iterable[Iterable[str]], exclusive: bool = False, ) -> None: - """ - Initialize ByTermSplit object with reference labels and terms for splits. + """Initialize ByTermSplit object with reference labels and terms for splits. Args: labelset: LabelsetCollection object containing terms for each @@ -38,6 +37,7 @@ def __init__( terms that should be matched to place a gene in that split. exclusive: if True, a gene can occur only once across all the splits; it will belong to the first split in which it occurs. + """ self.labelset = labelset self.split_terms = [set(x) for x in split_terms] @@ -69,13 +69,13 @@ def __init__( super().__init__() def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: - """ - Produce splits based on the terms associated with each gene ID. + """Produce splits based on the terms associated with each gene ID. For each gene ID, look up the term it's associated with in the labelset, and place it in the corresponding split. Returns as many splits as there are elements in the split_terms tuple. + """ # alias field to shorten the code below gdf = self.gene_id_to_terms @@ -120,8 +120,4 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: # yield it in the format returned by other splitters, e.g. a tuple of # numpy arrays, each of which contain indices into the 'ids' array # passed into the splitter. - yield tuple([ - numpy.asarray([ - ids.index(v) for v in x - ]) for x in result - ]) + yield tuple([numpy.asarray([ids.index(v) for v in x]) for x in result]) From e2b28f0d0cb2fc0bae122ac9691ddd2e25c3f03f Mon Sep 17 00:00:00 2001 From: Faisal Alquaddoomi Date: Wed, 4 Dec 2024 11:28:30 -0700 Subject: [PATCH 10/11] Changed 'id' ref that hadn't been updated in the id->gene_id change --- src/obnb/label/split/explicit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/obnb/label/split/explicit.py b/src/obnb/label/split/explicit.py index 99b44b8d..929cb799 100644 --- a/src/obnb/label/split/explicit.py +++ b/src/obnb/label/split/explicit.py @@ -87,7 +87,7 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: { gene_id for gene_id in ids - if gdf[gdf["GeneID"] == str(id)]["Terms"].values[0] & terms + if gdf[gdf["GeneID"] == str(gene_id)]["Terms"].values[0] & terms } if terms != {"*"} else None From be64301bb38b0ccb025d783479b018349daf7b7d Mon Sep 17 00:00:00 2001 From: Faisal Alquaddoomi Date: Wed, 4 Dec 2024 13:08:03 -0700 Subject: [PATCH 11/11] Small tweaks to get mypy to pass --- src/obnb/label/split/explicit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/obnb/label/split/explicit.py b/src/obnb/label/split/explicit.py index 929cb799..d5699756 100644 --- a/src/obnb/label/split/explicit.py +++ b/src/obnb/label/split/explicit.py @@ -90,16 +90,16 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: if gdf[gdf["GeneID"] == str(gene_id)]["Terms"].values[0] & terms } if terms != {"*"} - else None + else {"*"} ) for terms in self.split_terms ] - # if one of the resulting splits ended up as 'None', we need to + # if one of the resulting splits ended up as the wildcard, we need to # fill in that split with any gene that wasn't matched by any of # the other splits for idx in range(len(result)): - if result[idx] is None: + if result[idx] == {"*"}: result[idx] = { gene_id for gene_id in ids @@ -115,7 +115,7 @@ def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: # (we skip the first split since there's nothing with which to # compare it) for idx in range(1, len(result)): - result[idx] = result[idx] - set.union(*result[:idx]) + result[idx] = result[idx] - set().union(*result[:idx]) # yield it in the format returned by other splitters, e.g. a tuple of # numpy arrays, each of which contain indices into the 'ids' array