Skip to content

Commit

Permalink
More shell fixes
Browse files Browse the repository at this point in the history
- Add proper documentation strings.
- Fix the printed script name in help output.
- Add a `force-static` serve parameter that forces the
  creation of a static predictor from the expected
  model location.
  • Loading branch information
Alexander Alexandrov authored and aalexandrov committed Jul 10, 2019
1 parent f8572ab commit 624b872
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/gluonts/core/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
DEBUG = os.environ.get('DEBUG', 'false').lower() == 'true'
logging.basicConfig(level=logging.DEBUG if DEBUG else logging.INFO)

logger = logging.getLogger('SWIST')
logger = logging.getLogger('gluonts')


def metric(metric: str, value: Any) -> None:
Expand Down
53 changes: 45 additions & 8 deletions src/gluonts/shell/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from gluonts.model.predictor import Predictor

# Relative imports
from .sagemaker import SageMakerEnv
from .sagemaker import TrainEnv, ServeEnv

Forecaster = Type[Union[Estimator, Predictor]]

Expand Down Expand Up @@ -77,14 +77,39 @@ def cli() -> None:
type=click.Path(exists=True),
envvar="SAGEMAKER_DATA_PATH",
default='/opt/ml',
help='The root path of all folders mounted by the SageMaker runtime.',
)
@click.option("--forecaster", metavar="NAME", envvar="GLUONTS_FORECASTER")
def serve_command(data_path: str, forecaster: Optional[str]) -> None:
@click.option(
"--forecaster",
metavar='NAME',
envvar="GLUONTS_FORECASTER",
help=(
'An alias or a fully qualified name of a Predictor to use. '
'If this value is defined, the inference server runs in the '
'so-called dynamic mode, where the predictor is initialized for '
'each request using parameters provided in the "configuration" field '
'of the JSON request. Otherwise, the server runs in static mode, '
'where the predictor is initialized upfront from a serialized model '
'located in the {data-path}/model folder.'
),
)
@click.option(
"--force-static/--no-force-static",
envvar="GLUONTS_FORCE_STATIC",
default=False,
help=(
'Forces execution in static mode, even in situations where the '
'"forecaster" option is present.'
),
)
def serve_command(
data_path: str, forecaster: Optional[str], force_static: bool
) -> None:
from gluonts.shell import serve

env = SageMakerEnv(Path(data_path))
env = ServeEnv(Path(data_path))

if forecaster is not None:
if not force_static and forecaster is not None:
serve.run_inference_server(env, forecaster_type_by_name(forecaster))
else:
serve.run_inference_server(env, None)
Expand All @@ -96,12 +121,24 @@ def serve_command(data_path: str, forecaster: Optional[str]) -> None:
type=click.Path(exists=True),
envvar="SAGEMAKER_DATA_PATH",
default='/opt/ml',
help='The root path of all folders mounted by the SageMaker runtime.',
)
@click.option(
"--forecaster",
type=str,
envvar="GLUONTS_FORECASTER",
help=(
'An alias or a fully qualified name of a Predictor or Estimator to '
'use. If this value is not defined, the command will try to read it '
'from the hyperparameters dictionary under the "forecaster_name" key. '
'If the value denotes a Predictor, training will be skipped and the '
'command will only do an evaluation on the provided test dataset.'
),
)
@click.option("--forecaster", type=str, envvar="GLUONTS_FORECASTER")
def train_command(data_path: str, forecaster: Optional[str]) -> None:
from gluonts.shell import train

env = SageMakerEnv(Path(data_path))
env = TrainEnv(Path(data_path))

if forecaster is None:
try:
Expand All @@ -119,4 +156,4 @@ def train_command(data_path: str, forecaster: Optional[str]) -> None:


if __name__ == "__main__":
cli()
cli(prog_name=__package__)
19 changes: 9 additions & 10 deletions src/gluonts/shell/sagemaker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from gluonts.dataset.common import FileDataset, MetaData
from .params import parse_sagemaker_parameters
from .path import MLPath
from .path import ServePaths, TrainPaths


class DataConfig(BaseModel):
Expand All @@ -30,22 +30,21 @@ class DataConfig(BaseModel):
DATASET_NAMES = 'train', 'test'


