Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: save model as onnx option #137

Merged
merged 4 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion configs/ops/train/callbacks/callbacks_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,7 @@ callback_dict:
save_model:
# Stores the model after each epoch
_target_: niceml.dlframeworks.keras.callbacks.callback_factories.ModelCallbackFactory
model_subfolder: models/model-id_{short_id}-ep{epoch:03d}.hdf5
# model_subfolder may include model name or declare them separately
model_subfolder: models/model-id_{short_id}-ep{epoch:03d}
# model_subfolder: models
# model_filename: model-id_{short_id}-ep{epoch:03d}
60 changes: 54 additions & 6 deletions niceml/dlframeworks/keras/callbacks/callback_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from abc import ABC, abstractmethod
from os.path import join
from pathlib import Path
from typing import Any, List
from typing import Any, List, Optional

from niceml.dlframeworks.keras.callbacks.csvlogger import CSVLogger
from niceml.dlframeworks.keras.callbacks.modelcheckpoint import ModelCheckpoint
from niceml.dlframeworks.keras.callbacks.modelcheckpoint import (
ModelCheckpoint,
)
from niceml.experiments.experimentcontext import ExperimentContext
from niceml.utilities.factoryutils import subs_path_and_create_folder
from niceml.utilities.fsspec.locationutils import join_location_w_path
Expand All @@ -25,27 +27,58 @@ class InitCallbackFactory(CallbackFactory):
any experiment specific parameters"""

def __init__(self, callback: Any):
"""Initializes the InitCallbackFactory object with given callback"""
self.callback = callback

def create_callback(self, exp_context: ExperimentContext):
"""Returns the callback of the factory"""
return self.callback


# pylint: disable=too-few-public-methods
class ModelCallbackFactory(CallbackFactory):
"""Creates the model checkpoint callback"""

def __init__(self, model_subfolder: str, **kwargs):
def __init__(
self, model_subfolder: str, model_filename: Optional[str] = None, **kwargs
):
"""
Initializes the ModelCallbackFactory object, which creates
ModelCheckpoint callbacks. If model_filename is not given, it will
be inferred from the model_subfolder. Fileextensions will be ignored.

Args:
model_subfolder: name of the subfolder to save the model in
model_filename: filename of the model file without the file extension. If
model_filename is not given, it will be inferred from the model_subfolder
**kwargs: additional keyword arguments for ModelCheckpoint initialization
"""
self.kwargs = kwargs
self.model_subfolder = model_subfolder
self.model_subfolder = (
model_subfolder if model_filename else str(Path(model_subfolder).parent)
)
self.model_filename = model_filename or str(Path(model_subfolder).stem)

def create_callback(self, exp_context: ExperimentContext):
def create_callback(self, exp_context: ExperimentContext) -> ModelCheckpoint:
"""
Creates the model checkpoint callback based on the given experiment context.
The ModelCheckpoint callback saves the model of the experiment.

Args:
exp_context: experiment to create the model callback for

Returns:
ModelCheckpoint callback
"""
target_model_fs = join_location_w_path(
exp_context.fs_config, self.model_subfolder
)
file_formats = {"run_id": exp_context.run_id, "short_id": exp_context.short_id}
return ModelCheckpoint(
target_model_fs, file_formats=file_formats, **self.kwargs
target_model_fs,
file_formats=file_formats,
model_filename=self.model_filename,
**self.kwargs
)


Expand All @@ -54,16 +87,31 @@ class LoggingOutputCallbackFactory(CallbackFactory):
"""Creates a callback that logs the metrics to a csv file"""

def __init__(self, filename: str = "train_logs.csv"):
"""
Initializes a LoggingOutputCallbackFactory.

Args:
filename: name of the file to save the logging output to
"""
self.filename = filename

def create_callback(self, exp_context: ExperimentContext):
"""
Creates the CSVLogger callback for the given experiment context.
Args:
exp_context: experiment context to create the callback for

