Commit
[good] uploading training script that works?
ctr26 committed Sep 20, 2024
1 parent d0ff883 commit 5458c62
Showing 11 changed files with 89 additions and 51 deletions.
Empty file removed .env
34 changes: 17 additions & 17 deletions bioimage_embed/config.py
@@ -71,7 +71,7 @@ class ATransform:
 # VisionWrapper is a helper class for applying albumentations pipelines for image augmentations in autoencoding
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Transform:
     _target_: Any = "bioimage_embed.augmentations.VisionWrapper"
     _convert_: str = "object"
@@ -81,25 +81,25 @@ class Transform:
     )
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Dataset:
     # _target_: str = "torch.utils.data.Dataset"
     transform: Transform = Field(default_factory=Transform)
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class ImageFolderDataset(Dataset):
     _target_: str = "torchvision.datasets.ImageFolder"
     # transform: Transform = Field(default_factory=Transform)
     root: str = II("recipe.data")
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class NdDataset(ImageFolderDataset):
     transform: Transform = Field(default_factory=Transform)
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class TiffDataset(NdDataset):
     _target_: str = "bioimage_embed.datasets.TiffDataset"
 
@@ -108,15 +108,15 @@ class NgffDataset(NdDataset):
     _target_: str = "bioimage_embed.datasets.NgffDataset"
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class DataLoader:
     _target_: Any = "bioimage_embed.lightning.dataloader.DataModule"
     dataset: Any = Field(default_factory=ImageFolderDataset)
     num_workers: int = 1
     batch_size: int = II("recipe.batch_size")
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Model:
     _target_: Any = "bioimage_embed.models.create_model"
     model: str = II("recipe.model")
@@ -125,20 +125,20 @@ class Model:
     pretrained: bool = True
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Callback:
     pass
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class EarlyStopping(Callback):
     _target_: Any = "pytorch_lightning.callbacks.EarlyStopping"
     monitor: str = "loss/val"
     mode: str = "min"
     patience: int = 3
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class ModelCheckpoint(Callback):
     _target_: Any = "pytorch_lightning.callbacks.ModelCheckpoint"
     save_last = True
@@ -149,37 +149,37 @@ class ModelCheckpoint(Callback):
     dirpath: str = f"{II('paths.model')}/{II('uuid')}"
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class LightningModel:
     _target_: Any = "bioimage_embed.lightning.torch.AutoEncoderSupervised"
     # This should be pythae base autoencoder?
     model: Any = Field(default_factory=Model)
     args: Any = Field(default_factory=lambda: II("recipe"))
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Callbacks:
     # _target_: str = "collections.OrderedDict"
     model_checkpoint: Any = Field(default_factory=ModelCheckpoint)
     # early_stopping: Any = Field(default_factory=EarlyStopping)
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class WandbLogger:
     _target_: Any = "pytorch_lightning.loggers.WandbLogger"
     project: str = ""
     name: str = ""
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class TensorboardLogger:
     _target_: str = "pytorch_lightning.loggers.TensorboardLogger"
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Loggers:
     tensorboard: Any = Field(default_factory=TensorboardLogger)
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Trainer:
     # class Trainer(pytorch_lightning.Trainer):
     _target_: Any = "pytorch_lightning.Trainer"
@@ -203,7 +203,7 @@ class Trainer:
     # TODO add argument caching for checkpointing
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Paths:
     model: str = "models"
     logs: str = "logs"
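
Every hunk in this file makes the same change: the bare pydantic dataclass decorator becomes @dataclass(config=dict(extra="allow")). A minimal sketch of the difference, assuming a pydantic version that honors extra on dataclasses (LoaderConf below is illustrative, not one of the repo's classes): unknown keyword arguments are accepted and kept instead of failing validation, which is plausibly what lets the training script further down pass extras such as collate_fn and shuffle through config.DataLoader.

    from pydantic.dataclasses import dataclass

    @dataclass(config=dict(extra="allow"))
    class LoaderConf:
        num_workers: int = 1
        batch_size: int = 32

    # An unknown keyword would normally fail validation; with extra="allow"
    # it is accepted and kept, which is the behavior config.py relies on.
    conf = LoaderConf(num_workers=4, shuffle=True)
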
Empty file removed bioimage_embed/inference.py
Empty file removed bioimage_embed/models/tests/mae.py
Empty file removed bioimage_embed/tests/__init__.py
106 changes: 72 additions & 34 deletions scripts/idr/study.submitit.py
@@ -21,9 +21,19 @@
 import submitit
 import os
 import fsspec
+import logging
+import click
+from pytorch_lightning.callbacks import ModelCheckpoint # Added import
+import random
+from tqdm import tqdm
 
 
 torch.manual_seed(42)
 np.random.seed(42)
 
+NUM_GPUS_PER_NODE = 1
+NUM_NODES = 1
+
+CPUS_PER_TASK = 8
+
 params = {
     "model": "resnet50_vqvae",
@@ -45,20 +55,29 @@
 @memory.cache
 def get_file_list(glob_str,fs):
     return fs.glob(glob_str)
-    # return fs.open(glob_str,filecache={'cache_storage':'tmp/idr'})
-    # return fsspec.open_files(glob_str, recursive=True)
-    # return glob.glob(os.path.join(glob_str), recursive=True)
 
+@memory.cache
+def get_clean_file_list(glob_str, fs):
+    filelist = get_file_list(glob_str, fs)
+    # Use filter with tqdm
+    valid_files = list(filter(lambda x: check_image(fs,x), tqdm(filelist, desc="Validating images")))
+    return valid_files
+
+
+def collate_fn(batch):
+    # Filter out None values
+    batch = list(filter(lambda x: x[0] is not None, batch))
+    if len(batch) == 0:
+        logging.warning("Batch is empty")
+        return None
+    return torch.utils.data.dataloader.default_collate(batch)
 
 class GlobDataset(Dataset):
     def __init__(self, glob_str,transform=None,fs=fsspec.filesystem('file')):
         print("Getting file list, this may take a while")
-        self.file_list = get_file_list(glob_str,fs)
-        print("Done getting file list")
+        self.file_list = np.random.permutation(get_clean_file_list(glob_str, fs)).tolist()
+
+        print(f"Done getting file list: {len(self.file_list)}")
         self.transform = transform
 
     def __len__(self):
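
In the hunk above, the dataset returns (None, 0) for unreadable files and collate_fn drops those entries before default_collate runs, so one corrupt image cannot crash an epoch. A self-contained sketch of that pattern with a toy dataset (FlakyDataset is illustrative, not from the repo); the one subtlety is that a batch in which every item failed comes back as None and has to be skipped by the caller.

    import logging

    import torch
    from torch.utils.data import DataLoader, Dataset

    class FlakyDataset(Dataset):
        """Stands in for GlobDataset: every 4th item is 'corrupt'."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            if idx % 4 == 0:
                return None, 0                    # unreadable image
            return torch.randn(3, 224, 224), 0

    def collate_fn(batch):
        batch = [b for b in batch if b[0] is not None]
        if len(batch) == 0:
            logging.warning("Batch is empty")
            return None                           # caller must skip None batches
        return torch.utils.data.dataloader.default_collate(batch)

    loader = DataLoader(FlakyDataset(), batch_size=4, collate_fn=collate_fn)
    for batch in loader:
        if batch is None:                         # every item in it failed
            continue
        images, labels = batch                    # images: (<=4, 3, 224, 224)
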
@@ -67,33 +86,44 @@ def __len__(self):
     def __getitem__(self, idx):
         if torch.is_tensor(idx):
             idx = idx.tolist()
+
         img_name = self.file_list[idx]
-        obj = fs.open(img_name,filecache={'cache_storage':'tmp/idr'})
         try:
-            with obj as f:
-                image = Image.open(f)
+            # image = Image.open(img_name)
+            image = read_image(fs,img_name)
+            if self.transform:
+                image = self.transform(image=image)["image"]
+            return image,0
         except:
-            return None,None
-        # breakpoint()
-        image = np.array(image)
-        if self.transform:
-            # t = A.Compose([A.ToRGB(),transform, A.RandomCrop(224,224)])
-            t = A.Compose([A.ToRGB(),self.transform])
-            image = t(image=image)
+            logging.info(f"Could not open {img_name}")
+            breakpoint()
+            return None, 0
 
-        # breakpoint()
-
-        return image["image"], 0
+
+def check_image(fs,img_name):
+    obj = fs.open(img_name,filecache={'cache_storage':'tmp/idr'})
+    with obj as f:
+        try:
+            image = Image.open(f).verify()
+            return True
+        except:
+            return False
+
+def read_image(fs,img_name):
+    obj = fs.open(img_name,filecache={'cache_storage':'tmp/idr'})
+    with obj as f:
+        image = Image.open(f)
+        image = np.array(image)
+    return image
 
 
 root_dir = '/nfs/research/uhlmann/ctr26/idr/idr0093-mueller-perturbation/'
 fs = fsspec.filesystem('file')
-fs = fsspec.filesystem(
-    'ftp', host='ftp.ebi.ac.uk',
-    cache_storage='/tmp/files/')
+# fs = fsspec.filesystem(
+#     'ftp', host='ftp.ebi.ac.uk',
+#     cache_storage='/tmp/files/')
 root_dir = '/pub/databases/IDR/idr0093-mueller-perturbation/'
 
+root_dir = "/hps/nobackup/uhlmann/ctr26/idr/nfs/ftp/public/databases/IDR/"
+root_dir += "idr0093-mueller-perturbation/"
 # /nfs/ftp/public/databases/IDR/idr0093-mueller-perturbation/'
 # /nfs/ftp/public/databases/IDR/
@@ -109,12 +139,13 @@ def train(num_gpus_per_node=1,num_nodes=1):
     # )
 
     transform = instantiate(config.ATransform())
+    transform = A.Compose([A.ToRGB(),transform])
     dataset = GlobDataset(root_dir+'**/*.tif*',transform,fs=fs)
     # dataset = RandomDataset(32, 64)
-    dataloader = config.DataLoader(dataset=dataset,num_workers=os.cpu_count(),collate_fn=collate_fn)
+    dataloader = config.DataLoader(dataset=dataset,num_workers=CPUS_PER_TASK-1,collate_fn=collate_fn,shuffle=True,batch_size=params["batch_size"])
 
-    assert instantiate(dataloader,batch_size=1)
-    assert dataset[0]
+    # assert instantiate(dataloader,batch_size=1)
+    # assert dataset[0]
 
     model = config.Model(input_dim=input_dim)

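train() composes A.ToRGB() in front of the configured augmentation so single-channel microscopy tiffs become 3-channel inputs for the resnet50_vqvae encoder, and it reads the result back out of the returned dict. A short sketch of that call pattern, assuming a recent albumentations release that ships ToRGB (the crop size is illustrative): pipelines take named arrays and return a dict, hence the trailing ["image"].

    import albumentations as A
    import numpy as np

    pipeline = A.Compose([A.ToRGB(), A.RandomCrop(224, 224)])

    gray = np.random.randint(0, 256, (512, 512), dtype=np.uint8)  # tif-like
    rgb_crop = pipeline(image=gray)["image"]                      # (224, 224, 3)
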
@@ -123,13 +154,17 @@
         model=model
     )
     wandb = pl_loggers.WandbLogger(project="idr", name="0093",log_model="all")
+
 
     trainer = config.Trainer(
         accelerator="auto",
         devices=num_gpus_per_node,
         num_nodes=num_nodes,
         strategy="ddp",
-        callbacks=[],
+        enable_checkpointing=True,
+        callbacks=None,
+        # plugin=[],
+
         logger=[wandb],
     )

@@ -149,7 +184,10 @@ def train(num_gpus_per_node=1,num_nodes=1):
     bie.train()
     wandb.finish()
 
-def main():
+@click.command()
+@click.option("--gpus", default=1)
+@click.option("--nodes", default=1)
+def main( gpus, nodes):
     logdir = "lightning_slurm/"
     os.makedirs(logdir, exist_ok=True)

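main() is now a click command, so GPU and node counts come from the command line instead of the NUM_GPUS_PER_NODE / NUM_NODES constants. A minimal sketch of the same pattern (help texts are illustrative):

    import click

    @click.command()
    @click.option("--gpus", default=1, help="GPUs per node")
    @click.option("--nodes", default=1, help="Number of nodes")
    def main(gpus, nodes):
        click.echo(f"Requesting {gpus} GPU(s) on {nodes} node(s)")

    if __name__ == "__main__":
        main()  # e.g. python study.submitit.py --gpus 2 --nodes 4
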
@@ -159,13 +197,13 @@ def main():
         mem_gb=2 * 32 * 4, # 2GB per CPU, 32 CPUs per task, 4 tasks per node
         timeout_min=1440*2, # 48 hours
         # slurm_partition="your_partition_name", # Replace with your partition name
-        gpus_per_node=NUM_GPUS_PER_NODE,
+        gpus_per_node=gpus,
         tasks_per_node=1,
-        cpus_per_task=8,
-        nodes=NUM_NODES,
+        cpus_per_task=CPUS_PER_TASK,
+        nodes=nodes,
         slurm_constraint="a100",
     )
-    job = executor.submit(train, NUM_GPUS_PER_NODE, NUM_NODES)
+    job = executor.submit(train, gpus, nodes)
 
 if __name__ == "__main__":
-    train()
+    main()

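The __main__ guard now calls main(), which queues train on SLURM through submitit rather than running it in-process. A minimal, self-contained submitit sketch (the partition is a placeholder, as in the script's own comment); AutoExecutor falls back to running locally when no SLURM scheduler is present, which keeps the script testable off-cluster.

    import submitit

    def add(a, b):
        return a + b

    executor = submitit.AutoExecutor(folder="lightning_slurm/")
    executor.update_parameters(
        timeout_min=10,
        cpus_per_task=2,
        # slurm_partition="your_partition_name",  # cluster-specific
    )

    job = executor.submit(add, 2, 3)  # writes a batch script and submits it
    print(job.result())               # blocks until the job finishes -> 5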