Commit
Merge pull request #66 from ctr26/dev
@dataclass(config=dict(extra="allow"))
ctr26 authored Sep 27, 2024
2 parents 981d2e7 + f3977f5 commit 2b97491
Showing 3 changed files with 26 additions and 39 deletions.
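
The change repeated throughout `bioimage_embed/config.py` below swaps plain `@dataclass` for `@dataclass(config=dict(extra="allow"))`. A minimal sketch of what that permits, assuming pydantic v2 semantics (the `Settings` class and the `notes` key are invented for illustration, not taken from this repository):

```python
# Minimal sketch (assumed pydantic v2; Settings and "notes" are invented here).
from pydantic.dataclasses import dataclass


@dataclass(config=dict(extra="allow"))
class Settings:
    model: str = "resnet18_vae"


# With extra="allow", an undeclared keyword such as "notes" is tolerated;
# the default behaviour instead rejects unexpected keyword arguments.
s = Settings(model="resnet18_vae", notes="kept rather than rejected")
print(s)
```

Presumably the point is that the OmegaConf/Hydra-style configs in this file often carry keys beyond the declared schema, and the stricter default would turn each of those keys into a validation error.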
5 changes: 1 addition & 4 deletions Makefile
@@ -3,10 +3,7 @@ include .env
include secrets.env
export

@PHONY: all download test

download.data:
kaggle competitions download -c data-science-bowl-2018
@PHONY: all test

test:
poetry run pytest -v --tb=no
16 changes: 1 addition & 15 deletions README.md
@@ -56,22 +56,8 @@ OR
bie --help
```

### 3. Fetching Data:

This utility makes it simple to fetch the necessary datasets:

```bash
make download.data
```
If you don't have a Kaggle account, create one and then follow these steps:
1. Install the Kaggle API package so the Makefile can download the data; full instructions are in their [Github repository](https://github.com/Kaggle/kaggle-api).
2. To use the Kaggle API you also need to create an API token.
You can find out how in their [documentation](https://github.com/Kaggle/kaggle-api#api-credentials).
3. Add your username and key to a `kaggle.json` file in your home directory and restrict its permissions with `chmod 600 ~/.kaggle/kaggle.json`.
4. Don't forget to accept the conditions for the "2018 Data Science Bowl" on the Kaggle website.
Otherwise you will not be able to pull this data from the command line.

### 4. Developer Installation:
### 3. Developer Installation:

For those intending to contribute or looking for a deeper dive into the codebase, we use `poetry` to manage our dependencies and virtual environments:

44 changes: 24 additions & 20 deletions bioimage_embed/config.py
@@ -12,19 +12,22 @@
"""

# TODO need a way to copy signatures from original classes for validation
from omegaconf import OmegaConf
from bioimage_embed import augmentations as augs
import os
from dataclasses import field
from pydantic.dataclasses import dataclass

# from dataclasses import dataclass
from typing import List, Optional, Dict, Any

from pydantic import Field
from omegaconf import II
from . import utils


@dataclass
@dataclass(config=dict(extra="allow"))
class Recipe:
_target_: str = "types.SimpleNamespace"
model: str = "resnet18_vae"
@@ -58,7 +61,7 @@ class Recipe:

# Use the ALbumentations .to_dict() method to get the dictionary
# that pydantic can use
@dataclass
@dataclass(config=dict(extra="allow"))
class ATransform:
_target_: str = "albumentations.from_dict"
_convert_: str = "object"
@@ -71,7 +74,7 @@ class ATransform:
# VisionWrapper is a helper class for applying albumentations pipelines for image augmentations in autoencoding


@dataclass
@dataclass(config=dict(extra="allow"))
class Transform:
_target_: str = "bioimage_embed.augmentations.VisionWrapper"
_convert_: str = "object"
@@ -81,7 +84,7 @@ class Transform:
)


@dataclass
@dataclass(config=dict(extra="allow"))
class Dataset:
_target_: str = "torch.utils.data.Dataset"
transform: Any = Field(default_factory=Transform)
@@ -93,24 +96,24 @@ class Dataset:
# return self


@dataclass
@dataclass(config=dict(extra="allow"))
class FakeDataset(Dataset):
_target_: str = "torchvision.datasets.FakeData"


@dataclass
@dataclass(config=dict(extra="allow"))
class ImageFolderDataset(Dataset):
_target_: str = "torchvision.datasets.ImageFolder"
# transform: Transform = Field(default_factory=Transform)
root: str = II("recipe.data")


@dataclass
@dataclass(config=dict(extra="allow"))
class NdDataset(ImageFolderDataset):
transform: Transform = Field(default_factory=Transform)


@dataclass
@dataclass(config=dict(extra="allow"))
class TiffDataset(NdDataset):
_target_: str = "bioimage_embed.datasets.TiffDataset"

@@ -119,15 +122,15 @@ class NgffDataset(NdDataset):
_target_: str = "bioimage_embed.datasets.NgffDataset"


@dataclass
@dataclass(config=dict(extra="allow"))
class DataLoader:
_target_: str = "bioimage_embed.lightning.dataloader.DataModule"
dataset: Any = Field(default_factory=FakeDataset)
num_workers: int = 1
batch_size: int = II("recipe.batch_size")


@dataclass
@dataclass(config=dict(extra="allow"))
class Model:
_target_: str = "bioimage_embed.models.create_model"
model: str = II("recipe.model")
@@ -136,20 +139,20 @@ class Model:
pretrained: bool = True


@dataclass
@dataclass(config=dict(extra="allow"))
class Callback:
pass


@dataclass
@dataclass(config=dict(extra="allow"))
class EarlyStopping(Callback):
_target_: str = "pytorch_lightning.callbacks.EarlyStopping"
monitor: str = "loss/val"
mode: str = "min"
patience: int = 3


@dataclass
@dataclass(config=dict(extra="allow"))
class ModelCheckpoint(Callback):
_target_: str = "pytorch_lightning.callbacks.ModelCheckpoint"
save_last = True
@@ -160,7 +163,7 @@ class ModelCheckpoint(Callback):
dirpath: str = f"{II('paths.model')}/{II('uuid')}"


@dataclass
@dataclass(config=dict(extra="allow"))
class LightningModel:
_target_: str = "bioimage_embed.lightning.torch.AEUnsupervised"
# This should be pythae base autoencoder?
@@ -172,20 +175,20 @@ class LightningModelSupervised(LightningModel):
_target_: str = "bioimage_embed.lightning.torch.AESupervised"


@dataclass
@dataclass(config=dict(extra="allow"))
class Callbacks:
# _target_: str = "collections.OrderedDict"
model_checkpoint: ModelCheckpoint = Field(default_factory=ModelCheckpoint)
early_stopping: EarlyStopping = Field(default_factory=EarlyStopping)


@dataclass
@dataclass(config=dict(extra="allow"))
class Trainer:
_target_: str = "pytorch_lightning.Trainer"
# logger: Optional[any]
gradient_clip_val: float = 0.5
enable_checkpointing: bool = True
devices: str = "auto"
devices: Any = "auto"
accelerator: str = "auto"
accumulate_grad_batches: int = 16
min_epochs: int = 1
@@ -195,12 +198,13 @@ class Trainer:
callbacks: List[Any] = Field(
default_factory=lambda: list(vars(Callbacks()).values()), frozen=True
)
# TODO idea here would be to use pydantic to validate omegaconf


# TODO add argument caching for checkpointing


@dataclass
@dataclass(config=dict(extra="allow"))
class Paths:
model: str = "models"
logs: str = "logs"
@@ -212,7 +216,7 @@ def __post_init__(self):
os.makedirs(path, exist_ok=True)


@dataclass
@dataclass(config=dict(extra="allow"))
class Config:
# This has to be dataclass.field instead of pydantic Field for somereason
paths: Any = field(default_factory=Paths)
@@ -224,7 +228,7 @@ class Config:
uuid: str = field(default_factory=lambda: utils.hashing_fn(Recipe()))


@dataclass
@dataclass(config=dict(extra="allow"))
class SupervisedConfig(Config):
lit_model: LightningModel = field(default_factory=LightningModel)

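
Most of the classes above pair a Hydra-style `_target_` string with `II(...)` defaults such as `II("recipe.batch_size")`. As a rough sketch of how that interpolation behaves once the dataclasses are assembled into one OmegaConf tree (simplified stand-ins with trimmed fields, using plain stdlib dataclasses rather than the repository's pydantic ones):

```python
# Rough sketch of the interpolation pattern used above (plain dataclasses,
# trimmed fields; assumes omegaconf is installed). II("recipe.batch_size") is
# shorthand for the string "${recipe.batch_size}", resolved lazily once the
# dataclass sits inside a full config tree.
from dataclasses import dataclass, field

from omegaconf import II, OmegaConf


@dataclass
class Recipe:
    batch_size: int = 16


@dataclass
class DataLoader:
    batch_size: int = II("recipe.batch_size")  # resolves against the parent config


@dataclass
class Config:
    recipe: Recipe = field(default_factory=Recipe)
    dataloader: DataLoader = field(default_factory=DataLoader)


cfg = OmegaConf.structured(Config)
print(cfg.dataloader.batch_size)  # 16, pulled from recipe.batch_size
cfg.recipe.batch_size = 32
print(cfg.dataloader.batch_size)  # 32, the interpolation tracks its source field
```

The `_target_` strings are then consumed by whatever instantiation machinery the project uses (for example `hydra.utils.instantiate`), which is typically where stray extra config keys surface and may be why the `extra="allow"` change above was needed.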
