Skip to content

Commit

Permalink
Merge pull request #17866 from kostrykin/verify_imagediff_iou/upstream
Browse files Browse the repository at this point in the history
Add `pin_labels` attribute for `image_diff` comparison method
  • Loading branch information
mvdbeek authored Apr 3, 2024
2 parents af53d03 + 8515eed commit 28c8b2f
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 16 deletions.
1 change: 1 addition & 0 deletions lib/galaxy/tool_util/parser/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

DEFAULT_METRIC = "mae"
DEFAULT_EPS = 0.01
DEFAULT_PIN_LABELS = None


def is_dict(item):
Expand Down
2 changes: 2 additions & 0 deletions lib/galaxy/tool_util/parser/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
DEFAULT_DELTA_FRAC,
DEFAULT_EPS,
DEFAULT_METRIC,
DEFAULT_PIN_LABELS,
)
from galaxy.util import (
Element,
Expand Down Expand Up @@ -793,6 +794,7 @@ def __parse_test_attributes(output_elem, attrib, parse_elements=False, parse_dis
# Parameters for "image_diff" comparison
attributes["metric"] = attrib.pop("metric", DEFAULT_METRIC)
attributes["eps"] = float(attrib.pop("eps", DEFAULT_EPS))
attributes["pin_labels"] = attrib.pop("pin_labels", DEFAULT_PIN_LABELS)
if location and file is None:
file = os.path.basename(location) # If no file specified, try to get filename from URL last component
attributes["location"] = location
Expand Down
69 changes: 57 additions & 12 deletions lib/galaxy/tool_util/verify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
DEFAULT_DELTA_FRAC,
DEFAULT_EPS,
DEFAULT_METRIC,
DEFAULT_PIN_LABELS,
)
from galaxy.util import unicodify
from galaxy.util.compression_utils import get_fileobj
Expand Down Expand Up @@ -456,43 +457,87 @@ def files_contains(file1, file2, attributes=None):
raise AssertionError(f"Failed to find '{contains}' in history data. (lines_diff={lines_diff}).")


def _singleobject_intersection_over_union(
mask1: "numpy.typing.NDArray",
mask2: "numpy.typing.NDArray",
) -> "numpy.floating":
return numpy.logical_and(mask1, mask2).sum() / numpy.logical_or(mask1, mask2).sum()


def _multiobject_intersection_over_union(
    mask1: "numpy.typing.NDArray",
    mask2: "numpy.typing.NDArray",
    pin_labels: Optional[List[int]] = None,
    repeat_reverse: bool = True,
) -> List["numpy.floating"]:
    """Compute the IoU of each labeled object in ``mask1`` with its corresponding object in ``mask2``.

    For a label listed in ``pin_labels``, the corresponding object is the one with the *same*
    label value in ``mask2``. For any other label, the corresponding object is the overlapping
    object in ``mask2`` with the largest IoU, excluding pinned labels. With ``repeat_reverse``
    (the default), the correspondences are also computed in the opposite direction, since they
    are not necessarily symmetric.
    """
    iou_list = []
    for label1 in numpy.unique(mask1):
        cc1 = mask1 == label1

        # If the label is in `pin_labels`, then use the same label value to find the corresponding object in the second mask.
        if pin_labels is not None and label1 in pin_labels:
            cc2 = mask2 == label1
            iou_list.append(_singleobject_intersection_over_union(cc1, cc2))

        # Otherwise, use the object with the largest IoU value, excluding the pinned labels.
        else:
            cc1_iou_list = []
            for label2 in numpy.unique(mask2[cc1]):
                if pin_labels is not None and label2 in pin_labels:
                    continue
                cc2 = mask2 == label2
                cc1_iou_list.append(_singleobject_intersection_over_union(cc1, cc2))
            # All overlapping objects may be pinned (and thus excluded); report an IoU
            # of zero then instead of crashing on `max()` of an empty sequence.
            iou_list.append(max(cc1_iou_list, default=numpy.float64(0.0)))

    if repeat_reverse:
        iou_list.extend(_multiobject_intersection_over_union(mask2, mask1, pin_labels, repeat_reverse=False))

    return iou_list


def intersection_over_union(
    mask1: "numpy.typing.NDArray", mask2: "numpy.typing.NDArray", pin_labels: Optional[List[int]] = None
) -> "numpy.floating":
    """Compute the intersection over union (IoU) for the objects in two masks containing labels.

    The IoU is computed for each uniquely labeled image region (object), and the overall minimum value is returned (i.e. the worst value).
    To compute the IoU for each object, the corresponding object in the other mask needs to be determined.
    The object correspondences are not necessarily symmetric.
    By default, the corresponding object in the other mask is determined as the one with the largest IoU value.
    If the label of an object is listed in `pin_labels`, then the corresponding object in the other mask is determined as the object with the same label value.
    Objects with labels listed in `pin_labels` also cannot correspond to objects with different labels.
    This is particularly useful when specific image regions must always be labeled with a designated label value (e.g., the image background is often labeled with 0 or -1).

    Raises `AssertionError` if the masks differ in dtype, dimensionality, or shape,
    or if a pinned label is missing from either mask.
    """
    assert mask1.dtype == mask2.dtype
    assert mask1.ndim == mask2.ndim == 2
    assert mask1.shape == mask2.shape
    # A pinned label must occur in both masks, otherwise its correspondence is undefined.
    for label in pin_labels or []:
        count = sum(label in mask for mask in (mask1, mask2))
        count_str = {1: "one", 2: "both"}
        assert count == 2, f"Label {label} is pinned but missing in {count_str[2 - count]} of the images."
    if mask1.dtype == bool:
        # Binary masks contain a single object each, so compare them directly.
        return _singleobject_intersection_over_union(mask1, mask2)
    return min(_multiobject_intersection_over_union(mask1, mask2, pin_labels))


def _parse_label_list(label_list_str: Optional[str]) -> List[int]:
if label_list_str is None:
return []
else:
return min(_multiobject_intersection_over_union(mask1, mask2))
return [int(label.strip()) for label in label_list_str.split(",") if len(label_list_str) > 0]


def get_image_metric(
attributes: Dict[str, Any]
) -> Callable[["numpy.typing.NDArray", "numpy.typing.NDArray"], "numpy.floating"]:
metric_name = attributes.get("metric", DEFAULT_METRIC)
pin_labels = _parse_label_list(attributes.get("pin_labels", DEFAULT_PIN_LABELS))
metrics = {
"mae": lambda arr1, arr2: numpy.abs(arr1 - arr2).mean(),
# Convert to float before squaring to prevent overflows
"mse": lambda arr1, arr2: numpy.square((arr1 - arr2).astype(float)).mean(),
"rms": lambda arr1, arr2: math.sqrt(numpy.square((arr1 - arr2).astype(float)).mean()),
"fro": lambda arr1, arr2: numpy.linalg.norm((arr1 - arr2).reshape(1, -1), "fro"),
"iou": lambda arr1, arr2: 1 - intersection_over_union(arr1, arr2),
"iou": lambda arr1, arr2: 1 - intersection_over_union(arr1, arr2, pin_labels),
}
try:
return metrics[metric_name]
Expand Down
7 changes: 6 additions & 1 deletion lib/galaxy/tool_util/xsd/galaxy.xsd
Original file line number Diff line number Diff line change
Expand Up @@ -1825,6 +1825,11 @@ If you specify a `checksum`, it will be also used to check the integrity of the
<xs:documentation xml:lang="en">If ``compare`` is set to ``image_diff``, this is the maximum allowed distance between the data set that is generated in the test and the file in ``test-data/`` that is referenced by the ``file`` attribute, with distances computed with respect to the specified ``metric``. Default value is 0.01.</xs:documentation>
</xs:annotation>
</xs:attribute>
<xs:attribute name="pin_labels" type="xs:string" use="optional">
<xs:annotation>
<xs:documentation xml:lang="en">If ``compare`` is set to ``image_diff`` and ``metric`` is set to ``iou``, by default, object correspondences are established by maximizing the pairwise intersection over the union. If, however, the label of an object is listed in ``pin_labels``, then the corresponding object is determined according to the same label value (and that object cannot be the corresponding object of any other object with a different label).</xs:documentation>
</xs:annotation>
</xs:attribute>
</xs:complexType>
<xs:group name="TestOutputElement">
<xs:choice>
Expand Down Expand Up @@ -7788,7 +7793,7 @@ favour of a ``has_size`` assertion.</xs:documentation>
</xs:simpleType>
<xs:simpleType name="TestOutputMetricType">
<xs:annotation>
<xs:documentation xml:lang="en">If ``compare`` is set to ``image_diff``, this is the metric used to compute the distance between images for quantification of their difference. For intensity images, possible metrics are *mean absolute error* (``mae``, the default), *mean squared error* (``mse``), *root mean squared* error (``rms``), and the *Frobenius norm* (``fro``). In addition, for binary images and label maps (with multiple objects), ``iou`` can be used to compute *one minus* the *intersection over the union* (IoU). Object correspondances are established by taking the pair of objects, for which the IoU is highest, and the distance of the images is the worst value determined for any pair of corresponding objects.</xs:documentation>
<xs:documentation xml:lang="en">If ``compare`` is set to ``image_diff``, this is the metric used to compute the distance between images for quantification of their difference. For intensity images, possible metrics are *mean absolute error* (``mae``, the default), *mean squared error* (``mse``), *root mean squared* error (``rms``), and the *Frobenius norm* (``fro``). In addition, for binary images and label maps (with multiple objects), ``iou`` can be used to compute *one minus* the *intersection over the union* (IoU). Object correspondences are established by taking the pair of objects, for which the IoU is highest (also see the ``pin_labels`` attribute), and the distance of the images is the worst value determined for any pair of corresponding objects.</xs:documentation>
</xs:annotation>
<xs:restriction base="xs:string">
<xs:enumeration value="mae"/>
Expand Down
8 changes: 8 additions & 0 deletions test/functional/tools/image_diff.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@
<param name="in" value="im2_a.png" />
<output name="out" value="im2_b.png" compare="image_diff" metric="mae" eps="0.25" />
</test>
<test>
<param name="in" value="im2_a.png" />
<output name="out" value="im2_b.png" compare="image_diff" metric="iou" eps="0.75" />
</test>
<test expect_test_failure="true">
<param name="in" value="im2_a.png" />
<output name="out" value="im2_b.png" compare="image_diff" metric="iou" eps="0.75" pin_labels="2" />
</test>
<!-- test RGB data -->
<test>
<param name="in" value="im3_a.png" />
Expand Down
14 changes: 11 additions & 3 deletions test/unit/tool_util/test_verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def _encode_image(im, **kwargs):
F9 = _encode_image(
numpy.array(
[
[0, 0, 0],
[0, 1, 0],
[0, 1, 2],
[200, 200, 200],
[200, 1, 200],
[200, 1, 2],
],
dtype=numpy.uint8,
),
Expand Down Expand Up @@ -179,6 +179,14 @@ def generate_tests_image_diff():
(f6, f7, {"metric": "fro", "eps": 100 - 1e-4}, AssertionError),
(f6, f9, {"metric": "iou", "eps": (1 - 1 / 8) + 1e-4}, None),
(f6, f9, {"metric": "iou", "eps": (1 - 1 / 8) - 1e-4}, AssertionError),
# tests `pin_labels` with a label not present in any image
(f6, f9, {"metric": "iou", "eps": 0.999999, "pin_labels": "5"}, AssertionError),
# tests `pin_labels` with a label present in both images
(f6, f9, {"metric": "iou", "eps": 0.999999, "pin_labels": "200"}, AssertionError),
(f6, f9, {"metric": "iou", "eps": 1.0, "pin_labels": "200"}, None),
# tests `pin_labels` with a label only present in one image
(f6, f9, {"metric": "iou", "eps": 1.0, "pin_labels": "200, 1"}, AssertionError),
(f6, f9, {"metric": "iou", "eps": 1.0, "pin_labels": "200, 255"}, AssertionError),
]
return tests

Expand Down

0 comments on commit 28c8b2f

Please sign in to comment.