From 80c16e883e276f4d9bfcc1c28d34d66ddb08b92b Mon Sep 17 00:00:00 2001
From: Craig Russell
Date: Mon, 8 Jul 2024 18:19:02 +0100
Subject: [PATCH 01/16] [bug] imports for bie are needed

---
 bioimage_embed/__init__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/bioimage_embed/__init__.py b/bioimage_embed/__init__.py
index e135e24a..0e1a0d5a 100644
--- a/bioimage_embed/__init__.py
+++ b/bioimage_embed/__init__.py
@@ -10,8 +10,8 @@
 from .config import Config
 from . import augmentations
 
-# import logging
-# logging.captureWarnings(True)
+import logging
+logging.captureWarnings(True)
 
 __all__ = [
     "AESupervised",

From 35f30087cd4073325fc28a0121fbd2e8b7f7d89d Mon Sep 17 00:00:00 2001
From: Craig Russell
Date: Mon, 8 Jul 2024 18:20:41 +0100
Subject: [PATCH 02/16] [bug] missing batch_size

---
 bioimage_embed/config.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/bioimage_embed/config.py b/bioimage_embed/config.py
index a610567c..a7935be8 100644
--- a/bioimage_embed/config.py
+++ b/bioimage_embed/config.py
@@ -34,6 +34,7 @@ class Recipe:
     batch_size: int = 16
     data: str = "data"
     opt: str = "adamw"
+    batch_size: int = 16
     max_epochs: int = 125
     weight_decay: float = 0.001
     momentum: float = 0.9

From 46a670c9790dd43145f2727f9dcab192e6b7b06b Mon Sep 17 00:00:00 2001
From: Craig Russell
Date: Mon, 8 Jul 2024 18:22:49 +0100
Subject: [PATCH 03/16] [bug] trailing-whitespace hook is annoying

---
 .pre-commit-config.yaml |   2 +-
 scripts/idr/study.py    | 141 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 142 insertions(+), 1 deletion(-)
 create mode 100644 scripts/idr/study.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1a9191db..ed291b20 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,7 +4,7 @@ repos:
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
-      - id: trailing-whitespace
+      # - id: trailing-whitespace
 # - repo: https://github.com/psf/black
 #   rev: 22.10.0
 #   hooks:
diff --git a/scripts/idr/study.py b/scripts/idr/study.py
new file mode 100644
index 00000000..c0d53ec8
--- /dev/null
+++ b/scripts/idr/study.py
@@ -0,0 +1,141 @@
+import bioimage_embed
+import bioimage_embed.config as config
+# from ray.tune.integration.pytorch_lightning import (
+#     TuneReportCallback,
+#     TuneReportCheckpointCallback,
+
+# )
+from ray import tune
+import numpy as np
+from ray.train.torch import TorchTrainer
+from ray.train import ScalingConfig
+from hydra.utils import instantiate
+import ray
+from ray.train.lightning import (
+    RayDDPStrategy,
+    RayLightningEnvironment,
+    RayTrainReportCallback,
+    prepare_trainer,
+)
+from pytorch_lightning import loggers as pl_loggers
+
+if __name__ == "__main__":
+
+    ray.init()
+    input_dim = [3, 224, 224]
+    # trainer = instantiate(cfg.trainer)
+    params_space = {
+            "model": tune.choice(
+                [
+                    "resnet50_vqvae",
+                    "resnet110_vqvae_legacy",
+                    "resnet152_vqvae_legacy",
+                ]
+            ),
+            # "data": "data",
+            "opt": tune.choice(["adamw","LAMB"]),
+            "max_epochs": 1000,
+            "max_steps": -1,
+            "weight_decay": tune.uniform(0.0001, 0.01),
+            "momentum": tune.uniform(0.8, 0.99),
+            # "sched": "cosine",
+            "epochs": 1000,
+            "lr": tune.loguniform(1e-6, 1e-2),
+            "batch_size": tune.choice([2 **x for x in range(4,12)])
+            # tune.qlograndint(4, 4096,q=1,base=2),
+            # "min_lr": 1e-6,
+            # "t_initial": 10,
+            # "t_mul": 2,
+            # "decay_rate": 0.1,
+            # "warmup_lr": 1e-6,
+            # "warmup_lr_init": 1e-6,
+            # "warmup_epochs": 5,
+            # "cycle_limit": None,
+            # "t_in_epochs": False,
+            # "noisy": False,
+            # "noise_std": 0.1,
+            # "noise_pct": 0.67,
+            # "cooldown_epochs": 5,
+
# "warmup_t": 0, + # "seed": 42 + } + + # root = "/nfs/ftp/public/databases/IDR/idr0093-mueller-perturbation" + + # mock_dataset = config.ImageFolderDataset( + # image_size=input_dim, + # root="/nfs/", + # ) + + mock_dataset = config.ImageFolderDataset( + _target_="bioimage_embed.datasets.FakeImageFolder", + image_size=input_dim, + num_classes=1, + ) + + dataloader = config.DataLoader(dataset=mock_dataset) + # breakpoint() + model = config.Model(input_dim=input_dim) + + + lit_model = config.LightningModel( + _target_="bioimage_embed.lightning.torch.AutoEncoderSupervisedNChannels", + model=model, + ) + + trainer = config.Trainer( + devices="auto", + accelerator="auto", + strategy=RayDDPStrategy(), + # callbacks=[RayTrainReportCallback()], + plugins=[RayLightningEnvironment()], + ) + + def task(): + cfg = config.Config(dataloader=dataloader, model=model,trainer=trainer) + bie = bioimage_embed.BioImageEmbed(cfg) + # bie.icfg.trainer = prepare_trainer(bie.icfg.trainer) + bie.check() + return True + + assert task() + task = ray.remote(task) + gen = task.remote() + + def train(params): + + cfg = config.Config(dataloader=dataloader, + model=model, + trainer=trainer, + recipe=config.Recipe(**params)) + + + bie = bioimage_embed.BioImageEmbed(cfg) + wandb = pl_loggers.WandbLogger(project="bioimage-embed", name="shapes") + # bie.icfg.trainer = prepare_trainer(bie.icfg.trainer) + wandb.watch(bie.icfg.lit_model, log="all") + bie.train() + wandb.finish() + return bie + + + analysis = tune.run( + tune.with_parameters(train), + # resources_per_trial={"cpu": 32, "gpu": 1}, + config=params_space, + # metric="loss", + # mode="min", + num_samples=1, + scheduler=tune.schedulers.ASHAScheduler( + metric="val/loss", + mode="min", + max_t=10, + grace_period=1, + reduction_factor=2, + ), + ) + # results = tuner.fit() + print("Best hyperparameters found were: ", analysis.best_config) + + + # bie.export("model") From 69e2aaaae514b2d6a091e676854f9804e599940c Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Mon, 8 Jul 2024 18:23:15 +0100 Subject: [PATCH 04/16] [ref] formatting --- scripts/idr/study.py | 92 +++++++++++++++++++++----------------------- 1 file changed, 44 insertions(+), 48 deletions(-) diff --git a/scripts/idr/study.py b/scripts/idr/study.py index c0d53ec8..6e9a7366 100644 --- a/scripts/idr/study.py +++ b/scripts/idr/study.py @@ -3,7 +3,7 @@ # from ray.tune.integration.pytorch_lightning import ( # TuneReportCallback, # TuneReportCheckpointCallback, - + # ) from ray import tune import numpy as np @@ -20,46 +20,45 @@ from pytorch_lightning import loggers as pl_loggers if __name__ == "__main__": - ray.init() input_dim = [3, 224, 224] # trainer = instantiate(cfg.trainer) params_space = { - "model": tune.choice( - [ - "resnet50_vqvae", - "resnet110_vqvae_legacy", - "resnet152_vqvae_legacy", - ] - ), - # "data": "data", - "opt": tune.choice(["adamw","LAMB"]), - "max_epochs": 1000, - "max_steps": -1, - "weight_decay": tune.uniform(0.0001, 0.01), - "momentum": tune.uniform(0.8, 0.99), - # "sched": "cosine", - "epochs": 1000, - "lr": tune.loguniform(1e-6, 1e-2), - "batch_size": tune.choice([2 **x for x in range(4,12)]) - # tune.qlograndint(4, 4096,q=1,base=2), - # "min_lr": 1e-6, - # "t_initial": 10, - # "t_mul": 2, - # "decay_rate": 0.1, - # "warmup_lr": 1e-6, - # "warmup_lr_init": 1e-6, - # "warmup_epochs": 5, - # "cycle_limit": None, - # "t_in_epochs": False, - # "noisy": False, - # "noise_std": 0.1, - # "noise_pct": 0.67, - # "cooldown_epochs": 5, - # "warmup_t": 0, - # "seed": 42 + "model": tune.choice( + [ 
+ "resnet50_vqvae", + "resnet110_vqvae_legacy", + "resnet152_vqvae_legacy", + ] + ), + # "data": "data", + "opt": tune.choice(["adamw", "LAMB"]), + "max_epochs": 1000, + "max_steps": -1, + "weight_decay": tune.uniform(0.0001, 0.01), + "momentum": tune.uniform(0.8, 0.99), + # "sched": "cosine", + "epochs": 1000, + "lr": tune.loguniform(1e-6, 1e-2), + "batch_size": tune.choice([2**x for x in range(4, 12)]), + # tune.qlograndint(4, 4096,q=1,base=2), + # "min_lr": 1e-6, + # "t_initial": 10, + # "t_mul": 2, + # "decay_rate": 0.1, + # "warmup_lr": 1e-6, + # "warmup_lr_init": 1e-6, + # "warmup_epochs": 5, + # "cycle_limit": None, + # "t_in_epochs": False, + # "noisy": False, + # "noise_std": 0.1, + # "noise_pct": 0.67, + # "cooldown_epochs": 5, + # "warmup_t": 0, + # "seed": 42 } - + # root = "/nfs/ftp/public/databases/IDR/idr0093-mueller-perturbation" # mock_dataset = config.ImageFolderDataset( @@ -76,7 +75,6 @@ dataloader = config.DataLoader(dataset=mock_dataset) # breakpoint() model = config.Model(input_dim=input_dim) - lit_model = config.LightningModel( _target_="bioimage_embed.lightning.torch.AutoEncoderSupervisedNChannels", @@ -90,9 +88,9 @@ # callbacks=[RayTrainReportCallback()], plugins=[RayLightningEnvironment()], ) - + def task(): - cfg = config.Config(dataloader=dataloader, model=model,trainer=trainer) + cfg = config.Config(dataloader=dataloader, model=model, trainer=trainer) bie = bioimage_embed.BioImageEmbed(cfg) # bie.icfg.trainer = prepare_trainer(bie.icfg.trainer) bie.check() @@ -103,13 +101,13 @@ def task(): gen = task.remote() def train(params): - - cfg = config.Config(dataloader=dataloader, - model=model, - trainer=trainer, - recipe=config.Recipe(**params)) - - + cfg = config.Config( + dataloader=dataloader, + model=model, + trainer=trainer, + recipe=config.Recipe(**params), + ) + bie = bioimage_embed.BioImageEmbed(cfg) wandb = pl_loggers.WandbLogger(project="bioimage-embed", name="shapes") # bie.icfg.trainer = prepare_trainer(bie.icfg.trainer) @@ -118,7 +116,6 @@ def train(params): wandb.finish() return bie - analysis = tune.run( tune.with_parameters(train), # resources_per_trial={"cpu": 32, "gpu": 1}, @@ -137,5 +134,4 @@ def train(params): # results = tuner.fit() print("Best hyperparameters found were: ", analysis.best_config) - # bie.export("model") From d78b37b20a7b7adebb435745a76afb5629e86b97 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 10 Jul 2024 08:59:11 +0100 Subject: [PATCH 05/16] [feat] fixing single study run --- scripts/idr/study.py | 179 ++++++++++++++++++++----------------------- 1 file changed, 84 insertions(+), 95 deletions(-) diff --git a/scripts/idr/study.py b/scripts/idr/study.py index 6e9a7366..b7b5cf5c 100644 --- a/scripts/idr/study.py +++ b/scripts/idr/study.py @@ -5,6 +5,8 @@ # TuneReportCheckpointCallback, # ) +import albumentations as A +from types import SimpleNamespace from ray import tune import numpy as np from ray.train.torch import TorchTrainer @@ -17,121 +19,108 @@ RayTrainReportCallback, prepare_trainer, ) +import os +import glob +from PIL import Image +from typing import List +from torch.utils.data import Dataset +import torch +from joblib import Memory +from pydantic.dataclasses import dataclass from pytorch_lightning import loggers as pl_loggers - -if __name__ == "__main__": - ray.init() - input_dim = [3, 224, 224] - # trainer = instantiate(cfg.trainer) - params_space = { - "model": tune.choice( - [ - "resnet50_vqvae", - "resnet110_vqvae_legacy", - "resnet152_vqvae_legacy", - ] - ), +params = { + "model": "resnet50_vqvae", # 
"data": "data", - "opt": tune.choice(["adamw", "LAMB"]), + "opt": "adamw", "max_epochs": 1000, "max_steps": -1, - "weight_decay": tune.uniform(0.0001, 0.01), - "momentum": tune.uniform(0.8, 0.99), + "weight_decay":0.0001, + "momentum": 0.9, # "sched": "cosine", "epochs": 1000, - "lr": tune.loguniform(1e-6, 1e-2), - "batch_size": tune.choice([2**x for x in range(4, 12)]), - # tune.qlograndint(4, 4096,q=1,base=2), - # "min_lr": 1e-6, - # "t_initial": 10, - # "t_mul": 2, - # "decay_rate": 0.1, - # "warmup_lr": 1e-6, - # "warmup_lr_init": 1e-6, - # "warmup_epochs": 5, - # "cycle_limit": None, - # "t_in_epochs": False, - # "noisy": False, - # "noise_std": 0.1, - # "noise_pct": 0.67, - # "cooldown_epochs": 5, - # "warmup_t": 0, - # "seed": 42 + "lr": 1e-3, + "batch_size": 16, } +memory = Memory(location='.', verbose=0) + +@memory.cache +def get_file_list(glob_str): + return glob.glob(os.path.join(glob_str), recursive=True) + + +class GlobDataset(Dataset): + def __init__(self, glob_str,transform=None): + self.file_list = get_file_list(glob_str) + + def __len__(self): + return len(self.file_list) - # root = "/nfs/ftp/public/databases/IDR/idr0093-mueller-perturbation" + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + img_name = self.file_list[idx] + image = Image.open(img_name) + # breakpoint() + image = np.array(image) + if transform: + # t = A.Compose([A.ToRGB(),transform, A.RandomCrop(224,224)]) + t = A.Compose([A.ToRGB(),transform]) + image = t(image=image) + + # breakpoint() + # sample = {'image': image, 'path': img_name} + + return image["image"], 0 + +root_dir = '/nfs/ftp/public/databases/IDR/idr0093-mueller-perturbation/' +root_dir = '/nfs/research/uhlmann/ctr26/idr/idr0093-mueller-perturbation/' + +if __name__ == "__main__": + print("training") + input_dim = [3, 224, 224] + # mock_dataset = config.ImageFolderDataset( + # _target_="bioimage_embed.datasets.FakeImageFolder", # image_size=input_dim, - # root="/nfs/", + # num_classes=1, # ) + # breakpoint() + transform = instantiate(config.ATransform()) + dataset = GlobDataset(root_dir+'**/*.tif*',transform) + dataloader = config.DataLoader(dataset=dataset,num_workers=32) - mock_dataset = config.ImageFolderDataset( - _target_="bioimage_embed.datasets.FakeImageFolder", - image_size=input_dim, - num_classes=1, - ) + assert instantiate(dataloader,batch_size=1) + assert dataset[0] - dataloader = config.DataLoader(dataset=mock_dataset) - # breakpoint() model = config.Model(input_dim=input_dim) lit_model = config.LightningModel( - _target_="bioimage_embed.lightning.torch.AutoEncoderSupervisedNChannels", - model=model, + _target_="bioimage_embed.lightning.torch.AutoEncoderSupervised", + model=model ) + wandb = pl_loggers.WandbLogger(project="idr", name="0093") trainer = config.Trainer( - devices="auto", accelerator="auto", - strategy=RayDDPStrategy(), - # callbacks=[RayTrainReportCallback()], - plugins=[RayLightningEnvironment()], - ) - - def task(): - cfg = config.Config(dataloader=dataloader, model=model, trainer=trainer) - bie = bioimage_embed.BioImageEmbed(cfg) - # bie.icfg.trainer = prepare_trainer(bie.icfg.trainer) - bie.check() - return True - - assert task() - task = ray.remote(task) - gen = task.remote() - - def train(params): - cfg = config.Config( - dataloader=dataloader, - model=model, - trainer=trainer, - recipe=config.Recipe(**params), + devices=1, + num_nodes=1, + # strategy="ddp", + callbacks=[], + plugin=[], + logger=[wandb], ) - - bie = bioimage_embed.BioImageEmbed(cfg) - wandb = 
pl_loggers.WandbLogger(project="bioimage-embed", name="shapes") - # bie.icfg.trainer = prepare_trainer(bie.icfg.trainer) - wandb.watch(bie.icfg.lit_model, log="all") - bie.train() - wandb.finish() - return bie - - analysis = tune.run( - tune.with_parameters(train), - # resources_per_trial={"cpu": 32, "gpu": 1}, - config=params_space, - # metric="loss", - # mode="min", - num_samples=1, - scheduler=tune.schedulers.ASHAScheduler( - metric="val/loss", - mode="min", - max_t=10, - grace_period=1, - reduction_factor=2, - ), + + cfg = config.Config( + dataloader=dataloader, + lit_model=lit_model, + trainer=trainer, + recipe=config.Recipe(**params), ) - # results = tuner.fit() - print("Best hyperparameters found were: ", analysis.best_config) - - # bie.export("model") + # breakpoint() + + bie = bioimage_embed.BioImageEmbed(cfg) + wandb.watch(bie.icfg.lit_model, log="all") + + bie.train() + wandb.finish() From 1c3dbb700a304a063cee9ed09ab5152ce59e8bc6 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 10 Jul 2024 08:59:29 +0100 Subject: [PATCH 06/16] [init] Adding sbatch for running study.py --- scripts/idr/lightning.study.sh | 41 ++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 scripts/idr/lightning.study.sh diff --git a/scripts/idr/lightning.study.sh b/scripts/idr/lightning.study.sh new file mode 100644 index 00000000..e66d1762 --- /dev/null +++ b/scripts/idr/lightning.study.sh @@ -0,0 +1,41 @@ +#!/bin/bash -l +#SBATCH --nodes=1 +#SBATCH --gres=gpu:1 +#SBATCH --ntasks-per-node=1 +#SBATCH --time=0-02:00:00 +#SBATCH --job-name=lightning +#SBATCH --constraint=a100 # Ensure the job is scheduled on nodes with A100 GPUs +#SBATCH --mem-per-cpu=2GB +#SBATCH --cpus-per-task=32 +#SBATCH --output=lightning_%j.out +set -x + +source ./env/bin/activate + +# debugging flags (optional) +export NCCL_DEBUG=INFO +export PYTHONFAULTHANDLER=1 + +# on your cluster you might need these: +# set the network interface +export NCCL_SOCKET_IFNAME=^docker0,lo + +# might need the latest CUDA +# module load NCCL/2.4.7-1-cuda.10.0 + +# run script from above +echo "Starting Lightning training script" +srun python3 -u scripts/idr/study.py + + +# # shellcheck disable=SC2206 +# # SBATCH --job-name=lightning +# # SBATCH --nodes=4 # This needs to match Trainer(num_nodes=...) +# # SBATCH --ntasks-per-node=1 # This needs to match Trainer(devices=...) 
+# # SBATCH --cpus-per-task=16 +# # SBATCH --mem-per-cpu=2GB +# # SBATCH --tasks-per-node=1 +# # SBATCH --gpus-per-task=1 +# # SBATCH --constraint=a100 # Ensure the job is scheduled on nodes with A100 GPUs +# # SBATCH --output=lightning_%j.out +# # SBATCH --time=24:00:00 \ No newline at end of file From 500756fd68ab0d8fd600a1d1a1e49430cbcd06db Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 10 Jul 2024 09:00:09 +0100 Subject: [PATCH 07/16] [bug] pydantic types too rigid --- bioimage_embed/config.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/bioimage_embed/config.py b/bioimage_embed/config.py index a7935be8..17722f7c 100644 --- a/bioimage_embed/config.py +++ b/bioimage_embed/config.py @@ -64,7 +64,7 @@ class Recipe: # that pydantic can use @dataclass(config=dict(extra="allow")) class ATransform: - _target_: str = "albumentations.from_dict" + _target_: Any = "albumentations.from_dict" _convert_: str = "object" # _convert_: str = "all" transform_dict: Dict = Field( @@ -77,7 +77,7 @@ class ATransform: @dataclass(config=dict(extra="allow")) class Transform: - _target_: str = "bioimage_embed.augmentations.VisionWrapper" + _target_: Any = "bioimage_embed.augmentations.VisionWrapper" _convert_: str = "object" # transform: ATransform = field(default_factory=ATransform) transform_dict: Dict = Field( @@ -133,7 +133,7 @@ class DataLoader: @dataclass(config=dict(extra="allow")) class Model: - _target_: str = "bioimage_embed.models.create_model" + _target_: Any = "bioimage_embed.models.create_model" model: str = II("recipe.model") input_dim: List[int] = Field(default_factory=lambda: [3, 224, 224]) latent_dim: int = 64 @@ -147,7 +147,7 @@ class Callback: @dataclass(config=dict(extra="allow")) class EarlyStopping(Callback): - _target_: str = "pytorch_lightning.callbacks.EarlyStopping" + _target_: Any = "pytorch_lightning.callbacks.EarlyStopping" monitor: str = "loss/val" mode: str = "min" patience: int = 3 @@ -155,7 +155,7 @@ class EarlyStopping(Callback): @dataclass(config=dict(extra="allow")) class ModelCheckpoint(Callback): - _target_: str = "pytorch_lightning.callbacks.ModelCheckpoint" + _target_: Any = "pytorch_lightning.callbacks.ModelCheckpoint" save_last = True save_top_k = 1 monitor = "loss/val" @@ -168,8 +168,8 @@ class ModelCheckpoint(Callback): class LightningModel: _target_: str = "bioimage_embed.lightning.torch.AEUnsupervised" # This should be pythae base autoencoder? 
- model: Model = Field(default_factory=Model) - args: Recipe = Field(default_factory=lambda: II("recipe")) + model: Any = Field(default_factory=Model) + args: Any = Field(default_factory=lambda: II("recipe")) class LightningModelSupervised(LightningModel): @@ -179,14 +179,15 @@ class LightningModelSupervised(LightningModel): @dataclass(config=dict(extra="allow")) class Callbacks: # _target_: str = "collections.OrderedDict" - model_checkpoint: ModelCheckpoint = Field(default_factory=ModelCheckpoint) - early_stopping: EarlyStopping = Field(default_factory=EarlyStopping) + model_checkpoint: Any = Field(default_factory=ModelCheckpoint) + early_stopping: Any = Field(default_factory=EarlyStopping) + @dataclass(config=dict(extra="allow")) class Trainer: - _target_: str = "pytorch_lightning.Trainer" - # logger: Optional[any] + _target_: Any = "pytorch_lightning.Trainer" + logger: Optional[List[Any]] = Field(default_factory=List) gradient_clip_val: float = 0.5 enable_checkpointing: bool = True devices: Any = "auto" @@ -196,7 +197,7 @@ class Trainer: max_epochs: int = II("recipe.max_epochs") log_every_n_steps: int = 1 # This is not a clean implementation but I am not sure how to do it better - callbacks: List[Any] = Field( + callbacks: Any = Field( default_factory=lambda: list(vars(Callbacks()).values()), frozen=True ) # TODO idea here would be to use pydantic to validate omegaconf From 2a2b2d28b7eafae4ca5e0c15d2605d68b7d9ded1 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 10 Jul 2024 09:00:39 +0100 Subject: [PATCH 08/16] [bug] old augmentations were not ideal --- bioimage_embed/augmentations.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/bioimage_embed/augmentations.py b/bioimage_embed/augmentations.py index c83b2562..18bec875 100644 --- a/bioimage_embed/augmentations.py +++ b/bioimage_embed/augmentations.py @@ -3,7 +3,6 @@ from albumentations.pytorch import ToTensorV2 DEFAULT_AUGMENTATION_LIST = [ - # Flip the images horizontally or vertically with a 50% chance A.OneOf( [ A.HorizontalFlip(p=0.5), @@ -11,26 +10,20 @@ ], p=0.5, ), - # Rotate the images by a random angle within a specified range + A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), A.Rotate(limit=45, p=0.5), - # Randomly scale the image intensity to adjust brightness and contrast - A.RandomGamma(gamma_limit=(80, 120), p=0.5), - # Apply random elastic transformations to the images + # A.RandomGamma(gamma_limit=(80, 120), p=0.5), A.ElasticTransform( alpha=1, sigma=50, alpha_affine=50, p=0.5, ), - # Shift the image channels along the intensity axis - A.ChannelShuffle(p=0.5), - # Add a small amount of noise to the images - A.GaussNoise(var_limit=(10.0, 50.0), p=0.5), - # Crop a random part of the image and resize it back to the original size + # A.ChannelShuffle(p=0.5), A.RandomResizedCrop( - height=224, width=224, scale=(0.9, 1.0), ratio=(0.9, 1.1), p=0.5 + height=224, width=224, scale=(0.9, 1.0), ratio=(0.9, 1.1), p=1.0 ), - # Adjust image intensity with a specified range for individual channels + A.GaussNoise(var_limit=(10.0, 50.0), p=0.5), A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5), A.ToFloat(), ToTensorV2(), From 3a51bd9f64bc0de0df2007af920b68e3ff58e53d Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 10 Jul 2024 11:11:33 +0100 Subject: [PATCH 09/16] [bug] Everything needs to be Any (for now) --- bioimage_embed/config.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/bioimage_embed/config.py 
b/bioimage_embed/config.py index 17722f7c..5ea859bb 100644 --- a/bioimage_embed/config.py +++ b/bioimage_embed/config.py @@ -34,6 +34,7 @@ class Recipe: batch_size: int = 16 data: str = "data" opt: str = "adamw" + latent_dim: int = 64 batch_size: int = 16 max_epochs: int = 125 weight_decay: float = 0.001 @@ -136,7 +137,7 @@ class Model: _target_: Any = "bioimage_embed.models.create_model" model: str = II("recipe.model") input_dim: List[int] = Field(default_factory=lambda: [3, 224, 224]) - latent_dim: int = 64 + latent_dim: int = II("recipe.latent_dim") pretrained: bool = True @@ -180,14 +181,15 @@ class LightningModelSupervised(LightningModel): class Callbacks: # _target_: str = "collections.OrderedDict" model_checkpoint: Any = Field(default_factory=ModelCheckpoint) - early_stopping: Any = Field(default_factory=EarlyStopping) + # early_stopping: Any = Field(default_factory=EarlyStopping) @dataclass(config=dict(extra="allow")) class Trainer: +# class Trainer(pytorch_lightning.Trainer): _target_: Any = "pytorch_lightning.Trainer" - logger: Optional[List[Any]] = Field(default_factory=List) + logger: Any = None gradient_clip_val: float = 0.5 enable_checkpointing: bool = True devices: Any = "auto" @@ -195,6 +197,7 @@ class Trainer: accumulate_grad_batches: int = 16 min_epochs: int = 1 max_epochs: int = II("recipe.max_epochs") + num_nodes: int = 1, log_every_n_steps: int = 1 # This is not a clean implementation but I am not sure how to do it better callbacks: Any = Field( @@ -202,7 +205,6 @@ class Trainer: ) # TODO idea here would be to use pydantic to validate omegaconf - # TODO add argument caching for checkpointing From ef580bba4d820f934ddb50b07a953b989b14e9a8 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 10 Jul 2024 11:12:25 +0100 Subject: [PATCH 10/16] [bug] I think the images need normalising at the end --- bioimage_embed/augmentations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bioimage_embed/augmentations.py b/bioimage_embed/augmentations.py index 18bec875..1f5ee279 100644 --- a/bioimage_embed/augmentations.py +++ b/bioimage_embed/augmentations.py @@ -10,7 +10,7 @@ ], p=0.5, ), - A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + A.Rotate(limit=45, p=0.5), # A.RandomGamma(gamma_limit=(80, 120), p=0.5), A.ElasticTransform( From a8ce27e41ce2b3da6b50a74439fa6f1ac163be1a Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 10 Jul 2024 11:13:00 +0100 Subject: [PATCH 11/16] [bug] faulty transform logic in script --- scripts/idr/study.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/idr/study.py b/scripts/idr/study.py index b7b5cf5c..1790b1cb 100644 --- a/scripts/idr/study.py +++ b/scripts/idr/study.py @@ -51,6 +51,7 @@ def get_file_list(glob_str): class GlobDataset(Dataset): def __init__(self, glob_str,transform=None): self.file_list = get_file_list(glob_str) + self.transform = transform def __len__(self): return len(self.file_list) @@ -63,9 +64,9 @@ def __getitem__(self, idx): image = Image.open(img_name) # breakpoint() image = np.array(image) - if transform: + if self.transform: # t = A.Compose([A.ToRGB(),transform, A.RandomCrop(224,224)]) - t = A.Compose([A.ToRGB(),transform]) + t = A.Compose([A.ToRGB(),self.transform]) image = t(image=image) # breakpoint() From 7d393922be98f61f19c58e893f95c5fef1f403eb Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 24 Jul 2024 11:45:14 +0100 Subject: [PATCH 12/16] [feat] example of multinode training using lightning --- scripts/idr/lightning.study.sh | 27 
++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/scripts/idr/lightning.study.sh b/scripts/idr/lightning.study.sh index e66d1762..dac53e82 100644 --- a/scripts/idr/lightning.study.sh +++ b/scripts/idr/lightning.study.sh @@ -1,8 +1,8 @@ #!/bin/bash -l -#SBATCH --nodes=1 -#SBATCH --gres=gpu:1 +#SBATCH --nodes=3 +#SBATCH --gres=gpu:4 #SBATCH --ntasks-per-node=1 -#SBATCH --time=0-02:00:00 +#SBATCH --time=0-24:00:00 #SBATCH --job-name=lightning #SBATCH --constraint=a100 # Ensure the job is scheduled on nodes with A100 GPUs #SBATCH --mem-per-cpu=2GB @@ -10,32 +10,21 @@ #SBATCH --output=lightning_%j.out set -x -source ./env/bin/activate +source activate $1 # debugging flags (optional) export NCCL_DEBUG=INFO export PYTHONFAULTHANDLER=1 +# export NCCL_P2P_DISABLE=1 +# unset LOCAL_RANK # on your cluster you might need these: # set the network interface -export NCCL_SOCKET_IFNAME=^docker0,lo +# export NCCL_SOCKET_IFNAME=^docker0,lo # might need the latest CUDA # module load NCCL/2.4.7-1-cuda.10.0 # run script from above echo "Starting Lightning training script" -srun python3 -u scripts/idr/study.py - - -# # shellcheck disable=SC2206 -# # SBATCH --job-name=lightning -# # SBATCH --nodes=4 # This needs to match Trainer(num_nodes=...) -# # SBATCH --ntasks-per-node=1 # This needs to match Trainer(devices=...) -# # SBATCH --cpus-per-task=16 -# # SBATCH --mem-per-cpu=2GB -# # SBATCH --tasks-per-node=1 -# # SBATCH --gpus-per-task=1 -# # SBATCH --constraint=a100 # Ensure the job is scheduled on nodes with A100 GPUs -# # SBATCH --output=lightning_%j.out -# # SBATCH --time=24:00:00 \ No newline at end of file +srun python3 scripts/idr/study.py \ No newline at end of file From ba641ad26fa091cadc1b4e132baac9962f065769 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 24 Jul 2024 11:45:29 +0100 Subject: [PATCH 13/16] [feat] example using submitit --- scripts/idr/study.submitit.py | 158 ++++++++++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 scripts/idr/study.submitit.py diff --git a/scripts/idr/study.submitit.py b/scripts/idr/study.submitit.py new file mode 100644 index 00000000..ac8c38c3 --- /dev/null +++ b/scripts/idr/study.submitit.py @@ -0,0 +1,158 @@ +import bioimage_embed +import bioimage_embed.config as config +import wandb +from pytorch_lightning import LightningModule, Trainer +import albumentations as A +from types import SimpleNamespace +from ray import tune +import numpy as np +from ray.train.torch import TorchTrainer +from ray.train import ScalingConfig +from hydra.utils import instantiate +import os +import glob +from PIL import Image +from typing import List +from torch.utils.data import Dataset +import torch +from joblib import Memory +from pydantic.dataclasses import dataclass +from pytorch_lightning import loggers as pl_loggers +import submitit +import os + +NUM_GPUS_PER_NODE = 1 +NUM_NODES = 1 + + +params = { + "model": "resnet50_vqvae", + # "data": "data", + "opt": "lamb", + "latent_dim": 224**2//4, + "max_epochs": 1000, + "max_steps": -1, + "weight_decay": 0.0001, + "momentum": 0.9, + # "sched": "cosine", + "epochs": 1000, + "lr": 1e-3, + "batch_size": 16, + "sched": "cosine", + } +memory = Memory(location='.', verbose=0) + +@memory.cache +def get_file_list(glob_str): + return glob.glob(os.path.join(glob_str), recursive=True) + +def collate_fn(batch): + # Filter out None values + batch = list(filter(lambda x: x[0] is not None, batch)) + return torch.utils.data.dataloader.default_collate(batch) + +class 
GlobDataset(Dataset): + def __init__(self, glob_str,transform=None): + self.file_list = get_file_list(glob_str) + self.transform = transform + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + + img_name = self.file_list[idx] + try: + image = Image.open(img_name) + except: + return None,None + # breakpoint() + image = np.array(image) + if self.transform: + # t = A.Compose([A.ToRGB(),transform, A.RandomCrop(224,224)]) + t = A.Compose([A.ToRGB(),self.transform]) + image = t(image=image) + + # breakpoint() + + return image["image"], 0 + +root_dir = '/nfs/ftp/public/databases/IDR/idr0093-mueller-perturbation/' +root_dir = '/nfs/research/uhlmann/ctr26/idr/idr0093-mueller-perturbation/' + + + +def train(num_gpus_per_node=1,num_nodes=1): + + print("training") + input_dim = [3, 224, 224] + + # mock_dataset = config.ImageFolderDataset( + # _target_="bioimage_embed.datasets.FakeImageFolder", + # image_size=input_dim, + # num_classes=1, + # ) + + transform = instantiate(config.ATransform()) + dataset = GlobDataset(root_dir+'**/*.tif*',transform) + # dataset = RandomDataset(32, 64) + dataloader = config.DataLoader(dataset=dataset,num_workers=os.cpu_count(),collate_fn=collate_fn) + + assert instantiate(dataloader,batch_size=1) + assert dataset[0] + + model = config.Model(input_dim=input_dim) + + lit_model = config.LightningModel( + # _target_="bioimage_embed.lightning.torch.AutoEncoderSupervised", + model=model + ) + wandb = pl_loggers.WandbLogger(project="idr", name="0093",log_model="all") + trainer = config.Trainer( + accelerator="auto", + devices=num_gpus_per_node, + num_nodes=num_nodes, + strategy="ddp", + callbacks=[], + # plugin=[], + logger=[wandb], + ) + + cfg = config.Config( + dataloader=dataloader, + lit_model=lit_model, + trainer=trainer, + recipe=config.Recipe(**params), + ) + # breakpoint() + + bie = bioimage_embed.BioImageEmbed(cfg) + # wandb.watch(bie.icfg.lit_model, log="all") + # wandb.run.define_metric("mse/val", summary="best") + # wandb.run.define_metric("loss/val.loss", summary="best") + + bie.train() + wandb.finish() + +def main(): + logdir = "lightning_slurm/" + os.makedirs(logdir, exist_ok=True) + + # executor is the submission interface (logs are dumped in the folder) + executor = submitit.AutoExecutor(folder=logdir) + executor.update_parameters( + mem_gb=2 * 32 * 4, # 2GB per CPU, 32 CPUs per task, 4 tasks per node + timeout_min=1440*2, # 48 hours + # slurm_partition="your_partition_name", # Replace with your partition name + gpus_per_node=NUM_GPUS_PER_NODE, + tasks_per_node=1, + cpus_per_task=8, + nodes=NUM_NODES, + slurm_constraint="a100", + ) + job = executor.submit(train, NUM_GPUS_PER_NODE, NUM_NODES) + +if __name__ == "__main__": + main() \ No newline at end of file From 98769b85e9db262fcc27559106d1416ab819a8af Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 11 Sep 2024 14:45:43 +0100 Subject: [PATCH 14/16] [fix] script cleanup --- scripts/idr/study.submitit.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/scripts/idr/study.submitit.py b/scripts/idr/study.submitit.py index ac8c38c3..22ac7d93 100644 --- a/scripts/idr/study.submitit.py +++ b/scripts/idr/study.submitit.py @@ -20,7 +20,7 @@ from pytorch_lightning import loggers as pl_loggers import submitit import os - +import fsspec NUM_GPUS_PER_NODE = 1 NUM_NODES = 1 @@ -43,8 +43,11 @@ memory = Memory(location='.', verbose=0) @memory.cache -def get_file_list(glob_str): - return 
glob.glob(os.path.join(glob_str), recursive=True) +def get_file_list(glob_str,fs): + return fs.glob(glob_str) + # return fs.open(glob_str,filecache={'cache_storage':'tmp/idr'}) + # return fsspec.open_files(glob_str, recursive=True) + # return glob.glob(os.path.join(glob_str), recursive=True) def collate_fn(batch): # Filter out None values @@ -52,8 +55,10 @@ def collate_fn(batch): return torch.utils.data.dataloader.default_collate(batch) class GlobDataset(Dataset): - def __init__(self, glob_str,transform=None): - self.file_list = get_file_list(glob_str) + def __init__(self, glob_str,transform=None,fs=fsspec.filesystem('file')): + print("Getting file list, this may take a while") + self.file_list = get_file_list(glob_str,fs) + print("Done getting file list") self.transform = transform def __len__(self): @@ -64,8 +69,11 @@ def __getitem__(self, idx): idx = idx.tolist() img_name = self.file_list[idx] + obj = fs.open(img_name,filecache={'cache_storage':'tmp/idr'}) try: - image = Image.open(img_name) + with obj as f: + image = Image.open(f) + # image = Image.open(img_name) except: return None,None # breakpoint() @@ -79,10 +87,15 @@ def __getitem__(self, idx): return image["image"], 0 -root_dir = '/nfs/ftp/public/databases/IDR/idr0093-mueller-perturbation/' root_dir = '/nfs/research/uhlmann/ctr26/idr/idr0093-mueller-perturbation/' +fs = fsspec.filesystem('file') +fs = fsspec.filesystem( + 'ftp', host='ftp.ebi.ac.uk', + cache_storage='/tmp/files/') +root_dir = '/pub/databases/IDR/idr0093-mueller-perturbation/' - +# /nfs/ftp/public/databases/IDR/idr0093-mueller-perturbation/' +# /nfs/ftp/public/databases/IDR/ def train(num_gpus_per_node=1,num_nodes=1): @@ -96,7 +109,7 @@ def train(num_gpus_per_node=1,num_nodes=1): # ) transform = instantiate(config.ATransform()) - dataset = GlobDataset(root_dir+'**/*.tif*',transform) + dataset = GlobDataset(root_dir+'**/*.tif*',transform,fs=fs) # dataset = RandomDataset(32, 64) dataloader = config.DataLoader(dataset=dataset,num_workers=os.cpu_count(),collate_fn=collate_fn) @@ -155,4 +168,4 @@ def main(): job = executor.submit(train, NUM_GPUS_PER_NODE, NUM_NODES) if __name__ == "__main__": - main() \ No newline at end of file + train() \ No newline at end of file From 5b4cbc966f81748a49ea1bed3d7ea78aaa6e578d Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Fri, 20 Sep 2024 10:14:39 +0100 Subject: [PATCH 15/16] [good] uploading training script that works? 
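
Drops a handful of empty placeholder modules and reworks the submitit training
script: the IDR file list is validated up front (unreadable TIFFs are skipped),
the cleaned list is cached with joblib and shuffled with a fixed seed, and the
collate function continues to drop samples that still fail at read time so one
bad file cannot kill a batch. A rough sketch of the cached-validation idea
follows; the glob pattern and local-filesystem setup are illustrative only, not
the exact code added below:

    import fsspec
    from joblib import Memory
    from PIL import Image

    memory = Memory(location=".", verbose=0)
    fs = fsspec.filesystem("file")

    def check_image(fs, img_name):
        # Corrupt or unreadable files fail PIL's verify() and are filtered out.
        try:
            with fs.open(img_name) as f:
                Image.open(f).verify()
            return True
        except Exception:
            return False

    @memory.cache
    def get_clean_file_list(glob_str, fs):
        # Cache the slow glob + verification pass on disk so repeated runs start fast.
        return [p for p in fs.glob(glob_str) if check_image(fs, p)]

    # e.g. files = get_clean_file_list("/path/to/idr0093/**/*.tif*", fs)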
--- bioimage_embed/inference.py | 0 bioimage_embed/models/legacy/__init__.py | 0 bioimage_embed/models/tests/legacy.py | 0 bioimage_embed/models/tests/mae.py | 0 bioimage_embed/models/vit/tests/mae.py | 0 bioimage_embed/tests/test_training.py | 0 scripts/idr/study.submitit.py | 106 +++++++++++++++-------- 7 files changed, 72 insertions(+), 34 deletions(-) delete mode 100644 bioimage_embed/inference.py delete mode 100644 bioimage_embed/models/legacy/__init__.py delete mode 100644 bioimage_embed/models/tests/legacy.py delete mode 100644 bioimage_embed/models/tests/mae.py delete mode 100644 bioimage_embed/models/vit/tests/mae.py delete mode 100644 bioimage_embed/tests/test_training.py diff --git a/bioimage_embed/inference.py b/bioimage_embed/inference.py deleted file mode 100644 index e69de29b..00000000 diff --git a/bioimage_embed/models/legacy/__init__.py b/bioimage_embed/models/legacy/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/bioimage_embed/models/tests/legacy.py b/bioimage_embed/models/tests/legacy.py deleted file mode 100644 index e69de29b..00000000 diff --git a/bioimage_embed/models/tests/mae.py b/bioimage_embed/models/tests/mae.py deleted file mode 100644 index e69de29b..00000000 diff --git a/bioimage_embed/models/vit/tests/mae.py b/bioimage_embed/models/vit/tests/mae.py deleted file mode 100644 index e69de29b..00000000 diff --git a/bioimage_embed/tests/test_training.py b/bioimage_embed/tests/test_training.py deleted file mode 100644 index e69de29b..00000000 diff --git a/scripts/idr/study.submitit.py b/scripts/idr/study.submitit.py index 22ac7d93..af383f05 100644 --- a/scripts/idr/study.submitit.py +++ b/scripts/idr/study.submitit.py @@ -21,9 +21,19 @@ import submitit import os import fsspec +import logging +import click +from pytorch_lightning.callbacks import ModelCheckpoint # Added import +import random +from tqdm import tqdm + + +torch.manual_seed(42) +np.random.seed(42) + NUM_GPUS_PER_NODE = 1 NUM_NODES = 1 - +CPUS_PER_TASK = 8 params = { "model": "resnet50_vqvae", @@ -45,20 +55,29 @@ @memory.cache def get_file_list(glob_str,fs): return fs.glob(glob_str) - # return fs.open(glob_str,filecache={'cache_storage':'tmp/idr'}) - # return fsspec.open_files(glob_str, recursive=True) - # return glob.glob(os.path.join(glob_str), recursive=True) + +@memory.cache +def get_clean_file_list(glob_str, fs): + filelist = get_file_list(glob_str, fs) + # Use filter with tqdm + valid_files = list(filter(lambda x: check_image(fs,x), tqdm(filelist, desc="Validating images"))) + return valid_files + def collate_fn(batch): # Filter out None values batch = list(filter(lambda x: x[0] is not None, batch)) + if len(batch) == 0: + logging.warning("Batch is empty") + return None return torch.utils.data.dataloader.default_collate(batch) class GlobDataset(Dataset): def __init__(self, glob_str,transform=None,fs=fsspec.filesystem('file')): print("Getting file list, this may take a while") - self.file_list = get_file_list(glob_str,fs) - print("Done getting file list") + self.file_list = np.random.permutation(get_clean_file_list(glob_str, fs)).tolist() + + print(f"Done getting file list: {len(self.file_list)}") self.transform = transform def __len__(self): @@ -67,33 +86,44 @@ def __len__(self): def __getitem__(self, idx): if torch.is_tensor(idx): idx = idx.tolist() - img_name = self.file_list[idx] - obj = fs.open(img_name,filecache={'cache_storage':'tmp/idr'}) try: - with obj as f: - image = Image.open(f) - # image = Image.open(img_name) + image = read_image(fs,img_name) + if 
self.transform: + image = self.transform(image=image)["image"] + return image,0 except: - return None,None - # breakpoint() - image = np.array(image) - if self.transform: - # t = A.Compose([A.ToRGB(),transform, A.RandomCrop(224,224)]) - t = A.Compose([A.ToRGB(),self.transform]) - image = t(image=image) + logging.info(f"Could not open {img_name}") + breakpoint() + return None, 0 - # breakpoint() - return image["image"], 0 + +def check_image(fs,img_name): + obj = fs.open(img_name,filecache={'cache_storage':'tmp/idr'}) + with obj as f: + try: + image = Image.open(f).verify() + return True + except: + return False + +def read_image(fs,img_name): + obj = fs.open(img_name,filecache={'cache_storage':'tmp/idr'}) + with obj as f: + image = Image.open(f) + image = np.array(image) + return image + root_dir = '/nfs/research/uhlmann/ctr26/idr/idr0093-mueller-perturbation/' fs = fsspec.filesystem('file') -fs = fsspec.filesystem( - 'ftp', host='ftp.ebi.ac.uk', - cache_storage='/tmp/files/') +# fs = fsspec.filesystem( +# 'ftp', host='ftp.ebi.ac.uk', +# cache_storage='/tmp/files/') root_dir = '/pub/databases/IDR/idr0093-mueller-perturbation/' - +root_dir = "/hps/nobackup/uhlmann/ctr26/idr/nfs/ftp/public/databases/IDR/" +root_dir += "idr0093-mueller-perturbation/" # /nfs/ftp/public/databases/IDR/idr0093-mueller-perturbation/' # /nfs/ftp/public/databases/IDR/ @@ -109,12 +139,13 @@ def train(num_gpus_per_node=1,num_nodes=1): # ) transform = instantiate(config.ATransform()) + transform = A.Compose([A.ToRGB(),transform]) dataset = GlobDataset(root_dir+'**/*.tif*',transform,fs=fs) # dataset = RandomDataset(32, 64) - dataloader = config.DataLoader(dataset=dataset,num_workers=os.cpu_count(),collate_fn=collate_fn) + dataloader = config.DataLoader(dataset=dataset,num_workers=CPUS_PER_TASK-1,collate_fn=collate_fn,shuffle=True,batch_size=params["batch_size"]) - assert instantiate(dataloader,batch_size=1) - assert dataset[0] + # assert instantiate(dataloader,batch_size=1) + # assert dataset[0] model = config.Model(input_dim=input_dim) @@ -123,13 +154,17 @@ def train(num_gpus_per_node=1,num_nodes=1): model=model ) wandb = pl_loggers.WandbLogger(project="idr", name="0093",log_model="all") + + trainer = config.Trainer( accelerator="auto", devices=num_gpus_per_node, num_nodes=num_nodes, strategy="ddp", - callbacks=[], + enable_checkpointing=True, + callbacks=None, # plugin=[], + logger=[wandb], ) @@ -149,7 +184,10 @@ def train(num_gpus_per_node=1,num_nodes=1): bie.train() wandb.finish() -def main(): +@click.command() +@click.option("--gpus", default=1) +@click.option("--nodes", default=1) +def main( gpus, nodes): logdir = "lightning_slurm/" os.makedirs(logdir, exist_ok=True) @@ -159,13 +197,13 @@ def main(): mem_gb=2 * 32 * 4, # 2GB per CPU, 32 CPUs per task, 4 tasks per node timeout_min=1440*2, # 48 hours # slurm_partition="your_partition_name", # Replace with your partition name - gpus_per_node=NUM_GPUS_PER_NODE, + gpus_per_node=gpus, tasks_per_node=1, - cpus_per_task=8, - nodes=NUM_NODES, + cpus_per_task=CPUS_PER_TASK, + nodes=nodes, slurm_constraint="a100", ) - job = executor.submit(train, NUM_GPUS_PER_NODE, NUM_NODES) + job = executor.submit(train, gpus, nodes) if __name__ == "__main__": - train() \ No newline at end of file + main() From f270dcec96c637b81a74484b3ff22cb15cecfa6d Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Fri, 27 Sep 2024 10:41:06 +0100 Subject: [PATCH 16/16] [fix] stray , broke everything --- bioimage_embed/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/bioimage_embed/config.py b/bioimage_embed/config.py index 5ea859bb..e78f4977 100644 --- a/bioimage_embed/config.py +++ b/bioimage_embed/config.py @@ -197,7 +197,7 @@ class Trainer: accumulate_grad_batches: int = 16 min_epochs: int = 1 max_epochs: int = II("recipe.max_epochs") - num_nodes: int = 1, + num_nodes: int = 1 log_every_n_steps: int = 1 # This is not a clean implementation but I am not sure how to do it better callbacks: Any = Field(