diff --git a/inference/core/workflows/core_steps/transformations/perspective_correction/v1.py b/inference/core/workflows/core_steps/transformations/perspective_correction/v1.py index c82d222ea3..7d0df03e5f 100644 --- a/inference/core/workflows/core_steps/transformations/perspective_correction/v1.py +++ b/inference/core/workflows/core_steps/transformations/perspective_correction/v1.py @@ -406,11 +406,15 @@ def run( result_image = image if warp_image: # https://docs.opencv.org/4.9.0/da/d54/group__imgproc__transform.html#gaf73673a7e8e18ec6963e3774e6a94b87 - result_image = cv.warpPerspective( + warped_image = cv.warpPerspective( src=image.numpy_image, M=perspective_transformer, dsize=(transformed_rect_width, transformed_rect_height), ) + result_image = WorkflowImageData( + parent_metadata=image.parent_metadata, + numpy_image=warped_image, + ) if detections is None: result.append( diff --git a/tests/workflows/unit_tests/core_steps/transformations/test_perspective_correction.py b/tests/workflows/unit_tests/core_steps/transformations/test_perspective_correction.py index d7b381d0b7..ae7a8f626e 100644 --- a/tests/workflows/unit_tests/core_steps/transformations/test_perspective_correction.py +++ b/tests/workflows/unit_tests/core_steps/transformations/test_perspective_correction.py @@ -14,8 +14,15 @@ from inference.core.workflows.execution_engine.constants import ( KEYPOINTS_XY_KEY_IN_SV_DETECTIONS, ) -from inference.core.workflows.execution_engine.entities.base import Batch +from inference.core.workflows.execution_engine.entities.base import ( + Batch, + WorkflowImageData, + ImageParentMetadata +) +from inference.core.workflows.core_steps.transformations.perspective_correction.v1 import ( + PerspectiveCorrectionBlockV1, +) @pytest.mark.parametrize("broken_input", [1, "cat", np.array([])]) def test_pick_largest_perspective_polygons_raises_on_unexpected_type_of_input( @@ -286,3 +293,27 @@ def test_correct_detections_with_keypoints(): dtype="object", ) assert corrected_detections == expected_detections + + +def test_warp_image(): + # given + dummy_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + dummy_predictions = sv.Detections(xyxy=np.array([[10, 10, 20, 20]])) + perspective_correction_block = PerspectiveCorrectionBlockV1() + + workflow_image_data = WorkflowImageData(parent_metadata=ImageParentMetadata(parent_id="test"), numpy_image=dummy_image) + + # when + result = perspective_correction_block.run( + images=[workflow_image_data], + predictions=[dummy_predictions], + perspective_polygons=[[[1, 1], [99, 1], [99, 99], [1, 99]]], + transformed_rect_width=200, + transformed_rect_height=200, + extend_perspective_polygon_by_detections_anchor=None, + warp_image=True, + ) + + # then + assert "warped_image" in result[0], "warped_image key must be present in the result" + assert isinstance(result[0]["warped_image"], WorkflowImageData), f"warped_image must be of type WorkflowImageData"