Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lobster #20

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions docs/beignet.datasets.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# beignet.datasets

## Protein sequences

::: beignet.datasets.ATOM3DPPIDataset
::: beignet.datasets.CaLMDataset
::: beignet.datasets.DataFrameDataset
::: beignet.datasets.FASTADataset
::: beignet.datasets.LMDBDataset
::: beignet.datasets.SequenceDataset
::: beignet.datasets.SizedSequenceDataset
::: beignet.datasets.UniRef50Dataset
::: beignet.datasets.UniRef90Dataset
::: beignet.datasets.UniRef100Dataset
4 changes: 4 additions & 0 deletions docs/beignet.lightning.datamodules.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# beignet.lightning.datamodules

::: beignet.lightning.datamodules.CaLMLightningDataModule

File renamed without changes.
Empty file added src/beignet/data/__init__.py
Empty file.
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"cls_token": "<cls>",
"eos_token": "<eos>",
"mask_token": "<mask>",
"pad_token": "<pad>",
"unk_token": "<unk>"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"clean_up_tokenization_spaces": true,
"do_lower_case": false,
"model_max_length": 1024,
"tokenizer_class": "ProteinMLMTokenizer"
}
33 changes: 33 additions & 0 deletions src/beignet/data/tokenizers/protein_mlm_tokenizer/vocab.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
<cls>
<pad>
<eos>
<unk>
L
A
G
V
S
E
R
T
I
D
P
K
Q
N
F
Y
M
H
W
C
X
B
U
Z
O
.
-
<null_1>
<mask>
44 changes: 44 additions & 0 deletions src/beignet/datasets/__atom3d_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from os import PathLike
from pathlib import Path
from typing import Callable

import pooch

from beignet.transforms import Transform

from ._lmdb_dataset import LMDBDataset


class ATOM3DDataset(LMDBDataset):
    """Base class for ATOM3D datasets stored as LMDB archives.

    Optionally downloads an ATOM3D resource with ``pooch`` into
    ``root / f"ATOM3D{name}"`` and opens the LMDB database located at
    ``path`` inside that directory.
    """

    def __init__(
        self,
        root: str | PathLike,
        path: str | PathLike,
        resource: str,
        name: str,
        *,
        checksum: str | None = None,
        download: bool = False,
        transform: Callable | Transform | None = None,
    ):
        """
        Parameters
        ----------
        root : str | PathLike
            Root directory for the dataset. If `None`, the pooch cache
            directory for "beignet" is used.

        path : str | PathLike
            Location of the LMDB database, relative to the dataset
            directory ``root / f"ATOM3D{name}"``.

        resource : str
            URL of the resource to download.

        name : str
            Dataset name; data lives under ``root / f"ATOM3D{name}"``.

        checksum : str, optional
            Expected hash of the downloaded file, in pooch's
            ``known_hash`` format (e.g. ``"md5:..."`` or ``"sha256:..."``).

        download : bool, optional
            If `True`, download the resource (default: `False`).

        transform : Callable | Transform, optional
            Applied to each retrieved item (default: `None`).
        """
        if root is None:
            root = pooch.os_cache("beignet")

        if isinstance(root, str):
            root = Path(root)

        self.root = root.resolve()

        self.transform = transform

        if download:
            # `pooch.retrieve(url, known_hash, ...)`: the hash belongs in the
            # `known_hash` slot and the destination directory in `path`. The
            # original call passed the directory positionally as `known_hash`
            # and used `checksum=`, which is not a pooch keyword, so any
            # download attempt raised `TypeError`.
            pooch.retrieve(
                resource,
                known_hash=checksum,
                path=self.root / f"ATOM3D{name}",
            )
            # NOTE(review): the resource is typically a ``*.tar.gz`` archive;
            # if `path` points inside the extracted tree, a pooch `Untar`
            # processor may be required here — confirm against subclasses.

        super().__init__(
            self.root / f"ATOM3D{name}" / path,
            transform=transform,
        )
9 changes: 8 additions & 1 deletion src/beignet/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from .__uni_ref_dataset import _UniRefDataset
from ._atom3d_ppi_dataset import ATOM3DPPIDataset
from ._calm_dataset import CaLMDataset
from ._dataframe_dataset import DataFrameDataset
from ._fasta_dataset import FASTADataset
from ._lmdb_dataset import LMDBDataset
from ._sequence_dataset import SequenceDataset
from ._sized_sequence_dataset import SizedSequenceDataset
from ._uniref50_dataset import UniRef50Dataset
from ._uniref90_dataset import UniRef90Dataset
from ._uniref100_dataset import UniRef100Dataset

