Commit
- factor out API calls into utils.py
- Changed from producer/consumer parallelism to (1) collecting all packages
and then (2) collecting all metadata with a thread pool (see the sketch below)
- Changed from hardcoded BeautifulSoup parsing of the <pre> tag to the more
flexible `trafilatura.extract`, since some documents contain more complex HTML
nkandpa2 committed Jun 20, 2024
1 parent 68e44b9 commit f14783b
Showing 3 changed files with 59 additions and 121 deletions.
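
For orientation, here is a minimal sketch of the new two-phase flow in get-links.py: gather every package record first, then fan the per-package metadata lookups out over a thread pool and write each result as it completes. The names `fetch_all_packages`, `fetch_metadata`, and `writer` are placeholders for illustration, not identifiers from this diff.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def collect_then_fan_out(fetch_all_packages, fetch_metadata, writer, workers=8):
    # Phase 1: collect every package record up front (no producer/consumer queues).
    packages = fetch_all_packages()

    # Phase 2: fan the metadata lookups out over a thread pool and
    # write each record as soon as its future completes.
    with ThreadPoolExecutor(max_workers=workers) as executor:
        futures = {executor.submit(fetch_metadata, pkg): pkg for pkg in packages}
        for future in as_completed(futures):
            writer.write(future.result())  # per-package errors surface here
```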
38 changes: 5 additions & 33 deletions usgpo/download-files.py
@@ -1,16 +1,12 @@
import argparse
import datetime
import html
import os
import queue
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import jsonlines
import requests
from bs4 import BeautifulSoup
import trafilatura
from tqdm.auto import tqdm

from utils import api_query
from licensed_pile import logs
from licensed_pile.licenses import PermissiveLicenses
from licensed_pile.write import to_dolma
@@ -42,41 +38,18 @@ def parse_args():
return args


def api_query(endpoint, headers, params):
logger = logs.get_logger("usgpo")
response = requests.get(endpoint, headers=headers, params=params)
if response.status_code == 429:
# Sleep for an hour if we've hit the rate-limit
logger.info("Sleeping for one hour to avoid rate-limit")
time.sleep(60 * 60)
response = requests.get(endpoint, headers=headers, params=params)
return response


def download_file(api_key, file_url):
response = api_query(file_url, headers=None, params={"api_key": api_key})
text = response.text
return text


def parse_html(text):
# Most documents are primarily pre-formatted text inside of the a <pre> tag
# If so, just take the contents of that tag instead of the whole document
soup = BeautifulSoup(text, "html.parser")
pre_tag = soup.find("pre")
if pre_tag:
parsed_text = pre_tag.get_text()
else:
parsed_text = text
return html.unescape(parsed_text)


def construct_record(api_key, file):
file_url = file["links"].get("txtLink")
if file_url is None:
return None
raw_html = download_file(api_key, file_url)
parsed_text = parse_html(raw_html)
html = download_file(api_key, file_url)
text = trafilatura.extract(html)

return {
"id": file["package_id"],
@@ -85,8 +58,7 @@ def construct_record(api_key, file):
"author": file["author"],
"publisher": file["publisher"],
"category": file["category"],
"html": raw_html,
"text": parsed_text,
"text": text,
"source": SOURCE_NAME,
"added": datetime.datetime.utcnow().isoformat(),
"metadata": {"license": str(PermissiveLicenses.PD), "url": file_url},
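To make the parsing change concrete, the snippet below contrasts the removed `<pre>`-tag extraction with the new `trafilatura.extract` call on a made-up `raw_html` string; note that `trafilatura.extract` returns `None` when it finds no extractable content, so downstream code may want to handle that case. This is an illustrative sketch, not code from the commit.

```python
import html

import trafilatura
from bs4 import BeautifulSoup

raw_html = "<html><body><pre>Example &amp; pre-formatted text</pre></body></html>"

# Old approach: take the contents of the <pre> tag if one exists,
# otherwise fall back to the whole document.
soup = BeautifulSoup(raw_html, "html.parser")
pre_tag = soup.find("pre")
old_text = html.unescape(pre_tag.get_text() if pre_tag else raw_html)

# New approach: let trafilatura pick out the main content, which also
# copes with documents that are more complex HTML than a single <pre> block.
new_text = trafilatura.extract(raw_html)  # may be None for trivial/odd documents
```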
127 changes: 39 additions & 88 deletions usgpo/get-links.py
@@ -1,14 +1,12 @@
import argparse
import json
import os
import queue
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import jsonlines
import requests
from tqdm.auto import tqdm

from utils import api_query
from licensed_pile import logs


@@ -34,7 +32,6 @@ def parse_args():
"CRI",
"CZIC",
"GAOREPORTS",
"GOVPUB",
"GPO",
"HJOURNAL",
"HOB",
@@ -47,37 +44,13 @@ def parse_args():
return args


def api_query(endpoint, headers, params):
def get_packages(api_key, collections, start_date):
logger = logs.get_logger("usgpo")
response = requests.get(endpoint, headers=headers, params=params)
if response.status_code == 429:
# Sleep for an hour if we've hit the rate-limit
logger.info("Sleeping for one hour to avoid rate-limit")
time.sleep(60 * 60)
response = requests.get(endpoint, headers=headers, params=params)
return response


def get_collections(api_key):
logger = logs.get_logger("usgpo")
response = api_query(
"https://api.govinfo.gov/collections",
headers={"accept": "application/json"},
params={"api_key": args.api_key},
)
if response.status_code == 200:
output = response.json()
for record in output["collections"]:
yield record["collectionCode"]
else:
logger.error(f"get_collections received status code {response.status_code}")


def get_packages(api_key, collections, start_date, package_queue):
logger = logs.get_logger("usgpo")
url = f"https://api.govinfo.gov/published/{start_date}"
offset_mark = "*"
pbar = tqdm(desc="Producer")
packages = []
pbar = tqdm()
while url is not None:
response = api_query(
url,
@@ -93,20 +66,19 @@ def get_packages(api_key, collections, start_date, package_queue):
output = response.json()

for record in output["packages"]:
package_queue.put(record)
packages.append(record)
pbar.update(1)

url = output["nextPage"]
offset_mark = None
# Prevent too many API requests in a short period of time
# Sleep since a sudden burst of requests seems to result in erroneous rate-limiting
time.sleep(5)
else:
logger.error(
f"get_packages received status code {response.status_code} for query {url}"
)
break

package_queue.put(None)
return packages


def get_file_links(api_key, package):
@@ -122,63 +94,42 @@ def get_file_links(api_key, package):
return None


def get_package_metadata(api_key, package_queue, metadata_queue):
pbar = tqdm(desc="Consumer")
while True:
package = package_queue.get()
if package is None:
package_queue.put(None)
metadata_queue.put(None)
break

record = {
"title": package.get("title"),
"package_id": package.get("packageId"),
"date": package.get("dateIssued"),
"category": package.get("category"),
"author": package.get("governmentAuthor1"),
"publisher": package.get("publisher"),
"links": get_file_links(api_key, package),
}
metadata_queue.put(record)
pbar.update(1)


def write_metadata(output_dir, metadata_queue):
with jsonlines.open(os.path.join(output_dir, "links.jsonl"), mode="w") as writer:
pbar = tqdm(desc="Writer")
while True:
metadata = metadata_queue.get()
if metadata is None:
metadata_queue.task_done()
break

writer.write(metadata)
pbar.update(1)
def get_package_metadata(api_key, package):
record = {
"title": package.get("title"),
"package_id": package.get("packageId"),
"date": package.get("dateIssued"),
"category": package.get("category"),
"author": package.get("governmentAuthor1"),
"publisher": package.get("publisher"),
"links": get_file_links(api_key, package),
}
return record


def main(args):
logger = logs.get_logger("usgpo")
os.makedirs(args.output_dir, exist_ok=True)

package_queue = queue.Queue()
metadata_queue = queue.Queue()

with ThreadPoolExecutor(max_workers=args.workers + 2) as executor:
# One thread for getting each package (i.e. file) from the specified collections
executor.submit(
get_packages, args.api_key, args.collections, args.start_date, package_queue
)

# `args.workers` threads for getting package metadata
for _ in range(args.workers):
executor.submit(
get_package_metadata, args.api_key, package_queue, metadata_queue
)

# One thread for writing out the package metadata to disk
executor.submit(write_metadata, args.output_dir, metadata_queue)

metadata_queue.join()
# Get packages from the specified USGPO collections from `args.start_date` to current day
logger.info(f"Getting packages from the following collections: {args.collections}")
packages = get_packages(args.api_key, args.collections, args.start_date)

logger.info(f"Getting package metadata and writing out to {args.output_dir}")
with jsonlines.open(os.path.join(args.output_dir, "links.jsonl"), mode="w", flush=True) as writer:
# Spawn multiple worker threads to get the metadata associated with all packages
with ThreadPoolExecutor(max_workers=args.workers) as executor:
metadata_futures_to_package = {executor.submit(get_package_metadata, args.api_key, package): package for package in packages}

# Write out package metadata to file
for metadata_future in tqdm(as_completed(metadata_futures_to_package)):
package = metadata_futures_to_package[metadata_future]
try:
record = metadata_future.result()
except Exception as e:
logger.error(f"Package {package} raised exception {e}")
continue
writer.write(record)


if __name__ == "__main__":
15 changes: 15 additions & 0 deletions usgpo/utils.py
@@ -0,0 +1,15 @@
import requests
import time

from licensed_pile import logs


def api_query(endpoint, headers, params):
logger = logs.get_logger("usgpo")
response = requests.get(endpoint, headers=headers, params=params)
if response.status_code == 429:
# Sleep for an hour if we've hit the rate-limit
logger.info("Exceeded rate-limit, sleeping for one hour")
time.sleep(60 * 60)
response = requests.get(endpoint, headers=headers, params=params)
return response
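
Both scripts now import this shared helper instead of defining their own copy. Below is a usage sketch against the govinfo collections endpoint, mirroring the `get_collections` call that this commit removes from get-links.py; the API key value is a placeholder.

```python
from utils import api_query

response = api_query(
    "https://api.govinfo.gov/collections",
    headers={"accept": "application/json"},
    params={"api_key": "YOUR_GOVINFO_API_KEY"},  # placeholder, not a real key
)
if response.status_code == 200:
    for record in response.json()["collections"]:
        print(record["collectionCode"])
```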
