From f14783bc35054a98c995c4742003bda73c31997f Mon Sep 17 00:00:00 2001
From: Nikhil Kandpal
Date: Thu, 20 Jun 2024 14:05:09 -0400
Subject: [PATCH] - factor out API calls into `utils.py`

- Changed from producer/consumer parallelism to (1) collect all packages and
  then (2) collect all metadata with a thread pool
- Changed from hardcoded BeautifulSoup parsing of the `<pre>` tag to the more
  flexible `trafilatura.extract`, since some documents contain more complex
  HTML

---
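
A minimal sketch of the parsing change described above (illustration only, not
part of the diff to apply; the parse_html_old/parse_html_new names are invented
for the sketch). The old path pulled the contents of a <pre> tag with
BeautifulSoup; the new path hands the whole document to trafilatura. Note that
trafilatura.extract() returns None when it finds nothing extractable, so
downstream consumers of the "text" field may want to filter such records.

    import html

    import trafilatura
    from bs4 import BeautifulSoup

    def parse_html_old(raw_html):
        # Old approach: take the contents of the <pre> tag when present,
        # otherwise fall back to the full document.
        soup = BeautifulSoup(raw_html, "html.parser")
        pre_tag = soup.find("pre")
        return html.unescape(pre_tag.get_text() if pre_tag else raw_html)

    def parse_html_new(raw_html):
        # New approach: let trafilatura locate the main text in arbitrary HTML.
        # Returns None if no extractable content is found.
        return trafilatura.extract(raw_html)
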
 usgpo/download-files.py |  38 ++----------
 usgpo/get-links.py      | 127 ++++++++++++----------------------------
 usgpo/utils.py          |  15 +++++
 3 files changed, 59 insertions(+), 121 deletions(-)
 create mode 100644 usgpo/utils.py
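
A condensed sketch of the new get-links.py control flow (illustration only, not
part of the diff; it reuses get_packages and get_package_metadata from the
patch and assumes api_key, collections, start_date, workers, and an open
jsonlines writer are in scope, with error handling elided):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    # Phase 1: enumerate every package up front via the paginated API
    # (single thread, so pagination and rate-limit sleeps stay simple).
    packages = get_packages(api_key, collections, start_date)

    # Phase 2: fan the metadata lookups out over a thread pool and write each
    # record as its future completes, instead of coordinating producer/consumer
    # queues and sentinel values by hand.
    with ThreadPoolExecutor(max_workers=workers) as executor:
        futures = {executor.submit(get_package_metadata, api_key, pkg): pkg
                   for pkg in packages}
        for future in as_completed(futures):
            writer.write(future.result())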

diff --git a/usgpo/download-files.py b/usgpo/download-files.py
index f094893..69dad32 100644
--- a/usgpo/download-files.py
+++ b/usgpo/download-files.py
@@ -1,16 +1,12 @@
 import argparse
 import datetime
-import html
-import os
-import queue
-import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import jsonlines
-import requests
-from bs4 import BeautifulSoup
+import trafilatura
 from tqdm.auto import tqdm
 
+from utils import api_query
 from licensed_pile import logs
 from licensed_pile.licenses import PermissiveLicenses
 from licensed_pile.write import to_dolma
@@ -42,41 +38,18 @@ def parse_args():
     return args
 
 
-def api_query(endpoint, headers, params):
-    logger = logs.get_logger("usgpo")
-    response = requests.get(endpoint, headers=headers, params=params)
-    if response.status_code == 429:
-        # Sleep for an hour if we've hit the rate-limit
-        logger.info("Sleeping for one hour to avoid rate-limit")
-        time.sleep(60 * 60)
-        response = requests.get(endpoint, headers=headers, params=params)
-    return response
-
-
 def download_file(api_key, file_url):
     response = api_query(file_url, headers=None, params={"api_key": api_key})
     text = response.text
     return text
 
 
-def parse_html(text):
-    # Most documents are primarily pre-formatted text inside of a <pre> tag
-    # If so, just take the contents of that tag instead of the whole document
-    soup = BeautifulSoup(text, "html.parser")
-    pre_tag = soup.find("pre")
-    if pre_tag:
-        parsed_text = pre_tag.get_text()
-    else:
-        parsed_text = text
-    return html.unescape(parsed_text)
-
-
 def construct_record(api_key, file):
     file_url = file["links"].get("txtLink")
     if file_url is None:
         return None
-    raw_html = download_file(api_key, file_url)
-    parsed_text = parse_html(raw_html)
+    html = download_file(api_key, file_url)
+    text = trafilatura.extract(html)
 
     return {
         "id": file["package_id"],
@@ -85,8 +58,7 @@ def construct_record(api_key, file):
         "author": file["author"],
         "publisher": file["publisher"],
         "category": file["category"],
-        "html": raw_html,
-        "text": parsed_text,
+        "text": text,
         "source": SOURCE_NAME,
         "added": datetime.datetime.utcnow().isoformat(),
         "metadata": {"license": str(PermissiveLicenses.PD), "url": file_url},
diff --git a/usgpo/get-links.py b/usgpo/get-links.py
index 91d8ffa..b6ae62d 100644
--- a/usgpo/get-links.py
+++ b/usgpo/get-links.py
@@ -1,14 +1,12 @@
 import argparse
-import json
 import os
-import queue
 import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import jsonlines
-import requests
 from tqdm.auto import tqdm
 
+from utils import api_query
 from licensed_pile import logs
 
 
@@ -34,7 +32,6 @@ def parse_args():
             "CRI",
             "CZIC",
             "GAOREPORTS",
-            "GOVPUB",
             "GPO",
             "HJOURNAL",
             "HOB",
@@ -47,37 +44,13 @@ def parse_args():
     return args
 
 
-def api_query(endpoint, headers, params):
+def get_packages(api_key, collections, start_date):
     logger = logs.get_logger("usgpo")
-    response = requests.get(endpoint, headers=headers, params=params)
-    if response.status_code == 429:
-        # Sleep for an hour if we've hit the rate-limit
-        logger.info("Sleeping for one hour to avoid rate-limit")
-        time.sleep(60 * 60)
-        response = requests.get(endpoint, headers=headers, params=params)
-    return response
 
-
-def get_collections(api_key):
-    logger = logs.get_logger("usgpo")
-    response = api_query(
-        "https://api.govinfo.gov/collections",
-        headers={"accept": "application/json"},
-        params={"api_key": args.api_key},
-    )
-    if response.status_code == 200:
-        output = response.json()
-        for record in output["collections"]:
-            yield record["collectionCode"]
-    else:
-        logger.error(f"get_collections received status code {response.status_code}")
-
-
-def get_packages(api_key, collections, start_date, package_queue):
-    logger = logs.get_logger("usgpo")
     url = f"https://api.govinfo.gov/published/{start_date}"
     offset_mark = "*"
-    pbar = tqdm(desc="Producer")
+    packages = []
+    pbar = tqdm(desc="Getting packages")
     while url is not None:
         response = api_query(
             url,
@@ -93,20 +66,19 @@ def get_packages(api_key, collections, start_date, package_queue):
             output = response.json()
 
             for record in output["packages"]:
-                package_queue.put(record)
+                packages.append(record)
                 pbar.update(1)
 
             url = output["nextPage"]
             offset_mark = None
-            # Prevent too many API requests in a short period of time
+            # Sleep since a sudden burst of requests seems to result in erroneous rate-limiting
             time.sleep(5)
         else:
             logger.error(
                 f"get_packages received status code {response.status_code} for query {url}"
             )
             break
-
-    package_queue.put(None)
+    return packages
 
 
 def get_file_links(api_key, package):
@@ -122,63 +94,42 @@ def get_file_links(api_key, package):
     return None
 
 
-def get_package_metadata(api_key, package_queue, metadata_queue):
-    pbar = tqdm(desc="Consumer")
-    while True:
-        package = package_queue.get()
-        if package is None:
-            package_queue.put(None)
-            metadata_queue.put(None)
-            break
-
-        record = {
-            "title": package.get("title"),
-            "package_id": package.get("packageId"),
-            "date": package.get("dateIssued"),
-            "category": package.get("category"),
-            "author": package.get("governmentAuthor1"),
-            "publisher": package.get("publisher"),
-            "links": get_file_links(api_key, package),
-        }
-        metadata_queue.put(record)
-        pbar.update(1)
-
-
-def write_metadata(output_dir, metadata_queue):
-    with jsonlines.open(os.path.join(output_dir, "links.jsonl"), mode="w") as writer:
-        pbar = tqdm(desc="Writer")
-        while True:
-            metadata = metadata_queue.get()
-            if metadata is None:
-                metadata_queue.task_done()
-                break
-
-            writer.write(metadata)
-            pbar.update(1)
+def get_package_metadata(api_key, package):
+    record = {
+        "title": package.get("title"),
+        "package_id": package.get("packageId"),
+        "date": package.get("dateIssued"),
+        "category": package.get("category"),
+        "author": package.get("governmentAuthor1"),
+        "publisher": package.get("publisher"),
+        "links": get_file_links(api_key, package),
+    }
+    return record
 
 
 def main(args):
+    logger = logs.get_logger("usgpo")
     os.makedirs(args.output_dir, exist_ok=True)
-
-    package_queue = queue.Queue()
-    metadata_queue = queue.Queue()
-
-    with ThreadPoolExecutor(max_workers=args.workers + 2) as executor:
-        # One thread for getting each package (i.e. file) from the specified collections
-        executor.submit(
-            get_packages, args.api_key, args.collections, args.start_date, package_queue
-        )
-
-        # `args.workers` threads for getting package metadata
-        for _ in range(args.workers):
-            executor.submit(
-                get_package_metadata, args.api_key, package_queue, metadata_queue
-            )
-
-        # One thread for writing out the package metadata to disk
-        executor.submit(write_metadata, args.output_dir, metadata_queue)
-
-        metadata_queue.join()
+
+    # Get packages from the specified USGPO collections from `args.start_date` to current day
+    logger.info(f"Getting packages from the following collections: {args.collections}")
+    packages = get_packages(args.api_key, args.collections, args.start_date)
+
+    logger.info(f"Getting package metadata and writing out to {args.output_dir}")
+    with jsonlines.open(os.path.join(args.output_dir, "links.jsonl"), mode="w", flush=True) as writer:
+        # Spawn multiple worker threads to get the metadata associated with all packages
+        with ThreadPoolExecutor(max_workers=args.workers) as executor:
+            metadata_futures_to_package = {executor.submit(get_package_metadata, args.api_key, package): package for package in packages}
+
+            # Write out package metadata to file
+            for metadata_future in tqdm(as_completed(metadata_futures_to_package)):
+                package = metadata_futures_to_package[metadata_future]
+                try:
+                    record = metadata_future.result()
+                except Exception as e:
+                    logger.error(f"Package {package} raised exception {e}")
+                    continue
+                writer.write(record)
 
 
 if __name__ == "__main__":
diff --git a/usgpo/utils.py b/usgpo/utils.py
new file mode 100644
index 0000000..cfd35d3
--- /dev/null
+++ b/usgpo/utils.py
@@ -0,0 +1,15 @@
+import requests
+import time
+
+from licensed_pile import logs
+
+
+def api_query(endpoint, headers, params):
+    logger = logs.get_logger("usgpo")
+    response = requests.get(endpoint, headers=headers, params=params)
+    if response.status_code == 429:
+        # Sleep for an hour if we've hit the rate-limit
+        logger.info("Exceeded rate-limit, sleeping for one hour")
+        time.sleep(60 * 60)
+        response = requests.get(endpoint, headers=headers, params=params)
+    return response