Skip to content

Commit

Permalink
fix(image_list): sorting on scores works
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Jan 10, 2025
1 parent af14323 commit 4149ea6
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 83 deletions.
2 changes: 1 addition & 1 deletion src/nrtk_explorer/app/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
DEFAULT_DATASETS = [
f"{DIR_NAME}/coco-od-2017/test_val2017.json",
]
NUM_IMAGES_DEFAULT = 500
NUM_IMAGES_DEFAULT = 200
NUM_IMAGES_DEBOUNCE_TIME = 0.3 # seconds


Expand Down
30 changes: 18 additions & 12 deletions src/nrtk_explorer/app/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from nrtk_explorer.library import embeddings_extractor
from nrtk_explorer.library import dimension_reducers
from nrtk_explorer.library.dataset import get_dataset
from nrtk_explorer.library.scoring import partition
from nrtk_explorer.app.applet import Applet

from nrtk_explorer.app.images.image_ids import (
Expand Down Expand Up @@ -65,9 +66,7 @@ def on_server_ready(self, *args, **kwargs):
self.state.change("dataset_ids")(self.update_points)

self.server.controller.apply_transform.add(self.clear_points_transformations)
self.state.change("transform_enabled_switch")(
self.update_points_transformations_visibility
)
self.state.change("transform_enabled_switch")(self.update_points_transformations_state)

def on_feature_extraction_model_change(self, **kwargs):
feature_extraction_model = self.state.feature_extraction_model
Expand Down Expand Up @@ -118,16 +117,16 @@ def clear_points_transformations(self, **kwargs):
self.state.points_transformations = {} # ID to point
self._stashed_points_transformations = {}

def update_points_transformations_visibility(self, **kwargs):
def update_points_transformations_state(self, **kwargs):
    """Sync the visible transformed-image points with the GUI enable switch.

    When the switch is on, restore the stashed points; when off, stash the
    current points and blank out the visible set so the plot hides them.
    """
    switch_on = self.state.transform_enabled_switch
    if switch_on:
        # Switch enabled: bring the stashed points back into view.
        self.state.points_transformations = self._stashed_points_transformations
        return
    # Switch disabled: remember current points, then hide them.
    self._stashed_points_transformations = self.state.points_transformations
    self.state.points_transformations = {}

async def compute_source_points(self):
with self.state:
self.state.is_loading = True
self.clear_points_transformations()

# Don't lock server before enabling the spinner on client
await self.server.network_completion
Expand All @@ -146,8 +145,6 @@ async def compute_source_points(self):
id: point for id, point in zip(self.state.dataset_ids, points)
}

self.clear_points_transformations()

self.state.camera_position = []

with self.state:
Expand All @@ -162,16 +159,25 @@ def on_run_clicked(self):
self.update_points()

def on_run_transformations(self, id_to_image):
hits, misses = partition(
lambda id: image_id_to_dataset_id(id) in self._stashed_points_transformations,
id_to_image.keys(),
)

to_plot = {id: id_to_image[id] for id in misses}
transformation_features = self.extractor.extract(
id_to_image.values(),
list(to_plot.values()),
batch_size=int(self.state.model_batch_size),
)

points = self.compute_points(self.features, transformation_features)
ids_to_points = zip(to_plot.keys(), points)

ids = id_to_image.keys()
updated_points = {image_id_to_dataset_id(id): point for id, point in zip(ids, points)}
self.state.points_transformations = {**self.state.points_transformations, **updated_points}
updated_points = {image_id_to_dataset_id(id): point for id, point in ids_to_points}
self._stashed_points_transformations = {
**self._stashed_points_transformations,
**updated_points,
}
self.update_points_transformations_state()

