Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize honeypot selection algorithm #8857

Merged
merged 7 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions cvat/apps/engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,7 @@ def _to_abs_frame(rel_frame: int) -> int:
)

if bulk_context:
active_validation_frame_counts = bulk_context.active_validation_frame_counts
frame_selector = bulk_context.honeypot_frame_selector
else:
active_validation_frame_counts = {
validation_frame: 0 for validation_frame in task_active_validation_frames
Expand All @@ -1111,7 +1111,8 @@ def _to_abs_frame(rel_frame: int) -> int:
if real_frame in task_active_validation_frames:
active_validation_frame_counts[real_frame] += 1

frame_selector = HoneypotFrameSelector(active_validation_frame_counts)
frame_selector = HoneypotFrameSelector(active_validation_frame_counts)

requested_frames = frame_selector.select_next_frames(segment_honeypots_count)
requested_frames = list(map(_to_abs_frame, requested_frames))
else:
Expand Down Expand Up @@ -1358,7 +1359,7 @@ def __init__(
honeypot_frames: list[int],
all_validation_frames: list[int],
active_validation_frames: list[int],
validation_frame_counts: dict[int, int] | None = None
honeypot_frame_selector: HoneypotFrameSelector | None = None
):
self.updated_honeypots: dict[int, models.Image] = {}
self.updated_segments: list[int] = []
Expand All @@ -1370,7 +1371,7 @@ def __init__(
self.honeypot_frames = honeypot_frames
self.all_validation_frames = all_validation_frames
self.active_validation_frames = active_validation_frames
self.active_validation_frame_counts = validation_frame_counts
self.honeypot_frame_selector = honeypot_frame_selector

class TaskValidationLayoutWriteSerializer(serializers.Serializer):
disabled_frames = serializers.ListField(
Expand Down Expand Up @@ -1485,7 +1486,9 @@ def update(self, instance: models.Task, validated_data: dict[str, Any]) -> model
)
elif frame_selection_method == models.JobFrameSelectionMethod.RANDOM_UNIFORM:
# Reset distribution for active validation frames
bulk_context.active_validation_frame_counts = { f: 0 for f in active_validation_frames }
active_validation_frame_counts = { f: 0 for f in active_validation_frames }
frame_selector = HoneypotFrameSelector(active_validation_frame_counts)
bulk_context.honeypot_frame_selector = frame_selector

# Could be done using Django ORM, but using order_by() and filter()
# would result in an extra DB request
Expand Down
125 changes: 100 additions & 25 deletions cvat/apps/engine/task_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,122 @@
#
# SPDX-License-Identifier: MIT

from collections.abc import Mapping, Sequence
from typing import Generic, TypeVar
from __future__ import annotations

from typing import Callable, Counter, Generic, Iterable, Mapping, Sequence, TypeVar
zhiltsov-max marked this conversation as resolved.
Show resolved Hide resolved

import attrs
import numpy as np

_T = TypeVar("_T")
_K = TypeVar("_K")


@attrs.define
class _BaggedCounter(Generic[_K]):
# Stores items with count = k in a single "bag". Bags are stored in the ascending order
bags: dict[
int,
dict[_K, None],
# dict is used instead of a set to preserve item order. It's also more performant
]

@staticmethod
def from_iterable(items: Sequence[_K]) -> _BaggedCounter:
zhiltsov-max marked this conversation as resolved.
Show resolved Hide resolved
return _BaggedCounter.from_dict(Counter(items))

@staticmethod
def from_dict(item_counts: Mapping[_K, int]) -> _BaggedCounter:
return _BaggedCounter.from_counts(item_counts, item_count=item_counts.__getitem__)

@staticmethod
def from_counts(items: Sequence[_K], item_count: Callable[[_K], int]) -> _BaggedCounter:
bags = {}
for item in items:
count = item_count(item)
bags.setdefault(count, dict())[item] = None

return _BaggedCounter(bags=bags)

def __attrs_post_init__(self):
self._sort_bags()

def _sort_bags(self):
self.bags = dict(sorted(self.bags.items(), key=lambda e: e[0]))

def shuffle(self, *, rng: np.random.Generator | None):
if not rng:
rng = np.random.default_rng()

Check notice

Code scanning / SonarCloud

Results that depend on random number generation should be reproducible Low

Provide a seed for this random generator. See more on SonarQube Cloud

for count, bag in self.bags.items():
items = list(bag.items())
rng.shuffle(items)
self.bags[count] = dict(items)

def use_item(self, item: _K, *, count: int | None = None, bag: dict | None = None):
if count is not None:
if bag is None:
bag = self.bags[count]
elif count is None and bag is None:
count, bag = next((c, b) for c, b in self.bags.items() if item in b)
else:
raise AssertionError("'bag' can only be used together with 'count'")

bag.pop(item)

class HoneypotFrameSelector(Generic[_T]):
if not bag:
self.bags.pop(count)

next_bag = self.bags.get(count + 1)
if next_bag is None:
next_bag = {}
self.bags[count + 1] = next_bag
self._sort_bags() # the new bag can be added in the wrong position if there were gaps

next_bag[item] = None

def __iter__(self) -> Iterable[tuple[int, _K, dict]]:
for count, bag in self.bags.items(): # bags must be ordered
for item in bag:
yield (count, item, bag)

def select_next_least_used(self, count: int) -> Sequence[_K]:
pick = [None] * count
pick_original_use_counts = [(None, None)] * count
for i, (use_count, item, bag) in zip(range(count), self):
pick[i] = item
pick_original_use_counts[i] = (use_count, bag)

for item, (use_count, bag) in zip(pick, pick_original_use_counts):
self.use_item(item, count=use_count, bag=bag)

return pick


class HoneypotFrameSelector(Generic[_K]):
def __init__(
self, validation_frame_counts: Mapping[_T, int], *, rng: np.random.Generator | None = None
self,
validation_frame_counts: Mapping[_K, int],
*,
rng: np.random.Generator | None = None,
):
self.validation_frame_counts = validation_frame_counts

if not rng:
rng = np.random.default_rng()

self.rng = rng

def select_next_frames(self, count: int) -> Sequence[_T]:
self._counter = _BaggedCounter.from_dict(validation_frame_counts)
self._counter.shuffle(rng=rng)

def select_next_frames(self, count: int) -> Sequence[_K]:
# This approach guarantees that:
# - every GT frame is used
# - GT frames are used uniformly (at most min count + 1)
# - GT frames are not repeated in jobs
# - honeypot sets are different in jobs
# - honeypot sets are random
# if possible (if the job and GT counts allow this).
pick = []

for random_number in self.rng.random(count):
least_count = min(c for f, c in self.validation_frame_counts.items() if f not in pick)
least_used_frames = tuple(
f
for f, c in self.validation_frame_counts.items()
if f not in pick
if c == least_count
)

selected_item = int(random_number * len(least_used_frames))
selected_frame = least_used_frames[selected_item]
pick.append(selected_frame)
self.validation_frame_counts[selected_frame] += 1

return pick
# Picks must be reproducible for a given rng state.
"""
Selects 'count' least used items randomly, without repetition
"""
return self._counter.select_next_least_used(count)
Loading