
Commit

Merge pull request #11 from ctr26/dev
Dev
ctr26 authored Aug 26, 2023
2 parents 63906dd + 85b7b9d commit 87e76d2
Showing 6 changed files with 269 additions and 164 deletions.
5 changes: 4 additions & 1 deletion bio_vae/datasets.py
@@ -12,13 +12,16 @@

 from functools import lru_cache

+from albumentations import Compose
+from typing import Callable


 class DatasetGlob(Dataset):
     def __init__(
         self,
         path_glob,
         over_sampling=1,
-        transform=None,
+        transform: Callable = Compose([]),
         samples=-1,
         shuffle=True,
         **kwargs,
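
The new default replaces None with an empty albumentations pipeline, so DatasetGlob can apply its transform unconditionally instead of special-casing the no-transform path. A minimal standalone sketch (not from this repository) of how an empty Compose behaves as an identity transform:

    import numpy as np
    from albumentations import Compose, HorizontalFlip

    identity = Compose([])  # the new default: a pipeline with no augmentations
    img = np.zeros((8, 8, 3), dtype=np.uint8)

    # albumentations pipelines take keyword arguments and return a dict,
    # even when the transform list is empty, so callers need no None check
    out = identity(image=img)["image"]
    assert (out == img).all()

    # a non-trivial pipeline is invoked exactly the same way
    flip = Compose([HorizontalFlip(p=1.0)])
    flipped = flip(image=img)["image"]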
40 changes: 25 additions & 15 deletions bio_vae/lightning/dataloader.py
@@ -42,25 +42,35 @@ def __init__(
     def get_dataset(self):
         return self.dataset

-    def splitting(self, dataset, split=0.8, seed=42):
-        if len(dataset) < 4:
-            return dataset, dataset, dataset, dataset
-        spliting_shares = [
-            len(dataset) * split * split,  # train
-            len(dataset) * split * (1 - split),  # test
-            len(dataset) * split * (1 - split),  # predict
-            len(dataset) * (1 - split) * (1 - split),  # val
-        ]
-
-        train, test, predict, val = random_split(
+    def splitting(self, dataset, split_train=0.8, split_val=0.1, seed=42):
+        if len(dataset) < 3:
+            return dataset, dataset, dataset
+
+        train_share = int(len(dataset) * split_train)
+        val_share = int(len(dataset) * split_val)
+        test_share = len(dataset) - train_share - val_share
+
+        # Ensure that the splits add up correctly
+        if train_share + val_share + test_share != len(dataset):
+            raise ValueError("The splitting ratios do not add up to the length of the dataset")
+
+        torch.manual_seed(seed)  # for reproducibility
+
+        train, val, test = random_split(
             dataset,
-            list(map(int, saferound(spliting_shares, places=0))),
+            [train_share, val_share, test_share]
         )
-
-        return test, train, predict, val
+        return train, val, test

     def setup(self, stage=None):
-        self.test, self.train, self.predict, self.val = self.splitting(self.dataset)
+        self.train, self.val, self.test = self.splitting(self.dataset)

         # self.test = self.get_dataloader(test)
         # self.predict = self.get_dataloader(predict)
         # self.train = self.get_dataloader(train)
         # self.val = self.get_dataloader(val)

     def test_dataloader(self):
         return DataLoader(self.test, **self.data_loader_settings)
@@ -72,7 +82,7 @@ def val_dataloader(self):
         return DataLoader(self.val, **self.data_loader_settings)

     def predict_dataloader(self):
-        return DataLoader(self.predict, **self.data_loader_settings)
+        return DataLoader(self.dataset, **self.data_loader_settings)

     # def teardown(self, stage: Optional[str] = None):
     #     # Used to clean-up when the run is finished
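
The rewritten splitting derives the test share as the remainder after the integer train and val shares, so the three lengths always sum to len(dataset); seeding the global RNG before random_split pins the assignment, since random_split draws from torch's default generator unless one is passed explicitly. A minimal standalone sketch (toy TensorDataset, not from the repository) of the resulting behaviour:

    import torch
    from torch.utils.data import TensorDataset, random_split

    dataset = TensorDataset(torch.arange(100))
    split_train, split_val, seed = 0.8, 0.1, 42

    train_share = int(len(dataset) * split_train)        # 80
    val_share = int(len(dataset) * split_val)            # 10
    test_share = len(dataset) - train_share - val_share  # 10; absorbs integer rounding

    torch.manual_seed(seed)  # makes the random_split assignment reproducible
    train, val, test = random_split(dataset, [train_share, val_share, test_share])
    print(len(train), len(val), len(test))  # 80 10 10

With only three splits returned, predict_dataloader now iterates over the full dataset rather than a dedicated predict partition.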
2 changes: 0 additions & 2 deletions environment.yml
@@ -2,8 +2,6 @@
 channels:
   - bioconda
   - pytorch
-  - idr
-  - ome
   - conda-forge
   # - defaults
   - torch
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -28,7 +28,7 @@ umap-learn = {extras = ["plot"], version = "^0.5.3"}
 colorcet = "^3.0.1"
 holoviews = "^1.15.2"
 # idr-py = "^0.4.2"
-llvmlite = "^0.39.1"
+#llvmlite = "^0.39.1"
 torchmetrics = "^0.11.0"
 tensorboard = "^2.11.2"
 albumentations = "^1.3.0"
Binary file added scripts/.train_ivy_gap_legacy.py.swp
Binary file not shown.