
Pdb dataset fix #86

Merged · 5 commits · Mar 20, 2024
Changes from 3 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -3,6 +3,7 @@
 ### Datasets
 * Add stage-based conditions to `setup` in `ProteinDataModule` [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72)
 * Improves support for datamodules with multiple test sets. Generalises this to support GO and FOLD. Also adds multiple seq. ID-based splits for GO. [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72)
+* Add redownload checks for already-downloaded datasets and harmonise the PDB download interface [#86](https://github.com/a-r-j/ProteinWorkshop/pull/86)
 
 ### Models
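For orientation, a minimal sketch of the redownload-check pattern this changelog entry describes: before fetching, filter out PDB codes whose files already exist locally. Names here are illustrative, and `fetch_fn` merely stands in for the repo's `download_pdb_mmtf` helper; the PR's actual implementation is in `PDBDataModule.download()` below.

```python
# Illustrative sketch only, assuming files are stored as
# <structure_dir>/<pdb>.<structure_format>.
import os
import pathlib
from typing import Callable, List


def download_missing(
    pdb_codes: List[str],
    structure_dir: pathlib.Path,
    fetch_fn: Callable[[pathlib.Path, List[str]], None],
    structure_format: str = "mmtf.gz",
) -> List[str]:
    """Fetch only the structures not already present on disk."""
    missing = [
        pdb
        for pdb in pdb_codes
        if not os.path.exists(structure_dir / f"{pdb}.{structure_format}")
    ]
    if missing:
        fetch_fn(structure_dir, missing)
    return missing
```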
55 changes: 44 additions & 11 deletions proteinworkshop/datasets/pdb_dataset.py
@@ -1,8 +1,10 @@
-from typing import Callable, Iterable, List, Optional
+from typing import Callable, Iterable, List, Optional, Dict
 
 import hydra
 import omegaconf
+import os
 import pandas as pd
+import pathlib
 from graphein.ml.datasets import PDBManager
 from loguru import logger as log
 from torch_geometric.data import Dataset
@@ -11,7 +13,6 @@
 from proteinworkshop.datasets.base import ProteinDataModule, ProteinDataset
 from proteinworkshop.datasets.utils import download_pdb_mmtf
 
-
 class PDBData:
     def __init__(
         self,
@@ -130,21 +131,28 @@ class PDBDataModule(ProteinDataModule):
     def __init__(
         self,
         path: Optional[str] = None,
+        structure_dir: Optional[str] = None,
         pdb_dataset: Optional[PDBData] = None,
         transforms: Optional[Iterable[Callable]] = None,
         in_memory: bool = False,
         batch_size: int = 32,
         num_workers: int = 0,
         pin_memory: bool = False,
+        structure_format: str = "mmtf.gz",
+        overwrite: bool = False,
     ):
         super().__init__()
         self.root = path
         self.dataset = pdb_dataset
         self.dataset.path = path
-        self.format = "mmtf.gz"
+        self.format = structure_format
+        self.overwrite = overwrite
+
+        # Store structures in the given directory, defaulting to <root>/raw.
+        if structure_dir is not None:
+            self.structure_dir = pathlib.Path(structure_dir)
+        else:
+            self.structure_dir = pathlib.Path(self.root) / "raw"
+
         self.in_memory = in_memory
 
         if transforms is not None:
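A hedged usage sketch of the widened constructor above. Paths are illustrative and the `PDBData(...)` configuration is deliberately elided, so this is not runnable verbatim:

```python
from proteinworkshop.datasets.pdb_dataset import PDBData, PDBDataModule

pdb_data = PDBData(...)  # placeholder: configure a dataset selection as usual

datamodule = PDBDataModule(
    path="data/pdb",             # dataset root
    structure_dir=None,          # None falls back to <path>/raw
    pdb_dataset=pdb_data,
    structure_format="mmtf.gz",  # previously hard-coded, now configurable
    overwrite=False,             # new flag stored on the datamodule
)
datamodule.download()            # skips structures already in structure_dir
```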
@@ -157,19 +165,45 @@ def __init__
         self.num_workers = num_workers
         self.pin_memory = pin_memory
         self.batch_size = batch_size
 
-    def parse_dataset(self) -> pd.DataFrame:
-        return self.dataset.create_dataset()
+    def parse_dataset(self) -> Dict[str, pd.DataFrame]:
+        # Return cached splits if the dataset has already been parsed.
+        if hasattr(self, "splits"):
+            return getattr(self, "splits")
+
+        splits = self.dataset.create_dataset()
+        ids_to_exclude = self.exclude_pdbs()
+
+        if ids_to_exclude is not None:
+            for k, v in splits.items():
+                log.info(f"Split {k} has {len(v)} chains before excluding failing PDBs")
+                # Build chain-level ids of the form <pdb>_<chains>, e.g. 1abc_AB.
+                v["id"] = v["pdb"] + "_" + v["chain"].str.join("")
+                splits[k] = v.loc[~v.id.isin(ids_to_exclude)]
+                log.info(
+                    f"Split {k} has {len(splits[k])} chains after excluding failing PDBs"
+                )
+        self.splits = splits
+        return splits
 
     def exclude_pdbs(self):
         pass
 
     def download(self):
         pdbs = self.parse_dataset()
-        for k, v in pdbs:
-            log.info(f"Downloading {k} PDBs")
-            download_pdb_mmtf(pathlib.Path(self.root) / "raw", v.pdb.tolist())
+        for k, v in pdbs.items():
+            log.info(f"Downloading {k} PDBs to {self.structure_dir}")
+            pdblist = v.pdb.tolist()
+            # Skip structures already present on disk to avoid redownloading.
+            pdblist = [
+                pdb
+                for pdb in pdblist
+                if not os.path.exists(self.structure_dir / f"{pdb}.{self.format}")
+            ]
+            download_pdb_mmtf(self.structure_dir, pdblist)
 
     def parse_labels(self):
         raise NotImplementedError
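To make the exclusion filter in `parse_dataset` concrete, here is a small self-contained example of the `<pdb>_<chains>` id it matches against (data invented for illustration):

```python
# Invented data illustrating the id construction used by parse_dataset above.
import pandas as pd

split = pd.DataFrame({"pdb": ["1abc", "2xyz"], "chain": [["A"], ["A", "B"]]})
split["id"] = split["pdb"] + "_" + split["chain"].str.join("")
print(split["id"].tolist())  # ['1abc_A', '2xyz_AB']

# An exclude_pdbs() override returning {"2xyz_AB"} would drop the second row:
filtered = split.loc[~split["id"].isin({"2xyz_AB"})]
print(len(filtered))  # 1
```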
@@ -223,7 +257,6 @@ def test_dataset(self) -> Dataset:
 
 
 if __name__ == "__main__":
-    import pathlib
 
     from proteinworkshop import constants
 
@@ -234,4 +267,4 @@ def test_dataset(self) -> Dataset:
     print(cfg)
     ds = hydra.utils.instantiate(cfg)["datamodule"]
     print(ds)
-    ds.val_dataset()
\ No newline at end of file
+    ds.val_dataset()