Returns:
CSVLogger callback
"""
return CSVLogger(experiment_context=exp_context, filename=self.filename)


class CamCallbackFactory(CallbackFactory): # pylint: disable=too-few-public-methods
"""Callback factory for a cam callback"""

def __init__(self, images: List[str]):
"""Initializes a CamCallbackFactory"""
self.images = images

def create_callback(self, exp_context: ExperimentContext):
Expand Down
51 changes: 46 additions & 5 deletions niceml/dlframeworks/keras/callbacks/modelcheckpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@
from tempfile import TemporaryDirectory
from typing import Optional, Union

import onnx
import tf2onnx
from fsspec import AbstractFileSystem
from keras.callbacks import ModelCheckpoint as ModelCheckpointKeras
from keras.src.utils import io_utils, tf_utils

from niceml.utilities.fsspec.locationutils import LocationConfig, open_location
from niceml.utilities.fsspec.locationutils import (
LocationConfig,
open_location,
)


class ModelCheckpoint(ModelCheckpointKeras):
Expand All @@ -19,12 +24,26 @@ class ModelCheckpoint(ModelCheckpointKeras):
def __init__(
self,
output_location: Union[dict, LocationConfig],
model_filename: Optional[str] = None,
file_formats: Optional[dict] = None,
save_as_onnx: bool = False,
**kwargs,
):
"""
Initializes a ModelCheckpoint to save the model files

Args:
output_location: location to save the model to
model_filename: filename of the model file without the file extension
file_formats: dictionary of parameters to complete the model file name
save_as_onnx: whether to save the model also in onnx file format
**kwargs: additional keyword arguments for initialization
"""
super().__init__("", **kwargs)
self.output_location = output_location
self.model_filename = model_filename or "model"
self.file_formats = file_formats or {}
self.save_as_onnx = save_as_onnx

def _should_save(self, epoch, logs) -> bool:
"""Determines whether the model should be saved."""
Expand Down Expand Up @@ -76,12 +95,13 @@ def _save_model(self, epoch, batch, logs):
if self._should_save(epoch, logs):
target_fs: AbstractFileSystem
with open_location(self.output_location) as (target_fs, target_path):
target_fs.makedirs(dirname(target_path), exist_ok=True)
target_path = target_path.format(
target_fs.makedirs(target_path, exist_ok=True)
model_path = join(target_path, f"{self.model_filename}.hdf5")
model_path = model_path.format(
epoch=epoch + 1, **self.file_formats, **logs
)
with target_fs.open(
target_path, "wb"
model_path, "wb"
) as model_file, TemporaryDirectory() as temp_dir:
tmp_path = join(temp_dir, "model.h5")
if self.save_weights_only:
Expand All @@ -100,4 +120,25 @@ def _save_model(self, epoch, batch, logs):
)
with open(tmp_path, "rb") as tmp_model_file:
model_file.write(tmp_model_file.read())
logging.info("Saved model to %s", target_path)
logging.info("Saved model to %s", model_path)
if self.save_as_onnx:
self._save_onnx_model(epoch, logs)

def _save_onnx_model(self, epoch, logs):
"""Saves the model in onnx format.

Args:
epoch: the epoch this iteration is in.
logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.
"""
logs = logs or tf_utils.sync_to_numpy_or_python_type(logs)

target_fs: AbstractFileSystem
with open_location(self.output_location) as (target_fs, target_path):
target_fs.makedirs(dirname(target_path), exist_ok=True)
model_path = join(target_path, f"{self.model_filename}.onnx")
model_path = model_path.format(epoch=epoch + 1, **self.file_formats, **logs)
with target_fs.open(model_path, "wb"):
onnx_model, _ = tf2onnx.convert.from_keras(self.model)
onnx.save(onnx_model, model_path)
logging.info("Saved model to %s", model_path)
Loading
Loading