Add checkpoint artifact path prefix to MLflow logger #20538

Open · wants to merge 8 commits into base: master

Changes from 6 commits
34 changes: 34 additions & 0 deletions docs/source-pytorch/visualize/loggers.rst
@@ -54,3 +54,37 @@ Track and Visualize Experiments

</div>
</div>

.. _mlflow_logger:

MLflow Logger
-------------

The MLflow logger in PyTorch Lightning now accepts a ``checkpoint_path_prefix`` parameter. When model checkpoints are logged as MLflow artifacts, this prefix is prepended to each checkpoint's artifact path.

Example usage:

.. code-block:: python

    import lightning as L
    from lightning.pytorch.loggers import MLFlowLogger

    mlf_logger = MLFlowLogger(
        experiment_name="lightning_logs",
        tracking_uri="file:./ml-runs",
        checkpoint_path_prefix="my_prefix",
    )
    trainer = L.Trainer(logger=mlf_logger)


    # Your LightningModule definition
    class LitModel(L.LightningModule):
        def training_step(self, batch, batch_idx):
            # example
            self.logger.experiment.whatever_ml_flow_supports(...)

        def any_lightning_module_function_or_hook(self):
            self.logger.experiment.whatever_ml_flow_supports(...)


    # Train your model
    model = LitModel()
    trainer.fit(model)
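
For example, a checkpoint saved locally as ``epoch=0-step=3.ckpt`` would be logged under the artifact path ``my_prefix/epoch=0-step=3`` rather than ``epoch=0-step=3``. A minimal sketch of the path construction (the checkpoint file name here is a hypothetical illustration):

.. code-block:: python

    from pathlib import Path

    p = "checkpoints/epoch=0-step=3.ckpt"  # hypothetical local checkpoint file

    # With the default empty prefix, the artifact path is just the file stem,
    # so the pre-existing behavior is unchanged.
    assert (Path("") / Path(p).stem).as_posix() == "epoch=0-step=3"

    # With a prefix, the checkpoint is nested under it.
    assert (Path("my_prefix") / Path(p).stem).as_posix() == "my_prefix/epoch=0-step=3"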
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `checkpoint_path_prefix` parameter to the MLflow logger to control the artifact path prefix under which model checkpoints are stored ([#20538](https://github.com/Lightning-AI/pytorch-lightning/pull/20538))

## [2.5.0] - 2024-12-19

### Added
6 changes: 4 additions & 2 deletions src/lightning/pytorch/loggers/mlflow.py
@@ -97,7 +97,7 @@ def any_lightning_module_function_or_hook(self):
              :paramref:`~lightning.pytorch.callbacks.Checkpoint.save_top_k` ``== -1``
              which also logs every checkpoint during training.
            * if ``log_model == False`` (default), no checkpoint is logged.

        checkpoint_path_prefix: A string to prefix the checkpoint artifact's path.
        prefix: A string to put at the beginning of metric keys.
        artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
            default.
@@ -121,6 +121,7 @@ def __init__(
        tags: Optional[dict[str, Any]] = None,
        save_dir: Optional[str] = "./mlruns",
        log_model: Literal[True, False, "all"] = False,
        checkpoint_path_prefix: str = "",
        prefix: str = "",
        artifact_location: Optional[str] = None,
        run_id: Optional[str] = None,
@@ -147,6 +148,7 @@ def __init__(
        self._artifact_location = artifact_location
        self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
        self._initialized = False
        self._checkpoint_path_prefix = checkpoint_path_prefix

        from mlflow.tracking import MlflowClient

@@ -361,7 +363,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]

            # Artifact path on mlflow
-           artifact_path = Path(p).stem
+           artifact_path = Path(self._checkpoint_path_prefix) / Path(p).stem

            # Log the checkpoint
            self.experiment.log_artifact(self._run_id, p, artifact_path)
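
Code that later retrieves these checkpoints must include the prefix in the artifact path. A hedged sketch using MLflow's public download API (the run ID, prefix, and checkpoint name are illustrative assumptions, not part of this PR):

```python
import mlflow

# Hypothetical values: the artifact path mirrors
# Path(checkpoint_path_prefix) / Path(checkpoint_file).stem from above.
local_path = mlflow.artifacts.download_artifacts(
    run_id="abc123",
    artifact_path="my_prefix/epoch=0-step=3",
    dst_path="./downloaded",
)
print(local_path)  # local copy of the downloaded checkpoint artifact
```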
30 changes: 30 additions & 0 deletions tests/tests_pytorch/loggers/test_mlflow.py
@@ -427,3 +427,33 @@ def test_set_tracking_uri(mlflow_mock):
    mlflow_mock.set_tracking_uri.assert_not_called()
    _ = logger.experiment
    mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")


@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_log_model_with_checkpoint_path_prefix(mlflow_mock, tmp_path):
    """Test that the logger creates the folders and files in the right place with a prefix."""
    client = mlflow_mock.tracking.MlflowClient

    # Get model, logger, trainer and train
    model = BoringModel()
    logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model="all", checkpoint_path_prefix="my_prefix")
    logger = mock_mlflow_run_creation(logger, experiment_id="test-id")

    trainer = Trainer(
        default_root_dir=tmp_path,
        logger=logger,
        max_epochs=2,
        limit_train_batches=3,
        limit_val_batches=3,
    )
    trainer.fit(model)

    # Checkpoint log
    assert client.return_value.log_artifact.call_count == 2
    # Metadata and aliases log
    assert client.return_value.log_artifacts.call_count == 2

    # Check that the prefix is used in the artifact path
    for call in client.return_value.log_artifact.call_args_list:
        args, _ = call
        assert str(args[2]).startswith("my_prefix")
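
Beyond these mocked assertions, the resulting layout could also be checked against a real local tracking store. A rough sketch, assuming the ``file:./ml-runs`` store and prefix from the docs example (the run ID is a hypothetical placeholder):

```python
from mlflow.tracking import MlflowClient

client = MlflowClient(tracking_uri="file:./ml-runs")
run_id = "abc123"  # hypothetical; in practice, taken from MLFlowLogger.run_id

# Checkpoint artifacts should be listed under the configured prefix.
for artifact in client.list_artifacts(run_id, path="my_prefix"):
    print(artifact.path)  # e.g. my_prefix/epoch=0-step=3
```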