Skip to content

Commit

Permalink
Merge pull request #34 from NREL/pp/wiki_scrape_example
Browse files Browse the repository at this point in the history
Add Wikipedia scraping example
  • Loading branch information
ppinchuk authored Nov 14, 2024
2 parents 1d02910 + 53c17f9 commit 2e4f268
Show file tree
Hide file tree
Showing 10 changed files with 1,083 additions and 162 deletions.
230 changes: 87 additions & 143 deletions elm/ords/download.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
# -*- coding: utf-8 -*-
"""ELM Ordinance county file downloading logic"""
import pprint
import asyncio
import logging
from itertools import zip_longest, chain
from contextlib import AsyncExitStack

from elm.ords.llm import StructuredLLMCaller
from elm.ords.extraction import check_for_ordinance_info
from elm.ords.services.threaded import TempFileCache
from elm.ords.validation.location import CountyValidator
from elm.web.document import PDFDocument
from elm.web.file_loader import AsyncFileLoader
from elm.web.google_search import PlaywrightGoogleLinkSearch
from elm.web.google_search import google_results_as_docs, filter_documents


logger = logging.getLogger(__name__)
Expand All @@ -26,118 +21,6 @@
]


async def _search_single(
    location, question, browser_sem, num_results=10, **kwargs
):
    """Run one formatted Google query for a location and return the links."""
    # Fall back to a no-op async context when no semaphore was supplied,
    # so the `async with` below works either way.
    sem = AsyncExitStack() if browser_sem is None else browser_sem

    engine = PlaywrightGoogleLinkSearch(**kwargs)
    async with sem:
        query = question.format(location=location)
        return await engine.results(query, num_results=num_results)


async def _find_urls(location, num_results=10, browser_sem=None, **kwargs):
    """Run every question template through Google search concurrently."""
    tasks = []
    for question in QUESTION_TEMPLATES:
        coro = _search_single(
            location, question, browser_sem, num_results=num_results, **kwargs
        )
        # Name each task after the location for easier debugging/logging
        tasks.append(asyncio.create_task(coro, name=location))
    return await asyncio.gather(*tasks)


def _down_select_urls(search_results, num_urls=5):
"""Select the top 5 URLs."""
all_urls = chain.from_iterable(
zip_longest(*[results[0] for results in search_results])
)
urls = set()
for url in all_urls:
if not url:
continue
urls.add(url)
if len(urls) == num_urls:
break
return urls


async def _load_docs(urls, text_splitter, browser_semaphore=None, **kwargs):
    """Fetch a document for every URL, dropping any that come back empty."""
    # Defaults can be overridden by caller-supplied kwargs
    loader_options = {
        "html_read_kwargs": {"text_splitter": text_splitter},
        "file_cache_coroutine": TempFileCache.call,
        "browser_semaphore": browser_semaphore,
    }
    loader_options.update(kwargs)

    docs = await AsyncFileLoader(**loader_options).fetch_all(*urls)

    page_counts = {
        doc.metadata.get("source", "Unknown"): len(doc.pages) for doc in docs
    }
    logger.debug(
        "Loaded the following number of pages for docs: %s",
        pprint.PrettyPrinter().pformat(page_counts),
    )
    return [doc for doc in docs if not doc.empty]


async def _down_select_docs_correct_location(
    docs, location, county, state, **kwargs
):
    """Remove all documents not pertaining to the location."""
    validator = CountyValidator(StructuredLLMCaller(**kwargs))
    checks = await asyncio.gather(
        *[
            asyncio.create_task(
                validator.check(doc, county=county, state=state),
                name=location,
            )
            for doc in docs
        ]
    )
    matching = [doc for doc, passed in zip(docs, checks) if passed]
    # PDF documents sort first, then shorter documents before longer ones
    return sorted(
        matching,
        key=lambda doc: (not isinstance(doc, PDFDocument), len(doc.text)),
    )


async def _check_docs_for_ords(docs, text_splitter, **kwargs):
    """Keep only the documents flagged as containing ordinance info."""
    ord_docs = []
    for candidate in docs:
        # The check annotates the document's metadata in addition to
        # deciding whether it contains ordinance info
        checked = await check_for_ordinance_info(
            candidate, text_splitter, **kwargs
        )
        if checked.metadata["contains_ord_info"]:
            ord_docs.append(checked)
    return ord_docs


def _parse_all_ord_docs(all_ord_docs):
    """Return the best-matching document, or ``None`` for empty input."""
    if not all_ord_docs:
        return None

    # `sorted(...)[-1]` keeps the LAST best match on key ties (stable sort)
    ranked = sorted(all_ord_docs, key=_ord_doc_sorting_key)
    return ranked[-1]


