IDR script + rebase for testing against CI? #68

Merged 16 commits on Oct 1, 2024

Changes from all commits
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
```diff
@@ -4,7 +4,7 @@ repos:
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
-      - id: trailing-whitespace
+      # - id: trailing-whitespace
 # - repo: https://github.com/psf/black
 #   rev: 22.10.0
 #   hooks:
```
4 changes: 2 additions & 2 deletions bioimage_embed/__init__.py
```diff
@@ -10,8 +10,8 @@
 from .config import Config
 from . import augmentations
 
-# import logging
-# logging.captureWarnings(True)
+import logging
+logging.captureWarnings(True)
 
 __all__ = [
     "AESupervised",
```
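Enabling warning capture at import time routes `warnings.warn()` output through the `py.warnings` logger, so library deprecation noise follows the application's logging configuration instead of going straight to stderr. A minimal standalone illustration of the mechanism (not code from the PR):

```python
import logging
import warnings

logging.captureWarnings(True)  # warnings.warn() now emits via the "py.warnings" logger
logging.basicConfig(format="%(name)s | %(message)s")

warnings.warn("deprecated API")  # handled by logging as "py.warnings | ... UserWarning ..."
```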
17 changes: 5 additions & 12 deletions bioimage_embed/augmentations.py
```diff
@@ -3,34 +3,27 @@
 from albumentations.pytorch import ToTensorV2
 
 DEFAULT_AUGMENTATION_LIST = [
-    # Flip the images horizontally or vertically with a 50% chance
     A.OneOf(
         [
             A.HorizontalFlip(p=0.5),
             A.VerticalFlip(p=0.5),
         ],
         p=0.5,
     ),
-    # Rotate the images by a random angle within a specified range
+
     A.Rotate(limit=45, p=0.5),
-    # Randomly scale the image intensity to adjust brightness and contrast
-    A.RandomGamma(gamma_limit=(80, 120), p=0.5),
-    # Apply random elastic transformations to the images
+    # A.RandomGamma(gamma_limit=(80, 120), p=0.5),
     A.ElasticTransform(
         alpha=1,
         sigma=50,
         alpha_affine=50,
         p=0.5,
     ),
-    # Shift the image channels along the intensity axis
-    A.ChannelShuffle(p=0.5),
-    # Add a small amount of noise to the images
-    A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),
-    # Crop a random part of the image and resize it back to the original size
+    # A.ChannelShuffle(p=0.5),
     A.RandomResizedCrop(
-        height=224, width=224, scale=(0.9, 1.0), ratio=(0.9, 1.1), p=0.5
+        height=224, width=224, scale=(0.9, 1.0), ratio=(0.9, 1.1), p=1.0
     ),
-    # Adjust image intensity with a specified range for individual channels
+    A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),
     A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
     A.ToFloat(),
     ToTensorV2(),
```
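For orientation, `DEFAULT_AUGMENTATION_LIST` composes into a single Albumentations pipeline. With `RandomResizedCrop` now at `p=1.0`, every image leaves the pipeline at 224x224 (at `p=0.5`, half the images kept their original size, which presumably broke fixed-size batching). A usage sketch with a stand-in image, not code from the PR:

```python
import albumentations as A
import numpy as np

from bioimage_embed.augmentations import DEFAULT_AUGMENTATION_LIST

pipeline = A.Compose(DEFAULT_AUGMENTATION_LIST)
image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)  # stand-in HWC uint8 image
out = pipeline(image=image)["image"]  # float CHW tensor, 224x224 after the crop
```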
32 changes: 18 additions & 14 deletions bioimage_embed/config.py
```diff
@@ -34,6 +34,8 @@ class Recipe:
     batch_size: int = 16
     data: str = "data"
     opt: str = "adamw"
+    latent_dim: int = 64
+    batch_size: int = 16
     max_epochs: int = 125
     weight_decay: float = 0.001
     momentum: float = 0.9
@@ -63,7 +65,7 @@ class Recipe:
 # that pydantic can use
 @dataclass(config=dict(extra="allow"))
 class ATransform:
-    _target_: str = "albumentations.from_dict"
+    _target_: Any = "albumentations.from_dict"
     _convert_: str = "object"
     # _convert_: str = "all"
     transform_dict: Dict = Field(
@@ -76,7 +78,7 @@ class ATransform:
 
 @dataclass(config=dict(extra="allow"))
 class Transform:
-    _target_: str = "bioimage_embed.augmentations.VisionWrapper"
+    _target_: Any = "bioimage_embed.augmentations.VisionWrapper"
     _convert_: str = "object"
     # transform: ATransform = field(default_factory=ATransform)
     transform_dict: Dict = Field(
@@ -132,10 +134,10 @@ class DataLoader:
 
 @dataclass(config=dict(extra="allow"))
 class Model:
-    _target_: str = "bioimage_embed.models.create_model"
+    _target_: Any = "bioimage_embed.models.create_model"
     model: str = II("recipe.model")
     input_dim: List[int] = Field(default_factory=lambda: [3, 224, 224])
-    latent_dim: int = 64
+    latent_dim: int = II("recipe.latent_dim")
     pretrained: bool = True
 
@@ -146,15 +148,15 @@ class Callback:
 
 @dataclass(config=dict(extra="allow"))
 class EarlyStopping(Callback):
-    _target_: str = "pytorch_lightning.callbacks.EarlyStopping"
+    _target_: Any = "pytorch_lightning.callbacks.EarlyStopping"
     monitor: str = "loss/val"
     mode: str = "min"
     patience: int = 3
 
 
 @dataclass(config=dict(extra="allow"))
 class ModelCheckpoint(Callback):
-    _target_: str = "pytorch_lightning.callbacks.ModelCheckpoint"
+    _target_: Any = "pytorch_lightning.callbacks.ModelCheckpoint"
     save_last = True
     save_top_k = 1
     monitor = "loss/val"
@@ -167,8 +169,8 @@ class ModelCheckpoint(Callback):
 class LightningModel:
     _target_: str = "bioimage_embed.lightning.torch.AEUnsupervised"
     # This should be pythae base autoencoder?
-    model: Model = Field(default_factory=Model)
-    args: Recipe = Field(default_factory=lambda: II("recipe"))
+    model: Any = Field(default_factory=Model)
+    args: Any = Field(default_factory=lambda: II("recipe"))
 
 
 class LightningModelSupervised(LightningModel):
@@ -178,29 +180,31 @@ class LightningModelSupervised(LightningModel):
 @dataclass(config=dict(extra="allow"))
 class Callbacks:
     # _target_: str = "collections.OrderedDict"
-    model_checkpoint: ModelCheckpoint = Field(default_factory=ModelCheckpoint)
-    early_stopping: EarlyStopping = Field(default_factory=EarlyStopping)
+    model_checkpoint: Any = Field(default_factory=ModelCheckpoint)
+    # early_stopping: Any = Field(default_factory=EarlyStopping)
 
 
 @dataclass(config=dict(extra="allow"))
 class Trainer:
-    _target_: str = "pytorch_lightning.Trainer"
-    # logger: Optional[any]
+    # class Trainer(pytorch_lightning.Trainer):
+    _target_: Any = "pytorch_lightning.Trainer"
+    logger: Any = None
     gradient_clip_val: float = 0.5
     enable_checkpointing: bool = True
     devices: Any = "auto"
     accelerator: str = "auto"
     accumulate_grad_batches: int = 16
     min_epochs: int = 1
     max_epochs: int = II("recipe.max_epochs")
     num_nodes: int = 1
     log_every_n_steps: int = 1
     # This is not a clean implementation but I am not sure how to do it better
-    callbacks: List[Any] = Field(
+    callbacks: Any = Field(
         default_factory=lambda: list(vars(Callbacks()).values()), frozen=True
     )
     # TODO idea here would be to use pydantic to validate omegaconf
 
 
+    # TODO add argument caching for checkpointing
```
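Two patterns in this file are worth spelling out: `_target_` strings are resolved by `hydra.utils.instantiate`, and `II(...)` is an OmegaConf interpolation, so `Model.latent_dim` now follows `recipe.latent_dim` instead of being hard-coded. A minimal sketch of the resolution, assuming the pydantic dataclasses compose under a single OmegaConf tree the way the repo's `Config` does (not code from the PR):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

from bioimage_embed.config import Model, Recipe

# One tree, so II("recipe.latent_dim") has something to resolve against.
cfg = OmegaConf.create({"recipe": Recipe(latent_dim=128), "model": Model()})
assert cfg.model.latent_dim == 128  # interpolated from the recipe, not hard-coded
model = instantiate(cfg.model)      # calls bioimage_embed.models.create_model(...)
```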
Empty file removed bioimage_embed/inference.py
Empty file removed bioimage_embed/models/tests/mae.py
30 changes: 30 additions & 0 deletions scripts/idr/lightning.study.sh
```bash
@@ -0,0 +1,30 @@
#!/bin/bash -l
#SBATCH --nodes=3
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=1
#SBATCH --time=0-24:00:00
#SBATCH --job-name=lightning
#SBATCH --constraint=a100  # ensure the job is scheduled on nodes with A100 GPUs
#SBATCH --mem-per-cpu=2GB
#SBATCH --cpus-per-task=32
#SBATCH --output=lightning_%j.out
set -x

# activate the conda environment passed as the first argument
source activate $1

# debugging flags (optional)
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# export NCCL_P2P_DISABLE=1
# unset LOCAL_RANK

# on your cluster you might need these:
# set the network interface
# export NCCL_SOCKET_IFNAME=^docker0,lo

# might need the latest CUDA
# module load NCCL/2.4.7-1-cuda.10.0

# run the training script
echo "Starting Lightning training script"
srun python3 scripts/idr/study.py
```
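Since the script activates whichever environment is passed as its first argument, a submission looks like `sbatch scripts/idr/lightning.study.sh <env-name>`. For orientation, a Trainer consuming the full allocation requested above would be configured as in the sketch below; this is illustrative only, and note the PR's `study.py` actually runs with `devices=1, num_nodes=1`:

```python
import pytorch_lightning as pl

# Mirrors the SBATCH request above: 3 nodes with 4 A100s each.
# Under srun, Lightning reads the SLURM env vars to set up DDP ranks.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,    # --gres=gpu:4
    num_nodes=3,  # --nodes=3
    strategy="ddp",
)
```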
127 changes: 127 additions & 0 deletions scripts/idr/study.py
```python
@@ -0,0 +1,127 @@
import bioimage_embed
import bioimage_embed.config as config

# from ray.tune.integration.pytorch_lightning import (
#     TuneReportCallback,
#     TuneReportCheckpointCallback,
# )
import albumentations as A
from types import SimpleNamespace
from ray import tune
import numpy as np
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
from hydra.utils import instantiate
import ray
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)
import os
import glob
from PIL import Image
from typing import List
from torch.utils.data import Dataset
import torch
from joblib import Memory
from pydantic.dataclasses import dataclass
from pytorch_lightning import loggers as pl_loggers

params = {
    "model": "resnet50_vqvae",
    # "data": "data",
    "opt": "adamw",
    "max_epochs": 1000,
    "max_steps": -1,
    "weight_decay": 0.0001,
    "momentum": 0.9,
    # "sched": "cosine",
    "epochs": 1000,
    "lr": 1e-3,
    "batch_size": 16,
}

# cache the slow recursive glob over the NFS mount between runs
memory = Memory(location=".", verbose=0)


@memory.cache
def get_file_list(glob_str):
    return glob.glob(os.path.join(glob_str), recursive=True)


class GlobDataset(Dataset):
    """Dataset over all files matching a glob pattern."""

    def __init__(self, glob_str, transform=None):
        self.file_list = get_file_list(glob_str)
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = self.file_list[idx]
        image = Image.open(img_name)
        # breakpoint()
        image = np.array(image)
        if self.transform:
            # t = A.Compose([A.ToRGB(), transform, A.RandomCrop(224, 224)])
            t = A.Compose([A.ToRGB(), self.transform])
            image = t(image=image)

        # breakpoint()
        # sample = {'image': image, 'path': img_name}

        # dummy label: the dataset is unlabelled
        return image["image"], 0


# NB: the second assignment wins; the FTP mirror path is kept for reference
root_dir = "/nfs/ftp/public/databases/IDR/idr0093-mueller-perturbation/"
root_dir = "/nfs/research/uhlmann/ctr26/idr/idr0093-mueller-perturbation/"

if __name__ == "__main__":
    print("training")
    input_dim = [3, 224, 224]

    # mock_dataset = config.ImageFolderDataset(
    #     _target_="bioimage_embed.datasets.FakeImageFolder",
    #     image_size=input_dim,
    #     num_classes=1,
    # )
    # breakpoint()
    transform = instantiate(config.ATransform())
    dataset = GlobDataset(root_dir + "**/*.tif*", transform)
    dataloader = config.DataLoader(dataset=dataset, num_workers=32)

    assert instantiate(dataloader, batch_size=1)
    assert dataset[0]

    model = config.Model(input_dim=input_dim)

    lit_model = config.LightningModel(
        _target_="bioimage_embed.lightning.torch.AutoEncoderSupervised",
        model=model,
    )

    wandb = pl_loggers.WandbLogger(project="idr", name="0093")
    trainer = config.Trainer(
        accelerator="auto",
        devices=1,
        num_nodes=1,
        # strategy="ddp",
        callbacks=[],
        plugins=[],
        logger=[wandb],
    )

    cfg = config.Config(
        dataloader=dataloader,
        lit_model=lit_model,
        trainer=trainer,
        recipe=config.Recipe(**params),
    )
    # breakpoint()

    bie = bioimage_embed.BioImageEmbed(cfg)
    wandb.watch(bie.icfg.lit_model, log="all")

    bie.train()
    wandb.finish()
```
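As a quick sanity check of the data pipeline above, the dataset and transform can be exercised without the trainer. A minimal sketch reusing `GlobDataset` and `root_dir` as defined in `study.py` (illustrative, not part of the PR; `A.ToRGB` in `__getitem__` handles single-channel TIFFs):

```python
from hydra.utils import instantiate
import bioimage_embed.config as config

transform = instantiate(config.ATransform())  # albumentations.from_dict under the hood
dataset = GlobDataset(root_dir + "**/*.tif*", transform)
image, label = dataset[0]
print(image.shape, label)  # e.g. torch.Size([3, 224, 224]), dummy label 0
```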