Skip to content

Commit

Permalink
Fix label mapping issues
Browse files Browse the repository at this point in the history
  • Loading branch information
ljcornel committed Oct 17, 2024
1 parent fd85752 commit e98a140
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 6 deletions.
2 changes: 1 addition & 1 deletion geti_sdk/data_models/containers/label_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _generate_indices(self):
Map names and ID's to Label objects to enable quick label retrieval
"""
self._id_mapping = {x.id: x for x in self.data}
self._name_mapping = {x.name for x in self.data}
self._name_mapping = {x.name: x for x in self.data}

def get_by_id(self, id: str) -> Label:
"""
Expand Down
4 changes: 3 additions & 1 deletion geti_sdk/data_models/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ class ScoredLabel:

_identifier_fields: ClassVar[List[str]] = ["id"]

probability: float
probability: float = attr.field(converter=float) # float converter here to make
# sure we're storing probability
# as a float64 dtype
name: Optional[str] = None
color: Optional[str] = None
id: Optional[str] = None
Expand Down
6 changes: 5 additions & 1 deletion geti_sdk/deployment/deployed_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ def get_model_config(self) -> Dict[str, Any]:
config = {}
for child in model_info_node:
value = child.attrib["value"]
if " " in value:
if " " in value and "{" not in value:
value = value.split(" ")
value_list = []
for item in value:
Expand All @@ -769,6 +769,9 @@ def get_model_config(self) -> Dict[str, Any]:
except ValueError:
value_list.append(item)
config[child.tag] = value_list
elif "{" in value:
# Dictionaries are kept in string representation
config[child.tag] = value
else:
try:
value = int(value)
Expand Down Expand Up @@ -802,6 +805,7 @@ def _clean_model_config(self, configuration: Dict[str, Any]) -> Dict[str, Any]:
"model_type",
"optimization_config",
"task_type",
"labels",
]
for key in unused_keys:
configuration.pop(key, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class InferenceResultsToPredictionConverter(metaclass=abc.ABCMeta):
def __init__(
self, labels: LabelList, configuration: Optional[Dict[str, Any]] = None
):
self.labels = labels
self.labels = labels.get_non_empty_labels()
self.configuration = configuration

@abc.abstractmethod
Expand Down Expand Up @@ -175,8 +175,12 @@ def __init__(
self.confidence_threshold = configuration["confidence_threshold"]
if "label_ids" in configuration:
# Make sure the list of labels is sorted according to the order
# defined in the ModelAPI configuration
self.labels.sort_by_ids(configuration["label_ids"])
# defined in the ModelAPI configuration. If the 'label_ids' field
# only contains a single label, it will be typed as string. No need
# to sort in that case
ids = configuration["label_ids"]
if not isinstance(ids, str):
self.labels.sort_by_ids(configuration["label_ids"])

def _detection2array(self, detections: List[Detection]) -> np.ndarray:
"""
Expand Down

0 comments on commit e98a140

Please sign in to comment.