def _ord_doc_sorting_key(doc):
    """Sorting key: (year, is-PDF, negative text length, month, day)."""
    # Missing dates sort last via the (-1, -1, -1) sentinel
    year, month, day = doc.metadata.get("date", (-1, -1, -1))
    is_pdf = isinstance(doc, PDFDocument)
    return year, is_pdf, -len(doc.text), month, day


async def download_county_ordinance(
location,
text_splitter,
Expand All @@ -146,27 +29,29 @@ async def download_county_ordinance(
browser_semaphore=None,
**kwargs
):
"""Download the ordinance document for a single county.
"""Download the ordinance document(s) for a single county.
Parameters
----------
location : elm.ords.utilities.location.Location
location : :class:`elm.ords.utilities.location.Location`
Location objects representing the county.
text_splitter : obj, optional
Instance of an object that implements a `split_text` method.
The method should take text as input (str) and return a list
of text chunks. Langchain's text splitters should work for this
input.
of text chunks. Raw text from HTML pages will be passed through
this splitter to split the single web page into multiple pages
for the output document. Langchain's text splitters should work
for this input.
num_urls : int, optional
Number of unique Google search result URL's to check for
ordinance document. By default, ``5``.
file_loader_kwargs : dict, optional
Dictionary of keyword-argument pairs to initialize
:class:`elm.web.file_loader.AsyncFileLoader` with. The
:class:`elm.web.file_loader.AsyncFileLoader` with. If found, the
"pw_launch_kwargs" key in these will also be used to initialize
the :class:`elm.web.google_search.PlaywrightGoogleLinkSearch`
used for the google URL search. By default, ``None``.
browser_semaphore : asyncio.Semaphore, optional
browser_semaphore : :class:`asyncio.Semaphore`, optional
Semaphore instance that can be used to limit the number of
playwright browsers open concurrently. If ``None``, no limits
are applied. By default, ``None``.
Expand All @@ -180,30 +65,89 @@ async def download_county_ordinance(
Document instance for the downloaded document, or ``None`` if no
document was found.
"""
file_loader_kwargs = file_loader_kwargs or {}
pw_launch_kwargs = file_loader_kwargs.get("pw_launch_kwargs", {})
urls = await _find_urls(
location.full_name,
num_results=10,
browser_sem=browser_semaphore,
**pw_launch_kwargs
)
urls = _down_select_urls(urls, num_urls=num_urls)
logger.debug("Downloading documents for URLS: \n\t-%s", "\n\t-".join(urls))
docs = await _load_docs(
urls, text_splitter, browser_semaphore, **file_loader_kwargs
docs = await _docs_from_google_search(
location,
text_splitter,
num_urls,
browser_semaphore,
**(file_loader_kwargs or {})
)
docs = await _down_select_docs_correct_location(
docs,
location=location.full_name,
county=location.name,
state=location.state,
**kwargs
docs, location=location, **kwargs
)
docs = await _down_select_docs_correct_content(
docs, location=location, text_splitter=text_splitter, **kwargs
)
docs = await _check_docs_for_ords(docs, text_splitter, **kwargs)
logger.info(
"Found %d potential ordinance documents for %s",
len(docs),
location.full_name,
)
return _parse_all_ord_docs(docs)
return _sort_final_ord_docs(docs)


async def _docs_from_google_search(
    location, text_splitter, num_urls, browser_semaphore, **file_loader_kwargs
):
    """Run the Google location queries and load the results as documents."""
    search_queries = [
        template.format(location=location.full_name)
        for template in QUESTION_TEMPLATES
    ]
    # These loader options always take precedence over user-supplied ones
    file_loader_kwargs["html_read_kwargs"] = {"text_splitter": text_splitter}
    file_loader_kwargs["file_cache_coroutine"] = TempFileCache.call
    return await google_results_as_docs(
        search_queries,
        num_urls=num_urls,
        text_splitter=text_splitter,
        browser_semaphore=browser_semaphore,
        task_name=location.full_name,
        **file_loader_kwargs,
    )


async def _down_select_docs_correct_location(docs, location, **kwargs):
    """Keep only the documents that pertain to the given location."""
    county_validator = CountyValidator(StructuredLLMCaller(**kwargs))
    return await filter_documents(
        docs,
        validation_coroutine=county_validator.check,
        task_name=location.full_name,
        county=location.name,
        state=location.state,
    )


async def _down_select_docs_correct_content(docs, location, **kwargs):
    """Keep only the documents that appear to contain ordinance info."""
    task = location.full_name
    return await filter_documents(
        docs, validation_coroutine=_contains_ords, task_name=task, **kwargs
    )


async def _contains_ords(doc, **kwargs):
    """Check a single document for ordinance info.

    Parameters
    ----------
    doc : elm.web.document.BaseDocument
        Document instance to check.
    **kwargs
        Keyword arguments forwarded to
        :func:`~elm.ords.extraction.check_for_ordinance_info`.

    Returns
    -------
    bool
        ``True`` if the document contains ordinance info.
    """
    # BUG FIX: `check_for_ordinance_info` is a coroutine and must be
    # awaited. Without the await, `doc` is rebound to a coroutine object
    # and the metadata lookup below raises AttributeError.
    doc = await check_for_ordinance_info(doc, **kwargs)
    return doc.metadata.get("contains_ord_info", False)


def _sort_final_ord_docs(all_ord_docs):
    """Pick the best document by year, document type, and text length.

    Returns ``None`` when no documents are given.
    """
    if not all_ord_docs:
        return None

    # `sorted(...)[-1]` keeps the LAST best match on key ties (stable sort)
    ranked = sorted(all_ord_docs, key=_ord_doc_sorting_key)
    return ranked[-1]


def _ord_doc_sorting_key(doc):
    """Key favoring newer, PDF, and shorter documents, in that order."""
    # Documents without a "date" entry fall back to (-1, -1, -1) and
    # therefore sort behind anything with a real date
    year, month, day = doc.metadata.get("date", (-1, -1, -1))
    return year, isinstance(doc, PDFDocument), -len(doc.text), month, day
16 changes: 2 additions & 14 deletions elm/ords/process.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
"""ELM Ordinance full processing logic"""
import os
import time
import json
import asyncio
Expand All @@ -14,6 +13,7 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter

from elm import ApiBase
from elm.utilities import validate_azure_api_params
from elm.ords.download import download_county_ordinance
from elm.ords.extraction import (
extract_ordinance_text_with_ngram_validation,
Expand Down Expand Up @@ -269,7 +269,7 @@ async def _process_with_logs(
):
"""Process counties with logging enabled."""
counties = _load_counties_to_process(county_fp)
azure_api_key, azure_version, azure_endpoint = validate_api_params(
azure_api_key, azure_version, azure_endpoint = validate_azure_api_params(
azure_api_key, azure_version, azure_endpoint
)

Expand Down Expand Up @@ -379,18 +379,6 @@ def _load_counties_to_process(county_fp):
return load_counties_from_fp(county_fp)


def validate_api_params(azure_api_key=None, azure_version=None,
                        azure_endpoint=None):
    """Validate OpenAI API parameters, filling missing ones from env vars.

    Each parameter falls back to its corresponding environment variable
    when not supplied; an AssertionError is raised if neither is set.
    """
    resolved = []
    for value, env_var in (
        (azure_api_key, "AZURE_OPENAI_API_KEY"),
        (azure_version, "AZURE_OPENAI_VERSION"),
        (azure_endpoint, "AZURE_OPENAI_ENDPOINT"),
    ):
        value = value or os.environ.get(env_var)
        assert value is not None, f"Must set {env_var}!"
        resolved.append(value)
    return tuple(resolved)


def _configure_thread_pool_kwargs(tpe_kwargs):
"""Set thread pool workers to 5 if user didn't specify."""
tpe_kwargs = tpe_kwargs or {}
Expand Down
20 changes: 18 additions & 2 deletions elm/ords/services/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@


logger = logging.getLogger(__name__)
MISSING_SERVICE_MESSAGE = """Must initialize the queue for {service_name!r}.
You can likely use the following code structure to fix this:
from elm.ords.services.provider import RunningAsyncServices
services = [
...
{service_name}(...),
...
]
async with RunningAsyncServices(services):
# function call here
"""


class Service(ABC):
Expand All @@ -20,9 +34,11 @@ class Service(ABC):
@classmethod
def _queue(cls):
"""Get queue for class."""
queue = get_service_queue(cls.__name__)
service_name = cls.__name__
queue = get_service_queue(service_name)
if queue is None:
raise ELMOrdsNotInitializedError("Must initialize the queue!")
msg = MISSING_SERVICE_MESSAGE.format(service_name=service_name)
raise ELMOrdsNotInitializedError(msg)
return queue

@classmethod
Expand Down
2 changes: 2 additions & 0 deletions elm/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
"""ELM utility classes and functions. """

from .validation import validate_azure_api_params
Loading

0 comments on commit 2e4f268

Please sign in to comment.