Skip to content

Commit

Permalink
[Bug] Fix MLFlow import, and setup function that caused tests to fail (#510)
Browse files Browse the repository at this point in the history

Co-authored-by: Aleksander Wennersteen <[email protected]>
  • Loading branch information
awennersteen and Aleksander Wennersteen authored Jul 23, 2024
1 parent d06d132 commit d47e10e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 18 deletions.
8 changes: 4 additions & 4 deletions qadence/ml_tools/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,14 @@ class FeatureMapConfig:

num_repeats: int | dict[str, int] = 0
"""
Number of feature map layers repeated in the data reuploadig step.
Number of feature map layers repeated in the data reuploading step.
If all are to be repeated the same number of times, then can give a single
`int`. For different number of repeatitions for each feature, provide a dict
`int`. For different number of repetitions for each feature, provide a dict
of (str, int) where the key is the name of the variable and the value is the
number of repeatitions for that feature.
number of repetitions for that feature.
This amounts to the number of additional reuploads. So if `num_repeats` is N,
the data gets uploaded N+1 times. Defaults to no repeatition.
the data gets uploaded N+1 times. Defaults to no repetition.
"""

operation: Callable[[Parameter | Basic], AnalogBlock] | Type[RX] | None = None
Expand Down
16 changes: 12 additions & 4 deletions qadence/ml_tools/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, Callable, Union

from matplotlib.figure import Figure
from mlflow.models import infer_signature
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -82,6 +81,7 @@ def plot_mlflow(
def log_model_mlflow(
writer: Any, model: Module, dataloader: DataLoader | DictDataLoader | None
) -> None:
signature = None
if dataloader is not None:
xs: InputData
xs, *_ = next(iter(dataloader))
Expand All @@ -94,9 +94,17 @@ def log_model_mlflow(
xs[key] = val.numpy()
for key, val in preds.items():
preds[key] = val.detach.numpy()
signature = infer_signature(xs, preds)
else:
signature = None

try:
from mlflow.models import infer_signature

signature = infer_signature(xs, preds)
except ImportError:
logger.warning(
"An MLFlow specific function has been called but MLFlow failed to import."
"Please install MLFlow or adjust your code."
)

writer.pytorch.log_model(model, artifact_path="model", signature=signature)


Expand Down
19 changes: 9 additions & 10 deletions tests/ml_tools/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def dataloader(batch_size: int = 25) -> DataLoader:
return to_dataloader(x, y, batch_size=batch_size, infinite=True)


def setup(model: Module) -> tuple[Callable, Optimizer]:
def setup_model(model: Module) -> tuple[Callable, Optimizer]:
cnt = count()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
Expand Down Expand Up @@ -83,7 +83,7 @@ def clean_artifacts(run: Run) -> None:
def test_hyperparams_logging_mlflow(BasicQuantumModel: QuantumModel, tmp_path: Path) -> None:
model = BasicQuantumModel

loss_fn, optimizer = setup(model)
loss_fn, optimizer = setup_model(model)

hyperparams = {"max_iter": int(10), "lr": 0.1}

Expand Down Expand Up @@ -113,7 +113,7 @@ def test_hyperparams_logging_mlflow(BasicQuantumModel: QuantumModel, tmp_path: P
def test_hyperparams_logging_tensorboard(BasicQuantumModel: QuantumModel, tmp_path: Path) -> None:
model = BasicQuantumModel

loss_fn, optimizer = setup(model)
loss_fn, optimizer = setup_model(model)

hyperparams = {"max_iter": int(10), "lr": 0.1}

Expand All @@ -131,8 +131,7 @@ def test_hyperparams_logging_tensorboard(BasicQuantumModel: QuantumModel, tmp_pa

def test_model_logging_mlflow_basicQM(BasicQuantumModel: QuantumModel, tmp_path: Path) -> None:
model = BasicQuantumModel

loss_fn, optimizer = setup(model)
loss_fn, optimizer = setup_model(model)

config = TrainConfig(
folder=tmp_path,
Expand All @@ -154,7 +153,7 @@ def test_model_logging_mlflow_basicQNN(BasicQNN: QNN, tmp_path: Path) -> None:
data = dataloader()
model = BasicQNN

loss_fn, optimizer = setup(model)
loss_fn, optimizer = setup_model(model)

config = TrainConfig(
folder=tmp_path,
Expand All @@ -176,7 +175,7 @@ def test_model_logging_mlflow_basicAdjQNN(BasicAdjointQNN: QNN, tmp_path: Path)
data = dataloader()
model = BasicAdjointQNN

loss_fn, optimizer = setup(model)
loss_fn, optimizer = setup_model(model)

config = TrainConfig(
folder=tmp_path,
Expand All @@ -199,7 +198,7 @@ def test_model_logging_tensorboard(
) -> None:
model = BasicQuantumModel

loss_fn, optimizer = setup(model)
loss_fn, optimizer = setup_model(model)

config = TrainConfig(
folder=tmp_path,
Expand All @@ -219,7 +218,7 @@ def test_plotting_mlflow(BasicQNN: QNN, tmp_path: Path) -> None:
data = dataloader()
model = BasicQNN

loss_fn, optimizer = setup(model)
loss_fn, optimizer = setup_model(model)

def plot_model(model: QuantumModel, iteration: int) -> tuple[str, Figure]:
descr = f"model_prediction_epoch_{iteration}.png"
Expand Down Expand Up @@ -267,7 +266,7 @@ def test_plotting_tensorboard(BasicQNN: QNN, tmp_path: Path) -> None:
data = dataloader()
model = BasicQNN

loss_fn, optimizer = setup(model)
loss_fn, optimizer = setup_model(model)

def plot_model(model: QuantumModel, iteration: int) -> tuple[str, Figure]:
descr = f"model_prediction_epoch_{iteration}.png"
Expand Down

0 comments on commit d47e10e

Please sign in to comment.