-
Notifications
You must be signed in to change notification settings - Fork 16
/
action_matching.py
executable file
·230 lines (189 loc) · 8.16 KB
/
action_matching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
'''
Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild
'''
import jax
import jax.numpy as jnp
import numpy as np
import action_type as action_type_lib
# Default `tap_distance_threshold` for `check_actions_match`: two taps whose
# Euclidean distance is at most this value are treated as matching even when
# they do not share an annotation bounding box.
_TAP_DISTANCE_THRESHOLD = 0.14 # Fraction of the screen
# Default augmentation fractions for annotation bounding boxes. Each box's
# width/height is grown by this fraction of its own size before containment is
# tested (see `_resize_annotation_bounding_boxes`).
ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4
ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4
# Interval determining if an action is a tap or a swipe: a touch/lift pair
# whose Euclidean distance is at most this value is classified as a tap by
# `is_tap_action`.
_SWIPE_DISTANCE_THRESHOLD = 0.04
def _yx_in_bounding_boxes(
yx, bounding_boxes
):
"""Check if the (y,x) point is contained in each bounding box.
Args:
yx: The (y, x) coordinate in pixels of the point.
bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row
represents a bounding box: (y_top_left, x_top_left, box_height,
box_width). Note: containment is inclusive of the bounding box edges.
Returns:
is_inside: A 1D bool array where each element specifies if the point is
contained within the respective box.
"""
y, x = yx
# `bounding_boxes` has shape (n_elements, 4); we extract each array along the
# last axis into shape (n_elements, 1), then squeeze unneeded dimension.
top, left, height, width = [
jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1)
]
# The y-axis is inverted for AndroidEnv, so bottom = top + height.
bottom, right = top + height, left + width
return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and(
x >= left, x <= right)
def _resize_annotation_bounding_boxes(
annotation_positions, annotation_width_augment_fraction,
annotation_height_augment_fraction):
"""Resize the bounding boxes by the given fractions.
Args:
annotation_positions: Array of shape (N, 4), where each row represents the
(y, x, height, width) of the bounding boxes.
annotation_width_augment_fraction: The fraction to augment the box widths,
E.g., 1.4 == 240% total increase.
annotation_height_augment_fraction: Same as described for width, but for box
height.
Returns:
Resized bounding box.
"""
height_change = (
annotation_height_augment_fraction * annotation_positions[:, 2])
width_change = (
annotation_width_augment_fraction * annotation_positions[:, 3])
# Limit bounding box positions to the screen.
resized_annotations = jnp.stack([
jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)),
jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)),
jnp.minimum(1, annotation_positions[:, 2] + height_change),
jnp.minimum(1, annotation_positions[:, 3] + width_change),
],
axis=1)
return resized_annotations
def is_tap_action(normalized_start_yx,
                  normalized_end_yx):
    """Classify a touch/lift pair as a tap (True) or a swipe (False).

    Args:
        normalized_start_yx: The (y, x) coordinates where the touch begins.
        normalized_end_yx: The (y, x) coordinates where the touch is lifted.

    Returns:
        A boolean scalar that is True when the Euclidean distance between the
        two points does not exceed `_SWIPE_DISTANCE_THRESHOLD`.
    """
    start = jnp.array(normalized_start_yx)
    end = jnp.array(normalized_end_yx)
    travelled = jnp.linalg.norm(start - end)
    return travelled <= _SWIPE_DISTANCE_THRESHOLD
def _is_non_dual_point_action(action_type):
    """Return whether `action_type` is anything other than DUAL_POINT."""
    dual_point = action_type_lib.ActionType.DUAL_POINT
    is_dual_point = jnp.equal(action_type, dual_point)
    return jnp.logical_not(is_dual_point)
def _check_tap_actions_match(
    tap_1_yx,
    tap_2_yx,
    annotation_positions,
    matching_tap_distance_threshold_screen_percentage,
    annotation_width_augment_fraction,
    annotation_height_augment_fraction,
):
    """Determines if two tap actions are the same.

    Two taps match when both fall inside the same (enlarged) annotation
    bounding box, or when the Euclidean distance between them is at most the
    given threshold.

    Args:
        tap_1_yx: The (y, x) coordinates of the first tap.
        tap_2_yx: The (y, x) coordinates of the second tap.
        annotation_positions: Array of shape (num_bboxes, 4); each row is
            (y_top_left, x_top_left, box_height, box_width). May contain zero
            rows, in which case only the distance criterion applies.
        matching_tap_distance_threshold_screen_percentage: Maximum distance
            between the taps for them to match without sharing a box.
        annotation_width_augment_fraction: Fraction by which box widths are
            enlarged before the containment test.
        annotation_height_augment_fraction: Fraction by which box heights are
            enlarged before the containment test.

    Returns:
        A boolean scalar that is True when the taps are considered the same.
    """
    resized_annotation_positions = _resize_annotation_bounding_boxes(
        annotation_positions,
        annotation_width_augment_fraction,
        annotation_height_augment_fraction,
    )

    # Check which (enlarged) annotation boxes contain each tap.
    tap1_in_box = _yx_in_bounding_boxes(tap_1_yx, resized_annotation_positions)
    tap2_in_box = _yx_in_bounding_boxes(tap_2_yx, resized_annotation_positions)

    # `jnp.any` instead of `jnp.max`: same result for a non-empty array, but a
    # screen with no annotations (zero boxes) yields False rather than raising
    # on a zero-size reduction.
    both_in_box = jnp.any(tap1_in_box & tap2_in_box)

    # If the taps do not share a bounding box (either tap is outside every
    # box, or they land in different boxes), fall back to comparing the
    # points using Euclidean distance.
    within_threshold = (
        jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx))
        <= matching_tap_distance_threshold_screen_percentage
    )
    return jnp.logical_or(both_in_box, within_threshold)
