From 90ed652c25e03abbcbf04071aa2e24cba5a06b34 Mon Sep 17 00:00:00 2001 From: Jon Betts Date: Tue, 4 Apr 2023 18:52:27 +0100 Subject: [PATCH] Move `fetch_ordered_annotations` into a new annotation service --- h/activity/query.py | 17 ++-- h/presenters/document_searchindex.py | 14 +++ h/services/__init__.py | 4 + h/services/annotation.py | 35 ++++++++ h/services/annotation_json.py | 48 +++++------ h/storage.py | 35 -------- h/views/feeds.py | 19 +++-- tests/common/fixtures/services.py | 8 +- tests/h/activity/query_test.py | 31 +++---- tests/h/services/annotation_json_test.py | 62 ++++---------- tests/h/services/annotation_test.py | 60 +++++++++++++ tests/h/storage_test.py | 32 ------- tests/h/views/feeds_test.py | 103 ++++++++--------------- 13 files changed, 219 insertions(+), 249 deletions(-) create mode 100644 h/services/annotation.py create mode 100644 tests/h/services/annotation_test.py diff --git a/h/activity/query.py b/h/activity/query.py index ef63e1e1fd1..9c40ca19caf 100644 --- a/h/activity/query.py +++ b/h/activity/query.py @@ -2,9 +2,8 @@ import newrelic.agent from pyramid.httpexceptions import HTTPFound -from sqlalchemy.orm import subqueryload -from h import links, presenters, storage +from h import links, presenters from h.activity import bucketing from h.models import Annotation, Group from h.search import ( @@ -15,6 +14,7 @@ UsersAggregation, parser, ) +from h.services import AnnotationService class ActivityResults( @@ -115,7 +115,7 @@ def execute(request, query, page_size): # Load all referenced annotations from the database, bucket them, and add # the buckets to result.timeframes. - anns = fetch_annotations(request.db, search_result.annotation_ids) + anns = _fetch_annotations(request, search_result.annotation_ids) result.timeframes.extend(bucketing.bucket(anns)) # Fetch all groups @@ -155,16 +155,11 @@ def aggregations_for(query): @newrelic.agent.function_trace() -def fetch_annotations(session, ids): - def load_documents(query): - return query.options(subqueryload(Annotation.document)) - - annotations = storage.fetch_ordered_annotations( - session, ids, query_processor=load_documents +def _fetch_annotations(request, ids): + return request.find_service(AnnotationService).get_annotations_by_id( + ids=ids, eager_load=[Annotation.document] ) - return annotations - @newrelic.agent.function_trace() def _execute_search(request, query, page_size): diff --git a/h/presenters/document_searchindex.py b/h/presenters/document_searchindex.py index 36a91c75be0..5000e614d13 100644 --- a/h/presenters/document_searchindex.py +++ b/h/presenters/document_searchindex.py @@ -14,3 +14,17 @@ def asdict(self): document_dict["web_uri"] = self.document.web_uri return document_dict + + +def format_document(document): + if not document: + return {} + + document_dict = {} + if document.title: + document_dict["title"] = [document.title] + + if document.web_uri: + document_dict["web_uri"] = document.web_uri + + return document_dict diff --git a/h/services/__init__.py b/h/services/__init__.py index 5ea6037c11e..11febe07f98 100644 --- a/h/services/__init__.py +++ b/h/services/__init__.py @@ -1,10 +1,14 @@ """Service definitions that handle business logic.""" +from h.services.annotation import AnnotationService from h.services.auth_cookie import AuthCookieService from h.services.bulk_annotation import BulkAnnotationService from h.services.subscription import SubscriptionService def includeme(config): # pragma: no cover + config.register_service_factory( + "h.services.annotation.service_factory", iface=AnnotationService + ) config.register_service_factory(".annotation_json.factory", name="annotation_json") config.register_service_factory( ".annotation_moderation.annotation_moderation_service_factory", diff --git a/h/services/annotation.py b/h/services/annotation.py new file mode 100644 index 00000000000..7b546f695bc --- /dev/null +++ b/h/services/annotation.py @@ -0,0 +1,35 @@ +from typing import Iterable, List, Optional + +from sqlalchemy.orm import subqueryload + +from h.models import Annotation + + +class AnnotationService: + def __init__(self, db_session): + self._db = db_session + + def get_annotations_by_id( + self, ids: List[str], eager_load: Optional[List] = None + ) -> Iterable[Annotation]: + """ + Get annotations in the same order as the provided ids. + + :param ids: the list of annotation ids + :param eager_load: A list of annotatiopn relationships to eager load + like `Annotation.document` + """ + + if not ids: + return [] + + query = self._db.query(Annotation).filter(Annotation.id.in_(ids)) + + if eager_load: + query = query.options(subqueryload(prop) for prop in eager_load) + + return sorted(query, key=lambda annotation: ids.index(annotation.id)) + + +def service_factory(_context, request): + return AnnotationService(db_session=request.db) diff --git a/h/services/annotation_json.py b/h/services/annotation_json.py index bfe6c6e8d12..4cf181a4675 100644 --- a/h/services/annotation_json.py +++ b/h/services/annotation_json.py @@ -1,11 +1,9 @@ from copy import deepcopy -from sqlalchemy.orm import subqueryload - -from h import storage from h.models import Annotation, User from h.security import Identity, identity_permits from h.security.permissions import Permission +from h.services import AnnotationService from h.session import user_info from h.traversal import AnnotationContext from h.util.datetime import utc_iso8601 @@ -14,7 +12,14 @@ class AnnotationJSONService: """A service for generating API compatible JSON for annotations.""" - def __init__(self, session, links_service, flag_service, user_service): + def __init__( + self, + session, + annotation_service: AnnotationService, + links_service, + flag_service, + user_service, + ): """ Instantiate the service. @@ -24,6 +29,7 @@ def __init__(self, session, links_service, flag_service, user_service): :param user_service: UserService instance """ self._session = session + self._annotation_service = annotation_service self._links_service = links_service self._flag_service = flag_service self._user_service = user_service @@ -136,10 +142,18 @@ def present_all_for_user(self, annotation_ids, user: User): self._flag_service.all_flagged(user, annotation_ids) self._flag_service.flag_counts(annotation_ids) - annotations = storage.fetch_ordered_annotations( - self._session, - annotation_ids, - query_processor=self._eager_load_related_items, + annotations = self._annotation_service.get_annotations_by_id( + ids=annotation_ids, + eager_load=[ + # Optimise access to the document + Annotation.document, + # Optimise the check used for "hidden" above + Annotation.moderation, + # Optimise the permissions check for MODERATE permissions, + # which ultimately depends on group permissions, causing a + # group lookup for every annotation without this + Annotation.group, + ], ) # Optimise the user service `fetch()` call @@ -147,23 +161,6 @@ def present_all_for_user(self, annotation_ids, user: User): return [self.present_for_user(annotation, user) for annotation in annotations] - @staticmethod - def _eager_load_related_items(query): - # Ensure that accessing `annotation.document` or `.moderation` - # doesn't trigger any more queries by pre-loading these - - return query.options( - # Optimise access to the document which is called in - # `AnnotationJSONPresenter` - subqueryload(Annotation.document), - # Optimise the check used for "hidden" above - subqueryload(Annotation.moderation), - # Optimise the permissions check for MODERATE permissions, - # which ultimately depends on group permissions, causing a - # group lookup for every annotation without this - subqueryload(Annotation.group), - ) - @classmethod def _get_read_permission(cls, annotation): if not annotation.shared: @@ -187,6 +184,7 @@ def factory(_context, request): return AnnotationJSONService( session=request.db, # Services + annotation_service=request.find_service(AnnotationService), links_service=request.find_service(name="links"), flag_service=request.find_service(name="flag"), user_service=request.find_service(name="user"), diff --git a/h/storage.py b/h/storage.py index db1ca4ccbb3..243b2b5ffe4 100644 --- a/h/storage.py +++ b/h/storage.py @@ -49,41 +49,6 @@ def fetch_annotation(session, id_): return None -def fetch_ordered_annotations(session, ids, query_processor=None): - """ - Fetch all annotations with the given ids and order them based on the list of ids. - - The optional `query_processor` parameter allows for passing in a function - that can change the query before it is run, especially useful for - eager-loading certain data. The function will get the query as an argument - and has to return a query object again. - - :param session: the database session - :type session: sqlalchemy.orm.session.Session - - :param ids: the list of annotation ids - :type ids: list - - :param query_processor: an optional function that takes the query and - returns an updated query - :type query_processor: callable - - :returns: the annotation, if found, or None. - :rtype: h.models.Annotation, NoneType - """ - if not ids: - return [] - - ordering = {x: i for i, x in enumerate(ids)} - - query = session.query(models.Annotation).filter(models.Annotation.id.in_(ids)) - if query_processor: - query = query_processor(query) - - anns = sorted(query, key=lambda a: ordering.get(a.id)) - return anns - - def create_annotation(request, data): """ Create an annotation from already-validated data. diff --git a/h/views/feeds.py b/h/views/feeds.py index d5a54a38a6c..655485e69e4 100644 --- a/h/views/feeds.py +++ b/h/views/feeds.py @@ -2,19 +2,13 @@ from pyramid.view import view_config from webob.multidict import MultiDict -from h import search from h.feeds import render_atom, render_rss -from h.storage import fetch_ordered_annotations +from h.search import Search +from h.services import AnnotationService _ = i18n.TranslationStringFactory(__package__) -def _annotations(request): - """Return the annotations from the search API.""" - result = search.Search(request).run(MultiDict(request.params)) - return fetch_ordered_annotations(request.db, result.annotation_ids) - - @view_config(route_name="stream_atom") def stream_atom(request): """Get an Atom feed of the /stream page.""" @@ -40,3 +34,12 @@ def stream_rss(request): description=request.registry.settings.get("h.feed.description") or _("The Web. Annotated"), ) + + +def _annotations(request): + """Return the annotations from the search API.""" + result = Search(request).run(MultiDict(request.params)) + + return request.find_service(AnnotationService).get_annotations_by_id( + ids=result.annotation_ids + ) diff --git a/tests/common/fixtures/services.py b/tests/common/fixtures/services.py index 2300e3109dc..8767ab5bcf2 100644 --- a/tests/common/fixtures/services.py +++ b/tests/common/fixtures/services.py @@ -2,7 +2,7 @@ import pytest -from h.services import BulkAnnotationService +from h.services import AnnotationService, BulkAnnotationService from h.services.annotation_delete import AnnotationDeleteService from h.services.annotation_json import AnnotationJSONService from h.services.annotation_moderation import AnnotationModerationService @@ -34,6 +34,7 @@ "mock_service", "annotation_delete_service", "annotation_json_service", + "annotation_service", "auth_cookie_service", "auth_token_service", "bulk_annotation_service", @@ -88,6 +89,11 @@ def annotation_json_service(mock_service): return mock_service(AnnotationJSONService, name="annotation_json") +@pytest.fixture +def annotation_service(mock_service): + return mock_service(AnnotationService) + + @pytest.fixture def auth_cookie_service(mock_service): return mock_service(AuthCookieService) diff --git a/tests/h/activity/query_test.py b/tests/h/activity/query_test.py index c13f21b2f0a..d5f2666c50a 100644 --- a/tests/h/activity/query_test.py +++ b/tests/h/activity/query_test.py @@ -5,7 +5,8 @@ from pyramid.httpexceptions import HTTPFound from webob.multidict import MultiDict -from h.activity.query import check_url, execute, extract, fetch_annotations +from h.activity.query import check_url, execute, extract +from h.models import Annotation class TestExtract: @@ -190,7 +191,7 @@ def unparse(self): @pytest.mark.usefixtures( - "fetch_annotations", + "annotation_service", "_fetch_groups", "bucketing", "presenters", @@ -353,20 +354,22 @@ def test_it_returns_the_search_result_if_there_are_no_matches( assert result.timeframes == [] def test_it_fetches_the_annotations_from_the_database( - self, fetch_annotations, pyramid_request, search + self, annotation_service, pyramid_request, search ): execute(pyramid_request, MultiDict(), self.PAGE_SIZE) - fetch_annotations.assert_called_once_with( - pyramid_request.db, search.run.return_value.annotation_ids + annotation_service.get_annotations_by_id.assert_called_once_with( + ids=search.run.return_value.annotation_ids, eager_load=[Annotation.document] ) def test_it_buckets_the_annotations( - self, fetch_annotations, bucketing, pyramid_request + self, annotation_service, bucketing, pyramid_request ): result = execute(pyramid_request, MultiDict(), self.PAGE_SIZE) - bucketing.bucket.assert_called_once_with(fetch_annotations.return_value) + bucketing.bucket.assert_called_once_with( + annotation_service.get_annotations_by_id.return_value + ) assert result.timeframes == bucketing.bucket.return_value def test_it_fetches_the_groups_from_the_database( @@ -460,10 +463,6 @@ def test_it_returns_the_aggregations(self, pyramid_request): assert result.aggregations == mock.sentinel.aggregations - @pytest.fixture - def fetch_annotations(self, patch): - return patch("h.activity.query.fetch_annotations") - @pytest.fixture def _fetch_groups(self, group_pubids, patch): _fetch_groups = patch("h.activity.query._fetch_groups") @@ -607,16 +606,6 @@ def pyramid_request(self, pyramid_request): return pyramid_request -class TestFetchAnnotations: - def test_it_returns_annotations_by_ids(self, db_session, factories): - annotations = factories.Annotation.create_batch(3) - ids = [a.id for a in annotations] - - result = fetch_annotations(db_session, ids) - - assert annotations == result - - @pytest.fixture def pyramid_request(pyramid_request): class DummyRoute: diff --git a/tests/h/services/annotation_json_test.py b/tests/h/services/annotation_json_test.py index 023abe2aec4..16c2bf2e572 100644 --- a/tests/h/services/annotation_json_test.py +++ b/tests/h/services/annotation_json_test.py @@ -4,8 +4,8 @@ import pytest from h_matchers import Any from pyramid.authorization import Everyone -from sqlalchemy import event +from h.models import Annotation from h.security.permissions import Permission from h.services.annotation_json import AnnotationJSONService, factory from h.traversal import AnnotationContext @@ -183,14 +183,18 @@ def test_present_for_user_hidden_shows_everything_to_moderators( assert result["tags"] def test_present_all_for_user( - self, service, annotation, user, flag_service, user_service + self, service, annotation, user, annotation_service, flag_service, user_service ): - annotation_ids = [annotation.id] + annotation_service.get_annotations_by_id.return_value = [annotation] - result = service.present_all_for_user(annotation_ids, user) + result = service.present_all_for_user(sentinel.annotation_ids, user) - flag_service.all_flagged.assert_called_once_with(user, annotation_ids) - flag_service.flag_counts.assert_called_once_with(annotation_ids) + annotation_service.get_annotations_by_id.assert_called_once_with( + ids=sentinel.annotation_ids, + eager_load=[Annotation.document, Annotation.moderation, Annotation.group], + ) + flag_service.all_flagged.assert_called_once_with(user, sentinel.annotation_ids) + flag_service.flag_counts.assert_called_once_with(sentinel.annotation_ids) user_service.fetch_all.assert_called_once_with([annotation.userid]) assert result == [ @@ -198,49 +202,13 @@ def test_present_all_for_user( Any.dict.containing({"id": Any(), "hidden": False}) ] - @pytest.mark.parametrize("attribute", ("document", "moderation", "group")) - @pytest.mark.parametrize("with_preload", (True, False)) - def test_present_all_for_userpreloading_is_effective( - self, - service, - annotation, - user, - db_session, - query_counter, - attribute, - with_preload, - ): - # Ensure SQLAlchemy forgets all about our annotation - db_session.flush() - db_session.expire(annotation) - if with_preload: - service.present_all_for_user([annotation.id], user) - - query_counter.reset() - getattr(annotation, attribute) - - # If we preloaded, we shouldn't execute any queries (and vice versa) - assert bool(query_counter.count) != with_preload - @pytest.fixture - def query_counter(self, db_engine): - class QueryCounter: - count = 0 - - def __call__(self, *args, **kwargs): - self.count += 1 - - def reset(self): - self.count = 0 - - query_counter = QueryCounter() - event.listen(db_engine, "before_cursor_execute", query_counter) - return query_counter - - @pytest.fixture - def service(self, db_session, links_service, flag_service, user_service): + def service( + self, db_session, annotation_service, links_service, flag_service, user_service + ): return AnnotationJSONService( session=db_session, + annotation_service=annotation_service, links_service=links_service, flag_service=flag_service, user_service=user_service, @@ -274,6 +242,7 @@ def test_it( self, pyramid_request, AnnotationJSONService, + annotation_service, flag_service, links_service, user_service, @@ -284,6 +253,7 @@ def test_it( AnnotationJSONService.assert_called_once_with( session=pyramid_request.db, + annotation_service=annotation_service, links_service=links_service, flag_service=flag_service, user_service=user_service, diff --git a/tests/h/services/annotation_test.py b/tests/h/services/annotation_test.py new file mode 100644 index 00000000000..ddaee09c0d6 --- /dev/null +++ b/tests/h/services/annotation_test.py @@ -0,0 +1,60 @@ +import pytest +from sqlalchemy import event + +from h.models import Annotation +from h.services import AnnotationService + + +class TestAnnotationService: + @pytest.mark.parametrize("reverse", (True, False)) + def test_get_annotations_by_id(self, svc, factories, reverse): + annotations = factories.Annotation.create_batch(3) + if reverse: + annotations = list(reversed(annotations)) + + results = svc.get_annotations_by_id( + [annotation.id for annotation in annotations] + ) + + assert results == annotations + + def test_get_annotations_by_id_with_no_input(self, svc): + assert not svc.get_annotations_by_id(ids=[]) + + @pytest.mark.parametrize("attribute", ("document", "moderation", "group")) + def test_get_annotations_by_id_preloading( + self, svc, factories, db_session, query_counter, attribute + ): + annotation = factories.Annotation() + + # Ensure SQLAlchemy forgets all about our annotation + db_session.flush() + db_session.expire(annotation) + svc.get_annotations_by_id( + [annotation.id], eager_load=[getattr(Annotation, attribute)] + ) + query_counter.reset() + + getattr(annotation, attribute) + + # If we preloaded, we shouldn't execute any queries + assert not query_counter.count + + @pytest.fixture + def query_counter(self, db_engine): + class QueryCounter: + count = 0 + + def __call__(self, *args, **kwargs): + self.count += 1 + + def reset(self): + self.count = 0 + + query_counter = QueryCounter() + event.listen(db_engine, "before_cursor_execute", query_counter) + return query_counter + + @pytest.fixture + def svc(self, db_session): + return AnnotationService(db_session) diff --git a/tests/h/storage_test.py b/tests/h/storage_test.py index 115b2e2b478..35a5ffbd01b 100644 --- a/tests/h/storage_test.py +++ b/tests/h/storage_test.py @@ -7,7 +7,6 @@ from h_matchers import Any from h import storage -from h.models.annotation import Annotation from h.models.document import Document, DocumentURI from h.schemas import ValidationError from h.security import Permission @@ -27,37 +26,6 @@ def test_it_does_not_crash_if_id_is_invalid(self, db_session): assert storage.fetch_annotation(db_session, "foo") is None -class TestFetchOrderedAnnotations: - def test_it_returns_annotations_for_ids_in_the_same_order( - self, db_session, factories - ): - ann_1 = factories.Annotation(userid="luke") - ann_2 = factories.Annotation(userid="luke") - - assert [ann_2, ann_1] == storage.fetch_ordered_annotations( - db_session, [ann_2.id, ann_1.id] - ) - assert [ann_1, ann_2] == storage.fetch_ordered_annotations( - db_session, [ann_1.id, ann_2.id] - ) - - def test_it_allows_to_change_the_query(self, db_session, factories): - ann_1 = factories.Annotation(userid="luke") - ann_2 = factories.Annotation(userid="maria") - - def only_maria(query): - return query.filter(Annotation.userid == "maria") - - assert [ann_2] == storage.fetch_ordered_annotations( - db_session, [ann_2.id, ann_1.id], query_processor=only_maria - ) - - def test_it_handles_empty_ids(self): - results = storage.fetch_ordered_annotations(sentinel.db_session, ids=[]) - - assert results == [] - - class TestExpandURI: @pytest.mark.parametrize( "normalized,expected_uris", diff --git a/tests/h/views/feeds_test.py b/tests/h/views/feeds_test.py index 7bb02a20396..e810a7024ef 100644 --- a/tests/h/views/feeds_test.py +++ b/tests/h/views/feeds_test.py @@ -1,103 +1,66 @@ -from unittest import mock +from unittest.mock import sentinel import pytest -from h.search.core import SearchResult from h.views.feeds import stream_atom, stream_rss -@pytest.mark.usefixtures( - "fetch_ordered_annotations", "render_atom", "search_run", "routes" -) class TestStreamAtom: - def test_renders_atom(self, pyramid_request, render_atom): - stream_atom(pyramid_request) + def test_it(self, render_atom, pyramid_request, annotation_service): + result = stream_atom(pyramid_request) render_atom.assert_called_once_with( request=pyramid_request, - annotations=mock.sentinel.fetched_annotations, - atom_url="http://example.com/thestream.atom", - html_url="http://example.com/thestream", - title="Some feed", - subtitle="It contains stuff", + annotations=annotation_service.get_annotations_by_id.return_value, + atom_url="http://example.com/stream_atom", + html_url="http://example.com/stream", + title=sentinel.feed_title, + subtitle=sentinel.feed_subtitle, ) - def test_returns_rendered_atom(self, pyramid_request, render_atom): - result = stream_atom(pyramid_request) - assert result == render_atom.return_value + @pytest.fixture + def render_atom(self, patch): + return patch("h.views.feeds.render_atom") + -@pytest.mark.usefixtures( - "fetch_ordered_annotations", "render_rss", "search_run", "routes" -) class TestStreamRSS: - def test_renders_rss(self, pyramid_request, render_rss): - stream_rss(pyramid_request) + def test_it(self, render_rss, pyramid_request, annotation_service): + result = stream_rss(pyramid_request) render_rss.assert_called_once_with( request=pyramid_request, - annotations=mock.sentinel.fetched_annotations, - rss_url="http://example.com/thestream.rss", - html_url="http://example.com/thestream", - title="Some feed", - description="Stuff and things", + annotations=annotation_service.get_annotations_by_id.return_value, + rss_url="http://example.com/stream_rss", + html_url="http://example.com/stream", + title=sentinel.feed_title, + description=sentinel.feed_description, ) - def test_returns_rendered_rss(self, pyramid_request, render_rss): - result = stream_rss(pyramid_request) - assert result == render_rss.return_value - -@pytest.fixture -def fetch_ordered_annotations(patch): - fetch_ordered_annotations = patch("h.views.feeds.fetch_ordered_annotations") - fetch_ordered_annotations.return_value = mock.sentinel.fetched_annotations - return fetch_ordered_annotations + @pytest.fixture + def render_rss(self, patch): + return patch("h.views.feeds.render_rss") @pytest.fixture def pyramid_settings(pyramid_settings): - settings = {} - settings.update(pyramid_settings) - settings.update( - { - "h.feed.title": "Some feed", - "h.feed.subtitle": "It contains stuff", - "h.feed.description": "Stuff and things", - } - ) - return settings + pyramid_settings["h.feed.title"] = sentinel.feed_title + pyramid_settings["h.feed.subtitle"] = sentinel.feed_subtitle + pyramid_settings["h.feed.description"] = sentinel.feed_description + return pyramid_settings -@pytest.fixture -def render_atom(patch): - return patch("h.views.feeds.render_atom") - -@pytest.fixture -def render_rss(patch): - return patch("h.views.feeds.render_rss") +@pytest.fixture(autouse=True) +def Search(patch): + return patch("h.views.feeds.Search") -@pytest.fixture +@pytest.fixture(autouse=True) def routes(pyramid_config): - pyramid_config.add_route("stream_atom", "/thestream.atom") - pyramid_config.add_route("stream_rss", "/thestream.rss") - pyramid_config.add_route("stream", "/thestream") - - -@pytest.fixture -def search(patch): - return patch("h.views.feeds.search") - - -@pytest.fixture -def search_run(search): - result = SearchResult( - total=123, annotation_ids=["foo", "bar"], reply_ids=[], aggregations={} - ) - search_run = search.Search.return_value.run - search_run.return_value = result - return search_run + pyramid_config.add_route("stream_atom", "/stream_atom") + pyramid_config.add_route("stream_rss", "/stream_rss") + pyramid_config.add_route("stream", "/stream")