Skip to content

Commit

Permalink
Return trainer from train_embedding (#1571)
Browse files Browse the repository at this point in the history
* Return trainer object from `train_embedding`
  • Loading branch information
guarin authored Jul 4, 2024
1 parent bcf8ad1 commit e190906
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions lightly/embedding/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def train_embedding(
trainer_config: DictConfig,
checkpoint_callback_config: DictConfig,
summary_callback_config: DictConfig,
) -> None:
) -> Trainer:
"""Train the model on the provided dataset.
Args:
Expand All @@ -98,7 +98,7 @@ def train_embedding(
summary_callback_config: ModelSummary callback arguments
Returns:
A trained encoder, ready for embedding datasets.
The PyTorch Lightning Trainer object used for training.
"""
trainer_callbacks: List[Callback] = []
Expand Down Expand Up @@ -128,6 +128,9 @@ def train_embedding(
if checkpoint_cb.best_model_path != "":
self.checkpoint = os.path.join(self.cwd, checkpoint_cb.best_model_path)

# Return trainer to check final training state.
return trainer

def embed(self, *args: Any, **kwargs: Any) -> Any:
"""Must be implemented by classes which inherit from BaseEmbedding."""
raise NotImplementedError()

0 comments on commit e190906

Please sign in to comment.