def _check_drag_actions_match(
drag_1_touch_yx,
drag_1_lift_yx,
drag_2_touch_yx,
drag_2_lift_yx,
):
"""Determines if two drag actions are the same."""
# Store drag deltas (the change in the y and x coordinates from touch to
# lift), magnitudes, and the index of the main axis, which is the axis with
# the greatest change in coordinate value (e.g. a drag starting at (0, 0) and
# ending at (0.3, 0.5) has a main axis index of 1).
drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx
drag_1_magnitudes = jnp.abs(drag_1_deltas)
drag_1_main_axis = np.argmax(drag_1_magnitudes)
drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx
drag_2_magnitudes = jnp.abs(drag_2_deltas)
drag_2_main_axis = np.argmax(drag_2_magnitudes)
return jnp.equal(drag_1_main_axis, drag_2_main_axis)
def check_actions_match(
    action_1_touch_yx,
    action_1_lift_yx,
    action_1_action_type,
    action_2_touch_yx,
    action_2_lift_yx,
    action_2_action_type,
    annotation_positions,
    tap_distance_threshold = _TAP_DISTANCE_THRESHOLD,
    annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION,
    annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION,
):
    """Determines if two actions are considered to be the same.

    Two actions being "the same" is defined here as two actions that would
    result in a similar screen state.

    Args:
        action_1_touch_yx: The (y, x) coordinates of the first action's touch.
        action_1_lift_yx: The (y, x) coordinates of the first action's lift.
        action_1_action_type: The action type of the first action.
        action_2_touch_yx: The (y, x) coordinates of the second action's touch.
        action_2_lift_yx: The (y, x) coordinates of the second action's lift.
        action_2_action_type: The action type of the second action.
        annotation_positions: The positions of the UI annotations for the
            screen. It is a 2D array of shape (num_bboxes, 4), where each row
            represents a bounding box: (y_top_left, x_top_left, box_height,
            box_width). Note that containment is inclusive of the bounding
            box edges.
        tap_distance_threshold: The threshold that determines if two taps
            result in a matching screen state if they don't fall in the same
            bounding boxes.
        annotation_width_augment_fraction: The fraction to increase the width
            of the bounding box by.
        annotation_height_augment_fraction: The fraction to increase the
            height of the bounding box by.

    Returns:
        A boolean representing whether the two given actions are the same or
        not.
    """
    touch_1 = jnp.asarray(action_1_touch_yx)
    lift_1 = jnp.asarray(action_1_lift_yx)
    touch_2 = jnp.asarray(action_2_touch_yx)
    lift_2 = jnp.asarray(action_2_lift_yx)

    # If at least one action is global (i.e. not DUAL_POINT), only the
    # actions' types need to be compared.
    either_is_global = jnp.logical_or(
        _is_non_dual_point_action(action_1_action_type),
        _is_non_dual_point_action(action_2_action_type),
    )
    same_action_type = jnp.equal(action_1_action_type, action_2_action_type)

    # Classify each dual-point action as tap or drag once and reuse it.
    first_is_tap = is_tap_action(touch_1, lift_1)
    second_is_tap = is_tap_action(touch_2, lift_2)
    both_are_taps = jnp.logical_and(first_is_tap, second_is_tap)
    one_tap_one_drag = jnp.logical_xor(first_is_tap, second_is_tap)

    # Tap comparison only counts when both actions are taps.
    taps_match = jnp.logical_and(
        both_are_taps,
        _check_tap_actions_match(
            touch_1,
            touch_2,
            annotation_positions,
            tap_distance_threshold,
            annotation_width_augment_fraction,
            annotation_height_augment_fraction,
        ),
    )

    # Drag comparison is discarded when both actions are taps.
    drags_match = jnp.where(
        both_are_taps,
        False,
        _check_drag_actions_match(touch_1, lift_1, touch_2, lift_2),
    )

    # A tap can never match a drag.
    dual_point_match = jnp.where(
        one_tap_one_drag,
        False,
        jnp.logical_or(taps_match, drags_match),
    )

    return jnp.where(either_is_global, same_action_type, dual_point_match)