From 4149ea6442930fc117bbec5c068604ace111a06b Mon Sep 17 00:00:00 2001 From: Paul Elliott Date: Tue, 7 Jan 2025 13:38:53 -0500 Subject: [PATCH] fix(image_list): sorting on scores works --- src/nrtk_explorer/app/core.py | 2 +- src/nrtk_explorer/app/embeddings.py | 30 ++-- src/nrtk_explorer/app/images/annotations.py | 8 +- src/nrtk_explorer/app/images/images.py | 15 +- src/nrtk_explorer/app/transforms.py | 143 ++++++++++++------- src/nrtk_explorer/app/ui/image_list.css | 8 ++ src/nrtk_explorer/app/ui/image_list.py | 27 ++-- src/nrtk_explorer/library/object_detector.py | 2 + 8 files changed, 152 insertions(+), 83 deletions(-) diff --git a/src/nrtk_explorer/app/core.py b/src/nrtk_explorer/app/core.py index 00c8c6f6..fc16642c 100644 --- a/src/nrtk_explorer/app/core.py +++ b/src/nrtk_explorer/app/core.py @@ -32,7 +32,7 @@ DEFAULT_DATASETS = [ f"{DIR_NAME}/coco-od-2017/test_val2017.json", ] -NUM_IMAGES_DEFAULT = 500 +NUM_IMAGES_DEFAULT = 200 NUM_IMAGES_DEBOUNCE_TIME = 0.3 # seconds diff --git a/src/nrtk_explorer/app/embeddings.py b/src/nrtk_explorer/app/embeddings.py index 9fd73b8a..06f48535 100644 --- a/src/nrtk_explorer/app/embeddings.py +++ b/src/nrtk_explorer/app/embeddings.py @@ -2,6 +2,7 @@ from nrtk_explorer.library import embeddings_extractor from nrtk_explorer.library import dimension_reducers from nrtk_explorer.library.dataset import get_dataset +from nrtk_explorer.library.scoring import partition from nrtk_explorer.app.applet import Applet from nrtk_explorer.app.images.image_ids import ( @@ -65,9 +66,7 @@ def on_server_ready(self, *args, **kwargs): self.state.change("dataset_ids")(self.update_points) self.server.controller.apply_transform.add(self.clear_points_transformations) - self.state.change("transform_enabled_switch")( - self.update_points_transformations_visibility - ) + self.state.change("transform_enabled_switch")(self.update_points_transformations_state) def on_feature_extraction_model_change(self, **kwargs): feature_extraction_model = self.state.feature_extraction_model @@ -118,16 +117,16 @@ def clear_points_transformations(self, **kwargs): self.state.points_transformations = {} # ID to point self._stashed_points_transformations = {} - def update_points_transformations_visibility(self, **kwargs): + def update_points_transformations_state(self, **kwargs): if self.state.transform_enabled_switch: self.state.points_transformations = self._stashed_points_transformations else: - self._stashed_points_transformations = self.state.points_transformations self.state.points_transformations = {} async def compute_source_points(self): with self.state: self.state.is_loading = True + self.clear_points_transformations() # Don't lock server before enabling the spinner on client await self.server.network_completion @@ -146,8 +145,6 @@ async def compute_source_points(self): id: point for id, point in zip(self.state.dataset_ids, points) } - self.clear_points_transformations() - self.state.camera_position = [] with self.state: @@ -162,16 +159,25 @@ def on_run_clicked(self): self.update_points() def on_run_transformations(self, id_to_image): + hits, misses = partition( + lambda id: image_id_to_dataset_id(id) in self._stashed_points_transformations, + id_to_image.keys(), + ) + + to_plot = {id: id_to_image[id] for id in misses} transformation_features = self.extractor.extract( - id_to_image.values(), + list(to_plot.values()), batch_size=int(self.state.model_batch_size), ) - points = self.compute_points(self.features, transformation_features) + ids_to_points = zip(to_plot.keys(), points) - ids = 
id_to_image.keys() - updated_points = {image_id_to_dataset_id(id): point for id, point in zip(ids, points)} - self.state.points_transformations = {**self.state.points_transformations, **updated_points} + updated_points = {image_id_to_dataset_id(id): point for id, point in ids_to_points} + self._stashed_points_transformations = { + **self._stashed_points_transformations, + **updated_points, + } + self.update_points_transformations_state() # called by category filter def on_select(self, image_ids): diff --git a/src/nrtk_explorer/app/images/annotations.py b/src/nrtk_explorer/app/images/annotations.py index 67498b60..e3acef7b 100644 --- a/src/nrtk_explorer/app/images/annotations.py +++ b/src/nrtk_explorer/app/images/annotations.py @@ -68,8 +68,9 @@ def __init__( self.delete_from_cache_callback = delete_from_cache_callback def get_annotations(self, detector: ObjectDetector, id_to_image: Dict[str, Image.Image]): - hits, misses = partition(self.cache.get_item, id_to_image.keys()) - cached_predictions = {id: self.cache.get_item(id) for id in hits} + hits, misses = partition( + lambda id: self.cache.get_item(id) is not None, id_to_image.keys() + ) to_detect = {id: id_to_image[id] for id in misses} predictions = detector.eval( @@ -80,8 +81,7 @@ def get_annotations(self, detector: ObjectDetector, id_to_image: Dict[str, Image id, annotations, self.add_to_cache_callback, self.delete_from_cache_callback ) - predictions.update(**cached_predictions) - return predictions + return {id: self.cache.get_item(id) for id in id_to_image.keys()} def cache_clear(self): self.cache.clear() diff --git a/src/nrtk_explorer/app/images/images.py b/src/nrtk_explorer/app/images/images.py index ebcfc35d..c0f776b1 100644 --- a/src/nrtk_explorer/app/images/images.py +++ b/src/nrtk_explorer/app/images/images.py @@ -18,7 +18,7 @@ def convert_to_base64(img: Image.Image) -> str: return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode() -IMAGE_CACHE_SIZE = 200 +IMAGE_CACHE_SIZE = 500 @TrameApp() @@ -74,11 +74,15 @@ def _load_transformed_image(self, transform: ImageTransform, dataset_id: str): return transformed.resize(original.size) return transformed - def get_transformed_image(self, transform: ImageTransform, dataset_id: str, **kwargs): + def _get_transformed_image(self, transform: ImageTransform, dataset_id: str, **kwargs): image_id = dataset_id_to_transformed_image_id(dataset_id) image = self.transformed_images.get_item(image_id) or self._load_transformed_image( transform, dataset_id ) + return image_id, image + + def get_transformed_image(self, transform: ImageTransform, dataset_id: str, **kwargs): + image_id, image = self._get_transformed_image(transform, dataset_id, **kwargs) self.transformed_images.add_item(image_id, image, **kwargs) return image @@ -90,6 +94,13 @@ def get_stateful_transformed_image(self, transform: ImageTransform, dataset_id: on_clear_item=self._delete_from_state, ) + def get_transformed_image_without_cache_eviction( + self, transform: ImageTransform, dataset_id: str + ): + image_id, image = self._get_transformed_image(transform, dataset_id) + self.transformed_images.add_if_room(image_id, image) + return image + @change("current_dataset") def clear_all(self, **kwargs): self.original_images.clear() diff --git a/src/nrtk_explorer/app/transforms.py b/src/nrtk_explorer/app/transforms.py index e0aabd5b..ee21330b 100644 --- a/src/nrtk_explorer/app/transforms.py +++ b/src/nrtk_explorer/app/transforms.py @@ -1,5 +1,6 @@ import logging from typing import Dict, Callable +from collections.abc import 
Mapping from trame.ui.quasar import QLayout from trame.widgets import quasar @@ -44,10 +45,36 @@ ] +UPDATE_IMAGES_CHUNK_SIZE = 32 + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +class LazyDict(Mapping): + def __init__(self, *args, **kw): + self._raw_dict = dict(*args, **kw) + + def __getitem__(self, key): + val = self._raw_dict[key] + return val() if callable(val) else val + + def __setitem__(self, key, value): + self._raw_dict[key] = value + + def __iter__(self): + return iter(self._raw_dict) + + def __len__(self): + return len(self._raw_dict) + + def values(self): + return (self[k] for k in self._raw_dict) + + def items(self): + return ((k, self[k]) for k in self._raw_dict) + + class ProcessingStep: def __init__( self, @@ -157,6 +184,8 @@ def delete_meta_state(old_ids, new_ids): delete_state(self.state, dataset_id_to_meta(id)) change_checker(self.state, "dataset_ids")(delete_meta_state) + # clear score when changing model + # clear score when changing transform self._parameters_app = ParametersApp( server=server, @@ -201,7 +230,7 @@ def delete_meta_state(old_ids, new_ids): feature_enabled_state_key="transform_enabled", gui_switch_key="transform_enabled_switch", column_name=TRANSFORM_COLUMNS[0], - enabled_callback=self._start_transformed_images, + enabled_callback=self._start_update_images, ) self.server.controller.on_server_ready.add(self.on_server_ready) @@ -236,44 +265,32 @@ def on_transform(self, *args, **kwargs): def on_apply_transform(self, **kwargs): # Turn on switch if user clicked lower apply button self.state.transform_enabled_switch = True - self._start_transformed_images() - - def _start_transformed_images(self, *args, **kwargs): - logger.debug("_start_transformed_images") - if self._updating_images(): - if self._updating_transformed_images: - # computing stale transformed images, restart task - self._cancel_update_images() - else: - return # update_images will call update_transformed_images() at the end - self._update_task = asynchronous.create_task( - self.update_transformed_images(self.visible_dataset_ids) - ) - - async def update_transformed_images(self, dataset_ids): - self._updating_transformed_images = True - try: - await self._update_transformed_images(dataset_ids) - finally: - self._updating_transformed_images = False + self._start_update_images() - async def _update_transformed_images(self, dataset_ids): + async def update_transformed_images(self, dataset_ids, visible=False): if not self.state.transform_enabled: return transforms = list(map(lambda t: t["instance"], self.context.transforms)) transform = trans.ChainedImageTransform(transforms) - id_to_matching_size_img = {} + id_to_image = LazyDict() for id in dataset_ids: - with self.state: - transformed = self.images.get_stateful_transformed_image(transform, id) - id_to_matching_size_img[dataset_id_to_transformed_image_id(id)] = transformed - await self.server.network_completion + if visible: + with self.state: + transformed = self.images.get_stateful_transformed_image(transform, id) + id_to_image[dataset_id_to_transformed_image_id(id)] = transformed + await self.server.network_completion + else: + id_to_image[dataset_id_to_transformed_image_id(id)] = ( + lambda id=id: self.images.get_transformed_image_without_cache_eviction( + transform, id + ) + ) with self.state: annotations = self.transformed_detection_annotations.get_annotations( - self.detector, id_to_matching_size_img + self.detector, id_to_image ) await self.server.network_completion @@ -304,25 +321,26 @@ async def 
_update_transformed_images(self, dataset_ids): {"original_detection_to_transformed_detection_score": score}, ) - id_to_image = { - dataset_id_to_transformed_image_id(id): self.images.get_transformed_image( - transform, id - ) - for id in dataset_ids - } - - self.on_transform(id_to_image) - self.state.flush() # needed cuz in async func and modifying state or else UI does not update + # sortable score value may have changed which may have changed images that are in view + self.server.controller.check_images_in_view() + + self.on_transform(id_to_image) # inform embeddings app + self.state.flush() def compute_predictions_original_images(self, dataset_ids): if not self.state.predictions_original_images_enabled: return - image_id_to_image = { - dataset_id_to_image_id(id): self.images.get_image_without_cache_eviction(id) - for id in dataset_ids - } + image_id_to_image = LazyDict( + { + dataset_id_to_image_id( + id + ): lambda id=id: self.images.get_image_without_cache_eviction(id) + for id in dataset_ids + } + ) + self.predictions_original_images = self.original_detection_annotations.get_annotations( self.detector, image_id_to_image ) @@ -340,24 +358,42 @@ def compute_predictions_original_images(self, dataset_ids): self.state, dataset_id, {"original_ground_to_original_detection_score": score} ) - async def _update_images(self, dataset_ids): - # load images on state for ImageList - for id in dataset_ids: + async def _update_images(self, dataset_ids, visible=False): + if visible: + # load images on state for ImageList with self.state: - self.images.get_stateful_image(id) + for id in dataset_ids: + self.images.get_stateful_image(id) + self.ground_truth_annotations.get_annotations(dataset_ids) await self.server.network_completion + # always push to state because compute_predictions_original_images updates score metadata with self.state: - self.ground_truth_annotations.get_annotations(dataset_ids) + self.compute_predictions_original_images(dataset_ids) await self.server.network_completion + # sortable score value may have changed which may have changed images that are in view + self.server.controller.check_images_in_view() + + await self.update_transformed_images(dataset_ids, visible=visible) + async def _chunk_update_images(self, dataset_ids, visible=False): + ids = list(dataset_ids) + + for i in range(0, len(ids), UPDATE_IMAGES_CHUNK_SIZE): + chunk = ids[i : i + UPDATE_IMAGES_CHUNK_SIZE] + await self._update_images(chunk, visible=visible) + + async def _update_all_images(self, visible_images): with self.state: - self.compute_predictions_original_images(dataset_ids) - await self.server.network_completion + self.state.updating_images = True + + await self._chunk_update_images(visible_images, visible=True) + + other_images = set(self.state.user_selected_ids) - set(visible_images) + await self._chunk_update_images(other_images, visible=False) with self.state: - await self.update_transformed_images(dataset_ids) - await self.server.network_completion + self.state.updating_images = False def _cancel_update_images(self, **kwargs): if hasattr(self, "_update_task"): @@ -365,10 +401,9 @@ def _cancel_update_images(self, **kwargs): def _start_update_images(self, **kwargs): self._cancel_update_images() - self._update_task = asynchronous.create_task(self._update_images(self.visible_dataset_ids)) - - def _updating_images(self): - return hasattr(self, "_update_task") and not self._update_task.done() + self._update_task = asynchronous.create_task( + self._update_all_images(self.visible_dataset_ids) + ) def 
on_scroll(self, visible_ids): self.visible_dataset_ids = visible_ids diff --git a/src/nrtk_explorer/app/ui/image_list.css b/src/nrtk_explorer/app/ui/image_list.css index 4d660a76..fab77872 100644 --- a/src/nrtk_explorer/app/ui/image_list.css +++ b/src/nrtk_explorer/app/ui/image_list.css @@ -8,4 +8,12 @@ thead tr:first-child th { top: 0; } + + /* this is when the loading indicator appears */ + tr:last-child.q-table__progress th { + /* height of all previous header rows */ + top: 48px; + position: sticky; + z-index: 1; + } } diff --git a/src/nrtk_explorer/app/ui/image_list.py b/src/nrtk_explorer/app/ui/image_list.py index b9b9f0cf..4337f044 100644 --- a/src/nrtk_explorer/app/ui/image_list.py +++ b/src/nrtk_explorer/app/ui/image_list.py @@ -117,6 +117,11 @@ def __init__( @TrameApp() class ImageList(html.Div): + # keep identical ID across datasets from stopping update + @change("current_dataset") + def clear_old_visible_ids(self, **kwargs): + self.visible_ids = set() + def set_in_view_ids(self, ids): visible = set(ids) if self.visible_ids != visible: @@ -137,18 +142,18 @@ def update_image_list_ids(self, **kwargs): self._set_image_list_ids(self.state.user_selected_ids) @change("image_list_ids") - def reset_view_range(self, **kwargs): - self.visible_ids = set() - self.server.js_call(ref="image-list", method="resetVirtualScroll") + def check_images_in_view(self, **kwargs): if self.state.image_list_view_mode == "grid": - self.server.controller.get_visible_ids() + self.server.controller.get_visible_ids_for_grid() + return + self.server.js_call(ref="image-list", method="resetVirtualScroll") @change("image_list_view_mode") def update_pagination(self, **kwargs): old_pagination = self.state.pagination or {} if self.state.image_list_view_mode == "grid": self.state.pagination = {**old_pagination, "rowsPerPage": 12} - self.server.controller.get_visible_ids() + self.server.controller.get_visible_ids_for_grid() else: self.state.pagination = {**old_pagination, "rowsPerPage": 0} # show all rows @@ -164,7 +169,7 @@ def __init__(self, on_scroll, on_hover, **kwargs): with self: client.Style(CSS_FILE.read_text()) - get_visible_ids = client.JSEval( + get_visible_ids_for_grid = client.JSEval( exec=f''' ;const list = trame.refs['image-list'] if (!list) return @@ -175,13 +180,15 @@ def __init__(self, on_scroll, on_hover, **kwargs): }}, 0) "''', ) - self.ctrl.get_visible_ids = get_visible_ids.exec + self.ctrl.get_visible_ids_for_grid = get_visible_ids_for_grid.exec + self.ctrl.check_images_in_view = self.check_images_in_view with quasar.QTable( ref=("image-list"), classes="full-height sticky-header", flat=True, hide_bottom=("image_list_view_mode !== 'grid'", True), title="Sampled Images", + loading=("updating_images", False), grid=("image_list_view_mode === 'grid'", False), filter=("image_list_search", ""), id="image-list", # set id so that the ImageDetection component can select the container for tooltip positioning @@ -226,9 +233,9 @@ def __init__(self, on_scroll, on_hover, **kwargs): }}"''', "virtual-scroll-sticky-size-start='48'", r"v-model:pagination='pagination'", - f'''@update:pagination="() => {{ - if(get('image_list_view_mode').value !== 'grid') return; - trigger('{ self.server.controller.trigger_name(self.ctrl.get_visible_ids) }') + f'''@update:pagination="(e) => {{ + console.log('pagination updated') + trigger('{ self.server.controller.trigger_name(self.ctrl.check_images_in_view) }') }}"''', ], ): diff --git a/src/nrtk_explorer/library/object_detector.py b/src/nrtk_explorer/library/object_detector.py index 
281ffa0d..56d91301 100644
--- a/src/nrtk_explorer/library/object_detector.py
+++ b/src/nrtk_explorer/library/object_detector.py
@@ -68,6 +68,8 @@ def eval(
         batch_size: int = 0,  # 0 means use last successful batch size
     ) -> ImageIdToAnnotations:
         """Compute object recognition. Returns Annotations grouped by input image paths."""
+        if len(images) == 0:
+            return {}  # skip detector invocation when there is nothing to evaluate
         images_with_ids = [ImageWithId(id, img) for id, img in images.items()]
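
Reviewer note (illustration only, not part of the diff): the patch leans on two pieces working together — `LazyDict` (added in transforms.py) defers expensive image loading/transformation until the detector actually iterates the values, and `partition` (imported from `nrtk_explorer.library.scoring`; its implementation is not shown in this patch, so the version below is an assumed equivalent) splits IDs into cache hits and misses so only misses get wrapped in thunks. `load_image` and the sample IDs below are hypothetical stand-ins. A minimal, self-contained sketch:

```python
# Illustrative sketch only. `partition` semantics are assumed from its usage in the
# patch (hits satisfy the predicate); `LazyDict` is abridged from transforms.py;
# `load_image` is a hypothetical expensive per-image operation.
from collections.abc import Mapping
from itertools import tee


def partition(predicate, iterable):
    """Split iterable into (hits, misses) by predicate (assumed behavior)."""
    a, b = tee(iterable)
    return [x for x in a if predicate(x)], [x for x in b if not predicate(x)]


class LazyDict(Mapping):
    """Values may be callables; they are evaluated only when accessed."""

    def __init__(self, *args, **kw):
        self._raw_dict = dict(*args, **kw)

    def __getitem__(self, key):
        val = self._raw_dict[key]
        return val() if callable(val) else val

    def __iter__(self):
        return iter(self._raw_dict)

    def __len__(self):
        return len(self._raw_dict)

    def items(self):
        # Lazily evaluate each value as it is iterated.
        return ((k, self[k]) for k in self._raw_dict)


def load_image(image_id):
    # Hypothetical expensive operation (decode, transform, etc.).
    print(f"loading {image_id}")
    return f"<pixels of {image_id}>"


cache = {"img_1": "<cached pixels>"}
ids = ["img_1", "img_2", "img_3"]

hits, misses = partition(lambda i: i in cache, ids)
# Only misses are wrapped in thunks; nothing is loaded until iterated.
to_load = LazyDict({i: (lambda i=i: load_image(i)) for i in misses})

for image_id, image in to_load.items():
    print(image_id, image)
```

The `lambda i=i: ...` default-argument binding mirrors the `lambda id=id:` trick in `update_transformed_images` and avoids the late-binding pitfall of closing over a loop variable. Because `ObjectDetector.eval` only iterates the mapping it is given, passing a `LazyDict` means off-screen images are loaded (or transformed) only if and when the detector actually consumes them.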