Skip to content

Commit

Permalink
[good] uploading training script that works?
Browse files Browse the repository at this point in the history
  • Loading branch information
ctr26 committed Sep 27, 2024
1 parent 98769b8 commit 3b3b473
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 34 deletions.
Empty file removed bioimage_embed/inference.py
Empty file.
Empty file.
Empty file.
Empty file.
Empty file removed bioimage_embed/models/tests/mae.py
Empty file.
Empty file.
Empty file removed bioimage_embed/tests/__init__.py
Empty file.
Empty file.
106 changes: 72 additions & 34 deletions scripts/idr/study.submitit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,19 @@
import submitit
import os
import fsspec
import logging
import click
from pytorch_lightning.callbacks import ModelCheckpoint # Added import
import random
from tqdm import tqdm


torch.manual_seed(42)
np.random.seed(42)

NUM_GPUS_PER_NODE = 1
NUM_NODES = 1

CPUS_PER_TASK = 8

params = {
"model": "resnet50_vqvae",
Expand All @@ -45,20 +55,29 @@
@memory.cache
def get_file_list(glob_str,fs):
return fs.glob(glob_str)
# return fs.open(glob_str,filecache={'cache_storage':'tmp/idr'})
# return fsspec.open_files(glob_str, recursive=True)
# return glob.glob(os.path.join(glob_str), recursive=True)

@memory.cache
def get_clean_file_list(glob_str, fs):
filelist = get_file_list(glob_str, fs)
# Use filter with tqdm
valid_files = list(filter(lambda x: check_image(fs,x), tqdm(filelist, desc="Validating images")))
return valid_files


def collate_fn(batch):
# Filter out None values
batch = list(filter(lambda x: x[0] is not None, batch))
if len(batch) == 0:
logging.warning("Batch is empty")
return None
return torch.utils.data.dataloader.default_collate(batch)

class GlobDataset(Dataset):
def __init__(self, glob_str,transform=None,fs=fsspec.filesystem('file')):
print("Getting file list, this may take a while")
self.file_list = get_file_list(glob_str,fs)
print("Done getting file list")
self.file_list = np.random.permutation(get_clean_file_list(glob_str, fs)).tolist()

print(f"Done getting file list: {len(self.file_list)}")
self.transform = transform

def __len__(self):
Expand All @@ -67,33 +86,44 @@ def __len__(self):
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()

img_name = self.file_list[idx]
obj = fs.open(img_name,filecache={'cache_storage':'tmp/idr'})
try:
with obj as f:
image = Image.open(f)
# image = Image.open(img_name)
image = read_image(fs,img_name)
if self.transform:
image = self.transform(image=image)["image"]
return image,0
except:
return None,None
# breakpoint()
image = np.array(image)
if self.transform:
# t = A.Compose([A.ToRGB(),transform, A.RandomCrop(224,224)])
t = A.Compose([A.ToRGB(),self.transform])
image = t(image=image)
logging.info(f"Could not open {img_name}")
breakpoint()
return None, 0

# breakpoint()

return image["image"], 0

def check_image(fs,img_name):
obj = fs.open(img_name,filecache={'cache_storage':'tmp/idr'})
with obj as f:
try:
image = Image.open(f).verify()
return True
except:
return False

def read_image(fs,img_name):
obj = fs.open(img_name,filecache={'cache_storage':'tmp/idr'})
with obj as f:
image = Image.open(f)
image = np.array(image)
return image


root_dir = '/nfs/research/uhlmann/ctr26/idr/idr0093-mueller-perturbation/'
fs = fsspec.filesystem('file')
fs = fsspec.filesystem(
'ftp', host='ftp.ebi.ac.uk',
cache_storage='/tmp/files/')
# fs = fsspec.filesystem(
# 'ftp', host='ftp.ebi.ac.uk',
# cache_storage='/tmp/files/')
root_dir = '/pub/databases/IDR/idr0093-mueller-perturbation/'

root_dir = "/hps/nobackup/uhlmann/ctr26/idr/nfs/ftp/public/databases/IDR/"
root_dir += "idr0093-mueller-perturbation/"
# /nfs/ftp/public/databases/IDR/idr0093-mueller-perturbation/'
# /nfs/ftp/public/databases/IDR/

Expand All @@ -109,12 +139,13 @@ def train(num_gpus_per_node=1,num_nodes=1):
# )

transform = instantiate(config.ATransform())
transform = A.Compose([A.ToRGB(),transform])
dataset = GlobDataset(root_dir+'**/*.tif*',transform,fs=fs)
# dataset = RandomDataset(32, 64)
dataloader = config.DataLoader(dataset=dataset,num_workers=os.cpu_count(),collate_fn=collate_fn)
dataloader = config.DataLoader(dataset=dataset,num_workers=CPUS_PER_TASK-1,collate_fn=collate_fn,shuffle=True,batch_size=params["batch_size"])

assert instantiate(dataloader,batch_size=1)
assert dataset[0]
# assert instantiate(dataloader,batch_size=1)
# assert dataset[0]

model = config.Model(input_dim=input_dim)

Expand All @@ -123,13 +154,17 @@ def train(num_gpus_per_node=1,num_nodes=1):
model=model
)
wandb = pl_loggers.WandbLogger(project="idr", name="0093",log_model="all")


trainer = config.Trainer(
accelerator="auto",
devices=num_gpus_per_node,
num_nodes=num_nodes,
strategy="ddp",
callbacks=[],
enable_checkpointing=True,
callbacks=None,
# plugin=[],

logger=[wandb],
)

Expand All @@ -149,7 +184,10 @@ def train(num_gpus_per_node=1,num_nodes=1):
bie.train()
wandb.finish()

def main():
@click.command()
@click.option("--gpus", default=1)
@click.option("--nodes", default=1)
def main( gpus, nodes):
logdir = "lightning_slurm/"
os.makedirs(logdir, exist_ok=True)

Expand All @@ -159,13 +197,13 @@ def main():
mem_gb=2 * 32 * 4, # 2GB per CPU, 32 CPUs per task, 4 tasks per node
timeout_min=1440*2, # 48 hours
# slurm_partition="your_partition_name", # Replace with your partition name
gpus_per_node=NUM_GPUS_PER_NODE,
gpus_per_node=gpus,
tasks_per_node=1,
cpus_per_task=8,
nodes=NUM_NODES,
cpus_per_task=CPUS_PER_TASK,
nodes=nodes,
slurm_constraint="a100",
)
job = executor.submit(train, NUM_GPUS_PER_NODE, NUM_NODES)
job = executor.submit(train, gpus, nodes)

if __name__ == "__main__":
train()
main()

0 comments on commit 3b3b473

Please sign in to comment.