# called by category filter
def on_select(self, image_ids):
Expand Down
8 changes: 4 additions & 4 deletions src/nrtk_explorer/app/images/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def __init__(
self.delete_from_cache_callback = delete_from_cache_callback

def get_annotations(self, detector: ObjectDetector, id_to_image: Dict[str, Image.Image]):
hits, misses = partition(self.cache.get_item, id_to_image.keys())
cached_predictions = {id: self.cache.get_item(id) for id in hits}
hits, misses = partition(
lambda id: self.cache.get_item(id) is not None, id_to_image.keys()
)

to_detect = {id: id_to_image[id] for id in misses}
predictions = detector.eval(
Expand All @@ -80,8 +81,7 @@ def get_annotations(self, detector: ObjectDetector, id_to_image: Dict[str, Image
id, annotations, self.add_to_cache_callback, self.delete_from_cache_callback
)

predictions.update(**cached_predictions)
return predictions
return {id: self.cache.get_item(id) for id in id_to_image.keys()}

def cache_clear(self):
self.cache.clear()
15 changes: 13 additions & 2 deletions src/nrtk_explorer/app/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def convert_to_base64(img: Image.Image) -> str:
return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()


IMAGE_CACHE_SIZE = 200
IMAGE_CACHE_SIZE = 500


@TrameApp()
Expand Down Expand Up @@ -74,11 +74,15 @@ def _load_transformed_image(self, transform: ImageTransform, dataset_id: str):
return transformed.resize(original.size)
return transformed

def get_transformed_image(self, transform: ImageTransform, dataset_id: str, **kwargs):
def _get_transformed_image(self, transform: ImageTransform, dataset_id: str, **kwargs):
    """Return ``(image_id, image)`` for the transformed version of a dataset image.

    Checks the transformed-image cache first; on a miss (or falsy cached
    value) recomputes via ``_load_transformed_image``. Does NOT insert into
    the cache — callers decide the caching policy.
    """
    image_id = dataset_id_to_transformed_image_id(dataset_id)
    cached = self.transformed_images.get_item(image_id)
    # `or` intentionally treats any falsy cache hit as a miss, matching prior behavior.
    image = cached or self._load_transformed_image(transform, dataset_id)
    return image_id, image

def get_transformed_image(self, transform: ImageTransform, dataset_id: str, **kwargs):
    """Fetch (or compute) a transformed image and store it in the cache.

    Insertion may evict older entries; use the *_without_cache_eviction
    variant when eviction is undesirable.
    """
    key, img = self._get_transformed_image(transform, dataset_id, **kwargs)
    self.transformed_images.add_item(key, img, **kwargs)
    return img

Expand All @@ -90,6 +94,13 @@ def get_stateful_transformed_image(self, transform: ImageTransform, dataset_id:
on_clear_item=self._delete_from_state,
)

def get_transformed_image_without_cache_eviction(
    self, transform: ImageTransform, dataset_id: str
):
    """Fetch (or compute) a transformed image, caching only if space remains.

    Unlike ``get_transformed_image``, this never evicts existing cache
    entries — useful for background/bulk work that must not churn the cache.
    """
    key, img = self._get_transformed_image(transform, dataset_id)
    self.transformed_images.add_if_room(key, img)
    return img

@change("current_dataset")
def clear_all(self, **kwargs):
self.original_images.clear()
Expand Down
143 changes: 89 additions & 54 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Dict, Callable
from collections.abc import Mapping

from trame.ui.quasar import QLayout
from trame.widgets import quasar
Expand Down Expand Up @@ -44,10 +45,36 @@
]


UPDATE_IMAGES_CHUNK_SIZE = 32

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class LazyDict(Mapping):
    """Mapping whose values may be zero-argument callables, evaluated on access.

    A stored callable is invoked each time its key is read (no memoization),
    so expensive values are computed only if — and whenever — they are used.
    Non-callable values pass through unchanged.
    """

    def __init__(self, *args, **kw):
        # Underlying storage holds raw values and/or thunks.
        self._raw_dict = dict(*args, **kw)

    def __getitem__(self, key):
        stored = self._raw_dict[key]
        if callable(stored):
            # Lazily materialize: call the thunk on every access.
            return stored()
        return stored

    def __setitem__(self, key, value):
        self._raw_dict[key] = value

    def __iter__(self):
        yield from self._raw_dict

    def __len__(self):
        return len(self._raw_dict)

    def values(self):
        # Generator (not a view): forces evaluation of any thunks as consumed.
        return (self[key] for key in self._raw_dict)

    def items(self):
        # Generator (not a view): pairs each key with its materialized value.
        return ((key, self[key]) for key in self._raw_dict)


class ProcessingStep:
def __init__(
self,
Expand Down Expand Up @@ -157,6 +184,8 @@ def delete_meta_state(old_ids, new_ids):
delete_state(self.state, dataset_id_to_meta(id))

change_checker(self.state, "dataset_ids")(delete_meta_state)
# clear score when changing model
# clear score when changing transform

self._parameters_app = ParametersApp(
server=server,
Expand Down Expand Up @@ -201,7 +230,7 @@ def delete_meta_state(old_ids, new_ids):
feature_enabled_state_key="transform_enabled",
gui_switch_key="transform_enabled_switch",
column_name=TRANSFORM_COLUMNS[0],
enabled_callback=self._start_transformed_images,
enabled_callback=self._start_update_images,
)

self.server.controller.on_server_ready.add(self.on_server_ready)
Expand Down Expand Up @@ -236,44 +265,32 @@ def on_transform(self, *args, **kwargs):
def on_apply_transform(self, **kwargs):
# Turn on switch if user clicked lower apply button
self.state.transform_enabled_switch = True
self._start_transformed_images()

def _start_transformed_images(self, *args, **kwargs):
logger.debug("_start_transformed_images")
if self._updating_images():
if self._updating_transformed_images:
# computing stale transformed images, restart task
self._cancel_update_images()
else:
return # update_images will call update_transformed_images() at the end
self._update_task = asynchronous.create_task(
self.update_transformed_images(self.visible_dataset_ids)
)

async def update_transformed_images(self, dataset_ids):
self._updating_transformed_images = True
try:
await self._update_transformed_images(dataset_ids)
finally:
self._updating_transformed_images = False
self._start_update_images()

async def _update_transformed_images(self, dataset_ids):
async def update_transformed_images(self, dataset_ids, visible=False):
if not self.state.transform_enabled:
return

transforms = list(map(lambda t: t["instance"], self.context.transforms))
transform = trans.ChainedImageTransform(transforms)

id_to_matching_size_img = {}
id_to_image = LazyDict()
for id in dataset_ids:
with self.state:
transformed = self.images.get_stateful_transformed_image(transform, id)
id_to_matching_size_img[dataset_id_to_transformed_image_id(id)] = transformed
await self.server.network_completion
if visible:
with self.state:
transformed = self.images.get_stateful_transformed_image(transform, id)
id_to_image[dataset_id_to_transformed_image_id(id)] = transformed
await self.server.network_completion
else:
id_to_image[dataset_id_to_transformed_image_id(id)] = (
lambda id=id: self.images.get_transformed_image_without_cache_eviction(
transform, id
)
)

with self.state:
annotations = self.transformed_detection_annotations.get_annotations(
self.detector, id_to_matching_size_img
self.detector, id_to_image
)
await self.server.network_completion

Expand Down Expand Up @@ -304,25 +321,26 @@ async def _update_transformed_images(self, dataset_ids):
{"original_detection_to_transformed_detection_score": score},
)

id_to_image = {
dataset_id_to_transformed_image_id(id): self.images.get_transformed_image(
transform, id
)
for id in dataset_ids
}

self.on_transform(id_to_image)

self.state.flush() # needed cuz in async func and modifying state or else UI does not update
# sortable score value may have changed which may have changed images that are in view
self.server.controller.check_images_in_view()

self.on_transform(id_to_image) # inform embeddings app
self.state.flush()

def compute_predictions_original_images(self, dataset_ids):
if not self.state.predictions_original_images_enabled:
return

image_id_to_image = {
dataset_id_to_image_id(id): self.images.get_image_without_cache_eviction(id)
for id in dataset_ids
}
image_id_to_image = LazyDict(
{
dataset_id_to_image_id(
id
): lambda id=id: self.images.get_image_without_cache_eviction(id)
for id in dataset_ids
}
)

self.predictions_original_images = self.original_detection_annotations.get_annotations(
self.detector, image_id_to_image
)
Expand All @@ -340,35 +358,52 @@ def compute_predictions_original_images(self, dataset_ids):
self.state, dataset_id, {"original_ground_to_original_detection_score": score}
)

async def _update_images(self, dataset_ids):
# load images on state for ImageList
for id in dataset_ids:
async def _update_images(self, dataset_ids, visible=False):
if visible:
# load images on state for ImageList
with self.state:
self.images.get_stateful_image(id)
for id in dataset_ids:
self.images.get_stateful_image(id)
self.ground_truth_annotations.get_annotations(dataset_ids)
await self.server.network_completion

# always push to state because compute_predictions_original_images updates score metadata
with self.state:
self.ground_truth_annotations.get_annotations(dataset_ids)
self.compute_predictions_original_images(dataset_ids)
await self.server.network_completion
# sortable score value may have changed which may have changed images that are in view
self.server.controller.check_images_in_view()

await self.update_transformed_images(dataset_ids, visible=visible)

async def _chunk_update_images(self, dataset_ids, visible=False):
    """Run ``_update_images`` over ``dataset_ids`` in fixed-size batches.

    Chunking keeps each await short so the UI stays responsive and the
    whole task remains cancelable between batches.
    """
    pending = list(dataset_ids)
    while pending:
        batch = pending[:UPDATE_IMAGES_CHUNK_SIZE]
        pending = pending[UPDATE_IMAGES_CHUNK_SIZE:]
        await self._update_images(batch, visible=visible)

async def _update_all_images(self, visible_images):
with self.state:
self.compute_predictions_original_images(dataset_ids)
await self.server.network_completion
self.state.updating_images = True

await self._chunk_update_images(visible_images, visible=True)

other_images = set(self.state.user_selected_ids) - set(visible_images)
await self._chunk_update_images(other_images, visible=False)

with self.state:
await self.update_transformed_images(dataset_ids)
await self.server.network_completion
self.state.updating_images = False

def _cancel_update_images(self, **kwargs):
    """Cancel the in-flight image-update task, if one exists."""
    # hasattr guard: _update_task is only created on the first start,
    # so this is a no-op before any update has ever been scheduled.
    if hasattr(self, "_update_task"):
        self._update_task.cancel()

def _start_update_images(self, **kwargs):
self._cancel_update_images()
self._update_task = asynchronous.create_task(self._update_images(self.visible_dataset_ids))

def _updating_images(self):
return hasattr(self, "_update_task") and not self._update_task.done()
self._update_task = asynchronous.create_task(
self._update_all_images(self.visible_dataset_ids)
)

def on_scroll(self, visible_ids):
self.visible_dataset_ids = visible_ids
Expand Down
8 changes: 8 additions & 0 deletions src/nrtk_explorer/app/ui/image_list.css
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,12 @@
thead tr:first-child th {
top: 0;
}

/* this is when the loading indicator appears */
tr:last-child.q-table__progress th {
/* height of all previous header rows */
top: 48px;
position: sticky;
z-index: 1;
}
}
Loading

0 comments on commit 4149ea6

Please sign in to comment.