diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..5679bab7 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,33 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the OS, Python version and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.9" + # You can also specify other tool versions: + # nodejs: "19" + # rust: "1.64" + # golang: "1.19" + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/source/conf.py + +# Optionally build your docs in additional formats such as PDF and ePub +# formats: +# - pdf +# - epub + +# Optional but recommended, declare the Python requirements required +# to build your documentation +# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +python: + install: + - requirements: requirements/requirements.txt + - requirements: requirements/requirements-docs.txt diff --git a/examples/upload_and_predict_from_numpy.py b/examples/upload_and_predict_from_numpy.py index 4e73adf6..7d9a4811 100644 --- a/examples/upload_and_predict_from_numpy.py +++ b/examples/upload_and_predict_from_numpy.py @@ -62,7 +62,7 @@ def rotate_image(image: np.ndarray, angle: float) -> np.ndarray: rotated_image = rotate_image(image=numpy_image, angle=20) # Make sure that the project exists - ensure_trained_example_project(geti=geti, project_name=PROJECT_NAME) + project = ensure_trained_example_project(geti=geti, project_name=PROJECT_NAME) print( "Uploading and predicting example image now... The prediction results will be " @@ -71,7 +71,7 @@ def rotate_image(image: np.ndarray, angle: float) -> np.ndarray: # We can upload and predict the resulting array directly: sc_image, image_prediction = geti.upload_and_predict_image( - project_name=PROJECT_NAME, + project=project, image=rotated_image, visualise_output=False, delete_after_prediction=DELETE_AFTER_PREDICTION, @@ -100,7 +100,7 @@ def rotate_image(image: np.ndarray, angle: float) -> np.ndarray: print("Video generated, retrieving predictions...") # Create video, upload and predict from the list of frames sc_video, video_frames, frame_predictions = geti.upload_and_predict_video( - project_name=PROJECT_NAME, + project=project, video=rotation_video, frame_stride=1, visualise_output=False, diff --git a/examples/upload_and_predict_media_from_folder.py b/examples/upload_and_predict_media_from_folder.py index fb954d77..34407b64 100644 --- a/examples/upload_and_predict_media_from_folder.py +++ b/examples/upload_and_predict_media_from_folder.py @@ -38,11 +38,11 @@ # -------------------------------------------------- # Make sure that the specified project exists on the server - ensure_trained_example_project(geti=geti, project_name=PROJECT_NAME) + project = ensure_trained_example_project(geti=geti, project_name=PROJECT_NAME) # Upload the media in the folder and generate predictions geti.upload_and_predict_media_folder( - project_name=PROJECT_NAME, + project=project, media_folder=FOLDER_WITH_MEDIA, delete_after_prediction=DELETE_AFTER_PREDICTION, output_folder=OUTPUT_FOLDER, diff --git a/geti_sdk/annotation_readers/base_annotation_reader.py b/geti_sdk/annotation_readers/base_annotation_reader.py index 65c5fe49..55a245cb 100644 --- a/geti_sdk/annotation_readers/base_annotation_reader.py +++ b/geti_sdk/annotation_readers/base_annotation_reader.py @@ -33,12 +33,14 @@ def __init__( base_data_folder: str, annotation_format: str = ".json", task_type: Union[TaskType, str] = TaskType.DETECTION, + anomaly_reduction: bool = False, ): if task_type is not None and not isinstance(task_type, TaskType): task_type = TaskType(task_type) self.base_folder = base_data_folder self.annotation_format = annotation_format self.task_type = task_type + self.anomaly_reduction = anomaly_reduction self._filepaths: Optional[List[str]] = None diff --git a/geti_sdk/annotation_readers/datumaro_annotation_reader/datumaro_annotation_reader.py b/geti_sdk/annotation_readers/datumaro_annotation_reader/datumaro_annotation_reader.py index 1939ad8a..2bb4f216 100644 --- a/geti_sdk/annotation_readers/datumaro_annotation_reader/datumaro_annotation_reader.py +++ b/geti_sdk/annotation_readers/datumaro_annotation_reader/datumaro_annotation_reader.py @@ -46,6 +46,7 @@ class DatumAnnotationReader(AnnotationReader): TaskType.ANOMALY_CLASSIFICATION, TaskType.ANOMALY_DETECTION, TaskType.ANOMALY_SEGMENTATION, + TaskType.ANOMALY, ] def __init__( diff --git a/geti_sdk/annotation_readers/geti_annotation_reader.py b/geti_sdk/annotation_readers/geti_annotation_reader.py index e77d4f98..720d27e8 100644 --- a/geti_sdk/annotation_readers/geti_annotation_reader.py +++ b/geti_sdk/annotation_readers/geti_annotation_reader.py @@ -22,6 +22,7 @@ from geti_sdk.data_models import Annotation, TaskType from geti_sdk.data_models.media import MediaInformation +from geti_sdk.data_models.shapes import Rectangle from geti_sdk.rest_converters import AnnotationRESTConverter from geti_sdk.rest_converters.annotation_rest_converter import ( NormalizedAnnotationRESTConverter, @@ -41,6 +42,7 @@ def __init__( annotation_format: str = ".json", task_type: Optional[Union[TaskType, str]] = None, label_names_to_include: Optional[List[str]] = None, + anomaly_reduction: bool = False, ): """ :param base_data_folder: Path to the folder containing the annotations @@ -50,6 +52,10 @@ def __init__( :param label_names_to_include: Names of the labels that should be included when reading annotation data. This can be used to filter the annotations for certain labels. + :param anomaly_reduction: True to reduce all anomaly tasks to the single anomaly task. + This is done in accordance with the Intel Geti 2.5 Anomaly Reduction effort. + All pixel level annotations are converted to full rectangles. All anomaly tasks + are mapped to th new "Anomaly Detection" task wich corresponds to the old "Anomaly Classification". """ if annotation_format != ".json": raise ValueError( @@ -60,6 +66,7 @@ def __init__( base_data_folder=base_data_folder, annotation_format=annotation_format, task_type=task_type, + anomaly_reduction=anomaly_reduction, ) self._label_names_to_include = label_names_to_include self._normalized_annotations = self._has_normalized_annotations() @@ -160,6 +167,22 @@ def get_data( label_name=label_dict["name"] ) new_annotations.append(annotation_object) + if ( + self.anomaly_reduction + and annotation_object.labels[0].name.lower() == "anomalous" + ): + # Part of anomaly task reduction in Intel Geti 2.5 -> all anomaly tasks combined into one. + # Intel Geti now only accepts full rectangles for anomaly tasks. + new_annotations = [ + Annotation( + labels=[annotation_object.labels[0]], + shape=Rectangle.generate_full_box( + image_width=media_information.width, + image_height=media_information.height, + ), + ) + ] + break return new_annotations def get_all_label_names(self) -> List[str]: diff --git a/geti_sdk/benchmarking/benchmarker.py b/geti_sdk/benchmarking/benchmarker.py index bce2f332..3527b71a 100644 --- a/geti_sdk/benchmarking/benchmarker.py +++ b/geti_sdk/benchmarking/benchmarker.py @@ -53,7 +53,7 @@ class Benchmarker: def __init__( self, geti: Geti, - project: Union[str, Project], + project: Project, precision_levels: Optional[Sequence[str]] = None, models: Optional[Sequence[Model]] = None, algorithms: Optional[Sequence[str]] = None, @@ -83,7 +83,7 @@ def __init__( be called after initialization. :param geti: Geti instance on which the project to use for benchmarking lives - :param project: Project or project name to use for the benchmarking. The + :param project: Project to use for the benchmarking. The project must exist on the specified Geti instance :param precision_levels: List of model precision levels to run the benchmarking for. Throughput will be measured for each precision level @@ -111,11 +111,8 @@ def __init__( on. """ self.geti = geti - if isinstance(project, str): - project_name = project - else: - project_name = project.name - self.project = geti.get_project(project_name) + # Update project object to get the latest project details + self.project = self.geti.get_project(project_id=project.id) logging.info( f"Setting up Benchmarker for Intel® Geti™ project `{self.project.name}`." ) @@ -501,7 +498,7 @@ def prepare_benchmark(self, working_directory: os.PathLike = "."): output_folder = os.path.join(working_directory, f"deployment_{index}") with suppress_log_output(): self.geti.deploy_project( - project_name=self.project.name, + project=self.project, output_folder=output_folder, models=opt_models, ) diff --git a/geti_sdk/data_models/__init__.py b/geti_sdk/data_models/__init__.py index 210c7a28..baf5d805 100644 --- a/geti_sdk/data_models/__init__.py +++ b/geti_sdk/data_models/__init__.py @@ -177,6 +177,7 @@ TaskConfiguration, ) from .credit_system import CreditAccount, CreditBalance, Subscription +from .dataset import Dataset, Subset, TrainingDatasetStatistics from .enums import AnnotationKind, MediaType, TaskType from .job import Job from .label import Label, ScoredLabel @@ -185,7 +186,7 @@ from .model_group import ModelGroup, ModelSummary from .performance import Performance from .predictions import Prediction -from .project import Dataset, Pipeline, Project +from .project import Pipeline, Project from .status import ProjectStatus from .task import Task from .test_result import Score, TestResult @@ -198,6 +199,7 @@ "Label", "Task", "Pipeline", + "Dataset", "Image", "Video", "MediaItem", @@ -220,11 +222,12 @@ "ProjectStatus", "Job", "CodeDeploymentInformation", - "Dataset", "TestResult", "Score", "User", "CreditAccount", "CreditBalance", "Subscription", + "Subset", + "TrainingDatasetStatistics", ] diff --git a/geti_sdk/data_models/containers/algorithm_list.py b/geti_sdk/data_models/containers/algorithm_list.py index 04f568a2..e596100f 100644 --- a/geti_sdk/data_models/containers/algorithm_list.py +++ b/geti_sdk/data_models/containers/algorithm_list.py @@ -26,6 +26,7 @@ "anomaly_classification": "ote_anomaly_classification_padim", "anomaly_detection": "ote_anomaly_classification_padim", "anomaly_segmentation": "ote_anomaly_segmentation_padim", + "anomaly": "ote_anomaly_classification_padim", "rotated_detection": "Custom_Rotated_Detection_via_Instance_Segmentation_MaskRCNN_ResNet50", "instance_segmentation": "Custom_Counting_Instance_Segmentation_MaskRCNN_ResNet50", } diff --git a/geti_sdk/data_models/dataset.py b/geti_sdk/data_models/dataset.py new file mode 100644 index 00000000..59096071 --- /dev/null +++ b/geti_sdk/data_models/dataset.py @@ -0,0 +1,124 @@ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +from typing import ClassVar, Dict, List, Optional + +import attr + +from geti_sdk.data_models.containers import MediaList +from geti_sdk.data_models.enums import SubsetPurpose +from geti_sdk.data_models.media import Image, VideoFrame +from geti_sdk.data_models.utils import ( + deidentify, + str_to_datetime, + str_to_enum_converter, +) + + +@attr.define +class Dataset: + """ + Representation of a dataset for a project in Intel® Geti™. + + :var id: Unique database ID of the dataset + :var name: name of the dataset + """ + + _identifier_fields: ClassVar[str] = ["id", "creation_time"] + _GET_only_fields: ClassVar[List[str]] = ["use_for_training", "creation_time"] + + name: str + id: Optional[str] = None + creation_time: Optional[str] = attr.field(default=None, converter=str_to_datetime) + use_for_training: Optional[bool] = None + + def deidentify(self) -> None: + """ + Remove unique database ID from the Dataset. + """ + deidentify(self) + + def prepare_for_post(self) -> None: + """ + Set all fields to None that are not valid for making a POST request to the + /projects endpoint. + + :return: + """ + for field_name in self._GET_only_fields: + setattr(self, field_name, None) + + +@attr.define +class TrainingDatasetStatistics: + """ + Statistics for a specific dataset that was used for training a model. Note that + a `dataset` includes both the training, validation and testing set. + """ + + id: str + creation_time: str = attr.field(converter=str_to_datetime) + subset_info: Dict[str, int] + dataset_info: Dict[str, int] + + @property + def training_size(self) -> int: + """Return the number of dataset items in the training set""" + return self.subset_info["training"] + + @property + def validation_size(self) -> int: + """Return the number of dataset items in the validation set""" + return self.subset_info["validation"] + + @property + def testing_size(self) -> int: + """Return the number of dataset items in the testing set""" + return self.subset_info["testing"] + + @property + def number_of_videos(self) -> int: + """Return the total number of videos in the dataset""" + return self.dataset_info["videos"] + + @property + def number_of_frames(self) -> int: + """Return the total number of video frames in the dataset""" + return self.dataset_info["frames"] + + @property + def number_of_images(self) -> int: + """Return the total number of images in the dataset""" + return self.dataset_info["images"] + + +@attr.define +class Subset: + """ + Return the media items for a specific subset (i.e. 'training', 'validation' or + 'testing') + + :var images: List of images in the subset + :var frames: List of video frames in the subset + :var purpose: string representing the purpose of the subset. Can be either + + """ + + images: MediaList[Image] + frames: MediaList[VideoFrame] + purpose: str = attr.field(converter=str_to_enum_converter(SubsetPurpose)) + + @property + def size(self) -> int: + """Return the total number of items in the subset""" + return len(self.images) + len(self.frames) diff --git a/geti_sdk/data_models/enums/__init__.py b/geti_sdk/data_models/enums/__init__.py index f2d8f04e..347e8b78 100644 --- a/geti_sdk/data_models/enums/__init__.py +++ b/geti_sdk/data_models/enums/__init__.py @@ -27,6 +27,7 @@ from .prediction_mode import PredictionMode from .shape_type import ShapeType from .subscription_status import SubscriptionStatus +from .subset_purpose import SubsetPurpose from .task_type import TaskType __all__ = [ @@ -44,4 +45,5 @@ "JobType", "JobState", "DeploymentState", + "SubsetPurpose", ] diff --git a/geti_sdk/data_models/enums/domain.py b/geti_sdk/data_models/enums/domain.py index af4e4648..1a0db09d 100644 --- a/geti_sdk/data_models/enums/domain.py +++ b/geti_sdk/data_models/enums/domain.py @@ -28,9 +28,9 @@ class Domain(Enum): ANOMALY_CLASSIFICATION = "ANOMALY_CLASSIFICATION" ANOMALY_DETECTION = "ANOMALY_DETECTION" ANOMALY_SEGMENTATION = "ANOMALY_SEGMENTATION" + ANOMALY = "ANOMALY" INSTANCE_SEGMENTATION = "INSTANCE_SEGMENTATION" ROTATED_DETECTION = "ROTATED_DETECTION" - ANOMALY = "ANOMALY" def __str__(self) -> str: """ diff --git a/geti_sdk/data_models/enums/subset_purpose.py b/geti_sdk/data_models/enums/subset_purpose.py new file mode 100644 index 00000000..4a1c80fb --- /dev/null +++ b/geti_sdk/data_models/enums/subset_purpose.py @@ -0,0 +1,25 @@ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +from enum import Enum + + +class SubsetPurpose(Enum): + """ + Enum representing the purpose of a subset of a dataset on the Intel® Geti™ platform. + """ + + TRAINING = "training" + VALIDATION = "validation" + TESTING = "testing" diff --git a/geti_sdk/data_models/enums/task_type.py b/geti_sdk/data_models/enums/task_type.py index e4197e0c..d0df087e 100644 --- a/geti_sdk/data_models/enums/task_type.py +++ b/geti_sdk/data_models/enums/task_type.py @@ -26,6 +26,7 @@ class TaskType(Enum): ANOMALY_CLASSIFICATION = "anomaly_classification" ANOMALY_DETECTION = "anomaly_detection" ANOMALY_SEGMENTATION = "anomaly_segmentation" + ANOMALY = "anomaly" INSTANCE_SEGMENTATION = "instance_segmentation" ROTATED_DETECTION = "rotated_detection" DATASET = "dataset" @@ -111,9 +112,14 @@ def from_domain(cls, domain): TaskType.ANOMALY_CLASSIFICATION, TaskType.ANOMALY_DETECTION, TaskType.ANOMALY_SEGMENTATION, + TaskType.ANOMALY, ] -GLOBAL_TASK_TYPES = [TaskType.CLASSIFICATION, TaskType.ANOMALY_CLASSIFICATION] +GLOBAL_TASK_TYPES = [ + TaskType.CLASSIFICATION, + TaskType.ANOMALY_CLASSIFICATION, + TaskType.ANOMALY, +] SEGMENTATION_TASK_TYPES = [ TaskType.SEGMENTATION, diff --git a/geti_sdk/data_models/job.py b/geti_sdk/data_models/job.py index 8912ec31..fc295089 100644 --- a/geti_sdk/data_models/job.py +++ b/geti_sdk/data_models/job.py @@ -17,8 +17,8 @@ import attr +from geti_sdk.data_models.dataset import Dataset from geti_sdk.data_models.enums import JobState, JobType -from geti_sdk.data_models.project import Dataset from geti_sdk.data_models.status import StatusSummary from geti_sdk.data_models.utils import ( attr_value_serializer, diff --git a/geti_sdk/data_models/media.py b/geti_sdk/data_models/media.py index 47530efe..47eeb65c 100644 --- a/geti_sdk/data_models/media.py +++ b/geti_sdk/data_models/media.py @@ -55,6 +55,7 @@ class MediaInformation: display_url: str height: int width: int + extension: Optional[str] = None # Added in Geti v2.5 size: Optional[int] = None # Added in Geti v1.2 diff --git a/geti_sdk/data_models/media_identifiers.py b/geti_sdk/data_models/media_identifiers.py index 51ca45a8..737026b6 100644 --- a/geti_sdk/data_models/media_identifiers.py +++ b/geti_sdk/data_models/media_identifiers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions # and limitations under the License. -from typing import ClassVar, Dict +from typing import ClassVar, Dict, Optional import attr @@ -62,12 +62,14 @@ class VideoFrameIdentifier(MediaIdentifier): :var frame_index: Index of the video frame in the full video :var video_id: unique database ID of the video to which the frame belongs + :var key_index: Index of the key frame in the video """ _identifier_fields: ClassVar[str] = ["video_id"] frame_index: int video_id: str + key_index: Optional[int] = None # Added in Geti 2.5 @attr.define diff --git a/geti_sdk/data_models/model.py b/geti_sdk/data_models/model.py index 02c2ca8b..feb55108 100644 --- a/geti_sdk/data_models/model.py +++ b/geti_sdk/data_models/model.py @@ -58,6 +58,16 @@ class ModelPurgeInfo: user_uid: Optional[str] = None +@attr.define +class TrainingFramework: + """ + Representation of the training framework used to train the model. + """ + + type: str + version: str + + @attr.define class OptimizationConfigurationParameter: """ @@ -97,6 +107,7 @@ class BaseModel: default=None ) # Added in Intel Geti 1.1 total_disk_size: Optional[int] = None # Added in Intel Geti 2.3 + training_framework: Optional[TrainingFramework] = None # Added in Intel Geti 2.5 def __attrs_post_init__(self): """ diff --git a/geti_sdk/data_models/project.py b/geti_sdk/data_models/project.py index 7d99bfe8..c09a924f 100644 --- a/geti_sdk/data_models/project.py +++ b/geti_sdk/data_models/project.py @@ -18,6 +18,7 @@ import attr +from .dataset import Dataset from .label import Label from .performance import Performance from .task import Task @@ -151,40 +152,6 @@ def prepare_for_post(self) -> None: task.prepare_for_post() -@attr.define -class Dataset: - """ - Representation of a dataset for a project in Intel® Geti™. - - :var id: Unique database ID of the dataset - :var name: name of the dataset - """ - - _identifier_fields: ClassVar[str] = ["id", "creation_time"] - _GET_only_fields: ClassVar[List[str]] = ["use_for_training", "creation_time"] - - name: str - id: Optional[str] = None - creation_time: Optional[str] = attr.field(default=None, converter=str_to_datetime) - use_for_training: Optional[bool] = None - - def deidentify(self) -> None: - """ - Remove unique database ID from the Dataset. - """ - deidentify(self) - - def prepare_for_post(self) -> None: - """ - Set all fields to None that are not valid for making a POST request to the - /projects endpoint. - - :return: - """ - for field_name in self._GET_only_fields: - setattr(self, field_name, None) - - @attr.define class Project: """ diff --git a/geti_sdk/demos/demo_projects/anomaly_demos.py b/geti_sdk/demos/demo_projects/anomaly_demos.py index a23021fb..304e6cee 100644 --- a/geti_sdk/demos/demo_projects/anomaly_demos.py +++ b/geti_sdk/demos/demo_projects/anomaly_demos.py @@ -28,14 +28,14 @@ def create_anomaly_classification_demo_project( geti: Geti, n_images: int, n_annotations: int = -1, - project_name: str = "Anomaly classification demo", + project_name: str = "Anomaly demo", dataset_path: Optional[str] = None, ) -> Project: """ - Create a demo project of type 'anomaly_classification', based off the MVTec + Create a demo project of type 'anomaly', based off the MVTec anomaly detection dataset. - This method creates a project with a single 'Anomaly classification' task. + This method creates a project with a single 'Anomaly' task. :param geti: Geti instance, representing the GETi server on which the project should be created. @@ -44,7 +44,7 @@ def create_anomaly_classification_demo_project( :param n_annotations: Number of images that should be annotated. Pass -1 to upload annotations for all images. :param project_name: Name of the project to create. - Defaults to 'Anomaly classification demo' + Defaults to 'Anomaly demo' :param dataset_path: Path to the dataset to use as data source. Defaults to the 'data' directory in the top level folder of the geti_sdk package. If the dataset is not found in the target folder, this method will attempt to @@ -54,7 +54,7 @@ def create_anomaly_classification_demo_project( """ project_client = ProjectClient(session=geti.session, workspace_id=geti.workspace_id) data_path = get_mvtec_dataset(dataset_path) - logging.info(" ------- Creating anomaly classification project --------------- ") + logging.info(" ------- Creating anomaly project --------------- ") # Create annotation reader annotation_reader = DirectoryTreeAnnotationReader( @@ -104,13 +104,13 @@ def create_anomaly_classification_demo_project( def ensure_trained_anomaly_project( - geti: Geti, project_name: str = "Transistor anomaly classification" + geti: Geti, project_name: str = "Transistor anomaly detection" ): """ Check whether the project named `project_name` exists on the server, and create it if it not. - If the project does not exist, this method will create an anomaly classification + If the project does not exist, this method will create an anomaly detection project based on the MVTec AD `transistor` dataset. :param geti: Geti instance pointing to the Intel® Geti™ server @@ -126,7 +126,7 @@ def ensure_trained_anomaly_project( geti=geti, n_images=-1, project_name=project_name ) logging.info( - f"Project `{project_name}` of type `anomaly_classification` was created on " + f"Project `{project_name}` of type `anomaly` was created on " f"host `{geti.session.config.host}`." ) diff --git a/geti_sdk/deployment/predictions_postprocessing/results_converter/results_to_prediction_converter.py b/geti_sdk/deployment/predictions_postprocessing/results_converter/results_to_prediction_converter.py index ab566023..e4bf90ea 100644 --- a/geti_sdk/deployment/predictions_postprocessing/results_converter/results_to_prediction_converter.py +++ b/geti_sdk/deployment/predictions_postprocessing/results_converter/results_to_prediction_converter.py @@ -30,7 +30,7 @@ from geti_sdk.data_models.annotations import Annotation from geti_sdk.data_models.enums.domain import Domain -from geti_sdk.data_models.label import ScoredLabel +from geti_sdk.data_models.label import Label, ScoredLabel from geti_sdk.data_models.label_schema import LabelSchema from geti_sdk.data_models.predictions import Prediction from geti_sdk.data_models.shapes import ( @@ -85,6 +85,9 @@ def __init__(self, label_schema: LabelSchema): # add empty labels if only one non-empty label exits non_empty_labels = [label for label in all_labels if not label.is_empty] self.labels = all_labels if len(non_empty_labels) == 1 else non_empty_labels + self.label_name_mapping: Dict[str, Label] = { + label.name: label for label in self.labels + } # get the first empty label self.empty_label = next((label for label in all_labels if label.is_empty), None) multilabel = len(label_schema.get_groups(False)) > 1 @@ -110,8 +113,13 @@ def convert_to_prediction( """ labels = [] for label in inference_results.top_labels: + label_idx, label_name, label_prob = label + # label_idx does not necessarily match the label index in the project + # labels. Therefore, we map the label by name instead. labels.append( - ScoredLabel.from_label(self.labels[label[0]], float(label[-1])) + ScoredLabel.from_label( + self.label_name_mapping[label_name], float(label_prob) + ) ) if not labels and self.empty_label: diff --git a/geti_sdk/geti.py b/geti_sdk/geti.py index f1d96b34..8c017c9c 100644 --- a/geti_sdk/geti.py +++ b/geti_sdk/geti.py @@ -252,19 +252,28 @@ def credit_balance(self) -> Optional[int]: return balance.available if balance is not None else None def get_project( - self, project_name: str, project_id: Optional[str] = None + self, + project_name: Optional[str] = None, + project_id: Optional[str] = None, + project: Optional[Project] = None, ) -> Project: """ - Return the Intel® Geti™ project named `project_name`, if any. If no project by - that name is found on the Intel® Geti™ server, this method will raise a - KeyError. - - :param project_name: Name of the project to retrieve - :raises: KeyError if project named `project_name` is not found on the server - :return: Project identified by `project_name` + Return the Intel® Geti™ project by name or ID, if any. + If a project object is passed, the method will return the updated object. + If no project by that name is found on the Intel® Geti™ server, + this method will raise a KeyError. + + :param project_name: Name of the project to retrieve. + :param project_id: ID of the project to retrieve. If not specified, the + project with name `project_name` will be retrieved. + :param project: Project object to update. If provided, the associated `project_id` + will be used to update the project object. + :raises: KeyError if the project identified by one of the arguments is not found on the server + :raises: ValueError if there are several projects on the server named `project_name` + :return: Project identified by one of the arguments. """ - project = self.project_client.get_project_by_name( - project_name=project_name, project_id=project_id + project = self.project_client.get_project( + project_name=project_name, project_id=project_id, project=project ) if project is None: raise KeyError( @@ -275,8 +284,7 @@ def get_project( def download_project_data( self, - project_name: str, - project_id: Optional[str] = None, + project: Project, target_folder: Optional[str] = None, include_predictions: bool = False, include_active_models: bool = False, @@ -332,7 +340,7 @@ def download_project_data( Downloading a project may take a substantial amount of time if the project dataset is large. - :param project_name: Name of the project to download + :param project: Project object to download :param target_folder: Path to the local folder in which the project data should be saved. If not specified, a new directory will be created inside the current working directory. The name of the resulting directory will be @@ -354,7 +362,7 @@ def download_project_data( regarding the downloaded project """ project = self.import_export_module.download_project_data( - project=self.get_project(project_name=project_name, project_id=project_id), + project=project, target_folder=target_folder, include_predictions=include_predictions, include_active_models=include_active_models, @@ -363,7 +371,7 @@ def download_project_data( # Download deployment if include_deployment: logging.info("Creating deployment for project...") - self.deploy_project(project.name, output_folder=target_folder) + self.deploy_project(project, output_folder=target_folder) logging.info(f"Project '{project.name}' was downloaded successfully.") return project @@ -459,8 +467,7 @@ def upload_all_projects(self, target_folder: str) -> List[Project]: def export_project( self, filepath: os.PathLike, - project_name: str, - project_id: Optional[str] = None, + project: Project, ) -> None: """ Export a project with name `project_name` to the file specified by `filepath`. @@ -468,19 +475,15 @@ def export_project( and metadata required for project import to another instance of the Intel® Geti™ platform. :param filepath: Path to the file to save the project to - :param project_name: Name of the project to export - :param project_id: Optional ID of the project to export. If not specified, the - project with name `project_name` will be exported. + :param project: Project object to export """ - if project_id is None: - project_id = self.get_project(project_name=project_name).id - if project_id is None: + if project.id is None: raise ValueError( - f"Could not retrieve project ID for project '{project_name}'." - "Please specify the project ID explicitly." + f"Could not retrieve project ID for project '{project.name}'." + "Please reinitialize the project object." ) self.import_export_module.export_project( - project_id=project_id, filepath=filepath + project_id=project.id, filepath=filepath ) def import_project( @@ -523,7 +526,7 @@ def export_dataset( in the dataset, False to only include media with annotations. Defaults to False. """ - if type(export_format) is str: + if isinstance(export_format, str): export_format = DatasetFormat[export_format] self.import_export_module.export_dataset( project=project, @@ -549,6 +552,7 @@ def import_dataset( * anomaly_classification * anomaly_detection * anomaly_segmentation + * anomaly (choose this working with SaaS) * detection_oriented * detection_to_classification * detection_to_segmentation @@ -592,6 +596,7 @@ def create_single_task_project_from_dataset( * anomaly_classification * anomaly_detection * anomaly_segmentation + * anomaly (new task - anomaly classification) * instance_segmentation * rotated_detection @@ -638,7 +643,7 @@ def create_single_task_project_from_dataset( if criterion == "XOR": multilabel = False labels = generate_classification_labels(labels, multilabel=multilabel) - elif project_type == "anomaly_classification": + elif project_type == "anomaly_classification" or project_type == "anomaly": labels = ["Normal", "Anomalous"] # Create project @@ -858,7 +863,7 @@ def create_task_chain_project_from_dataset( def upload_and_predict_media_folder( self, - project_name: str, + project: Project, media_folder: str, output_folder: Optional[str] = None, delete_after_prediction: bool = False, @@ -867,7 +872,7 @@ def upload_and_predict_media_folder( ) -> bool: """ Upload a folder with media (images, videos or both) from local disk at path - `target_folder` to the project with name `project_name` on the Intel® Geti™ + `target_folder` to the project provided with the `project` argument on the Intel® Geti™ server. After the media upload is complete, predictions will be downloaded for all media in the folder. This method will create a 'predictions' directory in @@ -877,7 +882,7 @@ def upload_and_predict_media_folder( removed from the project on the Intel® Geti™ server after the predictions have been downloaded. - :param project_name: Name of the project to upload media to + :param project: Project object to upload the media to :param media_folder: Path to the folder to upload media from :param output_folder: Path to save the predictions to. If not specified, this method will create a folder named '_predictions' on @@ -892,16 +897,6 @@ def upload_and_predict_media_folder( :return: True if all media was uploaded, and predictions for all media were successfully downloaded. False otherwise """ - # Obtain project details from cluster - try: - project = self.get_project(project_name=project_name) - except ValueError: - logging.info( - f"Project '{project_name}' was not found on the cluster. Aborting " - f"media upload." - ) - return False - # Upload images image_client = ImageClient( session=self.session, workspace_id=self.workspace_id, project=project @@ -927,7 +922,7 @@ def upload_and_predict_media_folder( ) if not prediction_client.ready_to_predict: logging.info( - f"Project '{project_name}' is not ready to make predictions, likely " + f"Project '{project.name}' is not ready to make predictions, likely " f"because one of the tasks in the task chain does not have a " f"trained model yet. Aborting prediction." ) @@ -965,17 +960,17 @@ def upload_and_predict_media_folder( def upload_and_predict_image( self, - project_name: str, + project: Project, image: Union[np.ndarray, Image, VideoFrame, str, os.PathLike], visualise_output: bool = True, delete_after_prediction: bool = False, dataset_name: Optional[str] = None, ) -> Tuple[Image, Prediction]: """ - Upload a single image to a project named `project_name` on the Intel® Geti™ + Upload a single image to a project on the Intel® Geti™ server, and return a prediction for it. - :param project_name: Name of the project to upload the image to + :param project: Project object to upload the image to :param image: Image, numpy array representing an image, or filepath to an image to upload and get a prediction for :param visualise_output: True to show the resulting prediction, overlayed on @@ -989,8 +984,6 @@ def upload_and_predict_image( - Image object representing the image that was uploaded - Prediction for the image """ - project = self.get_project(project_name=project_name) - # Get the dataset to upload to dataset: Optional[Dataset] = None if dataset_name is not None: @@ -1030,7 +1023,7 @@ def upload_and_predict_image( ) if not prediction_client.ready_to_predict: raise ValueError( - f"Project '{project_name}' is not ready to make predictions. At least " + f"Project '{project.name}' is not ready to make predictions. At least " f"one of the tasks in the task chain does not have any models trained." ) prediction = prediction_client.get_image_prediction(uploaded_image) @@ -1048,21 +1041,21 @@ def upload_and_predict_image( def upload_and_predict_video( self, - project_name: str, + project: Project, video: Union[Video, str, os.PathLike, Union[Sequence[np.ndarray], np.ndarray]], frame_stride: Optional[int] = None, visualise_output: bool = True, delete_after_prediction: bool = False, ) -> Tuple[Video, MediaList[VideoFrame], List[Prediction]]: """ - Upload a single video to a project named `project_name` on the Intel® Geti™ + Upload a single video to a project on the Intel® Geti™ server, and return a list of predictions for the frames in the video. The parameter 'frame_stride' is used to control the stride for frame extraction. Predictions are only generated for the extracted frames. So to get predictions for all frames, `frame_stride=1` can be passed. - :param project_name: Name of the project to upload the image to + :param project: Project to upload the video to :param video: Video or filepath to a video to upload and get predictions for. Can also be a 4D numpy array or a list of 3D numpy arrays, shaped such that the array dimensions represent `frames x width x height x channels`, @@ -1081,8 +1074,6 @@ def upload_and_predict_video( have been generated - List of Predictions for the Video """ - project = self.get_project(project_name=project_name) - # Upload the video video_client = VideoClient( session=self.session, workspace_id=self.workspace_id, project=project @@ -1105,7 +1096,7 @@ def upload_and_predict_video( else: video_data = video if needs_upload: - logging.info(f"Uploading video to project '{project_name}'...") + logging.info(f"Uploading video to project '{project.name}'...") uploaded_video = video_client.upload_video(video=video_data) else: uploaded_video = video @@ -1116,7 +1107,7 @@ def upload_and_predict_video( ) if not prediction_client.ready_to_predict: raise ValueError( - f"Project '{project_name}' is not ready to make predictions. At least " + f"Project '{project.name}' is not ready to make predictions. At least " f"one of the tasks in the task chain does not have any models trained." ) if frame_stride is None: @@ -1141,7 +1132,7 @@ def upload_and_predict_video( def deploy_project( self, - project_name: str, + project: Project, output_folder: Optional[Union[str, os.PathLike]] = None, models: Optional[Sequence[BaseModel]] = None, enable_explainable_ai: bool = False, @@ -1156,7 +1147,7 @@ def deploy_project( for each task in the project. However, it is possible to specify a particular model to use, by passing it in the list of `models` as input to this method. - :param project_name: Name of the project to deploy + :param project: Project object to deploy :param output_folder: Path to a folder on local disk to which the Deployment should be downloaded. If no path is specified, the deployment will not be saved. @@ -1174,8 +1165,6 @@ def deploy_project( launch an OVMS container serving the models. :return: Deployment for the project """ - project = self.get_project(project_name=project_name) - deployment_client = self._deployment_clients.get(project.id, None) if deployment_client is None: # Create deployment client and add to cache. diff --git a/geti_sdk/http_session/geti_session.py b/geti_sdk/http_session/geti_session.py index 2a4a68f1..6357c412 100644 --- a/geti_sdk/http_session/geti_session.py +++ b/geti_sdk/http_session/geti_session.py @@ -14,6 +14,7 @@ import logging import time import warnings +from datetime import datetime from functools import cache from json import JSONDecodeError from typing import Any, Dict, Optional, Union @@ -546,7 +547,10 @@ def base_url(self) -> str: def _get_organization_id(self) -> str: """ Return the organization ID associated with the user and host information configured - in this Session + in this Session. + + NOTE: When authenticating with username and password, this method returns the + ID of the default organization for the user! """ if not self.use_token: result = self.get_rest_response( @@ -560,7 +564,26 @@ def _get_organization_id(self) -> str: method="GET", include_organization_id=False, ) - org_id = result.get("organizationId", None) + if "organizationId" in result.keys(): + # Geti < 2.5, return the id directly + org_id = result.get("organizationId", None) + elif "organizations" in result.keys(): + # Geti 2.5 and up: list of organizations, return the default one (the one + # which was created at the earliest time) + org_list = result["organizations"] + creation_times = [ + datetime.fromisoformat( + x["organizationCreatedAt"].replace("Z", "+00:00") + ) + for x in org_list + ] + ids = [x["organizationId"] for x in org_list] + earliest_creation_time = min(creation_times) + earliest_idx = creation_times.index(earliest_creation_time) + org_id = ids[earliest_idx] + else: + org_id = None + if org_id is None: raise ValueError( f"Unable to retrieve organization ID from the Intel Geti server. " diff --git a/geti_sdk/import_export/import_export_module.py b/geti_sdk/import_export/import_export_module.py index 0f3b8ac6..f6d1ff2d 100644 --- a/geti_sdk/import_export/import_export_module.py +++ b/geti_sdk/import_export/import_export_module.py @@ -8,13 +8,15 @@ from tqdm.contrib.logging import logging_redirect_tqdm from geti_sdk.annotation_readers.geti_annotation_reader import GetiAnnotationReader +from geti_sdk.data_models import Dataset from geti_sdk.data_models.containers.media_list import MediaList from geti_sdk.data_models.enums.dataset_format import DatasetFormat from geti_sdk.data_models.media import Image, Video -from geti_sdk.data_models.project import Dataset, Project +from geti_sdk.data_models.project import Project from geti_sdk.http_session.exception import GetiRequestException from geti_sdk.http_session.geti_session import GetiSession from geti_sdk.import_export.tus_uploader import TUSUploader +from geti_sdk.platform_versions import GETI_25_VERSION from geti_sdk.rest_clients.annotation_clients.annotation_client import AnnotationClient from geti_sdk.rest_clients.configuration_client import ConfigurationClient from geti_sdk.rest_clients.dataset_client import DatasetClient @@ -75,7 +77,7 @@ def download_project_data( # Download project creation parameters: self.project_client.download_project_info( - project_name=project.name, path_to_folder=target_folder + project=project, path_to_folder=target_folder ) # Download images @@ -225,6 +227,7 @@ def upload_project_data( annotation_reader = GetiAnnotationReader( base_data_folder=os.path.join(target_folder, "annotations"), task_type=None, + anomaly_reduction=(self.session.version >= GETI_25_VERSION), ) annotation_client = AnnotationClient[GetiAnnotationReader]( session=self.session, @@ -280,7 +283,7 @@ def upload_project_data( return project def download_all_projects( - self, target_folder: str, include_predictions: bool = True + self, target_folder: str = "./projects", include_predictions: bool = True ) -> List[Project]: """ Download all projects from the Geti Platform. @@ -293,8 +296,6 @@ def download_all_projects( projects = self.project_client.get_all_projects() # Validate or create target_folder - if target_folder is None: - target_folder = os.path.join(".", "projects") os.makedirs(target_folder, exist_ok=True, mode=0o770) logging.info( f"Found {len(projects)} projects in the designated workspace on the " @@ -332,7 +333,7 @@ def upload_all_projects(self, target_folder: str) -> List[Project]: project_folders = [ folder for folder in candidate_project_folders - if ProjectClient.is_project_dir(folder) + if ProjectClient._is_project_dir(folder) ] logging.info( f"Found {len(project_folders)} project data folders in the target " @@ -362,8 +363,8 @@ def import_dataset_as_new_project( :param filepath: The path to the dataset archive. :param project_name: The name of the new project. :param project_type: The type of the new project. Provide one of - [classification, classification_hierarchical, detection, segmentation, - instance_segmentation, anomaly_classification, anomaly_detection, anomaly_segmentation, + [classification, classification_hierarchical, detection, segmentation, instance_segmentation, + anomaly_classification, anomaly_detection, anomaly_segmentation, anomaly, detection_oriented, detection_to_classification, detection_to_segmentation] :return: The imported project. :raises: RuntimeError if the project type is not supported for the imported dataset. @@ -435,8 +436,7 @@ def import_dataset_as_new_project( logging.info( f"Project '{project_name}' was successfully imported from the dataset." ) - imported_project = self.project_client.get_project_by_name( - project_name=project_name, + imported_project = self.project_client.get_project( project_id=job.metadata.project_id, ) if imported_project is None: @@ -481,8 +481,7 @@ def import_project( ) job = monitor_job(session=self.session, job=job, interval=5) - imported_project = self.project_client.get_project_by_name( - project_name=project_name, + imported_project = self.project_client.get_project( project_id=job.metadata.project_id, ) if imported_project is None: @@ -505,7 +504,7 @@ def _tus_upload_file(self, upload_endpoint: str, filepath: os.PathLike) -> str: ) tus_uploader.upload() file_id = tus_uploader.get_file_id() - if file_id is None or len(file_id) < 2: + if file_id is None: raise RuntimeError("Failed to get file id for project {project_name}.") return file_id diff --git a/geti_sdk/import_export/tus_uploader.py b/geti_sdk/import_export/tus_uploader.py index b68c45ef..a020b7ec 100644 --- a/geti_sdk/import_export/tus_uploader.py +++ b/geti_sdk/import_export/tus_uploader.py @@ -169,9 +169,14 @@ def get_file_id(self) -> Optional[str]: :return: File id. """ - if self.upload_url is None: + if ( + self.upload_url is None + or len(file_id := self.upload_url.split("/")[-1]) < 2 + ): + # We get the file_id from the upload url. If the url is not set or the file_id + # is not valid (may be an empty string if the url is not valid), we return None. return - return self.upload_url.split("/")[-1] + return file_id def upload_chunk(self): """ diff --git a/geti_sdk/platform_versions.py b/geti_sdk/platform_versions.py index a46a0280..905a7ead 100644 --- a/geti_sdk/platform_versions.py +++ b/geti_sdk/platform_versions.py @@ -163,3 +163,4 @@ def is_sc_1_1(self) -> bool: GETI_116_VERSION = GetiVersion("1.16.0-release-20240320101320") GETI_20_VERSION = GetiVersion("2.0.0-release-20240320101320") GETI_22_VERSION = GetiVersion("2.2.0-release-20240320101320") +GETI_25_VERSION = GetiVersion("2.5.0-release-20240320101320") diff --git a/geti_sdk/rest_clients/annotation_clients/base_annotation_client.py b/geti_sdk/rest_clients/annotation_clients/base_annotation_client.py index 9f42b9b6..9dcd623c 100644 --- a/geti_sdk/rest_clients/annotation_clients/base_annotation_client.py +++ b/geti_sdk/rest_clients/annotation_clients/base_annotation_client.py @@ -27,6 +27,7 @@ from geti_sdk.data_models import ( AnnotationKind, AnnotationScene, + Dataset, Image, Project, Video, @@ -35,7 +36,6 @@ from geti_sdk.data_models.containers.media_list import MediaList from geti_sdk.data_models.label import Label from geti_sdk.data_models.media import MediaInformation, MediaItem -from geti_sdk.data_models.project import Dataset from geti_sdk.http_session import GetiSession from geti_sdk.rest_clients.dataset_client import DatasetClient from geti_sdk.rest_converters import AnnotationRESTConverter diff --git a/geti_sdk/rest_clients/dataset_client.py b/geti_sdk/rest_clients/dataset_client.py index 2cf30397..d209855e 100644 --- a/geti_sdk/rest_clients/dataset_client.py +++ b/geti_sdk/rest_clients/dataset_client.py @@ -16,9 +16,19 @@ import warnings from typing import List, Optional -from geti_sdk.data_models import Dataset, Project +from geti_sdk.data_models import ( + Dataset, + Image, + Model, + Project, + Subset, + TrainingDatasetStatistics, + VideoFrame, +) +from geti_sdk.data_models.containers import MediaList from geti_sdk.http_session import GetiSession from geti_sdk.http_session.exception import GetiRequestException +from geti_sdk.rest_converters import MediaRESTConverter from geti_sdk.utils import deserialize_dictionary @@ -157,3 +167,93 @@ def _media_folder_has_dataset_subfolders(self, path_to_folder: str) -> bool: if not os.path.isdir(os.path.join(path_to_folder, dataset.name)): return False return True + + def get_training_dataset_summary(self, model: Model) -> TrainingDatasetStatistics: + """ + Return information concerning the training dataset for the `model`. + This includes the number of images and video frames, and the statistics for + the subset splitting (i.e. the number of training, test and validation + images/video frames) + + :param model: Model to get the training dataset for + :return: A `TrainingDatasetStatistics` object, containing the training dataset + statistics for the model + """ + ds_info = model.training_dataset_info + dataset_storage_id = ds_info.get("dataset_storage_id", None) + revision_id = ds_info.get("dataset_revision_id", None) + if dataset_storage_id is None or revision_id is None: + raise ValueError( + f"Unable to fetch the required dataset information from the model. " + f"Expected dataset and revision id's, got {ds_info} instead." + ) + training_dataset = self.session.get_rest_response( + url=f"{self.base_url}/{dataset_storage_id}/training_revisions/{revision_id}", + method="GET", + ) + return deserialize_dictionary(training_dataset, TrainingDatasetStatistics) + + def get_media_in_training_dataset( + self, model: Model, subset: str = "training" + ) -> Subset: + """ + Return the media in the training dataset for the `model`, for + the specified `subset`. Subset can be `training`, `validation` or `testing`. + + :param model: Model for which to get the media in the training dataset + :param subset: The subset for which to return the media items. Can be either + `training` (the default), `validation` or `testing` + return: A `Subset` object, containing lists of `images` and `video_frames` in + the requested `subset` + :raises: ValueError if the DatasetClient is unable to fetch the required + dataset information from the model + """ + ds_info = model.training_dataset_info + dataset_storage_id = ds_info.get("dataset_storage_id", None) + revision_id = ds_info.get("dataset_revision_id", None) + if dataset_storage_id is None or revision_id is None: + raise ValueError( + f"Unable to fetch the required dataset information from the model. " + f"Expected dataset and revision id's, got {ds_info} instead." + ) + post_data = { + "condition": "and", + "rules": [{"field": "subset", "operator": "equal", "value": subset}], + } + + images: MediaList[Image] = MediaList([]) + video_frames: MediaList[VideoFrame] = MediaList([]) + + next_page = f"{self.base_url}/{dataset_storage_id}/training_revisions/{revision_id}/media:query" + while next_page: + response = self.session.get_rest_response( + url=next_page, method="POST", data=post_data + ) + next_page = response.get("next_page", None) + for item in response["media"]: + if item["type"] == "image": + item.pop("annotation_scene_id", None) + item.pop("editor_name", None) + item.pop("roi_id", None) + image = MediaRESTConverter.from_dict(item, Image) + images.append(image) + + if item["type"] == "video": + video_id = item["id"] + next_frame_page = ( + f"{self.base_url}/{dataset_storage_id}/training_revisions/" + f"{revision_id}/media/videos/{video_id}:query" + ) + while next_frame_page: + frames_response = self.session.get_rest_response( + url=next_frame_page, method="POST", data=post_data + ) + for frame in frames_response["media"]: + frame["video_id"] = video_id + video_frame = MediaRESTConverter.from_dict( + frame, VideoFrame + ) + video_frames.append(video_frame) + next_frame_page = frames_response.get("next_page", None) + + return Subset(images=images, frames=video_frames, purpose=subset) diff --git a/geti_sdk/rest_clients/media_client/image_client.py b/geti_sdk/rest_clients/media_client/image_client.py index 29f6c72f..b65b4842 100644 --- a/geti_sdk/rest_clients/media_client/image_client.py +++ b/geti_sdk/rest_clients/media_client/image_client.py @@ -22,9 +22,8 @@ import cv2 import numpy as np -from geti_sdk.data_models import Image, MediaType +from geti_sdk.data_models import Dataset, Image, MediaType from geti_sdk.data_models.containers import MediaList -from geti_sdk.data_models.project import Dataset from geti_sdk.rest_converters import MediaRESTConverter from .media_client import MEDIA_SUPPORTED_FORMAT_MAPPING, BaseMediaClient @@ -186,22 +185,25 @@ def upload_from_list( else: logging.debug("Retrieving full filepaths for image upload...") + filenames_lookup = { + os.path.basename(path): path + for path in glob.glob( + os.path.join(path_to_folder, "**"), recursive=True + ) + } for image_name in image_names[0:n_to_upload]: + matches: List[str] = [] if not extension_included: - matches: List[str] = [] for media_extension in media_formats: - match_for_item = glob.glob( - os.path.join( - path_to_folder, "**", f"{image_name}{media_extension}" - ), - recursive=True, - ) - if len(match_for_item) > 0: - matches += match_for_item - break + if ( + filename := f"{image_name}{media_extension}" + ) in filenames_lookup: + matches.append(filenames_lookup[filename]) else: - matches = glob.glob( - os.path.join(path_to_folder, "**", image_name), recursive=True + matches = ( + [filenames_lookup[image_name]] + if image_name in filenames_lookup + else [] ) if not matches: raise ValueError( diff --git a/geti_sdk/rest_clients/media_client/media_client.py b/geti_sdk/rest_clients/media_client/media_client.py index d96f7fac..cd268fd6 100644 --- a/geti_sdk/rest_clients/media_client/media_client.py +++ b/geti_sdk/rest_clients/media_client/media_client.py @@ -32,13 +32,12 @@ from tqdm.auto import tqdm from tqdm.contrib.logging import logging_redirect_tqdm -from geti_sdk.data_models import Image, MediaType, Project, Video, VideoFrame +from geti_sdk.data_models import Dataset, Image, MediaType, Project, Video, VideoFrame from geti_sdk.data_models.containers.media_list import MediaList, MediaTypeVar from geti_sdk.data_models.enums.media_type import ( SUPPORTED_IMAGE_FORMATS, SUPPORTED_VIDEO_FORMATS, ) -from geti_sdk.data_models.project import Dataset from geti_sdk.data_models.utils import numpy_from_buffer from geti_sdk.http_session import GetiRequestException, GetiSession from geti_sdk.rest_clients.dataset_client import DatasetClient diff --git a/geti_sdk/rest_clients/media_client/video_client.py b/geti_sdk/rest_clients/media_client/video_client.py index 90c7a875..c76cc6e0 100644 --- a/geti_sdk/rest_clients/media_client/video_client.py +++ b/geti_sdk/rest_clients/media_client/video_client.py @@ -20,9 +20,8 @@ import cv2 import numpy as np -from geti_sdk.data_models import MediaType, Video +from geti_sdk.data_models import Dataset, MediaType, Video from geti_sdk.data_models.containers import MediaList -from geti_sdk.data_models.project import Dataset from geti_sdk.http_session import GetiRequestException from geti_sdk.rest_converters import MediaRESTConverter diff --git a/geti_sdk/rest_clients/project_client/project_client.py b/geti_sdk/rest_clients/project_client/project_client.py index 14b60565..905333b7 100644 --- a/geti_sdk/rest_clients/project_client/project_client.py +++ b/geti_sdk/rest_clients/project_client/project_client.py @@ -17,11 +17,13 @@ import logging import os import time +import warnings from typing import Any, Dict, List, Optional, Tuple, Union from geti_sdk.data_models import Project, Task, TaskType from geti_sdk.data_models.utils import remove_null_fields from geti_sdk.http_session import GetiRequestException, GetiSession +from geti_sdk.platform_versions import GETI_25_VERSION from geti_sdk.rest_clients.dataset_client import DatasetClient from geti_sdk.rest_converters import ProjectRESTConverter from geti_sdk.utils.label_helpers import generate_unique_label_color @@ -31,6 +33,7 @@ ANOMALY_CLASSIFICATION_TASK, ANOMALY_DETECTION_TASK, ANOMALY_SEGMENTATION_TASK, + ANOMALY_TASK, BASE_TEMPLATE, CLASSIFICATION_TASK, CROP_TASK, @@ -48,6 +51,7 @@ TaskType.ANOMALY_CLASSIFICATION: ANOMALY_CLASSIFICATION_TASK, TaskType.ANOMALY_DETECTION: ANOMALY_DETECTION_TASK, TaskType.ANOMALY_SEGMENTATION: ANOMALY_SEGMENTATION_TASK, + TaskType.ANOMALY: ANOMALY_TASK, TaskType.INSTANCE_SEGMENTATION: INSTANCE_SEGMENTATION_TASK, TaskType.ROTATED_DETECTION: ROTATED_DETECTION_TASK, } @@ -97,7 +101,7 @@ def get_all_projects( project_detail_list: List[Project] = [] for project in project_list: try: - project_detail_list.append(self._get_project_detail(project)) + project_detail_list.append(self.get_project_by_id(project.id)) except GetiRequestException as e: if e.status_code == 403: logging.info( @@ -109,57 +113,47 @@ def get_all_projects( return project_list def get_project_by_name( - self, project_name: str, project_id: Optional[str] = None + self, + project_name: str, ) -> Optional[Project]: """ Get a project from the Intel® Geti™ server by project_name. + If multiple projects with the same name exist on the server, this method will + raise a ValueError. In that case, please use the `ProjectClient.get_project()` + method and provide a `project_id` to uniquely identify the project. + :param project_name: Name of the project to get - :param project_id: Optional ID of the project to get. Only used if more than - one project named `project_name` exists in the workspace. :raises: ValueError in case multiple projects with the specified name exist on the server, and no `project_id` is provided in order to allow unique identification of the project. :return: Project object containing the data of the project, if the project is - found on the server. Returns None if the project doesn't exist + found on the server. Returns None if the project doesn't exist. """ project_list = self.get_all_projects(get_project_details=False) matches = [project for project in project_list if project.name == project_name] if len(matches) == 1: - return self._get_project_detail(matches[0]) + return self.get_project_by_id(matches[0].id) elif len(matches) > 1: - if project_id is None: - detailed_matches = [ - self._get_project_detail(match) for match in matches - ] - projects_info = [ - ( - f"Name: {p.name}, Type: {p.project_type}, ID: {p.id}, " - f"creation_date: {p.creation_time}" - ) - for p in detailed_matches - ] - raise ValueError( - f"A total of {len(matches)} projects named `{project_name}` were " - f"found in the workspace. Unable to uniquely identify the " - f"desired project. Please provide a `project_id` to ensure the " - f"proper project is returned. The following projects were found:" - f"{projects_info}" - ) - else: - matched_project = next( - (project for project in matches if project.id == project_id), None + detailed_matches = [self.get_project_by_id(match.id) for match in matches] + projects_info = [ + ( + f"Name: {p.name},\t Type: {p.project_type},\t ID: {p.id},\t " + f"creation_date: {p.creation_time}\n" ) - if matched_project is None: - logging.info( - f"Projects with name `{project_name}` were found, but none of " - f"the project ID's `{[p.id for p in matches]}` matches the " - f"requested id `{project_id}`." - ) - return None - else: - return self._get_project_detail(matched_project) + for p in detailed_matches + ] + raise ValueError( + f"A total of {len(matches)} projects named `{project_name}` were " + f"found in the workspace. Unable to uniquely identify the " + f"desired project. Please provide a `project_id` to ensure the " + f"proper project is returned. The following projects were found:" + f"{projects_info}" + ) else: + warnings.warn( + f"Project with name {project_name} was not found on the server." + ) return None def get_or_create_project( @@ -212,40 +206,37 @@ def create_project( project_template = self._create_project_template( project_name=project_name, project_type=project_type, labels=labels ) - project = self.session.get_rest_response( + project_dict = self.session.get_rest_response( url=f"{self.base_url}projects", method="POST", data=project_template ) logging.info("Project created successfully.") - project = ProjectRESTConverter.from_dict(project) + project = ProjectRESTConverter.from_dict(project_dict) self._await_project_ready(project=project) return project - def download_project_info(self, project_name: str, path_to_folder: str) -> None: + def download_project_info(self, project: Project, path_to_folder: str) -> None: """ - Get the project data that can be used for project creation for a project on - the Intel® Geti™ server, named `project_name`. From the returned data, the + Get the project data that can be used for project creation on + the Intel® Geti™ server. From the returned data, the method `ProjectClient.get_or_create_project` can create a project on the Intel® Geti™ server. The data is retrieved from the cluster and saved in the target folder `path_to_folder`. - :param project_name: Name of the project to retrieve the data for + :param project: Project to download the data for :param path_to_folder: Target folder to save the project data to. Data will be saved as a .json file named "project.json" :raises ValueError: If the project with `project_name` is not found on the cluster """ - project = self.get_project_by_name(project_name) - if project is None: - raise ValueError( - f"Project with name {project_name} was not found on the server." - ) + # Update the project state + project = self.get_project_by_id(project.id) project_data = ProjectRESTConverter.to_dict(project) os.makedirs(path_to_folder, exist_ok=True, mode=0o770) project_config_path = os.path.join(path_to_folder, "project.json") with open(project_config_path, "w") as file: json.dump(project_data, file, indent=4) logging.info( - f"Project parameters for project '{project_name}' were saved to file " + f"Project parameters for project '{project.name}' were saved to file " f"{project_config_path}." ) @@ -368,7 +359,7 @@ def create_project_from_folder( return created_project @staticmethod - def is_project_dir(path_to_folder: str) -> bool: + def _is_project_dir(path_to_folder: str) -> bool: """ Check if the folder specified in `path_to_folder` is a directory containing valid Intel® Geti™ project data that can be used to upload to an @@ -435,6 +426,15 @@ def _create_project_template( for task_type, task_labels in zip( get_task_types_by_project_type(project_type), labels ): + # Anomaly task reduction introduced in Intel Geti 2.5 + # The last on-premises version of Intel Geti to support legacy anomaly projects is 2.0 + if ( + self.session.version >= GETI_25_VERSION + and task_type.is_anomaly + and task_type != TaskType.ANOMALY + ): + logging.info(f"The {task_type} task is mapped to {TaskType.ANOMALY}.") + task_type = TaskType.ANOMALY if not is_first_task and not previous_task_type.is_global: # Add crop task and connections, only for tasks that are not # first in the pipeline and are not preceded by a global task @@ -493,29 +493,22 @@ def _ensure_unique_task_name( return task_name def delete_project( - self, project: Union[str, Project], requires_confirmation: bool = True + self, project: Project, requires_confirmation: bool = True ) -> None: """ - Delete a project. The `project` to delete can either by a Project object or a - string containing the name of the project to delete. + Delete a project. By default, this method will ask for user confirmation before deleting the project. This can be overridden by passing `requires_confirmation = False`. - :param project: Project to delete, either a string containing the project - name or a Project instance + :param project: Project to delete :param requires_confirmation: True to ask for user confirmation before deleting the project, False to delete without confirmation. Defaults to True """ - if isinstance(project, str): - project = self.get_project_by_name(project_name=project) - if not isinstance(project, Project): - raise TypeError(f"{type(project)} is not a valid project type.") - if requires_confirmation: - # Update project details - project = self._get_project_detail(project) + # Update the project details + project = self.get_project_by_id(project.id) if project.datasets is None: project.datasets = [] image_count = 0 @@ -680,19 +673,49 @@ def _await_project_ready( f"seconds)." ) from error - def _get_project_detail(self, project: Union[Project, str]) -> Project: + def get_project_by_id(self, project_id: str) -> Optional[Project]: """ - Fetch the most recent project details from the Intel® Geti™ server + Get a project from the Intel® Geti™ server by project_id. - :param project: Name of the project or Project object representing the project - to get detailed information for. - :return: Updated Project object + :param project_id: ID of the project to get + :return: Project object containing the data of the project, if the project is + found on the server. Returns None if the project doesn't exist """ - if isinstance(project, str): - project = self.get_project_by_name(project_name=project) - return project + response = self.session.get_rest_response( + url=f"{self.base_url}projects/{project_id}", method="GET" + ) + return ProjectRESTConverter.from_dict(response) + + def get_project( + self, + project_name: Optional[str] = None, + project_id: Optional[str] = None, + project: Optional[Project] = None, + ) -> Optional[Project]: + """ + Get a project from the Intel® Geti™ server by project_name or project_id, or + update a provided Project object with the latest data from the server. + + :param project_name: Name of the project to get + :param project_id: ID of the project to get + :param project: Project object to update with the latest data from the server + :return: Project object containing the data of the project, if the project is + found on the server. Returns None if the project doesn't exist + """ + # The method prioritize the parameters in the following order: + if project_id is not None: + return self.get_project_by_id(project_id) + elif project is not None: + if project.id is not None: + return self.get_project_by_id(project.id) + else: + return self.get_project_by_name(project_name=project.name) + elif project_name is not None: + return self.get_project_by_name(project_name=project_name) else: - response = self.session.get_rest_response( - url=f"{self.base_url}projects/{project.id}", method="GET" + # No parameters provided + # Warn the user and return None + warnings.warn( + "At least one of the parameters `project_name`, `project_id`, or " + "`project` must be provided." ) - return ProjectRESTConverter.from_dict(response) diff --git a/geti_sdk/rest_clients/project_client/task_templates.py b/geti_sdk/rest_clients/project_client/task_templates.py index f336a3d9..7f8601c7 100644 --- a/geti_sdk/rest_clients/project_client/task_templates.py +++ b/geti_sdk/rest_clients/project_client/task_templates.py @@ -52,6 +52,15 @@ "labels": [], } +# This is the reduced anomaly task. +# It goes under `Anomaly` title, +# and it is `Anomally classification task` under the hood +ANOMALY_TASK = { + "title": "Anomaly", + "task_type": "anomaly", + "labels": [], +} + ANOMALY_DETECTION_TASK = { "title": "Anomaly detection task", "task_type": "anomaly_detection", diff --git a/geti_sdk/rest_clients/training_client.py b/geti_sdk/rest_clients/training_client.py index ef8f0f6c..db1095d0 100644 --- a/geti_sdk/rest_clients/training_client.py +++ b/geti_sdk/rest_clients/training_client.py @@ -16,6 +16,7 @@ from geti_sdk.data_models import ( Algorithm, + Dataset, Job, Project, ProjectStatus, @@ -24,7 +25,6 @@ ) from geti_sdk.data_models.containers import AlgorithmList from geti_sdk.data_models.enums import JobState, JobType -from geti_sdk.data_models.project import Dataset from geti_sdk.http_session import GetiSession from geti_sdk.rest_converters import ( ConfigurationRESTConverter, diff --git a/notebooks/001_create_project.ipynb b/notebooks/001_create_project.ipynb index 781fb572..4d2f5dd2 100644 --- a/notebooks/001_create_project.ipynb +++ b/notebooks/001_create_project.ipynb @@ -225,7 +225,7 @@ "id": "aa289ae0-36bb-40db-afb3-d1c89fb2a9e1", "metadata": {}, "source": [ - "The `project` object that was created by the `project_client.create_project()` method can also be retrieved by calling `project_client.get_project_by_name()`. This is useful if you do not want to create a new project, but would like to interact with an existing project instead" + "The `project` object that was created by the `project_client.create_project()` method can also be retrieved by calling `project_client.get_project()`. This is useful if you do not want to create a new project, but would like to interact with an existing project instead" ] }, { @@ -235,7 +235,7 @@ "metadata": {}, "outputs": [], "source": [ - "project = project_client.get_project_by_name(project_name=PROJECT_NAME)\n", + "project = project_client.get_project(project_name=PROJECT_NAME)\n", "print(project.summary)" ] }, @@ -395,7 +395,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.19" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/notebooks/003_upload_and_predict_image.ipynb b/notebooks/003_upload_and_predict_image.ipynb index eb064fee..630e384a 100644 --- a/notebooks/003_upload_and_predict_image.ipynb +++ b/notebooks/003_upload_and_predict_image.ipynb @@ -84,7 +84,7 @@ "metadata": {}, "outputs": [], "source": [ - "project = project_client.get_project_by_name(PROJECT_NAME)\n", + "project = project_client.get_project(project_name=PROJECT_NAME)\n", "image_client = ImageClient(\n", " session=geti.session, workspace_id=geti.workspace_id, project=project\n", ")\n", diff --git a/notebooks/005_modify_image.ipynb b/notebooks/005_modify_image.ipynb index 31b81001..c8a246d1 100644 --- a/notebooks/005_modify_image.ipynb +++ b/notebooks/005_modify_image.ipynb @@ -116,7 +116,7 @@ "metadata": {}, "outputs": [], "source": [ - "project = project_client.get_project_by_name(project_name=\"COCO horse detection demo\")" + "project = project_client.get_project(project_name=\"COCO horse detection demo\")" ] }, { diff --git a/notebooks/006_reconfigure_task.ipynb b/notebooks/006_reconfigure_task.ipynb index 2f61fbc3..b5dc4fc0 100644 --- a/notebooks/006_reconfigure_task.ipynb +++ b/notebooks/006_reconfigure_task.ipynb @@ -48,7 +48,7 @@ "PROJECT_NAME = \"COCO multitask animal demo\"\n", "projects = project_client.list_projects()\n", "\n", - "project = project_client.get_project_by_name(PROJECT_NAME)" + "project = project_client.get_project(PROJECT_NAME)" ] }, { @@ -233,7 +233,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.19" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/notebooks/007_train_project.ipynb b/notebooks/007_train_project.ipynb index a021b1fd..9ae790ba 100644 --- a/notebooks/007_train_project.ipynb +++ b/notebooks/007_train_project.ipynb @@ -65,7 +65,7 @@ "source": [ "PROJECT_NAME = \"COCO multitask animal demo\"\n", "\n", - "project = project_client.get_project_by_name(project_name=PROJECT_NAME)" + "project = project_client.get_project(project_name=PROJECT_NAME)" ] }, { @@ -297,7 +297,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.17" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/notebooks/011_benchmarking_models.ipynb b/notebooks/011_benchmarking_models.ipynb index 197170fa..a2151bdb 100644 --- a/notebooks/011_benchmarking_models.ipynb +++ b/notebooks/011_benchmarking_models.ipynb @@ -44,7 +44,7 @@ "outputs": [], "source": [ "PROJECT_NAME = \"COCO animal detection demo\"\n", - "project = geti.get_project(PROJECT_NAME)" + "project = geti.get_project(project_name=PROJECT_NAME)" ] }, { diff --git a/notebooks/014_asynchronous_inference.ipynb b/notebooks/014_asynchronous_inference.ipynb index 7cc09991..7a520717 100644 --- a/notebooks/014_asynchronous_inference.ipynb +++ b/notebooks/014_asynchronous_inference.ipynb @@ -55,7 +55,7 @@ "geti = Geti(server_config=geti_server_configuration)\n", "\n", "PROJECT_NAME = \"COCO multitask animal demo\"\n", - "project = geti.get_project(PROJECT_NAME)" + "project = geti.get_project(project_name=PROJECT_NAME)" ] }, { diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 5d51096a..07658ab9 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,6 +1,6 @@ # Requirements for running the tests vcrpy==6.0.* -pytest==8.2.* +pytest==8.3.* pytest-recording==0.13.* pytest-cov==5.0.* pytest-env==1.1.* diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 95c4cf49..37b86ce3 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,5 +1,5 @@ # Base requirements -datumaro==1.8.* +datumaro==1.9.* requests==2.32.* numpy>=1.26.4 omegaconf==2.3.* diff --git a/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_download_specific_dataset.cassette b/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_download_specific_dataset.cassette index 67664107..71b8e9a0 100644 --- a/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_download_specific_dataset.cassette +++ b/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_download_specific_dataset.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9b79182e779867bd17eb7ed154cee8b8c590f91f09ed8608218cbd310b3f7b39 -size 1910446 +oid sha256:bbf3edbf172648c35ed8d6eb811d3a88b9c1c25a9a14e79f3d30b1b016cd853a +size 1911130 diff --git a/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_upload_and_delete_image.cassette b/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_upload_and_delete_image.cassette index 71a6ec6e..65fed3a0 100644 --- a/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_upload_and_delete_image.cassette +++ b/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_upload_and_delete_image.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0815b5be884f9baf14e2ad42887be55317a8a0c86a0ef5b4a9a0e9ba1ea199eb -size 103778 +oid sha256:419b81298d68c02b41081c5b67daf1c01ac1e616e4e225419052b5c1db41bcec +size 103816 diff --git a/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_upload_from_list.cassette b/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_upload_from_list.cassette new file mode 100644 index 00000000..7b8c89d0 --- /dev/null +++ b/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_upload_from_list.cassette @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6721d8ab96b75e7d951aa8fbd9195cfc3cedb2ffe6b04c745517fee632e9fda2 +size 271894 diff --git a/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_upload_image_folder_and_download.cassette b/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_upload_image_folder_and_download.cassette index fea3ed62..06b0e06b 100644 --- a/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_upload_image_folder_and_download.cassette +++ b/tests/fixtures/cassettes/DEVELOP/TestImageClient.test_upload_image_folder_and_download.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e27f3c50ca954006251294c6f6c967e7d8ef7d49f7d277dcb445d7570562cdd9 -size 1497279 +oid sha256:7d436ab1cbc32ce8cf300ae1e4c239181dd4718ff19311c9f24c729c2b9e9cb1 +size 1541784 diff --git a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_download_active_model_for_task.cassette b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_download_active_model_for_task.cassette index 007fe20d..87a36eaa 100644 --- a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_download_active_model_for_task.cassette +++ b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_download_active_model_for_task.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:41d63370cd083bbf30891e33b4686282bc1a6685d30e87eed5841319df8bb01f -size 133255530 +oid sha256:ccafd2713808ee157815d5eada21e67b262fa4f78176557c359af958a32fd875 +size 134828761 diff --git a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_get_model_algorithm_task_and_version.cassette b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_get_model_algorithm_task_and_version.cassette index 600ce70d..513fb618 100644 --- a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_get_model_algorithm_task_and_version.cassette +++ b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_get_model_algorithm_task_and_version.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:465cf8cf69ce0f80ffbf9d4d41747c5c162fdad376603b36285a013821691648 -size 35093 +oid sha256:a6e8c5b76b0e108d0a564bedb15928fc5d0a5e63d5723947218096bab481dce4 +size 35627 diff --git a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_get_model_group_by_algo_name.cassette b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_get_model_group_by_algo_name.cassette index 50327fba..658b496c 100644 --- a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_get_model_group_by_algo_name.cassette +++ b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_get_model_group_by_algo_name.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d604f52815506f49d0e1803907f84bca6006d8bf211d4c2259a9bcb484f889ab -size 2966 +oid sha256:abb8f04689ee428c330dceff77cd505de3efd4b18353f35a0e19319eb56e071a +size 3012 diff --git a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_get_training_dataset.cassette b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_get_training_dataset.cassette new file mode 100644 index 00000000..5e0066a6 --- /dev/null +++ b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_get_training_dataset.cassette @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56428fbc702ad18455e3e082e82db871a7729b793788f85f213af33eaa6c1528 +size 33201 diff --git a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_predict_image.cassette b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_predict_image.cassette index 61a31b8e..04b857c8 100644 --- a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_predict_image.cassette +++ b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_predict_image.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:8708530801ef2384265918b44e42a2a95328535ad14ab5c43049cabfe0978db3 -size 1715277 +oid sha256:ca95796aab06dbb24bd56eb7ae208a4d0fd445ad4ffc129c306b0510cd70e382 +size 1735991 diff --git a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_project_setup_and_get_model_by_job.cassette b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_project_setup_and_get_model_by_job.cassette index 2072c196..f4cf5799 100644 --- a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_project_setup_and_get_model_by_job.cassette +++ b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_project_setup_and_get_model_by_job.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a8f4f822ee28609f69c90d5b645cf4c42a77a1b3e06c7ba4f0736d42e3a65b3a -size 278359 +oid sha256:3ff0690340926eab099702ea052ecfa4c4ca98b81b60e32bbfea226ef03e01d9 +size 165072 diff --git a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_purge_model.cassette b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_purge_model.cassette index c0ed5e3d..fcbea508 100644 --- a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_purge_model.cassette +++ b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_purge_model.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:46814c40ade667ca3a940d6ef12b7b57d69ddd48207a59a7da87f015a4d61202 -size 309737 +oid sha256:bfa65e40fcfa1e4c1d7b84f5ad1e7382decb115c88ded5ddedc38d127831aac7 +size 201313 diff --git a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_set_active_model.cassette b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_set_active_model.cassette index edb26681..f989cbef 100644 --- a/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_set_active_model.cassette +++ b/tests/fixtures/cassettes/DEVELOP/TestModelAndPredictionClient.test_set_active_model.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5be244b20ddb83081b81d8320e0713e34c519adb064c1e6bcdd721fffc6abc1d -size 347519 +oid sha256:afece9c1c5c19cf21713bfd03131e19d37859d23cdb19c9cc360d1f8083de907 +size 223880 diff --git a/tests/fixtures/cassettes/DEVELOP/geti.cassette b/tests/fixtures/cassettes/DEVELOP/geti.cassette index ca9ca59d..7402d84e 100644 --- a/tests/fixtures/cassettes/DEVELOP/geti.cassette +++ b/tests/fixtures/cassettes/DEVELOP/geti.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5d9e259aad5f724e48e7867ac0134f9bcdceabb70b623dcb76b967bb2dd5a92b -size 74426 +oid sha256:83bea18e7dbc3658a0b2eca2eaaf0b718f26fd0771d4b349397eeb63312369be +size 75380 diff --git a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_image_client.cassette b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_image_client.cassette index 6570a07f..52c18e3f 100644 --- a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_image_client.cassette +++ b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_image_client.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a075d7ac01e6f56e1c465eb6b7c916b9e3e26f19ce5dc90bb94c0a850c80e332 +oid sha256:9569b69d7706087562cb2031359f1e26594809c9f3db01b4e512b77c47764ecd size 7216 diff --git a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_image_client_deletion.cassette b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_image_client_deletion.cassette index cb33a72d..3cc1f5a8 100644 --- a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_image_client_deletion.cassette +++ b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_image_client_deletion.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:663924d32224331284fbef532ce85677e3cbcfeebb6b59e630b07b26096e074f -size 10480 +oid sha256:9c157582df0dc86a3d45d3f5b90024b2b0ece5a4c6ead995d3ff5c0f28c9a0ac +size 2504 diff --git a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client.cassette b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client.cassette index 30157c39..fb7720d6 100644 --- a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client.cassette +++ b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:77d9c7ab5d5f12ad5cd7d8f24ae1c372437ccfc69cdf872ed0f38b1aa83aa4d0 -size 7306 +oid sha256:fe40534382dedb3fc42d7fc1d5d67ff8188af97b02188e80c70778bc95a9b699 +size 7410 diff --git a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_add_annotated_media.cassette b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_add_annotated_media.cassette index 126ff854..bc634813 100644 --- a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_add_annotated_media.cassette +++ b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_add_annotated_media.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b597b5a7afe51c6a9f8e0ad447df623d76a346f075cc2884a865ab6cf4683136 -size 467061 +oid sha256:9f40f315f516a369ac2c6f15a3e25b2c68bc80ef97a71159b21b53f699184db3 +size 473166 diff --git a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_configuration_client.cassette b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_configuration_client.cassette index 9304c8e5..e158cc27 100644 --- a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_configuration_client.cassette +++ b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_configuration_client.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1bdcba2cd566c6108edf8c5916a03e0702a433e0e80a13cdc8ccea0041ceb94f -size 6965 +oid sha256:5ecd835637ab97b534be4a58b5431239e72f636db374d05abfee9a642f440678 +size 7062 diff --git a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_deletion.cassette b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_deletion.cassette index 522d8a53..a2360c16 100644 --- a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_deletion.cassette +++ b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_deletion.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:11a3c27595f26017a19407162006440c341e286f189e2a7118dfcd5dbc2e01d1 -size 10070 +oid sha256:d3068f4c4a8791a5ce430b1d1d15ab9fdb60e2df5e647a2ae622e537c2fef998 +size 2552 diff --git a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_model_client.cassette b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_model_client.cassette index 7f425847..9962ce24 100644 --- a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_model_client.cassette +++ b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_model_client.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bb3604dfdd1adbc7298fd6dda9e1b39c27c0bf2b1ec7c4527ba9d53f24cd8dbf -size 3870 +oid sha256:07e006abcd717813008a8f7477fe79e35db41e7630314bedff221a1f7b458f5a +size 3920 diff --git a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_prediction_client.cassette b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_prediction_client.cassette index fe74c2ec..d586f1df 100644 --- a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_prediction_client.cassette +++ b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_prediction_client.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c06791f6397f23244583871284a178a12bd84b7ab59508073533c36c76b96f29 -size 2437 +oid sha256:5ee5c3c87c3658eb1eae0f26c9b76309a7548c61434fd2f673d7836692b74a60 +size 2481 diff --git a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_set_auto_train.cassette b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_set_auto_train.cassette index 6eb42211..3794a22b 100644 --- a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_set_auto_train.cassette +++ b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_set_auto_train.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0231d465fbafc60399cc3ad93b392d7131b977127619bc428e033a6b8e9cbdf7 -size 18997 +oid sha256:f28d52b4f4b1f2e507e9444e8a76eac57c4deded64a7518e44258acf1264e58a +size 19657 diff --git a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_set_minimal_hypers.cassette b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_set_minimal_hypers.cassette index 97ae8a9b..821a7f9d 100644 --- a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_set_minimal_hypers.cassette +++ b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_set_minimal_hypers.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:033ffe39b285727c6f081b5d45d46f838cec83770e1ff535dba0404a8f342a5a -size 80204 +oid sha256:0b940a128fbd9b1bffd7ca659dddd03caceee59f9dd84e7a2652cc49e6ab1d1a +size 82987 diff --git a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_training_client.cassette b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_training_client.cassette index 31434921..198fd3dc 100644 --- a/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_training_client.cassette +++ b/tests/fixtures/cassettes/DEVELOP/geti_sdk_test_model_and_prediction_client_training_client.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5da5472a235a9e6c23f516b7d15b11d69abddb379c4433c4f6af5597f75b2aca -size 3870 +oid sha256:6bdbe93b5b15b03b18d540aa52578a7b5c52767a3c56e87a1ae2edeed78c2dd6 +size 3920 diff --git a/tests/fixtures/demos.py b/tests/fixtures/demos.py index b16e23e5..e7da452c 100644 --- a/tests/fixtures/demos.py +++ b/tests/fixtures/demos.py @@ -50,10 +50,10 @@ def fxt_anomaly_classification_demo_project( fxt_demo_images_and_annotations: Tuple[int, int], ) -> Project: """ - Create an annotated anomaly classification project on the Geti instance, and + Create an annotated anomaly detection project on the Geti instance, and return the Project object representing it. """ - project_name = f"{PROJECT_PREFIX}_anomaly_classification_demo" + project_name = f"{PROJECT_PREFIX}_anomaly_demo" project = create_anomaly_classification_demo_project( geti=fxt_geti_no_vcr, n_images=fxt_demo_images_and_annotations[0], @@ -62,9 +62,8 @@ def fxt_anomaly_classification_demo_project( ) yield project force_delete_project( - project_name=project_name, + project, project_client=fxt_project_client_no_vcr, - project_id=project.id, ) @@ -87,9 +86,8 @@ def fxt_segmentation_demo_project( ) yield project force_delete_project( - project_name=project.name, + project, project_client=fxt_project_client_no_vcr, - project_id=project.id, ) @@ -112,9 +110,8 @@ def fxt_detection_to_classification_demo_project( ) yield project force_delete_project( - project_name=project.name, + project, project_client=fxt_project_client_no_vcr, - project_id=project.id, ) @@ -137,9 +134,8 @@ def fxt_detection_to_segmentation_demo_project( ) yield project force_delete_project( - project_name=project_name, + project, project_client=fxt_project_client_no_vcr, - project_id=project.id, ) @@ -162,9 +158,8 @@ def fxt_classification_demo_project( ) yield project force_delete_project( - project_name=project_name, + project, project_client=fxt_project_client_no_vcr, - project_id=project.id, ) @@ -187,9 +182,8 @@ def fxt_detection_demo_project( ) yield project force_delete_project( - project_name=project_name, + project, project_client=fxt_project_client_no_vcr, - project_id=project.id, ) diff --git a/tests/fixtures/projects.py b/tests/fixtures/projects.py index 6b9df649..d23f3d4b 100644 --- a/tests/fixtures/projects.py +++ b/tests/fixtures/projects.py @@ -22,6 +22,7 @@ from geti_sdk.rest_clients import ProjectClient from tests.helpers import ProjectService, force_delete_project from tests.helpers.constants import CASSETTE_EXTENSION +from tests.helpers.enums import SdkTestMode @pytest.fixture(scope="class") @@ -45,6 +46,7 @@ def fxt_project_client_no_vcr(fxt_geti_no_vcr: Geti) -> ProjectClient: @pytest.fixture(scope="class") def fxt_project_service( fxt_vcr, + fxt_test_mode, fxt_geti: Geti, ) -> ProjectService: """ @@ -56,7 +58,9 @@ def fxt_project_service( The project is deleted once the test function finishes. """ - project_service = ProjectService(geti=fxt_geti, vcr=fxt_vcr) + project_service = ProjectService( + geti=fxt_geti, vcr=fxt_vcr, is_offline=(fxt_test_mode == SdkTestMode.OFFLINE) + ) yield project_service project_service.delete_project() @@ -64,6 +68,7 @@ def fxt_project_service( @pytest.fixture(scope="class") def fxt_project_service_2( fxt_vcr, + fxt_test_mode, fxt_geti: Geti, ) -> ProjectService: """ @@ -78,7 +83,9 @@ def fxt_project_service_2( NOTE: This fixture is the same as `fxt_project_service`, but was added to make it possible to persist two projects for the scope of one test class """ - project_service = ProjectService(geti=fxt_geti, vcr=fxt_vcr) + project_service = ProjectService( + geti=fxt_geti, vcr=fxt_vcr, is_offline=(fxt_test_mode == SdkTestMode.OFFLINE) + ) yield project_service project_service.delete_project() @@ -98,8 +105,8 @@ def fxt_project_finalizer(fxt_project_client: ProjectClient) -> Callable[[str], :var project_name: Name of the project for which to add the finalizer """ - def _project_finalizer(project_name: str, project_id: str) -> None: - force_delete_project(project_name, fxt_project_client, project_id) + def _project_finalizer(project: Project) -> None: + force_delete_project(project, fxt_project_client) return _project_finalizer diff --git a/tests/fixtures/unit_tests/benchmarker.py b/tests/fixtures/unit_tests/benchmarker.py index 86edc42b..69a7fc30 100644 --- a/tests/fixtures/unit_tests/benchmarker.py +++ b/tests/fixtures/unit_tests/benchmarker.py @@ -27,18 +27,17 @@ def fxt_benchmarker( fxt_mocked_geti: Geti, ) -> Benchmarker: _ = mocker.patch( - "geti_sdk.geti.ProjectClient.get_project_by_name", + "geti_sdk.geti.Geti.get_project", return_value=fxt_classification_project, ) _ = mocker.patch("geti_sdk.benchmarking.benchmarker.ModelClient") _ = mocker.patch("geti_sdk.benchmarking.benchmarker.TrainingClient") - project_name = "project name" algorithms_to_benchmark = ("ALGO_1", "ALGO_2") precision_levels = ("PRECISION_1", "PRECISION_2") images = ("path_1", "path_2") yield Benchmarker( geti=fxt_mocked_geti, - project=project_name, + project=mocker.MagicMock(), algorithms=algorithms_to_benchmark, precision_levels=precision_levels, benchmark_images=images, @@ -52,7 +51,7 @@ def fxt_benchmarker_task_chain( fxt_mocked_geti: Geti, ) -> Benchmarker: _ = mocker.patch( - "geti_sdk.geti.ProjectClient.get_project_by_name", + "geti_sdk.geti.Geti.get_project", return_value=fxt_detection_to_classification_project, ) model_client_object_mock = mocker.MagicMock() @@ -64,13 +63,12 @@ def fxt_benchmarker_task_chain( model_client_object_mock.get_all_active_models.return_value = active_models _ = mocker.patch("geti_sdk.benchmarking.benchmarker.TrainingClient") - project_name = "project name" precision_levels = ("PRECISION_1", "PRECISION_2") images = ("path_1", "path_2") yield Benchmarker( geti=fxt_mocked_geti, - project=project_name, + project=mocker.MagicMock(), precision_levels=precision_levels, benchmark_images=images, ) diff --git a/tests/helpers/finalizers.py b/tests/helpers/finalizers.py index 2094b4d6..67353acc 100644 --- a/tests/helpers/finalizers.py +++ b/tests/helpers/finalizers.py @@ -13,14 +13,12 @@ # and limitations under the License. import logging import time -from typing import Optional +from geti_sdk.data_models.project import Project from geti_sdk.rest_clients import ProjectClient, TrainingClient -def force_delete_project( - project_name: str, project_client: ProjectClient, project_id: Optional[str] = None -) -> None: +def force_delete_project(project: Project, project_client: ProjectClient) -> None: """ Deletes the project named 'project_name'. If any jobs are running for the project, this finalizer cancels them. @@ -30,17 +28,16 @@ def force_delete_project( :param project_id: Optional ID of the project to delete. This can be useful in case there are multiple projects with the same name in the workspace """ - project = project_client.get_project_by_name(project_name, project_id) try: project_client.delete_project(project=project, requires_confirmation=False) except TypeError: logging.warning( - f"Project {project_name} was not found on the server, it was most " + f"Project {project.name} was not found on the server, it was most " f"likely already deleted." ) except ValueError: logging.error( - f"Unable to delete project '{project_name}' from the server, it " + f"Unable to delete project '{project.name}' from the server, it " f"is most likely locked for deletion due to an operation/training " f"session that is in progress. " f"\n\n Attempting to cancel the job and re-try project deletion." diff --git a/tests/helpers/project_helpers.py b/tests/helpers/project_helpers.py index 50f5a3b4..a289a94e 100644 --- a/tests/helpers/project_helpers.py +++ b/tests/helpers/project_helpers.py @@ -94,7 +94,7 @@ def remove_all_test_projects(geti: Geti) -> List[str]: projects_removed: List[str] = [] for project in project_client.get_all_projects(get_project_details=False): if project.name.startswith(PROJECT_PREFIX): - force_delete_project(project.name, project_client, project_id=project.id) + force_delete_project(project, project_client) projects_removed.append(project.name) logging.info(f"{len(projects_removed)} test projects were removed from the server.") return projects_removed diff --git a/tests/helpers/project_service.py b/tests/helpers/project_service.py index 5969eb1f..0630aca2 100644 --- a/tests/helpers/project_service.py +++ b/tests/helpers/project_service.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions # and limitations under the License. import logging +import time from contextlib import nullcontext from typing import Any, Dict, List, Optional, Sequence, Union @@ -23,6 +24,7 @@ from geti_sdk.rest_clients import ( AnnotationClient, ConfigurationClient, + DatasetClient, ImageClient, ModelClient, PredictionClient, @@ -47,7 +49,7 @@ class ProjectService: VCR cassettes """ - def __init__(self, geti: Geti, vcr: Optional[VCR] = None): + def __init__(self, geti: Geti, vcr: Optional[VCR] = None, is_offline: bool = False): if vcr is None: self.vcr_context = nullcontext else: @@ -60,6 +62,8 @@ def __init__(self, geti: Geti, vcr: Optional[VCR] = None): ) self._project: Optional[Project] = None + self._project_creation_timestamp: Optional[float] = None + self._is_offline: bool = is_offline self._configuration_client: Optional[ConfigurationClient] = None self._image_client: Optional[ImageClient] = None self._annotation_client: Optional[AnnotationClient] = None @@ -67,6 +71,7 @@ def __init__(self, geti: Geti, vcr: Optional[VCR] = None): self._video_client: Optional[VideoClient] = None self._model_client: Optional[ModelClient] = None self._prediction_client: Optional[PredictionClient] = None + self._dataset_client: Optional[DatasetClient] = None self._client_names = [ "_configuration_client", "_image_client", @@ -75,7 +80,9 @@ def __init__(self, geti: Geti, vcr: Optional[VCR] = None): "_video_client", "_model_client", "_prediction_client", + "_dataset_client", ] + self._project_removal_delay = 5 # seconds def create_project( self, @@ -100,7 +107,7 @@ def create_project( project = self.project_client.create_project( project_name=project_name, project_type=project_type, labels=labels ) - self._project = project + self.project = project return project else: raise ValueError( @@ -183,7 +190,7 @@ def create_project_from_dataset( "Please either delete the existing project first or use a new " "instance to create another project" ) - self._project = project + self.project = project return project @property @@ -191,7 +198,8 @@ def project(self) -> Project: """ Returns the project managed by the ProjectService. - :return: + :return: The project managed by the ProjectService + :raises: ValueError if the ProjectService does not contain a project yet """ if self._project is None: raise ValueError( @@ -200,6 +208,17 @@ def project(self) -> Project: ) return self._project + @project.setter + def project(self, value: Optional[Project]) -> None: + """ + Set the project for the ProjectService. + """ + self._project = value + if self._project is None: + self._project_creation_timestamp = None + else: + self._project_creation_timestamp = time.time() + @property def has_project(self) -> bool: """ @@ -260,6 +279,20 @@ def annotation_client(self) -> AnnotationClient: ) return self._annotation_client + @property + def dataset_client(self) -> DatasetClient: + """Returns the DatasetClient instance for the project""" + if self._dataset_client is None: + with self.vcr_context( + f"{self.project.name}_dataset_client.{CASSETTE_EXTENSION}" + ): + self._dataset_client = DatasetClient( + session=self.session, + workspace_id=self.workspace_id, + project=self.project, + ) + return self._dataset_client + @property def configuration_client(self) -> ConfigurationClient: """Returns the ConfigurationClient instance for the project""" @@ -332,11 +365,16 @@ def is_training(self) -> bool: def delete_project(self): """Deletes the project from the server""" - if self._project is not None: + if self.has_project: with self.vcr_context(f"{self.project.name}_deletion.{CASSETTE_EXTENSION}"): - force_delete_project( - self.project.name, self.project_client, self.project.id - ) + # Server needs a moment to process the project before deletion + if ( + not self._is_offline + and (lifetime := time.time() - self._project_creation_timestamp) + < self._project_removal_delay + ): + time.sleep(self._project_removal_delay - lifetime) + force_delete_project(self.project, self.project_client) self.reset_state() def reset_state(self) -> None: @@ -344,7 +382,7 @@ def reset_state(self) -> None: Resets the state of the ProjectService instance. This method should be called once the project belonging to the project service is deleted from the server """ - self._project = None + self.project = None for client_name in self._client_names: setattr(self, client_name, None) diff --git a/tests/nightly/demos/test_demo_projects.py b/tests/nightly/demos/test_demo_projects.py index 1cbdf7d7..6d58860f 100644 --- a/tests/nightly/demos/test_demo_projects.py +++ b/tests/nightly/demos/test_demo_projects.py @@ -69,7 +69,7 @@ def test_get_mvtec_dataset(self, fxt_anomaly_dataset: str): ids=[ "Detection", "Classification", - "Anomaly classification", + "Anomaly", "Segmentation", "Detection to classification", "Detection to segmentation", @@ -160,9 +160,8 @@ def test_ensure_trained_example_project( ) if any_project is not None: force_delete_project( - project_name=non_existing_project_name, + project=any_project, project_client=fxt_project_client_no_vcr, - project_id=any_project.id, ) assert non_existing_project_name not in [ project.name for project in fxt_project_client_no_vcr.get_all_projects() diff --git a/tests/nightly/test_anomaly_classification.py b/tests/nightly/test_anomaly.py similarity index 89% rename from tests/nightly/test_anomaly_classification.py rename to tests/nightly/test_anomaly.py index 492d75b2..3b94b2cc 100644 --- a/tests/nightly/test_anomaly_classification.py +++ b/tests/nightly/test_anomaly.py @@ -6,8 +6,8 @@ from tests.nightly.test_nightly_project import TestNightlyProject -class TestAnomalyClassification(TestNightlyProject): - PROJECT_TYPE = "anomaly_classification" +class TestAnomaly(TestNightlyProject): + PROJECT_TYPE = "anomaly" __test__ = True def test_project_setup( @@ -23,9 +23,8 @@ def test_project_setup( existing_project = fxt_project_client_no_vcr.get_project_by_name(project_name) if existing_project is not None: force_delete_project( - project_name=project_name, + project=existing_project, project_client=fxt_project_client_no_vcr, - project_id=existing_project.id, ) assert project_name not in [ project.name for project in fxt_project_client_no_vcr.get_all_projects() @@ -40,7 +39,7 @@ def test_project_setup( def test_monitor_jobs(self, fxt_project_service_no_vcr: ProjectService): """ - For anomaly classification projects, the training is run in the project_setup + For anomaly projects, the training is run in the project_setup phase. No need to monitor jobs. """ pass diff --git a/tests/nightly/test_classification.py b/tests/nightly/test_classification.py index b2688a1a..6c6765c6 100644 --- a/tests/nightly/test_classification.py +++ b/tests/nightly/test_classification.py @@ -89,7 +89,8 @@ def test_export_import_project( # Project is exported assert not os.path.exists(archive_path) fxt_geti_no_vcr.export_project( - project_name=project.name, project_id=project.id, filepath=archive_path + filepath=archive_path, + project=project, ) assert os.path.exists(archive_path) diff --git a/tests/nightly/test_nightly_project.py b/tests/nightly/test_nightly_project.py index ee8022c8..0c1a2fed 100644 --- a/tests/nightly/test_nightly_project.py +++ b/tests/nightly/test_nightly_project.py @@ -127,7 +127,7 @@ def test_upload_and_predict_image( for j in range(n_attempts): try: image, prediction = fxt_geti_no_vcr.upload_and_predict_image( - project_name=project.name, + project=project, image=fxt_image_path, visualise_output=False, delete_after_prediction=False, @@ -160,7 +160,7 @@ def test_deployment( deployment_folder = os.path.join(fxt_temp_directory, project.name) deployment = fxt_geti_no_vcr.deploy_project( - project.name, + project, output_folder=deployment_folder, enable_explainable_ai=True, ) @@ -177,7 +177,7 @@ def test_deployment( local_prediction = deployment.infer(image_np) assert isinstance(local_prediction, Prediction) image, online_prediction = fxt_geti_no_vcr.upload_and_predict_image( - project.name, + project, image=image_bgr, delete_after_prediction=True, visualise_output=False, diff --git a/tests/pre-merge/integration/rest_clients/test_image_client.py b/tests/pre-merge/integration/rest_clients/test_image_client.py index 696ab017..fba5965a 100644 --- a/tests/pre-merge/integration/rest_clients/test_image_client.py +++ b/tests/pre-merge/integration/rest_clients/test_image_client.py @@ -15,6 +15,7 @@ import os import shutil import tempfile +from pathlib import Path from typing import List import cv2 @@ -132,6 +133,39 @@ def test_upload_image_folder_and_download( for image in images + old_images: assert image.name + ".jpg" in downloaded_filenames + # remove images + image_client.delete_images(images) + assert len(image_client.get_all_images()) == n_images + + @pytest.mark.vcr() + def test_upload_from_list( + self, + fxt_project_service: ProjectService, + fxt_default_labels: List[str], + fxt_image_folder: str, + request: FixtureRequest, + ): + image_client = fxt_project_service.image_client + old_images = image_client.get_all_images() + n_old_images = len(old_images) + + # Upload images from list + image_base_names = [ + p.name.split(".")[0] for p in Path.glob(Path(fxt_image_folder), "*") + ] + n_to_upload = len(image_base_names) // 2 + assert n_to_upload > 0 + images = image_client.upload_from_list( + fxt_image_folder, + image_names=image_base_names, + max_threads=1, + extension_included=False, + n_images=n_to_upload, + image_names_as_full_paths=False, + ) + assert len(images) == n_to_upload + assert len(image_client.get_all_images()) == n_old_images + len(images) + @pytest.mark.vcr() def test_download_specific_dataset( self, diff --git a/tests/pre-merge/integration/rest_clients/test_model_and_prediction_client.py b/tests/pre-merge/integration/rest_clients/test_model_and_prediction_client.py index 6408bb8a..fdba57e8 100644 --- a/tests/pre-merge/integration/rest_clients/test_model_and_prediction_client.py +++ b/tests/pre-merge/integration/rest_clients/test_model_and_prediction_client.py @@ -21,7 +21,7 @@ from geti_sdk.annotation_readers import DatumAnnotationReader from geti_sdk.data_models import Image, Prediction, Project -from geti_sdk.data_models.enums import JobState, PredictionMode +from geti_sdk.data_models.enums import JobState, PredictionMode, SubsetPurpose from geti_sdk.demos import EXAMPLE_IMAGE_PATH from geti_sdk.http_session import GetiRequestException from geti_sdk.platform_versions import GETI_15_VERSION, GETI_22_VERSION @@ -358,3 +358,33 @@ def test_purge_model( # Assert purged_model = model_client.update_model_detail(previous_model) assert purged_model.purge_info.is_purged + + @pytest.mark.vcr() + def test_get_training_dataset( + self, + fxt_project_service: ProjectService, + fxt_test_mode: SdkTestMode, + ) -> None: + """ + Test that the media in the training dataset for a model con be recovered. + """ + model = fxt_project_service.model_client.get_all_active_models()[0] + dataset_client = fxt_project_service.dataset_client + + training_ds_summary = dataset_client.get_training_dataset_summary(model) + + train_set = dataset_client.get_media_in_training_dataset(model, "training") + val_set = dataset_client.get_media_in_training_dataset(model, "validation") + test_set = dataset_client.get_media_in_training_dataset(model, "testing") + + n_train_items = training_ds_summary.training_size + n_val_items = training_ds_summary.validation_size + n_test_items = training_ds_summary.testing_size + + assert train_set.size == n_train_items + assert val_set.size == n_val_items + assert test_set.size == n_test_items + + assert train_set.purpose == SubsetPurpose.TRAINING + assert val_set.purpose == SubsetPurpose.VALIDATION + assert test_set.purpose == SubsetPurpose.TESTING diff --git a/tests/pre-merge/integration/test_geti.py b/tests/pre-merge/integration/test_geti.py index fc7b509f..009716f3 100644 --- a/tests/pre-merge/integration/test_geti.py +++ b/tests/pre-merge/integration/test_geti.py @@ -209,7 +209,7 @@ def test_create_single_task_project_from_dataset( max_threads=1, ) - request.addfinalizer(lambda: fxt_project_finalizer(project_name, project.id)) + request.addfinalizer(lambda: fxt_project_finalizer(project)) @pytest.mark.vcr() @pytest.mark.parametrize( @@ -257,7 +257,7 @@ def test_create_task_chain_project_from_dataset( enable_auto_train=False, max_threads=1, ) - request.addfinalizer(lambda: fxt_project_finalizer(project_name, project.id)) + request.addfinalizer(lambda: fxt_project_finalizer(project)) all_labels = fxt_default_labels + ["block"] for label_name in all_labels: @@ -287,7 +287,7 @@ def test_download_and_upload_project( target_folder = os.path.join(fxt_temp_directory, project.name) fxt_geti.download_project_data( - project.name, + project, target_folder=target_folder, max_threads=1, ) @@ -304,9 +304,7 @@ def test_download_and_upload_project( enable_auto_train=False, max_threads=1, ) - request.addfinalizer( - lambda: fxt_project_finalizer(uploaded_project.name, uploaded_project.id) - ) + request.addfinalizer(lambda: fxt_project_finalizer(uploaded_project)) image_client = ImageClient( session=fxt_geti.session, workspace_id=fxt_geti.workspace_id, @@ -393,7 +391,7 @@ def test_upload_and_predict_image( for j in range(n_attempts): try: image, prediction = fxt_geti.upload_and_predict_image( - project_name=project.name, + project=project, image=fxt_image_path, visualise_output=False, delete_after_prediction=False, @@ -419,7 +417,7 @@ def test_upload_and_predict_video( Verify that the `Geti.upload_and_predict_video` method works as expected """ video, frames, predictions = fxt_geti.upload_and_predict_video( - project_name=fxt_project_service.project.name, + project=fxt_project_service.project, video=fxt_video_path_1_light_bulbs, visualise_output=False, ) @@ -432,15 +430,16 @@ def test_upload_and_predict_video( # Check that invalid project raises a KeyError with pytest.raises(KeyError): + project = fxt_geti.get_project(project_name="invalid_project_name") fxt_geti.upload_and_predict_video( - project_name="invalid_project_name", + project=project, video=fxt_video_path_1_light_bulbs, visualise_output=False, ) # Check that video is not uploaded if it's already in the project video, frames, predictions = fxt_geti.upload_and_predict_video( - project_name=fxt_project_service.project.name, + project=fxt_project_service.project, video=video, visualise_output=False, ) @@ -450,7 +449,7 @@ def test_upload_and_predict_video( new_frames = video.to_frames(frame_stride=50, include_data=True) np_frames = [frame.numpy for frame in new_frames] np_video, frames, predictions = fxt_geti.upload_and_predict_video( - project_name=fxt_project_service.project.name, + project=fxt_project_service.project, video=np_frames, visualise_output=False, delete_after_prediction=True, @@ -475,14 +474,14 @@ def test_upload_and_predict_media_folder( image_output_folder = os.path.join(fxt_temp_directory, "inferred_images") video_success = fxt_geti.upload_and_predict_media_folder( - project_name=fxt_project_service.project.name, + project=fxt_project_service.project, media_folder=fxt_video_folder_light_bulbs, output_folder=video_output_folder, delete_after_prediction=True, max_threads=1, ) image_success = fxt_geti.upload_and_predict_media_folder( - project_name=fxt_project_service.project.name, + project=fxt_project_service.project, media_folder=fxt_image_folder_light_bulbs, output_folder=image_output_folder, delete_after_prediction=True, @@ -519,7 +518,7 @@ def test_deployment( for _ in range(n_attempts): try: deployment = fxt_geti.deploy_project( - project.name, + project, output_folder=deployment_folder, enable_explainable_ai=True, ) @@ -538,7 +537,7 @@ def test_deployment( local_prediction = deployment.infer(image_np) assert isinstance(local_prediction, Prediction) image, online_prediction = fxt_geti.upload_and_predict_image( - project.name, + project, image=image_np, delete_after_prediction=True, visualise_output=False, @@ -577,7 +576,7 @@ def test_post_inference_hooks( project = fxt_project_service.project deployment_folder = os.path.join(fxt_temp_directory, project.name) - deployment = fxt_geti.deploy_project(project.name) + deployment = fxt_geti.deploy_project(project) dataset_name = "Test hooks" # Add a GetiDataCollectionHook @@ -670,7 +669,7 @@ def test_download_project_including_models_and_predictions( fxt_temp_directory, project.name + "_all_inclusive" ) fxt_geti.download_project_data( - project_name=project.name, + project=project, target_folder=target_folder, include_predictions=True, include_active_models=True, diff --git a/tests/pre-merge/integration/utils/test_utils.py b/tests/pre-merge/integration/utils/test_utils.py index 22b49656..5aa575d9 100644 --- a/tests/pre-merge/integration/utils/test_utils.py +++ b/tests/pre-merge/integration/utils/test_utils.py @@ -22,6 +22,7 @@ from tests.helpers.constants import ( DUMMY_HOST, DUMMY_PASSWORD, + DUMMY_TOKEN, DUMMY_USER, PROJECT_PREFIX, ) @@ -85,12 +86,13 @@ def test_get_server_details_from_env(self, fxt_env_filepath: str): assert not hasattr(server_config, "username") assert not hasattr(server_config, "password") - environ_keys = ["GETI_HOST", "GETI_USERNAME", "GETI_PASSWORD"] + environ_keys = ["GETI_HOST", "GETI_USERNAME", "GETI_PASSWORD", "GETI_TOKEN"] expected_results = {} dummy_results = { "GETI_HOST": DUMMY_HOST, "GETI_USERNAME": DUMMY_USER, "GETI_PASSWORD": DUMMY_PASSWORD, + "GETI_TOKEN": DUMMY_TOKEN, } for ekey in environ_keys: evalue = os.environ.get(ekey, None) @@ -105,7 +107,5 @@ def test_get_server_details_from_env(self, fxt_env_filepath: str): assert server_config.host.replace("https://", "") == expected_results[ "GETI_HOST" ].replace("https://", "") - assert server_config.username == expected_results["GETI_USERNAME"] - assert server_config.password == expected_results["GETI_PASSWORD"] + assert server_config.token == expected_results["GETI_TOKEN"] assert server_config.proxies is None - assert not hasattr(server_config, "token") diff --git a/tests/pre-merge/unit/benchmarking/test_benchmarker.py b/tests/pre-merge/unit/benchmarking/test_benchmarker.py index 14afa952..d150aa20 100644 --- a/tests/pre-merge/unit/benchmarking/test_benchmarker.py +++ b/tests/pre-merge/unit/benchmarking/test_benchmarker.py @@ -32,8 +32,8 @@ def test_initialize( mocker: MockerFixture, ): # Arrange - mock_get_project_by_name = mocker.patch( - "geti_sdk.geti.ProjectClient.get_project_by_name", + mock_get_project = mocker.patch( + "geti_sdk.geti.Geti.get_project", return_value=fxt_classification_project, ) mocked_model_client = mocker.patch( @@ -42,7 +42,7 @@ def test_initialize( mocked_training_client = mocker.patch( "geti_sdk.benchmarking.benchmarker.TrainingClient" ) - project_name = "project name" + project_mock = mocker.MagicMock() algorithms_to_benchmark = ("ALGO_1", "ALGO_2") precision_levels = ("PRECISION_1", "PRECISION_2") images = ("path_1", "path_2") @@ -52,16 +52,14 @@ def test_initialize( # Single task project, benchmarking on images benchmarker = Benchmarker( geti=fxt_mocked_geti, - project=project_name, + project=project_mock, algorithms=algorithms_to_benchmark, precision_levels=precision_levels, benchmark_images=images, ) # Assert - mock_get_project_by_name.assert_called_once_with( - project_name=project_name, project_id=None - ) + mock_get_project.assert_called_once_with(project_id=project_mock.id) mocked_model_client.assert_called_once() mocked_training_client.assert_called_once() assert benchmarker._is_single_task @@ -74,7 +72,7 @@ def test_initialize( with pytest.raises(ValueError): benchmarker = Benchmarker( geti=fxt_mocked_geti, - project=project_name, + project=project_mock, algorithms=algorithms_to_benchmark, precision_levels=precision_levels, benchmark_images=images, @@ -88,8 +86,8 @@ def test_initialize_task_chain( mocker: MockerFixture, ): # Arrange - mock_get_project_by_name = mocker.patch( - "geti_sdk.geti.ProjectClient.get_project_by_name", + mocker.patch( + "geti_sdk.geti.Geti.get_project", return_value=fxt_detection_to_classification_project, ) fetched_images = (mocker.MagicMock(),) @@ -108,21 +106,18 @@ def test_initialize_task_chain( mocked_training_client = mocker.patch( "geti_sdk.benchmarking.benchmarker.TrainingClient" ) - project_name = "project name" + project_mock = mocker.MagicMock() precision_levels = ["PRECISION_1", "PRECISION_2"] # Act # Multi task project, no media provided benchmarker = Benchmarker( geti=fxt_mocked_geti, - project=project_name, + project=project_mock, precision_levels=precision_levels, ) # Assert - mock_get_project_by_name.assert_called_once_with( - project_name=project_name, project_id=None - ) mock_image_client_get_all.assert_called_once() mocked_model_client.assert_called_once() model_client_object_mock.get_all_active_models.assert_called_once() diff --git a/tests/pre-merge/unit/test_geti_unit.py b/tests/pre-merge/unit/test_geti_unit.py index 45dec52f..faee28fa 100644 --- a/tests/pre-merge/unit/test_geti_unit.py +++ b/tests/pre-merge/unit/test_geti_unit.py @@ -173,7 +173,7 @@ def test_upload_all_projects( for project in fxt_nightly_projects: os.makedirs(os.path.join(target_dir, project.name)) mock_is_project_dir = mocker.patch( - "geti_sdk.geti.ProjectClient.is_project_dir", return_value=True + "geti_sdk.geti.ProjectClient._is_project_dir", return_value=True ) mock_upload_project_data = mocker.patch( "geti_sdk.import_export.import_export_module.GetiIE.upload_project_data"