Skip to content

Commit

Permalink
replace class_names and class_colors in DataConfig w/ class_config
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Apr 5, 2024
1 parent b621ac7 commit f902484
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,8 @@ def update(self, pipeline: Optional[RVPipelineConfig] = None):
if self.data.uri is None and self.data.group_uris is None:
self.data.uri = pipeline.chip_uri

if not self.data.class_names:
# We want to defer validating class_names against class_colors
# until we have updated both. Hence, we use Config.copy(update=)
# here because it does not trigger pydantic validators.
self.data = self.data.copy(
update={'class_names': pipeline.dataset.class_config.names})
if not self.data.class_colors:
self.data.class_colors = pipeline.dataset.class_config.colors
if self.data.class_config is None:
self.data.class_config = pipeline.dataset.class_config

if not self.data.img_channels:
self.data.img_channels = self.get_img_channels(pipeline)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def register_plugin(registry: 'Registry'):
registry.set_plugin_version('rastervision.pytorch_learner', 6)
registry.set_plugin_version('rastervision.pytorch_learner', 7)
registry.register_renamed_type_hints('geo_data_window', 'window_sampling')


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
SceneDatasetConfig)
from rastervision.core.rv_pipeline import (WindowSamplingConfig)
from rastervision.pytorch_learner.utils import (
color_to_triple, validate_albumentation_transform, MinMaxNormalize,
validate_albumentation_transform, MinMaxNormalize,
deserialize_albumentation_transform, get_hubconf_dir_from_cfg,
torch_hub_load_local, torch_hub_load_github, torch_hub_load_uri)

