From 624b872edab880b44b5db6fe5d991b1deae08730 Mon Sep 17 00:00:00 2001 From: Alexander Alexandrov Date: Tue, 9 Jul 2019 18:26:17 +0200 Subject: [PATCH] More shell fixes - Add proper documentation strings. - Fix the printed script name in help output. - Add a `force-static` serve parameter that forces the server to run in static mode (predictor loaded from the expected model location), even when a forecaster is specified. --- src/gluonts/core/log.py | 2 +- src/gluonts/shell/__main__.py | 53 +++++++++++++++++++++---- src/gluonts/shell/sagemaker/__init__.py | 19 +++++---- src/gluonts/shell/sagemaker/path.py | 18 ++++++--- src/gluonts/shell/serve/__init__.py | 5 +-- src/gluonts/shell/train.py | 4 +- 6 files changed, 72 insertions(+), 29 deletions(-) diff --git a/src/gluonts/core/log.py b/src/gluonts/core/log.py index e4a12f0fb7..8363b25fe6 100644 --- a/src/gluonts/core/log.py +++ b/src/gluonts/core/log.py @@ -21,7 +21,7 @@ DEBUG = os.environ.get('DEBUG', 'false').lower() == 'true' logging.basicConfig(level=logging.DEBUG if DEBUG else logging.INFO) -logger = logging.getLogger('SWIST') +logger = logging.getLogger('gluonts') def metric(metric: str, value: Any) -> None: diff --git a/src/gluonts/shell/__main__.py b/src/gluonts/shell/__main__.py index ee45b4943d..979fb42738 100644 --- a/src/gluonts/shell/__main__.py +++ b/src/gluonts/shell/__main__.py @@ -26,7 +26,7 @@ from gluonts.model.predictor import Predictor # Relative imports -from .sagemaker import SageMakerEnv +from .sagemaker import TrainEnv, ServeEnv Forecaster = Type[Union[Estimator, Predictor]] @@ -77,14 +77,39 @@ def cli() -> None: type=click.Path(exists=True), envvar="SAGEMAKER_DATA_PATH", default='/opt/ml', + help='The root path of all folders mounted by the SageMaker runtime.', ) -@click.option("--forecaster", metavar="NAME", envvar="GLUONTS_FORECASTER") -def serve_command(data_path: str, forecaster: Optional[str]) -> None: +@click.option( + "--forecaster", + metavar='NAME', + envvar="GLUONTS_FORECASTER", + help=( + 'An alias or a fully qualified name of a
Predictor to use. ' + 'If this value is defined, the inference server runs in the ' + 'so-called dynamic mode, where the predictor is initialized for ' + 'each request using parameters provided in the "configuration" field ' + 'of the JSON request. Otherwise, the server runs in static mode, ' + 'where the predictor is initialized upfront from a serialized model ' + 'located in the {data-path}/model folder.' + ), +) +@click.option( + "--force-static/--no-force-static", + envvar="GLUONTS_FORCE_STATIC", + default=False, + help=( + 'Forces execution in static mode, even in situations where the ' + '"forecaster" option is present.' + ), +) +def serve_command( + data_path: str, forecaster: Optional[str], force_static: bool +) -> None: from gluonts.shell import serve - env = SageMakerEnv(Path(data_path)) + env = ServeEnv(Path(data_path)) - if forecaster is not None: + if not force_static and forecaster is not None: serve.run_inference_server(env, forecaster_type_by_name(forecaster)) else: serve.run_inference_server(env, None) @@ -96,12 +121,24 @@ def serve_command(data_path: str, forecaster: Optional[str]) -> None: type=click.Path(exists=True), envvar="SAGEMAKER_DATA_PATH", default='/opt/ml', + help='The root path of all folders mounted by the SageMaker runtime.', +) +@click.option( + "--forecaster", + type=str, + envvar="GLUONTS_FORECASTER", + help=( + 'An alias or a fully qualified name of a Predictor or Estimator to ' + 'use. If this value is not defined, the command will try to read it ' + 'from the hyperparameters dictionary under the "forecaster_name" key. ' + 'If the value denotes a Predictor, training will be skipped and the ' + 'command will only do an evaluation on the provided test dataset.'
+ ), ) -@click.option("--forecaster", type=str, envvar="GLUONTS_FORECASTER") def train_command(data_path: str, forecaster: Optional[str]) -> None: from gluonts.shell import train - env = SageMakerEnv(Path(data_path)) + env = TrainEnv(Path(data_path)) if forecaster is None: try: @@ -119,4 +156,4 @@ def train_command(data_path: str, forecaster: Optional[str]) -> None: if __name__ == "__main__": - cli() + cli(prog_name=__package__) diff --git a/src/gluonts/shell/sagemaker/__init__.py b/src/gluonts/shell/sagemaker/__init__.py index d31d86a232..b80648d472 100644 --- a/src/gluonts/shell/sagemaker/__init__.py +++ b/src/gluonts/shell/sagemaker/__init__.py @@ -19,7 +19,7 @@ from gluonts.dataset.common import FileDataset, MetaData from .params import parse_sagemaker_parameters -from .path import MLPath +from .path import ServePaths, TrainPaths class DataConfig(BaseModel): @@ -30,22 +30,21 @@ class DataConfig(BaseModel): DATASET_NAMES = 'train', 'test' -class SageMakerEnv: +class TrainEnv: def __init__(self, path: Path = Path("/opt/ml")) -> None: - self.path = _load_path(path) + self.path = TrainPaths(path) self.inputdataconfig = _load_inputdataconfig(self.path) self.channels = _load_channels(self.path, self.inputdataconfig) self.hyperparameters = _load_hyperparameters(self.path, self.channels) self.datasets = _load_datasets(self.hyperparameters, self.channels) -def _load_path(path: Path) -> MLPath: - ml_path = MLPath(path) - ml_path.makedirs() - return ml_path +class ServeEnv: + def __init__(self, path: Path = Path("/opt/ml")) -> None: + self.path = ServePaths(path) -def _load_inputdataconfig(path: MLPath) -> Optional[Dict[str, DataConfig]]: +def _load_inputdataconfig(path: TrainPaths) -> Optional[Dict[str, DataConfig]]: if path.inputdataconfig.exists(): with path.inputdataconfig.open() as json_file: return { @@ -57,7 +56,7 @@ def _load_inputdataconfig(path: MLPath) -> Optional[Dict[str, DataConfig]]: def _load_channels( - path: MLPath, inputdataconfig: Optional[Dict[str, 
DataConfig]] + path: TrainPaths, inputdataconfig: Optional[Dict[str, DataConfig]] ) -> Dict[str, Path]: """Lists the available channels in `/opt/ml/input/data`. @@ -80,7 +79,7 @@ def _load_channels( return {channel.name: channel for channel in path.data.iterdir()} -def _load_hyperparameters(path: MLPath, channels) -> dict: +def _load_hyperparameters(path: TrainPaths, channels) -> dict: with path.hyperparameters.open() as json_file: hyperparameters = parse_sagemaker_parameters(json.load(json_file)) diff --git a/src/gluonts/shell/sagemaker/path.py b/src/gluonts/shell/sagemaker/path.py index ef09528956..740eee60aa 100644 --- a/src/gluonts/shell/sagemaker/path.py +++ b/src/gluonts/shell/sagemaker/path.py @@ -16,20 +16,28 @@ from pathlib import Path -class MLPath: +class TrainPaths: def __init__(self, base="/opt/ml") -> None: self.base: Path = Path(base).expanduser().resolve() - self.config: Path = self.base / "input/config" - self.data: Path = self.base / "input/data" + self.config: Path = self.base / "input" / "config" + self.data: Path = self.base / "input" / "data" self.model: Path = self.base / "model" self.output: Path = self.base / "output" self.hyperparameters: Path = self.config / "hyperparameters.json" self.inputdataconfig: Path = self.config / "inputdataconfig.json" - def makedirs(self) -> None: self.config.mkdir(parents=True, exist_ok=True) self.data.mkdir(parents=True, exist_ok=True) self.model.mkdir(parents=True, exist_ok=True) self.output.mkdir(parents=True, exist_ok=True) - # (self.output / 'data').mkdir(parents=True, exist_ok=True) + + +class ServePaths: + def __init__(self, base="/opt/ml") -> None: + self.base: Path = Path(base).expanduser().resolve() + self.model: Path = self.base / "model" + self.output: Path = self.base / "output" + + self.model.mkdir(parents=True, exist_ok=True) + self.output.mkdir(parents=True, exist_ok=True) diff --git a/src/gluonts/shell/serve/__init__.py b/src/gluonts/shell/serve/__init__.py index 13ef3e3e58..0a23be79b0 100644 
--- a/src/gluonts/shell/serve/__init__.py +++ b/src/gluonts/shell/serve/__init__.py @@ -22,7 +22,7 @@ # First-party imports from gluonts.model.estimator import Estimator from gluonts.model.predictor import Predictor -from gluonts.shell.sagemaker import SageMakerEnv +from gluonts.shell.sagemaker import ServeEnv import logging import multiprocessing @@ -98,8 +98,7 @@ def stop(self, *args, **kwargs): def run_inference_server( - env: SageMakerEnv, - forecaster_type: Optional[Type[Union[Estimator, Predictor]]], + env: ServeEnv, forecaster_type: Optional[Type[Union[Estimator, Predictor]]] ) -> None: if forecaster_type is not None: ctor = forecaster_type.from_hyperparameters diff --git a/src/gluonts/shell/train.py b/src/gluonts/shell/train.py index a03fdf7b38..5da85be3e8 100644 --- a/src/gluonts/shell/train.py +++ b/src/gluonts/shell/train.py @@ -23,11 +23,11 @@ from gluonts.transform import Dataset, FilterTransformation, TransformedDataset # Relative imports -from .sagemaker import SageMakerEnv +from .sagemaker import TrainEnv def run_train_and_test( - env: SageMakerEnv, forecaster_type: Type[Union[Estimator, Predictor]] + env: TrainEnv, forecaster_type: Type[Union[Estimator, Predictor]] ) -> None: check_gpu_support()