Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix circular import problems #2059

Merged
merged 1 commit into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
from rastervision.pytorch_backend.pytorch_learner_backend import (
PyTorchLearnerSampleWriter, PyTorchLearnerBackend)
from rastervision.pytorch_backend.utils import chip_collate_fn_cc
from rastervision.pytorch_learner import (
ClassificationGeoDataConfig, ClassificationSlidingWindowGeoDataset)
from rastervision.pytorch_learner.dataset import (
ClassificationSlidingWindowGeoDataset)
from rastervision.core.data import ChipClassificationLabels

if TYPE_CHECKING:
import numpy as np
from rastervision.core.data import DatasetConfig, Scene
from rastervision.core.rv_pipeline import ChipOptions, PredictOptions
from rastervision.pytorch_learner import ClassificationGeoDataConfig


class PyTorchChipClassificationSampleWriter(PyTorchLearnerSampleWriter):
Expand Down Expand Up @@ -89,7 +90,8 @@ def predict_scene(self, scene: 'Scene', predict_options: 'PredictOptions'

def _make_chip_data_config(
self, dataset: 'DatasetConfig',
chip_options: 'ChipOptions') -> ClassificationGeoDataConfig:
chip_options: 'ChipOptions') -> 'ClassificationGeoDataConfig':
from rastervision.pytorch_learner import (ClassificationGeoDataConfig)
data_config = ClassificationGeoDataConfig(
scene_dataset=dataset, sampling=chip_options.sampling)
return data_config
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
from rastervision.core.data.utils.misc import save_img
from rastervision.core.data_sample import DataSample
from rastervision.pytorch_learner.learner import Learner
from rastervision.pytorch_learner.learner_config import DataConfig

if TYPE_CHECKING:
from torch.utils.data import Dataset
from rastervision.core.data import ClassConfig, DatasetConfig, Scene
from rastervision.core.rv_pipeline import RVPipelineConfig, ChipOptions
from rastervision.pytorch_learner.learner_config import LearnerConfig
from rastervision.pytorch_learner import DataConfig, LearnerConfig

SPLITS = ['train', 'valid', 'test']

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
from rastervision.pytorch_backend.utils import chip_collate_fn_od
from rastervision.pytorch_learner.dataset import (
ObjectDetectionSlidingWindowGeoDataset)
from rastervision.pytorch_learner.object_detection_learner_config import (
ObjectDetectionGeoDataConfig)

if TYPE_CHECKING:
from rastervision.core.data import DatasetConfig, Scene
from rastervision.core.rv_pipeline import (ChipOptions,
ObjectDetectionPredictOptions)
from rastervision.pytorch_learner.object_detection_utils import BoxList
from rastervision.pytorch_learner.object_detection_learner_config import (
ObjectDetectionGeoDataConfig)


class PyTorchObjectDetectionSampleWriter(PyTorchLearnerSampleWriter):
Expand Down Expand Up @@ -154,7 +154,8 @@ def predict_scene(self, scene: 'Scene',

def _make_chip_data_config(
self, dataset: 'DatasetConfig',
chip_options: 'ChipOptions') -> ObjectDetectionGeoDataConfig:
chip_options: 'ChipOptions') -> 'ObjectDetectionGeoDataConfig':
from rastervision.pytorch_learner import (ObjectDetectionGeoDataConfig)
data_config = ObjectDetectionGeoDataConfig(
scene_dataset=dataset, sampling=chip_options.sampling)
return data_config
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from rastervision.pytorch_backend.utils import chip_collate_fn_ss
from rastervision.pytorch_learner.dataset import (
SemanticSegmentationSlidingWindowGeoDataset)
from rastervision.pytorch_learner import SemanticSegmentationGeoDataConfig

if TYPE_CHECKING:
from rastervision.core.data import (DatasetConfig, Scene,
SemanticSegmentationLabelStore)
from rastervision.core.rv_pipeline import (
ChipOptions, SemanticSegmentationPredictOptions)
from rastervision.pytorch_learner import SemanticSegmentationGeoDataConfig


class PyTorchSemanticSegmentationSampleWriter(PyTorchLearnerSampleWriter):
Expand Down Expand Up @@ -118,9 +118,11 @@ def predict_scene(self, scene: 'Scene',

return labels

def _make_chip_data_config(
self, dataset: 'DatasetConfig',
chip_options: 'ChipOptions') -> SemanticSegmentationGeoDataConfig:
def _make_chip_data_config(self, dataset: 'DatasetConfig',
chip_options: 'ChipOptions'
) -> 'SemanticSegmentationGeoDataConfig':
from rastervision.pytorch_learner import (
SemanticSegmentationGeoDataConfig)
data_config = SemanticSegmentationGeoDataConfig(
scene_dataset=dataset, sampling=chip_options.sampling)
return data_config
Loading