Simplify inference with ONNX models #2060

Merged 7 commits on Feb 14, 2024

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Iterator
from typing import TYPE_CHECKING
from os.path import join
import uuid

@@ -7,12 +7,10 @@
from rastervision.pytorch_backend.pytorch_learner_backend import (
PyTorchLearnerSampleWriter, PyTorchLearnerBackend)
from rastervision.pytorch_backend.utils import chip_collate_fn_cc
from rastervision.pytorch_learner.dataset import (
ClassificationSlidingWindowGeoDataset)
from rastervision.core.data import ChipClassificationLabels
from rastervision.pytorch_learner.utils import predict_scene_cc

if TYPE_CHECKING:
import numpy as np
from rastervision.core.data import DatasetConfig, Scene
from rastervision.core.rv_pipeline import ChipOptions, PredictOptions
from rastervision.pytorch_learner import ClassificationGeoDataConfig
@@ -60,32 +58,9 @@ def chip_dataset(self,

def predict_scene(self, scene: 'Scene', predict_options: 'PredictOptions'
) -> 'ChipClassificationLabels':

if self.learner is None:
self.load_model()

chip_sz = predict_options.chip_sz
stride = predict_options.stride
batch_sz = predict_options.batch_sz

# Important to use self.learner.cfg.data instead of
# self.learner_cfg.data because of the updates
# Learner.from_model_bundle() makes to the custom transforms.
base_tf, _ = self.learner.cfg.data.get_data_transforms()
ds = ClassificationSlidingWindowGeoDataset(
scene, size=chip_sz, stride=stride, transform=base_tf)

predictions: Iterator['np.array'] = self.learner.predict_dataset(
ds,
raw_out=True,
numpy_out=True,
dataloader_kw=dict(batch_size=batch_sz),
progress_bar=True,
progress_bar_kw=dict(desc=f'Making predictions on {scene.id}'))

labels = ChipClassificationLabels.from_predictions(
ds.windows, predictions)

labels = predict_scene_cc(self.learner, scene, predict_options)
return labels
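
The sliding-window dataset setup and prediction loop removed above now live in shared helpers under rastervision.pytorch_learner.utils.prediction (see the __init__ changes at the end of this diff). Below is a minimal sketch of what predict_scene_cc plausibly wraps, reconstructed from the removed lines; the object-detection and semantic-segmentation helpers follow the same pattern, with the OD helper presumably also pruning duplicate boxes via score_thresh and merge_thresh.

# Sketch only, reconstructed from the code removed above; the actual helper
# is defined in rastervision.pytorch_learner.utils.prediction.
from rastervision.core.data import ChipClassificationLabels
from rastervision.pytorch_learner.dataset import (
    ClassificationSlidingWindowGeoDataset)


def predict_scene_cc(learner, scene, predict_options):
    # Use learner.cfg.data rather than learner_cfg.data so the transform
    # updates made by Learner.from_model_bundle() are picked up.
    base_tf, _ = learner.cfg.data.get_data_transforms()
    ds = ClassificationSlidingWindowGeoDataset(
        scene,
        size=predict_options.chip_sz,
        stride=predict_options.stride,
        transform=base_tf)
    predictions = learner.predict_dataset(
        ds,
        raw_out=True,
        numpy_out=True,
        dataloader_kw=dict(batch_size=predict_options.batch_sz),
        progress_bar=True,
        progress_bar_kw=dict(desc=f'Making predictions on {scene.id}'))
    return ChipClassificationLabels.from_predictions(ds.windows, predictions)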

def _make_chip_data_config(
------------ next file ------------
@@ -1,17 +1,14 @@
from typing import TYPE_CHECKING, Dict, Iterator
from typing import TYPE_CHECKING
from os.path import join, basename
import uuid

import numpy as np

from rastervision.pipeline.file_system import json_to_file
from rastervision.core.data_sample import DataSample
from rastervision.core.data.label import ObjectDetectionLabels
from rastervision.pytorch_backend.pytorch_learner_backend import (
PyTorchLearnerSampleWriter, PyTorchLearnerBackend)
from rastervision.pytorch_backend.utils import chip_collate_fn_od
from rastervision.pytorch_learner.dataset import (
ObjectDetectionSlidingWindowGeoDataset)
from rastervision.pytorch_learner.utils import predict_scene_od

if TYPE_CHECKING:
from rastervision.core.data import DatasetConfig, Scene
@@ -117,39 +114,9 @@ def chip_dataset(self,
def predict_scene(self, scene: 'Scene',
predict_options: 'ObjectDetectionPredictOptions'
) -> ObjectDetectionLabels:

chip_sz = predict_options.chip_sz
stride = predict_options.stride
batch_sz = predict_options.batch_sz

if self.learner is None:
self.load_model()

# Important to use self.learner.cfg.data instead of
# self.learner_cfg.data because of the updates
# Learner.from_model_bundle() makes to the custom transforms.
base_tf, _ = self.learner.cfg.data.get_data_transforms()
ds = ObjectDetectionSlidingWindowGeoDataset(
scene, size=chip_sz, stride=stride, transform=base_tf)

predictions: Iterator[Dict[str, 'np.ndarray']] = (
self.learner.predict_dataset(
ds,
raw_out=True,
numpy_out=True,
predict_kw=dict(out_shape=(chip_sz, chip_sz)),
dataloader_kw=dict(batch_size=batch_sz),
progress_bar=True,
progress_bar_kw=dict(desc=f'Making predictions on {scene.id}'))
)

labels = ObjectDetectionLabels.from_predictions(
ds.windows, predictions)
labels = ObjectDetectionLabels.prune_duplicates(
labels,
score_thresh=predict_options.score_thresh,
merge_thresh=predict_options.merge_thresh)

labels = predict_scene_od(self.learner, scene, predict_options)
return labels

def _make_chip_data_config(
------------ next file ------------
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Iterator
from typing import TYPE_CHECKING
from os.path import join
import uuid

@@ -10,12 +10,10 @@
from rastervision.pytorch_backend.pytorch_learner_backend import (
PyTorchLearnerSampleWriter, PyTorchLearnerBackend)
from rastervision.pytorch_backend.utils import chip_collate_fn_ss
from rastervision.pytorch_learner.dataset import (
SemanticSegmentationSlidingWindowGeoDataset)
from rastervision.pytorch_learner.utils import predict_scene_ss

if TYPE_CHECKING:
from rastervision.core.data import (DatasetConfig, Scene,
SemanticSegmentationLabelStore)
from rastervision.core.data import DatasetConfig, Scene
from rastervision.core.rv_pipeline import (
ChipOptions, SemanticSegmentationPredictOptions)
from rastervision.pytorch_learner import SemanticSegmentationGeoDataConfig
@@ -71,51 +69,9 @@ def chip_dataset(self,
def predict_scene(self, scene: 'Scene',
predict_options: 'SemanticSegmentationPredictOptions'
) -> 'SemanticSegmentationLabels':

if scene.label_store is None:
raise ValueError(
f'Scene.label_store is not set for scene {scene.id}')

if self.learner is None:
self.load_model()

chip_sz = predict_options.chip_sz
stride = predict_options.stride
crop_sz = predict_options.crop_sz
batch_sz = predict_options.batch_sz

label_store: 'SemanticSegmentationLabelStore' = scene.label_store
raw_out = label_store.smooth_output

# Important to use self.learner.cfg.data instead of
# self.learner_cfg.data because of the updates
# Learner.from_model_bundle() makes to the custom transforms.
base_tf, _ = self.learner.cfg.data.get_data_transforms()
pad_direction = 'end' if crop_sz is None else 'both'
ds = SemanticSegmentationSlidingWindowGeoDataset(
scene,
size=chip_sz,
stride=stride,
pad_direction=pad_direction,
transform=base_tf)

predictions: Iterator[np.ndarray] = self.learner.predict_dataset(
ds,
raw_out=raw_out,
numpy_out=True,
predict_kw=dict(out_shape=(chip_sz, chip_sz)),
dataloader_kw=dict(batch_size=batch_sz),
progress_bar=True,
progress_bar_kw=dict(desc=f'Making predictions on {scene.id}'))

labels = SemanticSegmentationLabels.from_predictions(
ds.windows,
predictions,
smooth=raw_out,
extent=scene.extent,
num_classes=len(label_store.class_config),
crop_sz=crop_sz)

labels = predict_scene_ss(self.learner, scene, predict_options)
return labels

def _make_chip_data_config(self, dataset: 'DatasetConfig',
------------ next file ------------
@@ -153,7 +153,6 @@ def build_default_model(self, num_classes: int,
class ClassificationLearnerConfig(LearnerConfig):
"""Configure a :class:`.ClassificationLearner`."""

data: Union[ClassificationImageDataConfig, ClassificationGeoDataConfig]
model: Optional[ClassificationModelConfig]

def build(self,
------------ next file ------------
@@ -151,18 +151,21 @@ def __init__(self,
training mode. Defaults to True.
"""
self.cfg = cfg

if model is None and cfg.model is None:
self.training = training
self._onnx_mode = (model_weights_path is not None
and model_weights_path.lower().endswith('.onnx'))
if self.onnx_mode and self.training:
raise ValueError('Training mode is not supported for ONNX models.')
if model is None and cfg.model is None and not self.onnx_mode:
raise ValueError(
'cfg.model can only be None if a custom model is specified.')
'cfg.model can only be None if a custom model is specified '
'or if model_weights_path is an .onnx file.')

if tmp_dir is None:
self._tmp_dir = get_tmp_dir()
tmp_dir = self._tmp_dir.name
self.tmp_dir = tmp_dir

self.training = training

self.train_ds = train_ds
self.valid_ds = valid_ds
self.test_ds = test_ds
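
ONNX mode is now decided once in the constructor: a model_weights_path ending in .onnx switches the Learner to ONNX Runtime inference, training is rejected, and cfg.model may be left unset. A hedged usage sketch, assuming Learner.from_model_bundle() forwards the bundle's .onnx weights through this path (the bundle URI is a placeholder):

# Sketch: load a bundle whose weights are an ONNX export, for inference only.
from rastervision.pytorch_learner import SemanticSegmentationLearner

learner = SemanticSegmentationLearner.from_model_bundle(
    'model-bundle.zip', training=False)
# Passing training=True together with .onnx weights raises ValueError.
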
@@ -198,38 +201,46 @@ def __init__(self,
# ---------------------------
# Set URIs
# ---------------------------
if output_dir is None and cfg.output_uri is None:
raise ValueError('output_dir or LearnerConfig.output_uri must '
'be specified.')
if output_dir is not None and cfg.output_uri is not None:
log.warning(
'Both output_dir and LearnerConfig.output_uri specified. '
'LearnerConfig.output_uri will be ignored.')
if output_dir is None:
assert cfg.output_uri is not None
self.output_dir = cfg.output_uri
self.model_bundle_uri = cfg.get_model_bundle_uri()
else:
self.output_dir = output_dir
self.model_bundle_uri = join(self.output_dir, 'model-bundle.zip')
if is_local(self.output_dir):
self.output_dir_local = self.output_dir
make_dir(self.output_dir_local)
else:
self.output_dir_local = get_local_path(self.output_dir, tmp_dir)
make_dir(self.output_dir_local, force_empty=True)
if self.training:
self.sync_from_cloud()
log.info(f'Local output dir: {self.output_dir_local}')
log.info(f'Remote output dir: {self.output_dir}')

self.modules_dir = join(self.output_dir, MODULES_DIRNAME)
self.checkpoints_dir_local = join(self.output_dir_local,
CHECKPOINTS_DIRNAME)
make_dir(self.checkpoints_dir_local)
self.output_dir = None
self.output_dir_local = None
self.model_bundle_uri = None
self.modules_dir = None
self.checkpoints_dir_local = None

if self.training:
if output_dir is None and cfg.output_uri is None:
raise ValueError('output_dir or LearnerConfig.output_uri must '
'be specified in training mode.')
if output_dir is not None and cfg.output_uri is not None:
log.warning(
'Both output_dir and LearnerConfig.output_uri specified. '
'LearnerConfig.output_uri will be ignored.')
if output_dir is None:
assert cfg.output_uri is not None
self.output_dir = cfg.output_uri
self.model_bundle_uri = cfg.get_model_bundle_uri()
else:
self.output_dir = output_dir
self.model_bundle_uri = join(self.output_dir,
'model-bundle.zip')
if is_local(self.output_dir):
self.output_dir_local = self.output_dir
make_dir(self.output_dir_local)
else:
self.output_dir_local = get_local_path(self.output_dir,
tmp_dir)
make_dir(self.output_dir_local, force_empty=True)
if self.training:
self.sync_from_cloud()
log.info(f'Local output dir: {self.output_dir_local}')
log.info(f'Remote output dir: {self.output_dir}')

self.modules_dir = join(self.output_dir, MODULES_DIRNAME)
self.checkpoints_dir_local = join(self.output_dir_local,
CHECKPOINTS_DIRNAME)
make_dir(self.checkpoints_dir_local)

# ---------------------------
self._onnx_mode = False
self.init_model_weights_path = model_weights_path
self.init_model_def_path = model_def_path
self.init_loss_def_path = loss_def_path
@@ -771,7 +782,7 @@ def predict_dataset(self,

dl_kw = dict(
collate_fn=self.get_collate_fn(),
batch_size=cfg.solver.batch_sz,
batch_size=cfg.solver.batch_sz if cfg.solver else 1,
num_workers=int(num_workers),
shuffle=False,
pin_memory=True)
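
Since solver is now optional on LearnerConfig (see below), predict_dataset falls back to a batch size of 1 when no solver config is present. Callers can still override it through dataloader_kw, as the removed backend code did; a short sketch, assuming ds and learner exist as in the examples above:

# Sketch: override the prediction batch size when cfg.solver is None.
predictions = learner.predict_dataset(
    ds,
    numpy_out=True,
    dataloader_kw=dict(batch_size=8),
    progress_bar=True)
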
@@ -1101,9 +1112,7 @@ def setup_model(self,
model_def_path (Optional[str], optional): Path to model definition.
Will be available when loading from a bundle. Defaults to None.
"""
self._onnx_mode = (model_weights_path is not None
and model_weights_path.lower().endswith('.onnx'))
if self._onnx_mode:
if self.onnx_mode:
self.model = self.load_onnx_model(model_weights_path)
return
if self.model is None:
@@ -1716,7 +1725,8 @@ def load_checkpoint(self):

def load_onnx_model(self, model_path: str) -> ONNXRuntimeAdapter:
log.info(f'Loading ONNX model from {model_path}')
onnx_model = ONNXRuntimeAdapter.from_file(model_path)
path = download_if_needed(model_path)
onnx_model = ONNXRuntimeAdapter.from_file(path)
return onnx_model
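
load_onnx_model now routes the path through download_if_needed (from rastervision.pipeline.file_system), so a remote URI works as well as a local file. A small sketch, assuming a learner loaded as in the earlier example (the S3 URI is a placeholder):

# Sketch: remote ONNX weights are downloaded locally before being handed
# to ONNX Runtime.
onnx_model = learner.load_onnx_model('s3://my-bucket/model.onnx')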

def log_data_stats(self):
------------ next file ------------
@@ -1373,8 +1373,8 @@ def learner_config_upgrader(cfg_dict: dict, version: int) -> dict:
@register_config('learner', upgrader=learner_config_upgrader)
class LearnerConfig(Config):
"""Config for Learner."""
model: Optional[ModelConfig]
solver: SolverConfig
model: Optional[ModelConfig] = None
solver: Optional[SolverConfig] = None
data: DataConfig

eval_train: bool = Field(
@@ -1411,7 +1411,9 @@ def validate_run_tensorboard(cls, v: bool, values: dict) -> bool:

@root_validator(skip_on_failure=True)
def validate_class_loss_weights(cls, values: dict) -> dict:
solver: SolverConfig = values.get('solver')
solver: Optional[SolverConfig] = values.get('solver')
if solver is None:
return values
class_loss_weights = solver.class_loss_weights
if class_loss_weights is not None:
data: DataConfig = values.get('data')
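
With model and solver both defaulting to None, an inference-only config needs nothing beyond data, and the class-loss-weights validator now simply returns when no solver is given. A hedged sketch, assuming the subclass configs inherit these defaults and that data_cfg is built elsewhere:

# Sketch: a config usable for prediction only; no model or solver section.
from rastervision.pytorch_learner import SemanticSegmentationLearnerConfig

cfg = SemanticSegmentationLearnerConfig(data=data_cfg)
assert cfg.model is None and cfg.solver is None
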
------------ next file ------------
@@ -201,7 +201,6 @@ def build_default_model(self, num_classes: int, in_channels: int,
class ObjectDetectionLearnerConfig(LearnerConfig):
"""Configure an :class:`.ObjectDetectionLearner`."""

data: Union[ObjectDetectionImageDataConfig, ObjectDetectionGeoDataConfig]
model: Optional[ObjectDetectionModelConfig]

def build(self,
------------ next file ------------
@@ -215,7 +215,6 @@ class RegressionLearnerConfig(LearnerConfig):
"""Configure a :class:`.RegressionLearner`."""

model: Optional[RegressionModelConfig]
data: Union[RegressionImageDataConfig, RegressionGeoDataConfig]

def build(self,
tmp_dir,
------------ next file ------------
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Union
from typing import Callable, Optional
from os.path import join
from enum import Enum
import logging
@@ -209,8 +209,6 @@ def build_default_model(self, num_classes: int,
class SemanticSegmentationLearnerConfig(LearnerConfig):
"""Configure a :class:`.SemanticSegmentationLearner`."""

data: Union[SemanticSegmentationImageDataConfig,
SemanticSegmentationGeoDataConfig]
model: Optional[SemanticSegmentationModelConfig]

def build(self,
------------ next file ------------
@@ -3,6 +3,7 @@
from rastervision.pytorch_learner.utils.utils import *
from rastervision.pytorch_learner.utils.torch_hub import *
from rastervision.pytorch_learner.utils.distributed import *
from rastervision.pytorch_learner.utils.prediction import *

__all__ = [
SplitTensor.__name__,
@@ -24,4 +25,7 @@
torch_hub_load_local.__name__,
DDPContextManager.__name__,
'DDP_BACKEND',
predict_scene_cc.__name__,
predict_scene_od.__name__,
predict_scene_ss.__name__,
]
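
Taken together, the exported helpers reduce scene-level inference to a single call per task, which is what the three backends' predict_scene methods now delegate to. A hedged sketch, assuming learner, scene, and predict_options are constructed as in the examples above:

# Sketch: the same call shape now backs chip classification (predict_scene_cc),
# object detection (predict_scene_od), and semantic segmentation
# (predict_scene_ss).
from rastervision.pytorch_learner.utils import predict_scene_ss

labels = predict_scene_ss(learner, scene, predict_options)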