Skip to content

Commit

Permalink
[feat] not final script, cpu bound, idr93 only, auto devices now, precision16
Browse files Browse the repository at this point in the history
  • Loading branch information
ctr26 committed Oct 1, 2024
1 parent d30ec91 commit 2266a5a
Showing 1 changed file with 51 additions and 14 deletions.
65 changes: 51 additions & 14 deletions scripts/idr/study.submitit.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

NUM_GPUS_PER_NODE = 1
NUM_NODES = 1
CPUS_PER_TASK = 8
CPUS_PER_TASK = 48

params = {
"model": "resnet50_vqvae",
Expand Down Expand Up @@ -116,19 +116,20 @@ def read_image(fs,img_name):
return image


root_dir = '/nfs/research/uhlmann/ctr26/idr/idr0093-mueller-perturbation/'
# root_dir = '/nfs/research/uhlmann/ctr26/idr/idr0093-mueller-perturbation/'
fs = fsspec.filesystem('file')
# fs = fsspec.filesystem(
# 'ftp', host='ftp.ebi.ac.uk',
# cache_storage='/tmp/files/')
root_dir = '/pub/databases/IDR/idr0093-mueller-perturbation/'
root_dir = "/hps/nobackup/uhlmann/ctr26/idr/nfs/ftp/public/databases/IDR/"
root_dir += "idr0093-mueller-perturbation/"
# root_dir = '/pub/databases/IDR/idr0093-mueller-perturbation/'

# root_dir += "idr0093-mueller-perturbation/"
# /nfs/ftp/public/databases/IDR/idr0093-mueller-perturbation/'
# /nfs/ftp/public/databases/IDR/

def train(num_gpus_per_node=1,num_nodes=1):
os.environ["WANDB_DATA_DIR"] = f"{os.getcwd()}/wandb"

def train(name, num_gpus_per_node=1,num_nodes=1):
print("training")
input_dim = [3, 224, 224]

Expand All @@ -140,9 +141,22 @@ def train(num_gpus_per_node=1,num_nodes=1):

transform = instantiate(config.ATransform())
transform = A.Compose([A.ToRGB(),transform])
root_dir = "/hps/nobackup/uhlmann/ctr26/idr/nfs/ftp/public/databases/IDR/"
root_dir += "idr0093-mueller-perturbation/"

# root_dir += f"idr00032-"
# root_dir += "idr0093-mueller-perturbation/"
# root_dir += f"{name}*/"
dataset = GlobDataset(root_dir+'**/*.tif*',transform,fs=fs)
# dataset = RandomDataset(32, 64)
dataloader = config.DataLoader(dataset=dataset,num_workers=CPUS_PER_TASK-1,collate_fn=collate_fn,shuffle=True,batch_size=params["batch_size"])
effective_batch_size = params["batch_size"]*num_nodes*num_gpus_per_node
num_workers = CPUS_PER_TASK
dataloader = config.DataLoader(dataset=dataset,
num_workers=num_workers-1,
collate_fn=collate_fn,
shuffle=True,
batch_size=effective_batch_size,
)

# assert instantiate(dataloader,batch_size=1)
# assert dataset[0]
Expand All @@ -153,19 +167,31 @@ def train(num_gpus_per_node=1,num_nodes=1):
# _target_="bioimage_embed.lightning.torch.AutoEncoderSupervised",
model=model
)
wandb = pl_loggers.WandbLogger(project="idr", name="0093",log_model="all")

wandb = pl_loggers.WandbLogger(project="idr", name=name,log_model=True)
checkpoint_callback = ModelCheckpoint(
monitor='loss/val',
dirpath=f"lightning_logs/{name}",
# filename='best-checkpoint',
save_top_k=1,
save_last=True,
mode='min',
)

trainer = config.Trainer(
accelerator="auto",
devices=num_gpus_per_node,
# devices=num_gpus_per_node,
precision=16,
devices=-1,
num_nodes=num_nodes,
strategy="ddp",
strategy="dp",
enable_checkpointing=True,
# callbacks=[checkpoint_callback],
callbacks=None,
# default_root_dir=f"lightning_logs/{name}",
# plugin=[],

logger=[wandb],
accumulate_grad_batches=16,
max_epochs=params["max_epochs"],
)

cfg = config.Config(
Expand All @@ -183,11 +209,21 @@ def train(num_gpus_per_node=1,num_nodes=1):

bie.train()
wandb.finish()
# best_model = bie.model
# validation = best_model.validate(ckpt_path="best")
# # best_model = bie.model.load_from_checkpoint(
# # checkpoint_callback.best_model_path
# # )
# # best_model.model.push_to_hf_hub(
bie.push_to_hf_hub(
f"bioimagearchive/{params['model']}-{name}"
)

@click.command()
@click.option("--name", default="idr0093")
@click.option("--gpus", default=1)
@click.option("--nodes", default=1)
def main( gpus, nodes):
def main(name, gpus, nodes):
logdir = "lightning_slurm/"
os.makedirs(logdir, exist_ok=True)

Expand All @@ -202,8 +238,9 @@ def main( gpus, nodes):
cpus_per_task=CPUS_PER_TASK,
nodes=nodes,
slurm_constraint="a100",
slurm_additional_parameters={'export': 'ALL'}, # This ensures all environment variables are passed
)
job = executor.submit(train, gpus, nodes)
job = executor.submit(train, name, gpus, nodes)

if __name__ == "__main__":
main()

0 comments on commit 2266a5a

Please sign in to comment.