From 1d56b5d912287e0dd708ad8dec052152ea95a95d Mon Sep 17 00:00:00 2001
From: Alexander Dutton
Date: Thu, 7 Feb 2019 14:34:31 +0000
Subject: [PATCH 1/2] Implement graceful handling of WDQS rate-limiting

---
 commons_api/settings.py                     |  1 +
 commons_api/wikidata/tasks/country.py       | 18 ++--
 commons_api/wikidata/tasks/legislature.py   | 35 +++----
 commons_api/wikidata/tasks/wikidata_item.py |  9 +-
 commons_api/wikidata/tests/__init__.py      |  1 +
 commons_api/wikidata/tests/updating.py      |  2 +-
 .../wikidata/tests/wdqs_rate_limiting.py    | 56 +++++++++++
 commons_api/wikidata/utils.py               | 99 ++++++++++++++++++-
 8 files changed, 184 insertions(+), 37 deletions(-)
 create mode 100644 commons_api/wikidata/tests/wdqs_rate_limiting.py

diff --git a/commons_api/settings.py b/commons_api/settings.py
index b05e568..e066ce2 100644
--- a/commons_api/settings.py
+++ b/commons_api/settings.py
@@ -84,6 +84,7 @@
 MEDIA_ROOT = os.environ.get('MEDIA_ROOT') or os.path.expanduser('~/media')
 
+WDQS_RETRIES = int(os.environ.get('WDQS_RETRIES', 5))
 WDQS_URL = 'https://query.wikidata.org/sparql'
 
 ENABLE_MODERATION = bool(os.environ.get('ENABLE_MODERATION'))
diff --git a/commons_api/wikidata/tasks/country.py b/commons_api/wikidata/tasks/country.py
index f7666f7..c8a4109 100644
--- a/commons_api/wikidata/tasks/country.py
+++ b/commons_api/wikidata/tasks/country.py
@@ -1,24 +1,18 @@
 import celery
 
-from SPARQLWrapper import SPARQLWrapper, JSON
-from django.conf import settings
-from django.template.loader import get_template
-
-from commons_api.wikidata.utils import item_uri_to_id
+from commons_api.wikidata import utils
 from commons_api.wikidata import models
 
 __all__ = ['refresh_country_list']
 
 
-@celery.shared_task
-def refresh_country_list():
-    sparql = SPARQLWrapper(settings.WDQS_URL)
-    sparql.setQuery(get_template('wikidata/query/country_list.rq').render())
-    sparql.setReturnFormat(JSON)
-    results = sparql.query().convert()
+@celery.shared_task(bind=True, queue='wdqs')
+@utils.queries_wikidata
+def refresh_country_list(self, rate_limiting_handler):
+    results = utils.templated_wikidata_query('wikidata/query/country_list.rq', {}, rate_limiting_handler)
     seen_ids = set()
     for result in results['results']['bindings']:
-        id = item_uri_to_id(result['item'])
+        id = utils.item_uri_to_id(result['item'])
         country = models.Country.objects.for_id_and_label(id, str(result['itemLabel']['value']))
         country.iso_3166_1_code = result['itemCode']['value'].upper() if result.get('itemCode') else None
         country.save()
diff --git a/commons_api/wikidata/tasks/legislature.py b/commons_api/wikidata/tasks/legislature.py
index ad602d7..129c09f 100644
--- a/commons_api/wikidata/tasks/legislature.py
+++ b/commons_api/wikidata/tasks/legislature.py
@@ -9,22 +9,18 @@
 import celery
 import collections
 import itertools
 
-from SPARQLWrapper import SPARQLWrapper, JSON
-from django.conf import settings
-from django.template.loader import get_template
-from commons_api.wikidata.namespaces import WD
 from commons_api.wikidata.utils import item_uri_to_id, statement_uri_to_id, get_date, templated_wikidata_query
-from .. import models
+from .. import models, utils
 
 
 @with_periodic_queuing_task(superclass=models.Country)
-@celery.shared_task
-def refresh_legislatures(id, queued_at):
+@celery.shared_task(bind=True, queue='wdqs')
+@utils.queries_wikidata
+def refresh_legislatures(self, id, queued_at, rate_limiting_handler):
     country = models.Country.objects.get(id=id, refresh_legislatures_last_queued=queued_at)
-    results = templated_wikidata_query('wikidata/query/legislature_list.rq', {'country': country})
-    # print(get_template('wikidata/query/legislature_list.rq').render())
-    # print(len(results['results']['bindings']))
+    results = templated_wikidata_query('wikidata/query/legislature_list.rq', {'country': country},
+                                       rate_limiting_handler)
     legislature_positions = collections.defaultdict(list)
     legislative_terms = collections.defaultdict(list)
     for result in results['results']['bindings']:
@@ -52,7 +48,8 @@ def refresh_legislatures(id, queued_at):
                        for position in positions]
 
     results = templated_wikidata_query('wikidata/query/legislature_terms_list.rq',
-                                       {'house_positions': house_positions})
+                                       {'house_positions': house_positions},
+                                       rate_limiting_handler)
     for result in results['results']['bindings']:
         if 'termSpecificPositionLabel' in result:
             term_specific_position = models.Position.objects.for_id_and_label(
@@ -83,12 +80,14 @@ def refresh_legislatures(id, queued_at):
 
 
 @with_periodic_queuing_task(superclass=models.LegislativeHouse)
-@celery.shared_task
-def refresh_members(id, queued_at):
+@celery.shared_task(bind=True, queue='wdqs')
+@utils.queries_wikidata
+def refresh_members(self, id, queued_at, rate_limiting_handler):
     house = models.LegislativeHouse.objects.get(id=id, refresh_members_last_queued=queued_at)
 
     results = templated_wikidata_query('wikidata/query/legislature_memberships.rq',
-                                       {'positions': house.positions.all()})
+                                       {'positions': house.positions.all()},
+                                       rate_limiting_handler)
     seen_statement_ids = set()
     for i, (statement, rows) in enumerate(itertools.groupby(results['results']['bindings'],
                                                             key=lambda row: row['statement']['value'])):
@@ -182,12 +181,14 @@ def refresh_members(id, queued_at):
 
 
 @with_periodic_queuing_task(superclass=models.LegislativeHouse)
-@celery.shared_task
-def refresh_districts(id, queued_at):
+@celery.shared_task(bind=True, queue='wdqs')
+@utils.queries_wikidata
+def refresh_districts(self, id, queued_at, rate_limiting_handler):
     house = models.LegislativeHouse.objects.get(id=id, refresh_districts_last_queued=queued_at)
 
     results = templated_wikidata_query('wikidata/query/legislature_constituencies.rq',
-                                       {'house': house})
+                                       {'house': house},
+                                       rate_limiting_handler)
     for result in results['results']['bindings']:
         electoral_district = models.ElectoralDistrict.objects.for_id_and_label(item_uri_to_id(result['constituency']),
diff --git a/commons_api/wikidata/tasks/wikidata_item.py b/commons_api/wikidata/tasks/wikidata_item.py
index 2e91406..261d0fa 100644
--- a/commons_api/wikidata/tasks/wikidata_item.py
+++ b/commons_api/wikidata/tasks/wikidata_item.py
@@ -10,8 +10,9 @@
 
 
 @with_periodic_queuing_task
-@celery.shared_task
-def refresh_labels(app_label, model, ids=None, queued_at=None):
+@celery.shared_task(bind=True, queue='wdqs')
+@utils.queries_wikidata
+def refresh_labels(self, app_label, model, ids=None, queued_at=None, rate_limiting_handler=None):
     """Refreshes all labels for the given model"""
     queryset = get_wikidata_model_by_name(app_label, model).objects.all()
     if queued_at is not None:
@@ -20,7 +21,9 @@ def refresh_labels(app_label, model, ids=None, queued_at=None):
         queryset = queryset.filter(id__in=ids)
     for items in utils.split_every(queryset, 250):
         items = {item.id: item for item in items}
-        results = utils.templated_wikidata_query('wikidata/query/labels.rq', {'ids': sorted(items)})
+        results = utils.templated_wikidata_query('wikidata/query/labels.rq',
+                                                 {'ids': sorted(items)},
+                                                 rate_limiting_handler)
     for id, rows in itertools.groupby(results['results']['bindings'],
                                       key=lambda row: row['id']['value']):
         id = utils.item_uri_to_id(id)
diff --git a/commons_api/wikidata/tests/__init__.py b/commons_api/wikidata/tests/__init__.py
index a948c7f..d32bfa3 100644
--- a/commons_api/wikidata/tests/__init__.py
+++ b/commons_api/wikidata/tests/__init__.py
@@ -3,6 +3,7 @@
 from .moderation import *
 from .popolo import *
 from .serializers import *
+from .wdqs_rate_limiting import *
 from .updating import *
 from .utils import *
 from .views import *
\ No newline at end of file
diff --git a/commons_api/wikidata/tests/updating.py b/commons_api/wikidata/tests/updating.py
index 95c53d6..5bbd3c7 100644
--- a/commons_api/wikidata/tests/updating.py
+++ b/commons_api/wikidata/tests/updating.py
@@ -51,6 +51,6 @@ def testRefreshForModelRefreshesMatchingLastQueued(self, templated_wikidata_quer
             }]}
         }
         wikidata_item.refresh_labels('wikidata', 'country', queued_at=self.refresh_labels_last_queued)
-        templated_wikidata_query.assert_called_once_with('wikidata/query/labels.rq', {'ids': [self.country.id]})
+        templated_wikidata_query.assert_called_once_with('wikidata/query/labels.rq', {'ids': [self.country.id]}, None)
         self.country.refresh_from_db()
         self.assertEqual({'en': 'France', 'de': 'Frankreich'}, self.country.labels)
diff --git a/commons_api/wikidata/tests/wdqs_rate_limiting.py b/commons_api/wikidata/tests/wdqs_rate_limiting.py
new file mode 100644
index 0000000..043276d
--- /dev/null
+++ b/commons_api/wikidata/tests/wdqs_rate_limiting.py
@@ -0,0 +1,56 @@
+import http.client
+import uuid
+from unittest import mock
+from urllib.error import HTTPError
+
+from celery.app.task import Context
+from django.conf import settings
+from django.test import TestCase
+
+from commons_api import celery_app
+from commons_api.wikidata.tasks import refresh_country_list
+
+
+class WDQSRateLimitingTestCase(TestCase):
+    @mock.patch('SPARQLWrapper.Wrapper.urlopener')
+    @mock.patch('time.sleep')
+    def testRetriesIfTooManyRequests(self, time_sleep, urlopen):
+        retry_after = 10
+        urlopen.side_effect = HTTPError(url=None,
+                                        code=http.client.TOO_MANY_REQUESTS,
+                                        hdrs={'Retry-After': str(retry_after)},
+                                        msg='', fp=None)
+        with self.assertRaises(HTTPError):
+            refresh_country_list()
+
+        self.assertEqual(settings.WDQS_RETRIES, urlopen.call_count)
+        self.assertEqual(settings.WDQS_RETRIES - 1, time_sleep.call_count)
+        time_sleep.assert_has_calls([mock.call(retry_after)] * (settings.WDQS_RETRIES - 1))
+
+    @mock.patch('SPARQLWrapper.Wrapper.urlopener')
+    @mock.patch('time.sleep')
+    def testSuspendsConsuming(self, time_sleep, urlopen):
+        retry_after = 10
+        urlopen.side_effect = HTTPError(url=None,
+                                        code=http.client.TOO_MANY_REQUESTS,
+                                        hdrs={'Retry-After': str(retry_after)},
+                                        msg='', fp=None)
+        refresh_country_list._default_request = Context(id=str(uuid.uuid4()),
+                                                        called_directly=False,
+                                                        delivery_info={'routing_key': 'wdqs'},
+                                                        hostname='nodename')
+
+        with mock.patch.object(celery_app.control, 'cancel_consumer') as cancel_consumer, \
+                mock.patch.object(celery_app.control, 'add_consumer') as add_consumer:
+            manager = mock.Mock()
+            manager.attach_mock(cancel_consumer, 'cancel_consumer')
+            manager.attach_mock(add_consumer, 'add_consumer')
+            manager.attach_mock(time_sleep, 'sleep')
+            with self.assertRaises(HTTPError):
+                refresh_country_list.run()
+            manager.assert_has_calls([mock.call.cancel_consumer('wdqs', connection=mock.ANY, destination=['nodename']),
+                                      mock.call.sleep(retry_after),
+                                      mock.call.sleep(retry_after),
+                                      mock.call.sleep(retry_after),
+                                      mock.call.sleep(retry_after),
+                                      mock.call.add_consumer('wdqs', connection=mock.ANY, destination=['nodename'])])
diff --git a/commons_api/wikidata/utils.py b/commons_api/wikidata/utils.py
index c605502..af8641a 100644
--- a/commons_api/wikidata/utils.py
+++ b/commons_api/wikidata/utils.py
@@ -1,16 +1,27 @@
+import logging
+import time
+
+import celery.app
+import celery.task
+import functools
+import http.client
 import itertools
-from typing import Mapping
+import urllib.error
+from typing import Callable, Mapping
 import re
 
 import SPARQLWrapper
 from django.conf import settings
+from django.dispatch import Signal
 from django.template.loader import get_template
 from django.utils import translation
 
-from . import models
 from .namespaces import WD, WDS
 
+logger = logging.getLogger(__name__)
+
+
 def lang_dict(terms):
     return {term.language: str(term) for term in terms}
 
@@ -83,9 +94,89 @@
         yield group(first)
 
 
-def templated_wikidata_query(query_name, context):
+wdqs_rate_limiting = Signal(['retry_after'])
+
+
+def templated_wikidata_query(query_name: str, context: dict,
+                             rate_limiting_handler: Callable[[bool], None] = None) -> dict:
+    """Constructs a query for Wikidata using django.template and returns the parsed results
+
+    If the query elicits a `429 Too Many Requests` response, it retries up to `settings.WDQS_RETRIES` times, and
+    calls the rate_limiting_handler callback if provided to signal the start and end of a "stop querying" period.
+
+    :param query_name: A template name that can be loaded by Django's templating system
+    :param context: A template context dict to use when rendering the query
+    :param rate_limiting_handler: A function to handle rate-limiting requests. Should suspend all querying if called
+        with `True`, and resume it if called with `False`.
+    :returns: The parsed SRJ results as a basic Python data structure
+    """
+
+    rate_limiting_handler = rate_limiting_handler or (lambda suspend: None)
     sparql = SPARQLWrapper.SPARQLWrapper(settings.WDQS_URL)
     sparql.setMethod(SPARQLWrapper.POST)
     sparql.setQuery(get_template(query_name).render(context))
     sparql.setReturnFormat(SPARQLWrapper.JSON)
-    return sparql.query().convert()
+    has_suspended = False
+    try:
+        for i in range(1, settings.WDQS_RETRIES + 1):
+            try:
+                logger.info("Performing query %r (attempt %d/%d)", query_name, i, settings.WDQS_RETRIES)
+                response = sparql.query()
+            except urllib.error.HTTPError as e:
+                if e.code == http.client.TOO_MANY_REQUESTS and i < settings.WDQS_RETRIES:
+                    if not has_suspended:
+                        has_suspended = True
+                        rate_limiting_handler(True)
+                    retry_after = int(e.headers.get('Retry-After', 60))
+                    time.sleep(retry_after)
+                else:
+                    raise
+            else:
+                return response.convert()
+    finally:
+        if has_suspended:
+            rate_limiting_handler(False)
+
+
+def queries_wikidata(task_func):
+    """Decorator for task functions that query Wikidata
+
+    This decorator passes a rate_limiting_handler argument to the wrapped task, which should be passed on to
+    `templated_wikidata_query` to handle rate-limiting requests from WDQS by suspending the execution of tasks that
+    query Wikidata.
+    This is achieved by having Celery cancel the worker's consumption of the given queue when `suspend` is
+    True, and resume it otherwise.
+
+    This behaviour doesn't occur if the task was called directly, i.e. not in a worker.
+
+    Tasks that query Wikidata should be separated from other tasks by being sent to a different queue, e.g.:
+
+        @celery.shared_task(bind=True, queue='wdqs')
+        @utils.queries_wikidata
+        def task_function(self, …, rate_limiting_handler=None):
+            …
+    """
+    @functools.wraps(task_func)
+    def new_task_func(self: celery.task.Task, *args, **kwargs):
+        def handle_ratelimiting(suspend):
+            app = self.app
+            # Celery routes to the right queue using the default exchange and a routing key, so the routing key
+            # tells us our queue name.
+            queue = self.request.delivery_info['routing_key']
+            # This identifies the current celery worker
+            nodename = self.request.hostname
+            with app.connection_or_acquire() as conn:
+                if suspend:
+                    logger.info("WDQS rate-limiting started; WDQS task consumption suspended")
+                    app.control.cancel_consumer(queue, connection=conn, destination=[nodename])
+                    self.update_state(state='RATE_LIMITED')
+                else:
+                    logger.info("WDQS rate-limiting finished; WDQS task consumption resumed")
+                    app.control.add_consumer(queue, connection=conn, destination=[nodename])
+                    self.update_state(state='ACTIVE')
+
+        # Only use a handler if executing in a celery worker.
+        rate_limiting_handler = handle_ratelimiting if not self.request.called_directly else None
+
+        return task_func(self, *args, rate_limiting_handler=rate_limiting_handler, **kwargs)
+
+    return new_task_func

From dfedd3e8c3be9e30b6843d2e1d80a687cfc02819 Mon Sep 17 00:00:00 2001
From: Alexander Dutton
Date: Thu, 7 Feb 2019 15:30:14 +0000
Subject: [PATCH 2/2] Make the default worker handle tasks from both the
 default and wdqs queues

---
 Procfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Procfile b/Procfile
index 0040402..314e2e9 100644
--- a/Procfile
+++ b/Procfile
@@ -1,6 +1,6 @@
 release: python manage.py migrate
 web: gunicorn commons_api.wsgi:application
-worker: celery -A commons_api worker --beat --without-heartbeat -X shapefiles -n default-worker@%h
+worker: celery -A commons_api worker --beat --without-heartbeat -Q celery,wdqs -n default-worker@%h
 # celery_beat: celery -A commons_api beat --without-heartbeat
 # shapefiles_worker: celery -A commons_api worker --without-heartbeat -c 1 -Q shapefiles -n shapefiles-worker@%h --max-tasks-per-child=1
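
Usage sketch (not part of either patch): a minimal example of the contract the rate_limiting_handler callback
must satisfy when templated_wikidata_query is called outside a Celery worker, assuming a configured Django
environment. The handler name and logging setup here are illustrative; the query template
'wikidata/query/country_list.rq' is the one already used by refresh_country_list in the first patch.

    import logging

    from commons_api.wikidata import utils

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)


    def log_only_handler(suspend: bool) -> None:
        # templated_wikidata_query calls this once with True when the first
        # 429 Too Many Requests response arrives, and once with False after
        # the query finally succeeds or the retries are exhausted.
        if suspend:
            logger.info("WDQS asked us to back off; pausing further queries")
        else:
            logger.info("Back-off over; queries may resume")


    results = utils.templated_wikidata_query('wikidata/query/country_list.rq', {},
                                             log_only_handler)
    for binding in results['results']['bindings']:
        print(binding['itemLabel']['value'])

Inside a worker, the queries_wikidata decorator supplies an equivalent callback automatically and uses it to
pause consumption of the wdqs queue, so task code never needs to construct a handler itself.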