Expand Down Expand Up @@ -594,44 +594,23 @@ def validate_channel_display_groups(
return validate_channel_display_groups(v)


def ensure_class_colors(
class_names: List[str],
class_colors: Optional[List[Union[str, RGBTuple]]] = None):
"""Ensure that class_colors is valid.
If class_names is empty, fill with random colors.
Args:
class_names: see DataConfig.class_names
class_colors: see DataConfig.class_colors
"""
if class_colors is not None:
if len(class_names) != len(class_colors):
raise ConfigError(f'len(class_names) ({len(class_names)}) != '
f'len(class_colors) ({len(class_colors)})\n'
f'class_names: {class_names}\n'
f'class_colors: {class_colors}')
elif len(class_names) > 0:
class_colors = [color_to_triple() for _ in class_names]
return class_colors


def data_config_upgrader(cfg_dict: dict, version: int) -> dict:
if version < 2:
if version == 1:
cfg_dict['type_hint'] = 'image_data'
elif version < 3:
elif version == 2:
cfg_dict['img_channels'] = cfg_dict.get('img_channels')
elif version == 6:
class_names = cfg_dict.pop('class_names', [])
class_colors = cfg_dict.pop('class_colors', [])
cfg_dict['class_config'] = ClassConfig(
names=class_names, colors=class_colors)
return cfg_dict


@register_config('data', upgrader=data_config_upgrader)
class DataConfig(Config):
"""Config related to dataset for training and testing."""
class_names: List[str] = Field([], description='Names of classes.')
class_colors: Optional[List[Union[str, RGBTuple]]] = Field(
None,
description=('Colors used to display classes. '
'Can be color 3-tuples in list form.'))
class_config: ClassConfig | None = Field(None, description='Class config.')
img_channels: Optional[PosInt] = Field(
None, description='The number of channels of the training images.')
img_sz: PosInt = Field(
Expand Down Expand Up @@ -675,23 +654,28 @@ class DataConfig(Config):
('Optional limit on the number of items in the preview plots produced '
'during training.'))

@property
def class_names(self):
if self.class_config is None:
return None
return self.class_config.names

@property
def class_colors(self):
if self.class_config is None:
return None
return self.class_config.colors

@property
def num_classes(self):
return len(self.class_names)
return len(self.class_config)

# validators
_base_tf = validator(
'base_transform', allow_reuse=True)(validate_albumentation_transform)
_aug_tf = validator(
'aug_transform', allow_reuse=True)(validate_albumentation_transform)

@root_validator(skip_on_failure=True)
def ensure_class_colors(cls, values: dict) -> dict:
class_names = values.get('class_names')
class_colors = values.get('class_colors')
values['class_colors'] = ensure_class_colors(class_names, class_colors)
return values

@validator('augmentors', each_item=True)
def validate_augmentors(cls, v: str) -> str:
if v not in augmentors:
Expand Down Expand Up @@ -1224,14 +1208,13 @@ def validate_sampling(
return v

@root_validator(skip_on_failure=True)
def get_class_info_from_class_config_if_needed(cls, values: dict) -> dict:
no_classes = len(values['class_names']) == 0
def get_class_config_from_dataset_if_needed(cls, values: dict) -> dict:
has_class_config = values.get('class_config') is not None
if has_class_config:
return values
has_scene_dataset = values.get('scene_dataset') is not None
if no_classes and has_scene_dataset:
class_config: ClassConfig = values['scene_dataset'].class_config
class_config.update()
values['class_names'] = class_config.names
values['class_colors'] = class_config.colors
if has_scene_dataset:
values['class_config'] = values['scene_dataset'].class_config
return values

def build_scenes(self,
Expand Down
3 changes: 1 addition & 2 deletions tests/pytorch_learner/test_classification_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ def _test_learner(self,
data_cfg = ClassificationGeoDataConfig(
scene_dataset=dataset_cfg,
sampling=sampling_cfg,
class_names=class_config.names,
class_colors=class_config.colors,
class_config=class_config,
plot_options=PlotOptions(
channel_display_groups=channel_display_groups),
num_workers=0)
Expand Down
30 changes: 22 additions & 8 deletions tests/pytorch_learner/test_data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,25 @@ def assertNoError(self, fn: Callable, msg: str = ''):
except Exception:
self.fail(msg)

def test_upgrader(self):
def test_upgrader_v2(self):
old_cfg_dict = DataConfig().dict()
del old_cfg_dict['img_channels']
new_cfg_dict = data_config_upgrader(old_cfg_dict, version=2)
self.assertNoError(lambda: build_config(new_cfg_dict))

def test_upgrader_v6(self):
class_names = ['bg', 'fg']
class_colors = ['black', 'white']
old_cfg_dict = DataConfig().dict()
old_cfg_dict['class_names'] = class_names
old_cfg_dict['class_colors'] = class_colors
del old_cfg_dict['class_config']
new_cfg_dict = data_config_upgrader(old_cfg_dict, version=6)
self.assertIn('class_config', new_cfg_dict)
self.assertListEqual(new_cfg_dict['class_config'].names, class_names)
self.assertListEqual(new_cfg_dict['class_config'].colors, class_colors)
self.assertNoError(lambda: build_config(new_cfg_dict))


class TestSemanticSegmentationDataConfig(unittest.TestCase):
def assertNoError(self, fn: Callable, msg: str = ''):
Expand Down Expand Up @@ -182,6 +195,7 @@ def test_build_cc(self):

nclasses = 2
class_names = [f'class_{i}' for i in range(nclasses)]
class_config = ClassConfig(names=class_names)
chip_sz = 100
img_sz = 200
nchannels = 3
Expand All @@ -191,7 +205,7 @@ def test_build_cc(self):
data_dir = join(tmp_dir, 'data')
for split in ['train', 'valid']:
os.makedirs(join(data_dir, split))
for c in class_names:
for c in class_config.names:
class_dir = join(data_dir, split, c)
os.makedirs(class_dir)
for i in range(nchips):
Expand All @@ -207,7 +221,7 @@ def test_build_cc(self):
# data config -- unzipped
data_cfg = ClassificationImageDataConfig(
uri=data_dir,
class_names=class_names,
class_config=class_config,
img_channels=nchannels,
img_sz=img_sz)
train_ds, val_ds, test_ds = data_cfg.build(tmp_dir)
Expand All @@ -229,7 +243,7 @@ def test_build_cc(self):
zipdir(data_dir, zip_path)
data_cfg = ClassificationImageDataConfig(
uri=zip_path,
class_names=class_names,
class_config=class_config,
img_channels=nchannels,
img_sz=img_sz)
train_ds, val_ds, test_ds = data_cfg.build(tmp_dir)
Expand All @@ -255,6 +269,7 @@ def test_build_ss(self):

nclasses = 2
class_names = [f'class_{i}' for i in range(nclasses)]
class_config = ClassConfig(names=class_names)
chip_sz = 100
img_sz = 200
nchannels = 3
Expand Down Expand Up @@ -283,7 +298,7 @@ def test_build_ss(self):
# data config -- unzipped
data_cfg = SemanticSegmentationImageDataConfig(
uri=data_dir,
class_names=class_names,
class_config=class_config,
img_channels=nchannels,
img_sz=img_sz)
train_ds, val_ds, test_ds = data_cfg.build(tmp_dir)
Expand Down Expand Up @@ -356,7 +371,7 @@ def test_window_config(self):
self.assertRaises(ValidationError,
lambda: WindowSamplingConfig(**args))

def test_get_class_info_from_class_config_if_needed(self):
def test_get_class_config_from_dataset_if_needed(self):
class_config = ClassConfig(names=['bg', 'fg'])
scene_dataset = DatasetConfig(
class_config=class_config, train_scenes=[], validation_scenes=[])
Expand Down Expand Up @@ -423,8 +438,7 @@ def make_scene(num_channels: int, num_classes: int) -> SceneConfig:
scene_dataset=dataset_cfg,
sampling=WindowSamplingConfig(
size=chip_sz, stride=chip_sz, padding=0),
class_names=class_config.names,
class_colors=class_config.colors,
class_config=class_config,
img_sz=img_sz,
num_workers=0)
with get_tmp_dir() as tmp_dir:
Expand Down
3 changes: 1 addition & 2 deletions tests/pytorch_learner/test_object_detection_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ def _test_learner(self,
size=200,
max_windows=8,
neg_ratio=0.5),
class_names=class_config.names,
class_colors=class_config.colors,
class_config=class_config,
plot_options=PlotOptions(
channel_display_groups=channel_display_groups),
num_workers=0)
Expand Down
3 changes: 1 addition & 2 deletions tests/pytorch_learner/test_regression_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ def _test_learner(self,
sampling=WindowSamplingConfig(
method=WindowSamplingMethod.random, size=20,
max_windows=8),
class_names=class_config.names,
class_colors=class_config.colors,
class_config=class_config,
plot_options=RegressionPlotOptions(
channel_display_groups=channel_display_groups),
num_workers=0)
Expand Down
3 changes: 1 addition & 2 deletions tests/pytorch_learner/test_semantic_segmentation_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ def _test_learner(self,
sampling=WindowSamplingConfig(
method=WindowSamplingMethod.random, size=20,
max_windows=8),
class_names=class_config.names,
class_colors=class_config.colors,
class_config=class_config,
aug_transform=aug_tf,
plot_options=PlotOptions(
channel_display_groups=channel_display_groups),
Expand Down

0 comments on commit f902484

Please sign in to comment.