Skip to content

Commit

Permalink
fix error for automl object detection models when initializing error …
Browse files Browse the repository at this point in the history
…analysis and optimize explanation execution (microsoft#2245)
  • Loading branch information
imatiach-msft authored Aug 16, 2023
1 parent 3cc2e06 commit 5c7a0fa
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,23 @@ def __init__(self, model, dataset, image_mode, transformations,
test = np.array(
self.dataset.iloc[:, 0].tolist()
)
test = pd.DataFrame(
data=[
get_base64_string_from_path(img_path) for img_path in test
],
columns=[MLFlowSchemaLiterals.INPUT_COLUMN_IMAGE],
)
if self.task_type == ModelTask.OBJECT_DETECTION:
test = pd.DataFrame(
data=[[x for x in get_base64_string_from_path(
img_path, return_image_size=True)] for
img_path in test],
columns=[
MLFlowSchemaLiterals.INPUT_COLUMN_IMAGE,
MLFlowSchemaLiterals.INPUT_IMAGE_SIZE],
)
else:
test = pd.DataFrame(
data=[
get_base64_string_from_path(
img_path) for img_path in test
],
columns=[MLFlowSchemaLiterals.INPUT_COLUMN_IMAGE],
)
else:
test = get_images(self.dataset, self.image_mode,
self.transformations)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,19 @@ def __init__(self, model: Any,
self._image_mode = image_mode
if task_type == ModelTask.OBJECT_DETECTION:
if is_automl_image_model(model):
self._model = MLflowDRiseWrapper(model._model, classes)
try:
python_model = model._model._model_impl.python_model
automl_wrapper = python_model._model
inner_model = automl_wrapper._model
number_of_classes = automl_wrapper._number_of_classes
self._model = PytorchDRiseWrapper(
inner_model, number_of_classes, device=device)
except Exception as e:
warnings.warn(("Could not extract inner automl model." +
"Explanation may take longer to compute." +
"Inner exception: {}").format(str(e)),
UserWarning)
self._model = MLflowDRiseWrapper(model._model, classes)
else:
self._model = PytorchDRiseWrapper(
model._model, len(classes), device=device)
Expand Down

0 comments on commit 5c7a0fa

Please sign in to comment.