class SageMakerEnv:
class TrainEnv:
def __init__(self, path: Path = Path("/opt/ml")) -> None:
self.path = _load_path(path)
self.path = TrainPaths(path)
self.inputdataconfig = _load_inputdataconfig(self.path)
self.channels = _load_channels(self.path, self.inputdataconfig)
self.hyperparameters = _load_hyperparameters(self.path, self.channels)
self.datasets = _load_datasets(self.hyperparameters, self.channels)


def _load_path(path: Path) -> MLPath:
ml_path = MLPath(path)
ml_path.makedirs()
return ml_path
class ServeEnv:
def __init__(self, path: Path = Path("/opt/ml")) -> None:
self.path = ServePaths(path)


def _load_inputdataconfig(path: MLPath) -> Optional[Dict[str, DataConfig]]:
def _load_inputdataconfig(path: TrainPaths) -> Optional[Dict[str, DataConfig]]:
if path.inputdataconfig.exists():
with path.inputdataconfig.open() as json_file:
return {
Expand All @@ -57,7 +56,7 @@ def _load_inputdataconfig(path: MLPath) -> Optional[Dict[str, DataConfig]]:


def _load_channels(
path: MLPath, inputdataconfig: Optional[Dict[str, DataConfig]]
path: TrainPaths, inputdataconfig: Optional[Dict[str, DataConfig]]
) -> Dict[str, Path]:
"""Lists the available channels in `/opt/ml/input/data`.
Expand All @@ -80,7 +79,7 @@ def _load_channels(
return {channel.name: channel for channel in path.data.iterdir()}


def _load_hyperparameters(path: MLPath, channels) -> dict:
def _load_hyperparameters(path: TrainPaths, channels) -> dict:
with path.hyperparameters.open() as json_file:
hyperparameters = parse_sagemaker_parameters(json.load(json_file))

Expand Down
18 changes: 13 additions & 5 deletions src/gluonts/shell/sagemaker/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,28 @@
from pathlib import Path


class MLPath:
class TrainPaths:
def __init__(self, base="/opt/ml") -> None:
self.base: Path = Path(base).expanduser().resolve()
self.config: Path = self.base / "input/config"
self.data: Path = self.base / "input/data"
self.config: Path = self.base / "input" / "config"
self.data: Path = self.base / "input" / "data"
self.model: Path = self.base / "model"
self.output: Path = self.base / "output"

self.hyperparameters: Path = self.config / "hyperparameters.json"
self.inputdataconfig: Path = self.config / "inputdataconfig.json"

def makedirs(self) -> None:
self.config.mkdir(parents=True, exist_ok=True)
self.data.mkdir(parents=True, exist_ok=True)
self.model.mkdir(parents=True, exist_ok=True)
self.output.mkdir(parents=True, exist_ok=True)
# (self.output / 'data').mkdir(parents=True, exist_ok=True)


class ServePaths:
def __init__(self, base="/opt/ml") -> None:
self.base: Path = Path(base).expanduser().resolve()
self.model: Path = self.base / "model"
self.output: Path = self.base / "output"

self.model.mkdir(parents=True, exist_ok=True)
self.output.mkdir(parents=True, exist_ok=True)
5 changes: 2 additions & 3 deletions src/gluonts/shell/serve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# First-party imports
from gluonts.model.estimator import Estimator
from gluonts.model.predictor import Predictor
from gluonts.shell.sagemaker import SageMakerEnv
from gluonts.shell.sagemaker import ServeEnv

import logging
import multiprocessing
Expand Down Expand Up @@ -98,8 +98,7 @@ def stop(self, *args, **kwargs):


def run_inference_server(
env: SageMakerEnv,
forecaster_type: Optional[Type[Union[Estimator, Predictor]]],
env: ServeEnv, forecaster_type: Optional[Type[Union[Estimator, Predictor]]]
) -> None:
if forecaster_type is not None:
ctor = forecaster_type.from_hyperparameters
Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/shell/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
from gluonts.transform import Dataset, FilterTransformation, TransformedDataset

# Relative imports
from .sagemaker import SageMakerEnv
from .sagemaker import TrainEnv


def run_train_and_test(
env: SageMakerEnv, forecaster_type: Type[Union[Estimator, Predictor]]
env: TrainEnv, forecaster_type: Type[Union[Estimator, Predictor]]
) -> None:
check_gpu_support()

Expand Down

0 comments on commit 624b872

Please sign in to comment.