Skip to content

Commit

Permalink
remove LearnerPipeline and LearnerPipelineConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Feb 8, 2024
1 parent f032fe1 commit 9ad0072
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ def register_plugin(registry: 'Registry'):
import rastervision.pipeline
from rastervision.pytorch_learner.learner_config import *
from rastervision.pytorch_learner.learner import *
from rastervision.pytorch_learner.learner_pipeline_config import *
from rastervision.pytorch_learner.learner_pipeline import *
from rastervision.pytorch_learner.classification_learner_config import *
from rastervision.pytorch_learner.classification_learner import *
from rastervision.pytorch_learner.regression_learner_config import *
Expand All @@ -26,9 +24,6 @@ def register_plugin(registry: 'Registry'):
from rastervision.pytorch_learner.dataset import *

__all__ = [
# LearnerPipeline
LearnerPipeline.__name__,
LearnerPipelineConfig.__name__,
# Learner
Learner.__name__,
SemanticSegmentationLearner.__name__,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,24 @@
from rastervision.pipeline import rv_config_ as rv_config
from rastervision.pipeline.utils import get_env_var
from rastervision.pipeline.file_system import (
sync_to_dir, json_to_file, file_to_json, make_dir, zipdir,
download_if_needed, download_or_copy, sync_from_dir, get_local_path, unzip,
str_to_file, is_local, get_tmp_dir)
sync_to_dir, json_to_file, make_dir, zipdir, download_if_needed,
download_or_copy, sync_from_dir, get_local_path, unzip, is_local,
get_tmp_dir)
from rastervision.pipeline.file_system.utils import file_exists
from rastervision.pipeline.utils import terminate_at_exit
from rastervision.pipeline.config import (build_config, upgrade_config,
save_pipeline_config)
from rastervision.pipeline.config import build_config
from rastervision.pytorch_learner.utils import (
get_hubconf_dir_from_cfg, aggregate_metrics, log_metrics_to_csv,
log_system_details, ONNXRuntimeAdapter, DDPContextManager)
aggregate_metrics, DDPContextManager, get_hubconf_dir_from_cfg,
get_learner_config_from_bundle_dir, log_metrics_to_csv, log_system_details,
ONNXRuntimeAdapter)
from rastervision.pytorch_learner.dataset.visualizer import Visualizer

if TYPE_CHECKING:
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import Dataset, Sampler

from rastervision.pytorch_learner import (LearnerConfig,
LearnerPipelineConfig)
from rastervision.pytorch_learner import LearnerConfig

warnings.filterwarnings('ignore')

Expand Down Expand Up @@ -305,14 +304,7 @@ def from_model_bundle(cls: Type,
unzip(model_bundle_path, model_bundle_dir)

if cfg is None:
config_path = join(model_bundle_dir, 'pipeline-config.json')

config_dict = file_to_json(config_path)
config_dict = upgrade_config(config_dict)

learner_pipeline_cfg: 'LearnerPipelineConfig' = build_config(
config_dict)
cfg = learner_pipeline_cfg.learner
cfg = get_learner_config_from_bundle_dir(model_bundle_dir)

hub_dir = join(model_bundle_dir, MODULES_DIRNAME)
model_def_path = None
Expand Down Expand Up @@ -1024,8 +1016,8 @@ def setup_training(self, loss_def_path: Optional[str] = None) -> None:
"""
cfg = self.cfg

self.config_path = join(self.output_dir, 'learner-config.json')
str_to_file(cfg.json(), self.config_path)
self.config_path = join(self.output_dir_local, 'learner-config.json')
cfg.to_file(self.config_path)
self.log_path = join(self.output_dir_local, 'log.csv')
self.last_model_weights_path = join(self.output_dir_local,
'last-model.pth')
Expand Down Expand Up @@ -1399,9 +1391,6 @@ def save_model_bundle(self, export_onnx: bool = True):
This is a zip file with the model weights in .pth format and a serialized
copy of the LearningConfig, which allows for making predictions in the future.
"""
from rastervision.pytorch_learner.learner_pipeline_config import (
LearnerPipelineConfig)

if self.cfg.model is None:
log.warning(
'Model was not configured via ModelConfig, and therefore, '
Expand All @@ -1417,9 +1406,8 @@ def save_model_bundle(self, export_onnx: bool = True):
self._bundle_modules(model_bundle_dir)
self._bundle_transforms(model_bundle_dir)

pipeline_cfg = LearnerPipelineConfig(learner=self.cfg)
save_pipeline_config(pipeline_cfg,
join(model_bundle_dir, 'pipeline-config.json'))
cfg_uri = join(model_bundle_dir, 'learner-config.json')
shutil.copy(self.config_path, cfg_uri)

zip_path = join(self.output_dir_local, basename(self.model_bundle_uri))
log.info(f'Saving bundle to {zip_path}.')
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import (Any, Dict, Sequence, Tuple, Optional, Union, List,
Iterable, Container)
from typing import (TYPE_CHECKING, Any, Dict, Sequence, Tuple, Optional, Union,
List, Iterable, Container)
from os.path import basename, join, isfile
import logging

Expand All @@ -14,8 +14,13 @@
import pandas as pd
import onnxruntime as ort

from rastervision.pipeline.file_system import get_tmp_dir
from rastervision.pipeline.config import ConfigError
from rastervision.pipeline.file_system.utils import (file_exists, file_to_json,
get_tmp_dir)
from rastervision.pipeline.config import (build_config, Config, ConfigError,
upgrade_config)

if TYPE_CHECKING:
from rastervision.pytorch_learner import LearnerConfig

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -487,3 +492,23 @@ def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
if isinstance(out, np.ndarray):
out = torch.from_numpy(out)
return out


def get_learner_config_from_bundle_dir(
model_bundle_dir: str) -> 'LearnerConfig':
config_path = join(model_bundle_dir, 'learner-config.json')
if file_exists(config_path):
cfg = Config.from_file(config_path)
else:
# backward compatibility
config_path = join(model_bundle_dir, 'pipeline-config.json')
if not file_exists(config_path):
raise FileNotFoundError(
'Could not find a valid config file in the bundle.')
pipeline_cfg_dict = file_to_json(config_path)
cfg_dict = pipeline_cfg_dict['learner']
cfg_dict['plugin_versions'] = pipeline_cfg_dict['plugin_versions']
cfg_dict = upgrade_config(cfg_dict)
cfg_dict.pop('plugin_versions', None)
cfg = build_config(cfg_dict)
return cfg

0 comments on commit 9ad0072

Please sign in to comment.