diff --git a/cl/recap/tasks.py b/cl/recap/tasks.py index ee674a9f25..8d89c09d71 100644 --- a/cl/recap/tasks.py +++ b/cl/recap/tasks.py @@ -144,6 +144,20 @@ async def process_recap_upload(pq: ProcessingQueue) -> None: docket = await process_recap_acms_docket(pq.pk) +def build_pdf_retrieval_task_chain( + fq: PacerFetchQueue, rate_limit: str = None +): + rd_pk = fq.recap_document_id + pacer_fetch_task = fetch_pacer_doc_by_rd.si(rd_pk, fq.pk) + if rate_limit: + pacer_fetch_task = pacer_fetch_task.set(rate_limit=rate_limit) + return chain( + pacer_fetch_task, + extract_recap_pdf.si(rd_pk), + mark_fq_successful.si(fq.pk), + ) + + def do_pacer_fetch(fq: PacerFetchQueue): """Process a request made by a user to get an item from PACER. @@ -160,12 +174,7 @@ def do_pacer_fetch(fq: PacerFetchQueue): result = c.apply_async() elif fq.request_type == REQUEST_TYPE.PDF: # Request by recap_document_id - rd_pk = fq.recap_document_id - result = chain( - fetch_pacer_doc_by_rd.si(rd_pk, fq.pk), - extract_recap_pdf.si(rd_pk), - mark_fq_successful.si(fq.pk), - ).apply_async() + result = build_pdf_retrieval_task_chain(fq).apply_async() elif fq.request_type == REQUEST_TYPE.ATTACHMENT_PAGE: result = fetch_attachment_page.apply_async(args=(fq.pk,)) return result diff --git a/cl/search/management/commands/pacer_bulk_fetch.py b/cl/search/management/commands/pacer_bulk_fetch.py new file mode 100644 index 0000000000..140ec5395c --- /dev/null +++ b/cl/search/management/commands/pacer_bulk_fetch.py @@ -0,0 +1,233 @@ +import logging +from datetime import datetime + +from django.contrib.auth.models import User +from django.core.management.base import CommandError +from django.db.models import Q + +from cl import settings +from cl.lib.celery_utils import CeleryThrottle +from cl.lib.command_utils import VerboseCommand +from cl.lib.pacer_session import get_or_cache_pacer_cookies +from cl.recap.models import REQUEST_TYPE, PacerFetchQueue +from cl.recap.tasks import build_pdf_retrieval_task_chain +from 
cl.search.models import Court, RECAPDocument
+
+logger = logging.getLogger(__name__)
+
+
+class Command(VerboseCommand):
+    help = "Download multiple documents from PACER with rate limiting"
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.user = None
+        self.recap_documents = None
+        self.courts_with_docs = {}
+        self.total_launched = 0
+        self.total_errors = 0
+        self.pacer_username = None
+        self.pacer_password = None
+        self.throttle = None
+        self.queue_name = None
+        self.rate_limit = None
+
+    def add_arguments(self, parser) -> None:
+        parser.add_argument(
+            "--rate-limit",
+            type=str,
+            help="The maximum rate for requests, e.g. '1/m', or '10/2h' or similar. Defaults to 1/2s",
+        )
+        parser.add_argument(
+            "--min-page-count",
+            type=int,
+            help="Get docs with this number of pages or more",
+        )
+        parser.add_argument(
+            "--max-page-count",
+            type=int,
+            help="Get docs with this number of pages or less",
+        )
+        parser.add_argument(
+            "--username",
+            type=str,
+            help="Username to associate with the processing queues (defaults to 'recap')",
+        )
+        parser.add_argument(
+            "--queue-name",
+            type=str,
+            help="Celery queue name used for processing tasks",
+        )
+        parser.add_argument(
+            "--testing",
+            type=str,
+            help="Prevents creation of log file",
+        )
+
+    @staticmethod
+    def setup_logging(testing: bool = False) -> None:
+        if not testing:
+            logging.basicConfig(
+                filename=f'pacer_bulk_fetch_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log',
+                level=logging.INFO,
+                format="%(asctime)s - %(levelname)s - %(message)s",
+            )
+
+    def setup_celery(self, options) -> None:
+        """Setup Celery by setting the queue_name, rate_limit and throttle."""
+        # argparse stores None for unsupplied options, so the key always
+        # exists and dict.get's second argument is dead code; use `or`.
+        self.queue_name = options.get("queue_name") or "pacer_bulk_fetch"
+        self.rate_limit = options.get("rate_limit") or "1/2s"
+        self.throttle = CeleryThrottle(queue_name=self.queue_name)
+
+    def handle_pacer_session(self, options) -> None:
+        """Make sure we have an active PACER session for the user."""
+        self.pacer_username 
= options.get( + "pacer_username", settings.PACER_USERNAME + ) + self.pacer_password = options.get( + "pacer_password", settings.PACER_PASSWORD + ) + get_or_cache_pacer_cookies( + self.user.pk, + username=self.pacer_username, + password=self.pacer_password, + ) + + def set_user(self, username: str) -> None: + """Get user or raise CommandError""" + if not username: + raise CommandError( + "No username provided, cannot create PacerFetchQueues." + ) + try: + self.user = User.objects.get(username=username) + except User.DoesNotExist: + raise CommandError(f"User {username} does not exist") + + def identify_documents(self, options: dict) -> None: + """Get eligible documents grouped by court""" + filters = [ + Q(pacer_doc_id__isnull=False), + Q(is_available=False), + ] + if options.get("min_page_count"): + filters.append(Q(page_count__gte=options["min_page_count"])) + if options.get("max_page_count"): + filters.append(Q(page_count__lte=options["max_page_count"])) + + self.recap_documents = ( + RECAPDocument.objects.filter(*filters) + .values( + "id", + "page_count", + "docket_entry__docket__court_id", + "pacer_doc_id", + ) + .order_by("-page_count") + ) + + courts = ( + Court.objects.filter( + dockets__docket_entries__recap_documents__in=[ + recap_doc_id["id"] for recap_doc_id in self.recap_documents + ] + ) + .order_by("pk") + .distinct() + ) + + for court in courts: + self.courts_with_docs[court.pk] = [ + doc + for doc in self.recap_documents + if doc["docket_entry__docket__court_id"] == court.pk + ] + + def enqueue_pacer_fetch(self, doc: dict) -> None: + self.throttle.maybe_wait() + + fq = PacerFetchQueue.objects.create( + request_type=REQUEST_TYPE.PDF, + recap_document_id=doc.get("id"), + user_id=self.user.pk, + ) + build_pdf_retrieval_task_chain( + fq, + rate_limit=self.rate_limit, + ).apply_async(queue=self.queue_name) + self.total_launched += 1 + logger.info( + f"Launched download for doc {doc.get('id')} from court {doc.get('docket_entry__docket__court_id')}" + 
f"\nProgress: {self.total_launched}/{len(self.recap_documents)}"
+        )
+
+    def execute_round(
+        self, remaining_courts: dict, options: dict, is_last_round: bool
+    ) -> dict:
+        remaining_courts_copy = (
+            remaining_courts.copy()
+        )  # don't remove elements from list we're iterating over
+        court_keys = remaining_courts.keys()
+        for court_index, court_id in enumerate(court_keys):
+            doc = remaining_courts[court_id].pop(0)
+
+            try:
+                self.enqueue_pacer_fetch(doc)
+            except Exception as e:
+                self.total_errors += 1
+                logger.error(
+                    f"Error queuing document {doc.get('id')}: {str(e)}",
+                    exc_info=True,
+                )
+            finally:
+                # If this court doesn't have any more docs, remove from dict:
+                if len(remaining_courts[court_id]) == 0:
+                    remaining_courts_copy.pop(court_id)
+
+        return remaining_courts_copy
+
+    def process_documents(self, options: dict) -> None:
+        """Process documents in round-robin fashion by court"""
+        remaining_courts = self.courts_with_docs
+        court_doc_counts = [
+            len(self.courts_with_docs[court_id])
+            for court_id in self.courts_with_docs.keys()
+        ]
+        # default=0 guards against ValueError when no documents matched
+        rounds = max(court_doc_counts, default=0)
+
+        for i in range(rounds):
+            is_last_round = i == rounds - 1
+            remaining_courts = self.execute_round(
+                remaining_courts, options, is_last_round
+            )
+
+        if self.total_errors:
+            logger.error(
+                f"Finished processing with {self.total_errors} error{'s' if self.total_errors > 1 else ''}."
+            )
+
+    def handle(self, *args, **options) -> None:
+        self.setup_logging(options.get("testing", False))
+        self.setup_celery(options)
+
+        logger.info("Starting pacer_bulk_fetch command")
+
+        try:
+            # argparse stores None when --username is omitted, so use `or`
+            # to make the documented 'recap' default actually apply:
+            self.set_user(options.get("username") or "recap")
+            self.handle_pacer_session(options)
+
+            self.identify_documents(options)
+
+            logger.info(
+                f"{self.user} found {len(self.recap_documents)} documents across {len(self.courts_with_docs)} courts."
+            )
+
+            self.process_documents(options)
+
+            logger.info(
+                f"Created {self.total_launched} processing queues for a total of {len(self.recap_documents)} docs found."
+ ) + + except Exception as e: + logger.error(f"Fatal error in command: {str(e)}", exc_info=True) + raise diff --git a/cl/search/tests/test_pacer_bulk_fetch.py b/cl/search/tests/test_pacer_bulk_fetch.py new file mode 100644 index 0000000000..ec5219a65a --- /dev/null +++ b/cl/search/tests/test_pacer_bulk_fetch.py @@ -0,0 +1,251 @@ +import random +from unittest.mock import MagicMock, patch + +from cl.recap.models import PacerFetchQueue +from cl.search.factories import ( + CourtFactory, + DocketEntryFactory, + DocketFactory, + RECAPDocumentFactory, +) +from cl.search.management.commands.pacer_bulk_fetch import Command +from cl.search.models import Docket, RECAPDocument +from cl.tests.cases import TestCase +from cl.users.factories import UserFactory + + +class BulkFetchPacerDocsTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.user = UserFactory() + + cls.courts = [CourtFactory() for _ in range(6)] + + dockets_per_court = 15 + entries_per_docket = 8 + + page_count_ranges = [ + (1000, 2000), + (500, 999), + (100, 499), + (1, 99), + ] + cls.big_page_count = 1000 + cls.big_docs_count = 0 + + for court in cls.courts: + [DocketFactory(court=court) for _ in range(dockets_per_court)] + + for docket in Docket.objects.all(): + docket_entries = [ + DocketEntryFactory(docket=docket) + for _ in range(entries_per_docket) + ] + + for de in docket_entries: + min_pages, max_pages = random.choice(page_count_ranges) + page_count = random.randint(min_pages, max_pages) + cls.big_docs_count += 1 if page_count >= 1000 else 0 + RECAPDocumentFactory( + docket_entry=de, + page_count=page_count, + is_available=False, + ) + + def setUp(self): + self.command = Command() + self.big_docs_created = RECAPDocument.objects.filter( + page_count__gte=self.big_page_count, + is_available=False, + pacer_doc_id__isnull=False, + ) + self.assertEqual(self.big_docs_count, self.big_docs_created.count()) + + @patch( + "cl.search.management.commands.pacer_bulk_fetch.CeleryThrottle.maybe_wait" + ) + 
@patch( + "cl.search.management.commands.pacer_bulk_fetch.build_pdf_retrieval_task_chain" + ) + @patch( + "cl.search.management.commands.pacer_bulk_fetch.get_or_cache_pacer_cookies" + ) + def test_document_filtering( + self, + mock_pacer_cookies, + mock_chain_builder, + mock_throttle, + ): + """Test document filtering according to command arguments passed.""" + # Setup mock chain + mock_chain = MagicMock() + mock_chain_builder.return_value = mock_chain + + self.command.handle( + min_page_count=self.big_page_count, + request_interval=1.0, + username=self.user.username, + testing=True, + ) + + self.assertEqual( + mock_chain.apply_async.call_count, + self.big_docs_count, + f"Expected {self.big_docs_count} documents to be processed", + ) + + fetch_queues = PacerFetchQueue.objects.all() + self.assertEqual( + fetch_queues.count(), + self.big_docs_count, + f"Expected {self.big_docs_count} fetch queues", + ) + + enqueued_doc_ids = [fq.recap_document_id for fq in fetch_queues] + big_doc_ids = self.big_docs_created.values_list("id", flat=True) + self.assertSetEqual(set(enqueued_doc_ids), set(big_doc_ids)) + + @patch( + "cl.search.management.commands.pacer_bulk_fetch.CeleryThrottle.maybe_wait" + ) + @patch( + "cl.search.management.commands.pacer_bulk_fetch.build_pdf_retrieval_task_chain" + ) + @patch( + "cl.search.management.commands.pacer_bulk_fetch.get_or_cache_pacer_cookies" + ) + def test_rate_limiting( + self, + mock_pacer_cookies, + mock_chain_builder, + mock_throttle, + ): + """Test rate limiting.""" + # Setup mock chain + mock_chain = MagicMock() + mock_chain_builder.return_value = mock_chain + + rate_limit = "10/m" + self.command.handle( + min_page_count=1000, + rate_limit=rate_limit, + username=self.user.username, + testing=True, + ) + + # Verify the rate limit was passed correctly + for call in mock_chain_builder.call_args_list: + with self.subTest(call=call): + _, kwargs = call + self.assertEqual( + kwargs.get("rate_limit"), + rate_limit, + "Rate limit should be 
passed to chain builder", + ) + + self.assertEqual( + mock_throttle.call_count, + self.big_docs_count, + "CeleryThrottle.maybe_wait should be called for each document", + ) + + @patch( + "cl.search.management.commands.pacer_bulk_fetch.CeleryThrottle.maybe_wait" + ) + @patch( + "cl.search.management.commands.pacer_bulk_fetch.build_pdf_retrieval_task_chain" + ) + @patch( + "cl.search.management.commands.pacer_bulk_fetch.get_or_cache_pacer_cookies" + ) + def test_error_handling( + self, + mock_pacer_cookies, + mock_chain_builder, + mock_throttle, + ): + """Test that errors are handled gracefully""" + mock_chain_builder.side_effect = Exception("Chain building error") + + self.command.handle( + min_page_count=1000, + username=self.user.username, + testing=True, + ) + + self.assertEqual( + PacerFetchQueue.objects.count(), + self.big_docs_count, + "PacerFetchQueue objects should still be created even if chain building fails", + ) + + @patch( + "cl.search.management.commands.pacer_bulk_fetch.CeleryThrottle.maybe_wait" + ) + @patch( + "cl.search.management.commands.pacer_bulk_fetch.build_pdf_retrieval_task_chain" + ) + @patch( + "cl.search.management.commands.pacer_bulk_fetch.get_or_cache_pacer_cookies" + ) + def test_round_robin( + self, + mock_pacer_cookies, + mock_chain_builder, + mock_throttle, + ): + """ + Verify that each call to 'execute_round' never processes the same court + more than once. + """ + mock_chain = MagicMock() + mock_chain_builder.return_value = mock_chain + + calls_per_round = [] + original_execute_round = self.command.execute_round + + def track_rounds_side_effect(remaining_courts, options, is_last_round): + """ + Tracks PacerFetchQueue creation before and after calling execute_round + to identify which courts were processed in each round. 
+ """ + start_count = PacerFetchQueue.objects.count() + updated_remaining = original_execute_round( + remaining_courts, options, is_last_round + ) + end_count = PacerFetchQueue.objects.count() + + # Get the fetch queues created in this round + current_round_queues = PacerFetchQueue.objects.order_by("pk")[ + start_count:end_count + ] + calls_per_round.append(current_round_queues) + + return updated_remaining + + with patch.object( + Command, "execute_round", side_effect=track_rounds_side_effect + ): + self.command.handle( + min_page_count=1000, + request_interval=1.0, + username=self.user.username, + testing=True, + ) + + for round_index, round_queues in enumerate(calls_per_round, start=1): + court_ids_this_round = [] + + for queue in round_queues: + court_id = queue.recap_document.docket_entry.docket.court_id + court_ids_this_round.append(court_id) + + with self.subTest( + court_ids_this_round=court_ids_this_round, + round_index=round_index, + ): + self.assertEqual( + len(court_ids_this_round), + len(set(court_ids_this_round)), + f"Round {round_index} had duplicate courts: {court_ids_this_round}", + )