From e98a1401bc7db6f4058f73f786cf4de2defe0a85 Mon Sep 17 00:00:00 2001 From: ljcornel Date: Thu, 17 Oct 2024 14:12:47 +0200 Subject: [PATCH] Fix label mapping issues --- geti_sdk/data_models/containers/label_list.py | 2 +- geti_sdk/data_models/label.py | 4 +++- geti_sdk/deployment/deployed_model.py | 6 +++++- .../results_to_prediction_converter.py | 10 +++++++--- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/geti_sdk/data_models/containers/label_list.py b/geti_sdk/data_models/containers/label_list.py index 84a1cc26..c661dc80 100644 --- a/geti_sdk/data_models/containers/label_list.py +++ b/geti_sdk/data_models/containers/label_list.py @@ -47,7 +47,7 @@ def _generate_indices(self): Map names and ID's to Label objects to enable quick label retrieval """ self._id_mapping = {x.id: x for x in self.data} - self._name_mapping = {x.name for x in self.data} + self._name_mapping = {x.name: x for x in self.data} def get_by_id(self, id: str) -> Label: """ diff --git a/geti_sdk/data_models/label.py b/geti_sdk/data_models/label.py index c800eb9e..ff474a89 100644 --- a/geti_sdk/data_models/label.py +++ b/geti_sdk/data_models/label.py @@ -114,7 +114,9 @@ class ScoredLabel: _identifier_fields: ClassVar[List[str]] = ["id"] - probability: float + probability: float = attr.field(converter=float) # float converter here to make + # sure we're storing probability + # as a float64 dtype name: Optional[str] = None color: Optional[str] = None id: Optional[str] = None diff --git a/geti_sdk/deployment/deployed_model.py b/geti_sdk/deployment/deployed_model.py index d910ba40..4e9520f3 100644 --- a/geti_sdk/deployment/deployed_model.py +++ b/geti_sdk/deployment/deployed_model.py @@ -760,7 +760,7 @@ def get_model_config(self) -> Dict[str, Any]: config = {} for child in model_info_node: value = child.attrib["value"] - if " " in value: + if " " in value and "{" not in value: value = value.split(" ") value_list = [] for item in value: @@ -769,6 +769,9 @@ def get_model_config(self) -> Dict[str, Any]: except ValueError: value_list.append(item) config[child.tag] = value_list + elif "{" in value: + # Dictionaries are kept in string representation + config[child.tag] = value else: try: value = int(value) @@ -802,6 +805,7 @@ def _clean_model_config(self, configuration: Dict[str, Any]) -> Dict[str, Any]: "model_type", "optimization_config", "task_type", + "labels", ] for key in unused_keys: configuration.pop(key, None) diff --git a/geti_sdk/deployment/predictions_postprocessing/results_converter/results_to_prediction_converter.py b/geti_sdk/deployment/predictions_postprocessing/results_converter/results_to_prediction_converter.py index 13b3ab29..e22d53bb 100644 --- a/geti_sdk/deployment/predictions_postprocessing/results_converter/results_to_prediction_converter.py +++ b/geti_sdk/deployment/predictions_postprocessing/results_converter/results_to_prediction_converter.py @@ -51,7 +51,7 @@ class InferenceResultsToPredictionConverter(metaclass=abc.ABCMeta): def __init__( self, labels: LabelList, configuration: Optional[Dict[str, Any]] = None ): - self.labels = labels + self.labels = labels.get_non_empty_labels() self.configuration = configuration @abc.abstractmethod @@ -175,8 +175,12 @@ def __init__( self.confidence_threshold = configuration["confidence_threshold"] if "label_ids" in configuration: # Make sure the list of labels is sorted according to the order - # defined in the ModelAPI configuration - self.labels.sort_by_ids(configuration["label_ids"]) + # defined in the ModelAPI configuration. If the 'label_ids' field + # only contains a single label, it will be typed as string. No need + # to sort in that case + ids = configuration["label_ids"] + if not isinstance(ids, str): + self.labels.sort_by_ids(configuration["label_ids"]) def _detection2array(self, detections: List[Detection]) -> np.ndarray: """