__all__ = [
"ATOM3DPPIDataset",
"CaLMDataset",
"DataFrameDataset",
"FASTADataset",
"LMDBDataset",
"SequenceDataset",
"SizedSequenceDataset",
"UniRef100Dataset",
Expand Down
52 changes: 52 additions & 0 deletions src/beignet/datasets/_atom3d_ppi_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from os import PathLike
from typing import Callable

from pandas import DataFrame

from beignet.transforms import Transform

from .__atom3d_dataset import ATOM3DDataset


class ATOM3DPPIDataset(ATOM3DDataset):
    """ATOM3D protein-protein interaction (PPI) dataset.

    Each item is a pair of `pandas.DataFrame`s built from the LMDB record:
    the paired atoms (input) and the neighboring atoms (target).
    """

    def __init__(
        self,
        root: str | PathLike,
        *,
        download: bool = False,
        joint_transform: Callable | Transform | None = None,
        target_transform: Callable | Transform | None = None,
        transform: Callable | Transform | None = None,
    ):
        """
        Parameters
        ----------
        root : str | PathLike
            Root directory where the dataset subdirectory exists or, if
            `download` is `True`, will be created.

        download : bool, optional
            If `True`, download the dataset (default: `False`).

        joint_transform : Callable | Transform, optional
            Applied jointly to the ``(input, target)`` pair (default: `None`).

        target_transform : Callable | Transform, optional
            Applied to the target frame (default: `None`).

        transform : Callable | Transform, optional
            Applied to the input frame (default: `None`).
        """
        super().__init__(
            root,
            "raw/DIPS/data",
            "https://zenodo.org/record/4911102/files/PPI-raw.tar.gz",
            "PPI",
            # Bare 32-hex digests are MD5 (Zenodo publishes MD5); pooch
            # interprets an unprefixed `known_hash` as SHA-256, so the
            # algorithm prefix is required for verification to succeed.
            checksum="md5:621977d132b39957e3480a24a30a7358",
            download=download,
        )

        self.joint_transform = joint_transform

        self.target_transform = target_transform

        # NOTE(review): `transform` is stored here but not forwarded to
        # `super().__init__`; if `LMDBDataset.__getitem__` also applies
        # `self.transform`, it would be applied twice — confirm upstream.
        self.transform = transform

    def __getitem__(self, index: int) -> tuple[DataFrame, DataFrame]:
        """Return the ``(input, target)`` pair of data frames at `index`."""
        item = super().__getitem__(index)

        # `inputs` rather than `input` to avoid shadowing the builtin.
        inputs = DataFrame(**item["atoms_pairs"])

        target = DataFrame(**item["atoms_neighbors"])

        if self.joint_transform is not None:
            inputs, target = self.joint_transform(inputs, target)

        if self.transform is not None:
            inputs = self.transform(inputs)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return inputs, target
83 changes: 83 additions & 0 deletions src/beignet/datasets/_calm_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import shutil
from os import PathLike
from pathlib import Path
from typing import Callable

import pooch
from pooch import Decompress, Untar

from beignet.transforms import Transform

from ._fasta_dataset import FASTADataset


class CaLMDataset(FASTADataset):
    """FASTA dataset of the CaLM (Codon adaptation Language Model) corpus."""

    def __init__(
        self,
        root: str | PathLike | None = None,
        *,
        train: bool = True,
        transform: Callable | Transform | None = None,
        target_transform: Callable | Transform | None = None,
    ):
        """
        Parameters
        ----------
        root : str | PathLike, optional
            Root directory where the dataset subdirectory exists or will be
            created and the dataset downloaded. If `None`, the pooch cache
            directory for "beignet" is used.

        train : bool, optional
            If `True`, use the training split; otherwise, use the held-out
            test split (default: `True`).

        transform : Callable | Transform, optional
            A `Callable` or `Transform` that maps a sequence to a
            transformed sequence (default: `None`).

        target_transform : Callable | Transform, optional
            A `Callable` or `Transform` that maps a target (a cluster
            identifier) to a transformed target (default: `None`).
        """
        if root is None:
            root = pooch.os_cache("beignet")

        if isinstance(root, str):
            root = Path(root)

        self._root = root.resolve()

        # "CaLMDataset" -> "CaLM"; used for on-disk file and directory names.
        name = self.__class__.__name__.replace("Dataset", "")

        if train:
            path = pooch.retrieve(
                "https://opig.stats.ox.ac.uk/data/downloads/training_data.tar.gz",
                "sha256:22673dc6db6fa0dfa9fb6d2b1e94fe2a94f01e0a726e597f03d9b31d1b503f0e",
                f"{name}-train.fasta.gz",
                root / name,
                processor=Decompress(
                    name=f"{name}-train.fasta",
                ),
                progressbar=True,
            )
            # NOTE(review): the training URL is a ``.tar.gz`` archive but is
            # only gunzipped here (the test branch uses `Untar`) — confirm the
            # training resource really is a plain gzipped FASTA.
        else:
            pooch.retrieve(
                "https://opig.stats.ox.ac.uk/data/downloads/heldout.tar.gz",
                "sha256:6433a04c75aa16b555bb5fe2e0501315e5e98811d19447f6f8bc05939e8cb23d",
                f"{name}-test.fasta.gz",
                root / name,
                processor=Untar(
                    extract_dir=f"{name}-test",
                ),
                progressbar=True,
            )

            path = root / name / f"{name}-test.fasta"

            # Relocate the extracted FASTA only on first use; moving
            # unconditionally fails once the extracted source file has
            # already been relocated by a previous instantiation.
            if not path.exists():
                shutil.move(
                    f"{root}/{name}/{name}-test/hsapiens.fasta",
                    f"{root}/{name}/{name}-test.fasta",
                )

        super().__init__(path)

        self.transform = transform

        self.target_transform = target_transform
68 changes: 68 additions & 0 deletions src/beignet/datasets/_dataframe_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Callable, Sequence, TypeVar

import torch
from pandas import DataFrame
from torch import Tensor
from torch.utils.data import Dataset

from beignet.transforms import Transform

T = TypeVar("T")


class DataFrameDataset(Dataset):
    """A `torch.utils.data.Dataset` backed by a `pandas.DataFrame`.

    Inputs are drawn from `columns` and, optionally, targets from
    `target_columns`; targets are coerced to `torch.Tensor`s.
    """

    # Underlying frame; one row per dataset item.
    data: DataFrame

    def __init__(
        self,
        data: DataFrame,
        *,
        transform: Callable | Transform | None = None,
        target_transform: Callable | Transform | None = None,
        columns: Sequence[str] | None = None,
        target_columns: Sequence[str] | None = None,
    ):
        """
        Parameters
        ----------
        data : DataFrame
            The data frame to index into.

        transform : Callable | Transform, optional
            Applied to each input (default: `None`).

        target_transform : Callable | Transform, optional
            Applied to each target (default: `None`).

        columns : Sequence[str], optional
            Input column name(s). If `None`, all columns of `data` are
            used (default: `None`).

        target_columns : Sequence[str], optional
            Target column name(s). If `None`, `__getitem__` returns the
            input only (default: `None`).
        """
        self.data = data

        self.transform = transform

        self.target_transform = target_transform

        self.columns = columns

        self.target_columns = target_columns

    def __len__(self) -> int:
        return len(self.data)

    # The original annotation `T | (T, T)` raised `TypeError` at class
    # definition time (`|` rejects a tuple operand); `tuple[T, T]` is the
    # intended type.
    def __getitem__(self, index: int) -> T | tuple[T, T]:
        """Return the input (and the target, when `target_columns` is set)
        at `index`."""
        item = self.data.iloc[index]

        # Fall back to every column when none were specified; the original
        # raised `TypeError` (`len(None)`) when `columns` was left at its
        # default.
        columns = list(self.data.columns) if self.columns is None else self.columns

        if len(columns) > 1:
            input = tuple(item[column] for column in columns)
        else:
            input = item[columns[0]]

        if self.transform is not None:
            input = self.transform(input)

        if self.target_columns is None:
            return input

        if len(self.target_columns) > 1:
            target = tuple(item[column] for column in self.target_columns)
        else:
            target = item[self.target_columns[0]]

        if self.target_transform is not None:
            target = self.target_transform(target)

        # Coerce targets to tensors. (A duplicated, unreachable
        # `elif not isinstance(target, Tensor)` branch was removed.)
        if len(self.target_columns) > 1:
            if not all(isinstance(y, Tensor) for y in target):
                target = tuple(torch.as_tensor(y) for y in target)
        elif not isinstance(target, Tensor):
            target = torch.as_tensor(target)

        return input, target
Loading
Loading