diff --git a/pyproject.toml b/pyproject.toml
index 1becd6a..dba338a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "graphpro"
-version = "0.10.1"
+version = "0.11.0"
 authors = [
   { name="Pegerto Fernandez", email="pegerto@gmail.com" },
 ]
diff --git a/src/graphpro/collection.py b/src/graphpro/collection.py
--- a/src/graphpro/collection.py
+++ b/src/graphpro/collection.py
@@ -1,9 +1,10 @@
 """ This collection module allow a set of utilities to manage a collection of graphs.
 """
 import pickle
+import random
+from typing import Callable, Optional
 
 from torch_geometric.data import InMemoryDataset
-from typing import Callable, Optional
 
 from .graph import Graph
 from .model import Target
@@ -51,6 +52,37 @@ def to_dataset(self, root: str, node_encoders = [], target: Target = None) -> I
         """
         return GraphProDataset(root, self, node_encoders, target)
 
+    def split(self,
+              test_size: float = 0.8,
+              seed: int = None):
+        """ Split the graph collection into training and validation sets.
+
+        Exactly ``round(len(collection) * test_size)`` randomly chosen graphs
+        form the first (training) collection and the remaining graphs form
+        the second (validation) one, so the split sizes are exact. Graphs
+        keep their original relative order inside each collection.
+
+        Args:
+            test_size: fraction of the graphs, clamped to [0, 1], placed in
+                the first returned collection (the training set).
+            seed: seed of the random selection; the same seed always gives
+                the same split, while None leaves the selection unseeded.
+
+        Returns:
+            A (training, validation) tuple of GraphCollection objects.
+        """
+        graphs = list(self._graphs)
+        fraction = min(max(test_size, 0.0), 1.0)
+        # A private generator keeps the split reproducible without
+        # reseeding the process-wide random module state.
+        rng = random.Random(seed)
+        picked = set(rng.sample(range(len(graphs)),
+                                round(len(graphs) * fraction)))
+        train = [g for i, g in enumerate(graphs) if i in picked]
+        val = [g for i, g in enumerate(graphs) if i not in picked]
+
+        return GraphCollection(train), GraphCollection(val)
+
     @staticmethod
     def load(filename: str):
         """ Loads a collection from a stored file, restoring the collection
diff --git a/test/graphpro/collection_test.py b/test/graphpro/collection_test.py
--- a/test/graphpro/collection_test.py
+++ b/test/graphpro/collection_test.py
@@ -38,7 +38,14 @@ def test_graphs_are_iterable():
         assert graph is not None
 
 
+def test_graph_collection_split():
+    collection = GraphCollection([SIMPLE_G] * 100)
+    train, test = collection.split(seed=42)
+    assert len(train) == 80
+    assert len(test) == 20
+
+
 def test_graphs_dataset():
     col = GraphCollection([SIMPLE_G, SIMPLE_G])
     ds = col.to_dataset('.')
     assert torch.all(ds[0].edge_index.eq(SIMPLE_G.to_data().edge_index))