Skip to content

Commit

Permalink
[feat] pytorch lightning 2.0.0? with auto resume?
Browse files Browse the repository at this point in the history
  • Loading branch information
ctr26 committed Sep 30, 2024
1 parent 95b22ea commit c8898c5
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
7 changes: 6 additions & 1 deletion bioimage_embed/bie.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,12 @@ def train(self, resume: bool = True):
# best_checkpoint_path = chkpt_callbacks.best_model_path

# TODO add tests for checkpointing (properply)
return self._train()
try:
logging.info("Attempting to resume training")
return self._train("last")
except Exception:
logging.info("Forced to start from scratch")
return self._train(None)

def train_resume(self):
return self.train("last")
Expand Down
1 change: 0 additions & 1 deletion bioimage_embed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ class Trainer:
min_epochs: int = 1
max_epochs: int = II("recipe.max_epochs")
log_every_n_steps: int = 1
ckpt_path: str = "last"
# This is not a clean implementation but I am not sure how to do it better
callbacks: List[Any] = Field(
default_factory=lambda: list(vars(Callbacks()).values()), frozen=True
Expand Down
2 changes: 1 addition & 1 deletion bioimage_embed/lightning/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def configure_optimizers(self):
optimizer, lr_scheduler = self.timm_optimizers(self.model)
return self.timm_to_lightning(optimizer, lr_scheduler)

def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
def lr_scheduler_step(self, scheduler, metric):
scheduler.step(epoch=self.current_epoch, metric=metric)

def log_wandb(self):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ python = "^3.9,<3.11"
umap-learn = { extras = ["plot"], version = "^0.5.3" }
jsonargparse = { extras = ["signatures"], version = "^4.21.0" }
llvmlite = "0.40.1"
pytorch-lightning = "1.*.*"
pytorch-lightning = "2.0.0"
torchinfo = "^1.8.0"
matplotlib = "^3.7.2"
pyro-ppl = "^1.8.6"
Expand Down

0 comments on commit c8898c5

Please sign in to comment.