Skip to content

Commit

Permalink
Move fetch_ordered_annotations into a new annotation service
Browse files Browse the repository at this point in the history
  • Loading branch information
Jon Betts committed Apr 5, 2023
1 parent 4696df6 commit 90ed652
Show file tree
Hide file tree
Showing 13 changed files with 219 additions and 249 deletions.
17 changes: 6 additions & 11 deletions h/activity/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import newrelic.agent
from pyramid.httpexceptions import HTTPFound
from sqlalchemy.orm import subqueryload

from h import links, presenters, storage
from h import links, presenters
from h.activity import bucketing
from h.models import Annotation, Group
from h.search import (
Expand All @@ -15,6 +14,7 @@
UsersAggregation,
parser,
)
from h.services import AnnotationService


class ActivityResults(
Expand Down Expand Up @@ -115,7 +115,7 @@ def execute(request, query, page_size):

# Load all referenced annotations from the database, bucket them, and add
# the buckets to result.timeframes.
anns = fetch_annotations(request.db, search_result.annotation_ids)
anns = _fetch_annotations(request, search_result.annotation_ids)
result.timeframes.extend(bucketing.bucket(anns))

# Fetch all groups
Expand Down Expand Up @@ -155,16 +155,11 @@ def aggregations_for(query):


@newrelic.agent.function_trace()
def fetch_annotations(session, ids):
def load_documents(query):
return query.options(subqueryload(Annotation.document))

annotations = storage.fetch_ordered_annotations(
session, ids, query_processor=load_documents
def _fetch_annotations(request, ids):
return request.find_service(AnnotationService).get_annotations_by_id(
ids=ids, eager_load=[Annotation.document]
)

return annotations


@newrelic.agent.function_trace()
def _execute_search(request, query, page_size):
Expand Down
14 changes: 14 additions & 0 deletions h/presenters/document_searchindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,17 @@ def asdict(self):
document_dict["web_uri"] = self.document.web_uri

return document_dict


def format_document(document):
    """
    Return a dict representation of a document.

    A falsy `document` gives an empty dict. Otherwise the title is stored
    wrapped in a one-item list under "title" and the web URI under
    "web_uri"; a key is left out when its value is falsy.
    """
    formatted = {}
    if not document:
        return formatted

    title = document.title
    if title:
        formatted["title"] = [title]

    web_uri = document.web_uri
    if web_uri:
        formatted["web_uri"] = web_uri

    return formatted
4 changes: 4 additions & 0 deletions h/services/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Service definitions that handle business logic."""
from h.services.annotation import AnnotationService
from h.services.auth_cookie import AuthCookieService
from h.services.bulk_annotation import BulkAnnotationService
from h.services.subscription import SubscriptionService


def includeme(config): # pragma: no cover
config.register_service_factory(
"h.services.annotation.service_factory", iface=AnnotationService
)
config.register_service_factory(".annotation_json.factory", name="annotation_json")
config.register_service_factory(
".annotation_moderation.annotation_moderation_service_factory",
Expand Down
35 changes: 35 additions & 0 deletions h/services/annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Iterable, List, Optional

from sqlalchemy.orm import subqueryload

from h.models import Annotation


class AnnotationService:
    """A service for fetching annotations from the database."""

    def __init__(self, db_session):
        """
        Instantiate the service.

        :param db_session: the database session used to run queries
        """
        self._db = db_session

    def get_annotations_by_id(
        self, ids: List[str], eager_load: Optional[List] = None
    ) -> Iterable[Annotation]:
        """
        Get annotations in the same order as the provided ids.

        :param ids: the list of annotation ids
        :param eager_load: A list of annotation relationships to eager load
            like `Annotation.document`
        :return: a list of the annotations returned by the query, sorted by
            the position of their id in `ids` (an empty list if `ids` is
            empty)
        """

        if not ids:
            return []

        query = self._db.query(Annotation).filter(Annotation.id.in_(ids))

        if eager_load:
            # `options()` takes loader options as positional arguments, so
            # unpack them rather than passing one generator object.
            query = query.options(*(subqueryload(prop) for prop in eager_load))

        # Build the id -> position lookup once instead of calling
        # `ids.index()` for every row, which is quadratic in the number of
        # ids. `setdefault` keeps the first position of a repeated id, which
        # is what `list.index` returns.
        positions = {}
        for position, id_ in enumerate(ids):
            positions.setdefault(id_, position)

        return sorted(query, key=lambda annotation: positions[annotation.id])


def service_factory(_context, request):
    """Create an AnnotationService that uses the request's database session."""
    db_session = request.db
    return AnnotationService(db_session=db_session)
48 changes: 23 additions & 25 deletions h/services/annotation_json.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from copy import deepcopy

from sqlalchemy.orm import subqueryload

from h import storage
from h.models import Annotation, User
from h.security import Identity, identity_permits
from h.security.permissions import Permission
from h.services import AnnotationService
from h.session import user_info
from h.traversal import AnnotationContext
from h.util.datetime import utc_iso8601
Expand All @@ -14,7 +12,14 @@
class AnnotationJSONService:
"""A service for generating API compatible JSON for annotations."""

def __init__(self, session, links_service, flag_service, user_service):
def __init__(
self,
session,
annotation_service: AnnotationService,
links_service,
flag_service,
user_service,
):
"""
Instantiate the service.
Expand All @@ -24,6 +29,7 @@ def __init__(self, session, links_service, flag_service, user_service):
:param user_service: UserService instance
"""
self._session = session
self._annotation_service = annotation_service
self._links_service = links_service
self._flag_service = flag_service
self._user_service = user_service
Expand Down Expand Up @@ -136,34 +142,25 @@ def present_all_for_user(self, annotation_ids, user: User):
self._flag_service.all_flagged(user, annotation_ids)
self._flag_service.flag_counts(annotation_ids)

annotations = storage.fetch_ordered_annotations(
self._session,
annotation_ids,
query_processor=self._eager_load_related_items,
annotations = self._annotation_service.get_annotations_by_id(
ids=annotation_ids,
eager_load=[
# Optimise access to the document
Annotation.document,
# Optimise the check used for "hidden" above
Annotation.moderation,
# Optimise the permissions check for MODERATE permissions,
# which ultimately depends on group permissions, causing a
# group lookup for every annotation without this
Annotation.group,
],
)

# Optimise the user service `fetch()` call
self._user_service.fetch_all([annotation.userid for annotation in annotations])

return [self.present_for_user(annotation, user) for annotation in annotations]

@staticmethod
def _eager_load_related_items(query):
    """Return `query` with eager-loading options for related items added."""
    # Pre-loading these ensures that accessing `annotation.document` or
    # `.moderation` doesn't trigger any more queries
    loader_options = (
        # Optimise access to the document which is called in
        # `AnnotationJSONPresenter`
        subqueryload(Annotation.document),
        # Optimise the check used for "hidden" above
        subqueryload(Annotation.moderation),
        # Optimise the permissions check for MODERATE permissions,
        # which ultimately depends on group permissions, causing a
        # group lookup for every annotation without this
        subqueryload(Annotation.group),
    )

    return query.options(*loader_options)

@classmethod
def _get_read_permission(cls, annotation):
if not annotation.shared:
Expand All @@ -187,6 +184,7 @@ def factory(_context, request):
return AnnotationJSONService(
session=request.db,
# Services
annotation_service=request.find_service(AnnotationService),
links_service=request.find_service(name="links"),
flag_service=request.find_service(name="flag"),
user_service=request.find_service(name="user"),
Expand Down
35 changes: 0 additions & 35 deletions h/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,41 +49,6 @@ def fetch_annotation(session, id_):
return None


def fetch_ordered_annotations(session, ids, query_processor=None):
    """
    Fetch the annotations with the given ids, ordered like the list of ids.

    `query_processor`, when given, is called with the query before it is
    run and must return a query object again. This is especially useful for
    eager-loading certain data.

    :param session: the database session
    :type session: sqlalchemy.orm.session.Session

    :param ids: the list of annotation ids
    :type ids: list

    :param query_processor: an optional function that takes the query and
        returns an updated query
    :type query_processor: callable

    :returns: the annotations found, sorted by the position of their id in
        `ids`; an empty list when `ids` is empty
    :rtype: list of h.models.Annotation
    """
    if not ids:
        return []

    # Map each id to its position in `ids` (the last position wins for a
    # repeated id).
    position_by_id = {}
    for position, annotation_id in enumerate(ids):
        position_by_id[annotation_id] = position

    annotation_query = session.query(models.Annotation).filter(
        models.Annotation.id.in_(ids)
    )
    if query_processor:
        annotation_query = query_processor(annotation_query)

    def sort_key(annotation):
        return position_by_id.get(annotation.id)

    return sorted(annotation_query, key=sort_key)


def create_annotation(request, data):
"""
Create an annotation from already-validated data.
Expand Down
19 changes: 11 additions & 8 deletions h/views/feeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,13 @@
from pyramid.view import view_config
from webob.multidict import MultiDict

from h import search
from h.feeds import render_atom, render_rss
from h.storage import fetch_ordered_annotations
from h.search import Search
from h.services import AnnotationService

_ = i18n.TranslationStringFactory(__package__)


def _annotations(request):
    """Run the search for the request's params and fetch the annotations it found."""
    searcher = search.Search(request)
    search_params = MultiDict(request.params)
    search_result = searcher.run(search_params)

    return fetch_ordered_annotations(request.db, search_result.annotation_ids)


@view_config(route_name="stream_atom")
def stream_atom(request):
"""Get an Atom feed of the /stream page."""
Expand All @@ -40,3 +34,12 @@ def stream_rss(request):
description=request.registry.settings.get("h.feed.description")
or _("The Web. Annotated"),
)


def _annotations(request):
    """Run the search for the request's params and fetch the annotations it found."""
    searcher = Search(request)
    search_params = MultiDict(request.params)
    search_result = searcher.run(search_params)

    annotation_service = request.find_service(AnnotationService)
    return annotation_service.get_annotations_by_id(
        ids=search_result.annotation_ids
    )
8 changes: 7 additions & 1 deletion tests/common/fixtures/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from h.services import BulkAnnotationService
from h.services import AnnotationService, BulkAnnotationService
from h.services.annotation_delete import AnnotationDeleteService
from h.services.annotation_json import AnnotationJSONService
from h.services.annotation_moderation import AnnotationModerationService
Expand Down Expand Up @@ -34,6 +34,7 @@
"mock_service",
"annotation_delete_service",
"annotation_json_service",
"annotation_service",
"auth_cookie_service",
"auth_token_service",
"bulk_annotation_service",
Expand Down Expand Up @@ -88,6 +89,11 @@ def annotation_json_service(mock_service):
return mock_service(AnnotationJSONService, name="annotation_json")


@pytest.fixture
def annotation_service(mock_service):
    """Return whatever `mock_service` produces for `AnnotationService`."""
    service = mock_service(AnnotationService)
    return service


@pytest.fixture
def auth_cookie_service(mock_service):
    """Return whatever `mock_service` produces for `AuthCookieService`."""
    service = mock_service(AuthCookieService)
    return service
Expand Down
31 changes: 10 additions & 21 deletions tests/h/activity/query_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from pyramid.httpexceptions import HTTPFound
from webob.multidict import MultiDict

from h.activity.query import check_url, execute, extract, fetch_annotations
from h.activity.query import check_url, execute, extract
from h.models import Annotation


class TestExtract:
Expand Down Expand Up @@ -190,7 +191,7 @@ def unparse(self):


@pytest.mark.usefixtures(
"fetch_annotations",
"annotation_service",
"_fetch_groups",
"bucketing",
"presenters",
Expand Down Expand Up @@ -353,20 +354,22 @@ def test_it_returns_the_search_result_if_there_are_no_matches(
assert result.timeframes == []

def test_it_fetches_the_annotations_from_the_database(
self, fetch_annotations, pyramid_request, search
self, annotation_service, pyramid_request, search
):
execute(pyramid_request, MultiDict(), self.PAGE_SIZE)

fetch_annotations.assert_called_once_with(
pyramid_request.db, search.run.return_value.annotation_ids
annotation_service.get_annotations_by_id.assert_called_once_with(
ids=search.run.return_value.annotation_ids, eager_load=[Annotation.document]
)

def test_it_buckets_the_annotations(
self, fetch_annotations, bucketing, pyramid_request
self, annotation_service, bucketing, pyramid_request
):
result = execute(pyramid_request, MultiDict(), self.PAGE_SIZE)

bucketing.bucket.assert_called_once_with(fetch_annotations.return_value)
bucketing.bucket.assert_called_once_with(
annotation_service.get_annotations_by_id.return_value
)
assert result.timeframes == bucketing.bucket.return_value

def test_it_fetches_the_groups_from_the_database(
Expand Down Expand Up @@ -460,10 +463,6 @@ def test_it_returns_the_aggregations(self, pyramid_request):

assert result.aggregations == mock.sentinel.aggregations

@pytest.fixture
def fetch_annotations(self, patch):
    """Patch `h.activity.query.fetch_annotations` and return the result of `patch`."""
    target = "h.activity.query.fetch_annotations"
    return patch(target)

@pytest.fixture
def _fetch_groups(self, group_pubids, patch):
_fetch_groups = patch("h.activity.query._fetch_groups")
Expand Down Expand Up @@ -607,16 +606,6 @@ def pyramid_request(self, pyramid_request):
return pyramid_request


class TestFetchAnnotations:
    """Tests for `fetch_annotations`."""

    def test_it_returns_annotations_by_ids(self, db_session, factories):
        """Fetching by the ids of three created annotations returns them in order."""
        expected = factories.Annotation.create_batch(3)
        annotation_ids = [annotation.id for annotation in expected]

        fetched = fetch_annotations(db_session, annotation_ids)

        assert expected == fetched


@pytest.fixture
def pyramid_request(pyramid_request):
class DummyRoute:
Expand Down
Loading

0 comments on commit 90ed652

Please sign in to comment.