refactor multi test set datasets; add seq id test splits to GO (#72)
* refactor multi test set datasets; add seq id test splits to GO

* remove commented code

* remove commented code

* add reference

* Add .bib to readme

* fix commas

* update changelog

---------

Co-authored-by: Jamasb <[email protected]>
a-r-j and Jamasb authored Feb 9, 2024
1 parent 0177167 commit 99696a6
Showing 7 changed files with 146 additions and 64 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,9 @@
### 0.2.6 (UNRELEASED)

### Datasets
* Add stage-based conditions to `setup` in `ProteinDataModule` [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72)
* Improves support for datamodules with multiple test sets, generalising this to the GO and FOLD datasets, and adds multiple sequence identity-based test splits for GO. [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72)

### Models

* Adds missing `pos` attribute to GearNet `required_batch_attributes` (fixes [#73](https://github.com/a-r-j/ProteinWorkshop/issues/73)) [#74](https://github.com/a-r-j/ProteinWorkshop/pull/74)
15 changes: 15 additions & 0 deletions README.md
@@ -557,3 +557,18 @@ To build a local version of the project's Sphinx documentation web pages:
pip install -r docs/.docs.requirements # one-time only
rm -rf docs/build/ && sphinx-build docs/source/ docs/build/ # NOTE: errors can safely be ignored
```
## Citing `ProteinWorkshop`
Please consider citing `proteinworkshop` if it proves useful in your work.
```bibtex
@inproceedings{
jamasb2024evaluating,
title={Evaluating Representation Learning on the Protein Structure Universe},
author={Arian R. Jamasb and Alex Morehead and Zuobai Zhang and Chaitanya K. Joshi and Kieran Didi and Simon V. Mathis and Charles Harris and Jian Tang and Jianlin Cheng and Pietro Lio and Tom L. Blundell},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
}
```
7 changes: 7 additions & 0 deletions citation.bib
@@ -0,0 +1,7 @@
@inproceedings{
jamasb2024evaluating,
title={Evaluating Representation Learning on the Protein Structure Universe},
author={Arian R. Jamasb and Alex Morehead and Zuobai Zhang and Chaitanya K. Joshi and Kieran Didi and Simon V. Mathis and Charles Harris and Jian Tang and Jianlin Cheng and Pietro Lio and Tom L. Blundell},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
}
26 changes: 19 additions & 7 deletions proteinworkshop/datasets/base.py
@@ -1,4 +1,5 @@
"""Base classes for protein structure datamodules and datasets."""
import copy
import os
import pathlib
from abc import ABC, abstractmethod
@@ -81,12 +82,23 @@ def download(self):

def setup(self, stage: Optional[str] = None):
self.download()
logger.info("Preprocessing training data")
self.train_ds = self.train_dataset()
logger.info("Preprocessing validation data")
self.val_ds = self.val_dataset()
logger.info("Preprocessing test data")
self.test_ds = self.test_dataset()

if stage == "fit" or stage is None:
logger.info("Preprocessing training data")
self.train_ds = self.train_dataset()
logger.info("Preprocessing validation data")
self.val_ds = self.val_dataset()
elif stage == "test":
logger.info("Preprocessing test data")
if hasattr(self, "test_dataset_names"):
for split in self.test_dataset_names:
setattr(self, f"{split}_ds", self.test_dataset(split))
else:
self.test_ds = self.test_dataset()
elif stage == "lazy_init":
logger.info("Preprocessing validation data")
self.val_ds = self.val_dataset()

# self.class_weights = self.get_class_weights()

@property
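
For context, here is a minimal, self-contained sketch of the stage dispatch the hunk above introduces. The class and `_load` helper are illustrative stand-ins, not part of the commit: Lightning passes `stage="fit"` or `stage="test"` from `Trainer.fit`/`Trainer.test`, while `"lazy_init"` is a workshop-specific value used to build only the validation set.

```python
from typing import List, Optional


class TinyDataModule:
    """Illustrative stand-in for a ProteinDataModule subclass."""

    # Datamodules with several test sets expose their split names like this.
    test_dataset_names: List[str] = ["fold", "family", "superfamily"]

    def setup(self, stage: Optional[str] = None) -> None:
        if stage == "fit" or stage is None:
            self.train_ds = self._load("training")
            self.val_ds = self._load("validation")
        elif stage == "test":
            # One attribute per named split: self.fold_ds, self.family_ds, ...
            for split in self.test_dataset_names:
                setattr(self, f"{split}_ds", self._load(f"test_{split}"))
        elif stage == "lazy_init":
            # Only the validation set is needed to warm up lazy layers.
            self.val_ds = self._load("validation")

    def _load(self, split: str) -> List[str]:
        return [split]  # stand-in for real dataset construction


dm = TinyDataModule()
dm.setup(stage="test")
print(dm.fold_ds, dm.family_ds, dm.superfamily_ds)
```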
@@ -518,7 +530,7 @@ def get(self, idx: int) -> Data:
:return: PyTorch Geometric Data object.
"""
if self.in_memory:
return self._batch_format(self.data[idx])
return self._batch_format(copy.deepcopy(self.data[idx]))

if self.out_names is not None:
fname = f"{self.out_names[idx]}.pt"
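
The `copy.deepcopy` added above guards the in-memory cache: transforms applied downstream often mutate the returned object in place, which would otherwise corrupt the cached copy for every later access. A toy sketch of the failure mode (these classes are illustrative, not the real `ProteinDataset`):

```python
import copy


class InMemoryCache:
    def __init__(self, items):
        self.data = list(items)

    def get_unsafe(self, idx):
        return self.data[idx]  # hands out the cached object itself

    def get_safe(self, idx):
        return copy.deepcopy(self.data[idx])  # caller gets a private copy


cache = InMemoryCache([{"pos": [0.0, 1.0]}])
item = cache.get_unsafe(0)
item["pos"][0] = 99.0            # an in-place "transform"
print(cache.data[0]["pos"][0])   # 99.0 -- the cache was silently corrupted

cache = InMemoryCache([{"pos": [0.0, 1.0]}])
item = cache.get_safe(0)
item["pos"][0] = 99.0
print(cache.data[0]["pos"][0])   # 0.0 -- the cache is untouched
```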
44 changes: 20 additions & 24 deletions proteinworkshop/datasets/fold_classification.py
@@ -1,7 +1,7 @@
import os
import pathlib
import tarfile
from typing import Callable, Dict, Iterable, Optional
from typing import Callable, Dict, Iterable, List, Literal, Optional

import omegaconf
import pandas as pd
@@ -72,6 +72,11 @@ def __init__(
else:
self.transform = None

@property
def test_dataset_names(self) -> List[str]:
"""Provides a list of test set split names."""
return ["fold", "family", "superfamily"]

def download(self):
self.download_data_files()
self.download_structures()
@@ -152,16 +157,12 @@ def parse_class_map(self) -> Dict[str, str]:
)
return dict(class_map.values)

def setup(self, stage: Optional[str] = None):
self.download_data_files()
self.download_structures()
self.train_ds = self.train_dataset()
self.val_ds = self.val_dataset()
self.test_ds = self.test_dataset()

def _get_dataset(self, split: str) -> ProteinDataset:
if hasattr(self, f"{split}_ds"):
return getattr(self, f"{split}_ds")

df = self.parse_dataset(split)
return ProteinDataset(
ds = ProteinDataset(
root=str(self.data_dir),
pdb_dir=str(self.structure_dir),
pdb_codes=list(df.id),
@@ -171,15 +172,19 @@ def _get_dataset(self, split: str) -> ProteinDataset:
transform=self.transform,
in_memory=self.in_memory,
)
setattr(self, f"{split}_ds", ds)
return ds

def train_dataset(self) -> ProteinDataset:
return self._get_dataset("training")

def val_dataset(self) -> ProteinDataset:
return self._get_dataset("validation")

def test_dataset(self) -> ProteinDataset:
return self._get_dataset(f"test_{self.split}")
def test_dataset(
self, split: Literal["fold", "family", "superfamily"]
) -> ProteinDataset:
return self._get_dataset(f"test_{split}")

def train_dataloader(self) -> ProteinDataLoader:
self.train_ds = self.train_dataset()
@@ -201,8 +206,10 @@ def val_dataloader(self) -> ProteinDataLoader:
num_workers=self.num_workers,
)

def test_dataloader(self) -> ProteinDataLoader:
self.test_ds = self.test_dataset()
def test_dataloader(
self, split: Literal["fold", "family", "superfamily"]
) -> ProteinDataLoader:
self.test_ds = self.test_dataset(split)
return ProteinDataLoader(
self.test_ds,
batch_size=self.batch_size,
@@ -211,17 +218,6 @@ def test_dataloader(self) -> ProteinDataLoader:
num_workers=self.num_workers,
)

def get_test_loader(self, split: str) -> ProteinDataLoader:
log.info(f"Getting test loader: {split}")
test_ds = self._get_dataset(f"test_{split}")
return ProteinDataLoader(
test_ds,
batch_size=self.batch_size,
shuffle=False,
pin_memory=self.pin_memory,
num_workers=self.num_workers,
)

def parse_dataset(self, split: str) -> pd.DataFrame:
"""
Parses the raw dataset files to Pandas DataFrames.
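
Both refactored datamodules now memoise parsed splits on the instance: the first call to `_get_dataset` builds the dataset and stores it as `<split>_ds`; later calls return the cached attribute. A hedged sketch of the pattern in isolation (`_build` is a toy stand-in for the parse-and-construct step above):

```python
class SplitCache:
    """Illustrates the setattr/getattr memoisation used by _get_dataset."""

    def _get_dataset(self, split: str) -> list:
        if hasattr(self, f"{split}_ds"):      # cached by an earlier call
            return getattr(self, f"{split}_ds")
        ds = self._build(split)               # expensive parse + construct
        setattr(self, f"{split}_ds", ds)      # cache for next time
        return ds

    def _build(self, split: str) -> list:
        print(f"building {split}")            # runs once per split
        return [split]


dm = SplitCache()
dm._get_dataset("test_fold")  # prints "building test_fold"
dm._get_dataset("test_fold")  # served from the cache, no print
```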
99 changes: 75 additions & 24 deletions proteinworkshop/datasets/go.py
@@ -2,7 +2,7 @@
import zipfile
from functools import lru_cache
from pathlib import Path
from typing import Callable, Dict, Iterable, Literal, Optional
from typing import Callable, Dict, Iterable, List, Literal, Optional

import omegaconf
import pandas as pd
@@ -70,6 +70,14 @@ def __init__(

self.shuffle_labels = shuffle_labels

self.test_seq_similarity_cutoffs: List[float] = [
0.3,
0.4,
0.5,
0.7,
0.95,
]

if transforms is not None:
self.transform = self.compose_transforms(
omegaconf.OmegaConf.to_container(transforms, resolve=True)
@@ -79,14 +87,19 @@

self.train_fname = self.data_dir / "nrPDB-GO_train.txt"
self.val_fname = self.data_dir / "nrPDB-GO_valid.txt"
self.test_fname = self.data_dir / "nrPDB-GO_test.txt"
self.test_fname = self.data_dir / "nrPDB-GO_test.csv"
self.label_fname = self.data_dir / "nrPDB-GO_annot.tsv"
self.url = "https://zenodo.org/record/6622158/files/GeneOntology.zip"

log.info(
f"Setting up Gene Ontology dataset. Fraction {self.dataset_fraction}"
)

@property
def test_dataset_names(self) -> List[str]:
"""Provides a list of test set split names."""
return ["test_0.3", "test_0.4", "test_0.5", "test_0.7", "test_0.95"]

@lru_cache
def parse_labels(self) -> Dict[str, torch.Tensor]:
"""
@@ -130,11 +143,23 @@ def parse_labels(self) -> Dict[str, torch.Tensor]:
return labels

def _get_dataset(
self, split: Literal["training", "validation", "testing"]
self,
split: Literal[
"training",
"validation",
"test_0.3",
"test_0.4",
"test_0.5",
"test_0.7",
"test_0.95",
],
) -> ProteinDataset:
if hasattr(self, f"{split}_ds"):
return getattr(self, f"{split}_ds")

df = self.parse_dataset(split)
log.info("Initialising Graphein dataset...")
return ProteinDataset(
ds = ProteinDataset(
root=str(self.data_dir),
pdb_dir=str(self.pdb_dir),
pdb_codes=list(df.pdb),
@@ -147,15 +172,22 @@ def _get_dataset(
format=self.format,
in_memory=self.in_memory,
)
setattr(self, f"{split}_ds", ds)
return ds

def train_dataset(self) -> ProteinDataset:
return self._get_dataset("training")

def val_dataset(self) -> ProteinDataset:
return self._get_dataset("validation")

def test_dataset(self) -> ProteinDataset:
return self._get_dataset("testing")
def test_dataset(
self,
split: Literal[
"test_0.3", "test_0.4", "test_0.5", "test_0.7", "test_0.95"
],
) -> ProteinDataset:
return self._get_dataset(split)

def train_dataloader(self) -> ProteinDataLoader:
return ProteinDataLoader(
@@ -175,9 +207,14 @@ def val_dataloader(self) -> ProteinDataLoader:
num_workers=self.num_workers,
)

def test_dataloader(self) -> ProteinDataLoader:
def test_dataloader(
self,
split: Literal[
"test_0.3", "test_0.4", "test_0.5", "test_0.7", "test_0.95"
],
) -> ProteinDataLoader:
return ProteinDataLoader(
self.test_dataset(),
self.test_dataset(split),
batch_size=self.batch_size,
shuffle=False,
pin_memory=self.pin_memory,
@@ -205,7 +242,16 @@ def exclude_pdbs(self):
pass

def parse_dataset(
self, split: Literal["training", "validation", "testing"]
self,
split: Literal[
"training",
"validation",
"test_0.3",
"test_0.4",
"test_0.5",
"test_0.7",
"test_0.95",
],
) -> pd.DataFrame:
# sourcery skip: remove-unnecessary-else, swap-if-else-branches, switch
"""
@@ -221,8 +267,11 @@
data = data.sample(frac=self.dataset_fraction)
elif split == "validation":
data = pd.read_csv(self.val_fname, sep="\t", header=None)
elif split == "testing":
data = pd.read_csv(self.test_fname, sep="\t", header=None)
elif split.startswith("test_"):
cutoff = int(float(split.split("_")[1]) * 100)
data = pd.read_csv(self.test_fname, sep=",")
data = data.loc[data[f"<{cutoff}%"] == 1]
data = pd.DataFrame(data["PDB-chain"].values)
else:
raise ValueError(f"Unknown split: {split}")

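Each test split name is just a cutoff from `test_seq_similarity_cutoffs` rendered as a string, and the cutoff selects an indicator column in the test CSV: `"test_0.3"` keeps only chains marked in the `<30%` column, and so on up to `<95%`. A small sketch of that filtering with a hand-made frame (the toy rows here are invented; the real file is `nrPDB-GO_test.csv`):

```python
import pandas as pd

# Toy stand-in for nrPDB-GO_test.csv: one indicator column per cutoff.
data = pd.DataFrame(
    {
        "PDB-chain": ["1abc_A", "2def_B", "3ghi_C"],
        "<30%": [1, 0, 1],
        "<95%": [1, 1, 1],
    }
)

split = "test_0.3"
cutoff = int(float(split.split("_")[1]) * 100)     # "test_0.3" -> 30
subset = data.loc[data[f"<{cutoff}%"] == 1]        # chains under the cutoff
subset = pd.DataFrame(subset["PDB-chain"].values)  # single unnamed column
print(subset)  # rows for 1abc_A and 3ghi_C
```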
@@ -304,16 +353,18 @@ def __call__(self, data: Protein) -> Protein:
cfg.datamodule.transforms = []
log.info("Loaded config")

ds = hydra.utils.instantiate(cfg)
print(ds)
# labels = ds["datamodule"].parse_labels()
ds.datamodule.setup()
dl = ds["datamodule"].train_dataloader()
for batch in dl:
print(batch)
dl = ds["datamodule"].val_dataloader()
for batch in dl:
print(batch)
dl = ds["datamodule"].test_dataloader()
for batch in dl:
print(batch)
ds = hydra.utils.instantiate(cfg)["datamodule"]
ds.parse_dataset("test_0.3")
ds.parse_dataset("test_0.95")
# print(ds)
## labels = ds["datamodule"].parse_labels()
# ds.datamodule.setup()
# dl = ds["datamodule"].train_dataloader()
# for batch in dl:
# print(batch)
# dl = ds["datamodule"].val_dataloader()
# for batch in dl:
# print(batch)
# dl = ds["datamodule"].test_dataloader()
# for batch in dl:
# print(batch)
15 changes: 6 additions & 9 deletions proteinworkshop/train.py
@@ -144,7 +144,7 @@ def train_model(

log.info("Initializing lazy layers...")
with torch.no_grad():
datamodule.setup() # type: ignore
datamodule.setup(stage="lazy_init") # type: ignore
batch = next(iter(datamodule.val_dataloader()))
log.info(f"Unfeaturized batch: {batch}")
batch = model.featurise(batch)
@@ -185,16 +185,13 @@ def train_model(

if cfg.get("test"):
log.info("Starting testing!")
# Run test on all splits if using fold_classification dataset
if (
cfg.dataset.datamodule._target_
== "proteinworkshop.datasets.fold_classification.FoldClassificationDataModule"
):
splits = ["fold", "family", "superfamily"]
if hasattr(datamodule, "test_dataset_names"):
splits = datamodule.test_dataset_names
wandb_logger = copy.deepcopy(trainer.logger)
for split in splits:
dataloader = datamodule.get_test_loader(split)
for i, split in enumerate(splits):
dataloader = datamodule.test_dataloader(split)
trainer.logger = False
log.info(f"Testing on {split} ({i+1} / {len(splits)})...")
results = trainer.test(
model=model, dataloaders=dataloader, ckpt_path="best"
)[0]
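
With the dataset-specific branch gone, any datamodule that declares `test_dataset_names` is now tested split by split. A minimal sketch of the dispatch, with stub objects standing in for the Lightning trainer and the datamodule (names here are illustrative):

```python
class StubTrainer:
    def test(self, dataloader) -> None:
        print(f"running test pass on {dataloader}")


class StubDataModule:
    test_dataset_names = ["fold", "family", "superfamily"]

    def test_dataloader(self, split: str) -> str:
        # A real datamodule returns a ProteinDataLoader for the split.
        return f"loader:{split}"


trainer, datamodule = StubTrainer(), StubDataModule()
if hasattr(datamodule, "test_dataset_names"):
    splits = datamodule.test_dataset_names
    for i, split in enumerate(splits):
        print(f"Testing on {split} ({i + 1} / {len(splits)})...")
        trainer.test(datamodule.test_dataloader(split))
```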
