diff --git a/CHANGELOG.md b/CHANGELOG.md index a92a6561..1feddd64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ * Improves support for datamodules with multiple test sets. Generalises this to support GO and FOLD. Also adds multiple seq ID.-based splits for GO. [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72) * Add redownload checks for already downloaded datasets and harmonise pdb download interface [#86](https://github.com/a-r-j/ProteinWorkshop/pull/86) * Remove remaining errors from PDB dataset change +* Add option to create pdb datasets with sequence-based splits [#88](https://github.com/a-r-j/ProteinWorkshop/pull/88) ### Models diff --git a/proteinworkshop/config/dataset/pdb.yaml b/proteinworkshop/config/dataset/pdb.yaml index 703a43af..991917d7 100644 --- a/proteinworkshop/config/dataset/pdb.yaml +++ b/proteinworkshop/config/dataset/pdb.yaml @@ -23,4 +23,7 @@ datamodule: remove_ligands: [] # Exclude specific ligands from any available protein-ligand complexes remove_non_standard_residues: True # Include only proteins containing standard amino acid residues remove_pdb_unavailable: True # Include only proteins that are available to download - split_sizes: [0.8, 0.1, 0.1] # Cross-validation ratios to use for train, val, and test splits + train_val_test: [0.8, 0.1, 0.1] # Cross-validation ratios to use for train, val, and test splits + split_type: "sequence_similarity" # Split sequences by sequence similarity clustering, other option is "random" + split_sequence_similarity: 0.3 # Clustering at 30% sequence similarity (argument is ignored if split_type="random") + overwrite_sequence_clusters: False # Previous clusterings at same sequence similarity are reused and not overwritten diff --git a/proteinworkshop/datasets/pdb_dataset.py b/proteinworkshop/datasets/pdb_dataset.py index ff211847..1de4abb7 100644 --- a/proteinworkshop/datasets/pdb_dataset.py +++ b/proteinworkshop/datasets/pdb_dataset.py @@ -1,4 +1,4 @@ -from typing import Callable, 
Iterable, List, Optional, Dict +from typing import Callable, Iterable, List, Optional, Dict, Literal import hydra import omegaconf @@ -29,7 +29,10 @@ def __init__( remove_ligands: List[str], remove_non_standard_residues: bool, remove_pdb_unavailable: bool, - split_sizes: List[float], + train_val_test: List[float], + split_type: Literal["sequence_similarity", "random"], + split_sequence_similarity: float, + overwrite_sequence_clusters: bool ): self.fraction = fraction self.molecule_type = molecule_type @@ -44,11 +47,16 @@ def __init__( self.remove_pdb_unavailable = remove_pdb_unavailable self.min_length = min_length self.max_length = max_length - self.split_sizes = split_sizes + assert abs(sum(train_val_test) - 1) < 1e-6, f"train_val_test need to sum to 1, but sum to {sum(train_val_test)}" + self.train_val_test = train_val_test + self.split_type = split_type + self.split_sequence_similarity = split_sequence_similarity + self.overwrite_sequence_clusters = overwrite_sequence_clusters + self.splits = ["train", "val", "test"] def create_dataset(self): log.info(f"Initializing PDBManager in {self.path}...") - pdb_manager = PDBManager(root_dir=self.path) + pdb_manager = PDBManager(root_dir=self.path, splits=self.splits, split_ratios=self.train_val_test) num_chains = len(pdb_manager.df) log.info(f"Starting with: {num_chains} chains") @@ -109,13 +117,21 @@ def create_dataset(self): pdb_manager.remove_unavailable_pdbs(update=True) log.info(f"{len(pdb_manager.df)} chains remaining") - log.info(f"Splitting dataset into {self.split_sizes}...") - split_names = ["train", "val", "test"] - splits = pdb_manager.split_df_proportionally( - df=pdb_manager.df, - splits=split_names, - split_ratios=self.split_sizes, - ) + if self.split_type == "random": + log.info(f"Splitting dataset via random split into {self.train_val_test}...") + splits = pdb_manager.split_df_proportionally( + df=pdb_manager.df, + splits=self.splits, + split_ratios=self.train_val_test, + ) + + elif self.split_type == 
"sequence_similarity": + log.info(f"Splitting dataset via sequence-similarity split into {self.train_val_test}...") + log.info(f"Using {self.split_sequence_similarity} sequence similarity for split") + pdb_manager.cluster(min_seq_id=self.split_sequence_similarity, update=True) + splits = pdb_manager.split_clusters( + pdb_manager.df, update=True, overwrite = self.overwrite_sequence_clusters) + log.info(splits["train"]) return splits