
Merge pull request #360 from flairNLP/refactor-filter
Refactor `ExtractionFilter` and `Requires`
MaxDall authored Feb 18, 2024
2 parents 7ece2f0 + 2659bee commit 69c9984
Showing 3 changed files with 85 additions and 18 deletions.
76 changes: 69 additions & 7 deletions src/fundus/scraping/filter.py
@@ -7,20 +7,47 @@


def inverse(filter_func: Callable[P, bool]) -> Callable[P, bool]:
    """Logical not operator that can be used on filters.

    Args:
        filter_func: The filter function to invert.

    Returns:
        A filter function evaluating to the negation of <filter_func>.
    """

    def __call__(*args: P.args, **kwargs: P.kwargs) -> bool:
        return not filter_func(*args, **kwargs)

    return __call__


def lor(*filters: Callable[P, bool]) -> Callable[P, bool]:
    """Logical or operator that can be used on filters.

    Args:
        *filters: The filter functions to combine with logical or.

    Returns:
        A filter function evaluating to the logical or of <filters>.
    """

    def __call__(*args: P.args, **kwargs: P.kwargs) -> bool:
        return any(f(*args, **kwargs) for f in filters)

    return __call__


def land(*filters: Callable[P, bool]) -> Callable[P, bool]:
    """Logical and operator that can be used on filters.

    Args:
        *filters: The filter functions to combine with logical and.

    Returns:
        A filter function evaluating to the logical and of <filters>.
    """

    def __call__(*args: P.args, **kwargs: P.kwargs) -> bool:
        return all(f(*args, **kwargs) for f in filters)
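
As a usage note, these combinators compose ordinary boolean filter callables; a minimal sketch (the `missing_title`/`missing_body` filters are hypothetical, not part of fundus):

```python
# Minimal sketch: composing filters with lor/inverse.
# Extraction filters return True when an article should be filtered out.

def missing_title(extraction: dict) -> bool:
    return not extraction.get("title")

def missing_body(extraction: dict) -> bool:
    return not extraction.get("body")

# Filter out articles lacking a title or a body.
incomplete = lor(missing_title, missing_body)

print(incomplete({"title": "Headline", "body": "Text"}))  # False -> passes
print(incomplete({"title": "Headline", "body": ""}))      # True  -> filtered out
print(inverse(missing_title)({"title": "Headline"}))      # True  -> inverted filter
```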

@@ -38,7 +65,7 @@ def __call__(self, url: str) -> bool:
"""Filters a website, represented by a given <url>, on the criterion if it represents an <article>
Args:
url (str): The url the evaluation should be based on.
url: The url the evaluation should be based on.
Returns:
bool: True if an <url> should be filtered out and not
@@ -55,21 +82,26 @@ def url_filter(url: str) -> bool:
    return url_filter


class SupportsBool(Protocol):
    def __bool__(self) -> bool:
        ...


class ExtractionFilter(Protocol):
    """Protocol to define filters used after article extraction.

    Filters satisfying this protocol should work inversely to the built-in filter(),
    so that True gets filtered out and False doesn't.
    """

-    def __call__(self, extracted: Dict[str, Any]) -> bool:
+    def __call__(self, extraction: Dict[str, Any]) -> SupportsBool:
        """This should implement a selection based on <extraction>.

        <extraction> will be a dictionary returned by a parser, mapping the
        attribute names of the parser to the extracted values.

        Args:
-            extracted (dict[str, Any]): The extracted values the evaluation
+            extraction: The extracted values the evaluation
                should be based on.

        Returns:
@@ -79,11 +111,41 @@ def __call__(self, extracted: Dict[str, Any]) -> bool:
        ...
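
A callable satisfying this protocol can be a plain function; a hypothetical sketch (`drop_short_bodies` is illustrative, not part of fundus):

```python
from typing import Any, Dict

# Hypothetical ExtractionFilter: drop articles whose body is shorter than
# 100 characters. Returning a truthy value means the article is filtered out.
def drop_short_bodies(extraction: Dict[str, Any]) -> bool:
    body = extraction.get("body") or ""
    return len(str(body)) < 100
```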


class FilterResultWithMissingAttributes:
    def __init__(self, *attributes: str) -> None:
        self.missing_attributes = attributes

    def __bool__(self) -> bool:
        return bool(self.missing_attributes)


class Requires:
    def __init__(self, *required_attributes: str) -> None:
        """Class to filter extractions based on attribute values.

        If a required attribute is not present in the extracted data, this
        filter won't be passed.

        Args:
            *required_attributes: Attributes required to evaluate to True in order to
                pass the filter. If no attributes are given, all attributes will be evaluated.
        """
        self.required_attributes = set(required_attributes)

-    def __call__(self, extracted: Dict[str, Any]) -> bool:
-        return not all(
-            bool(value := extracted.get(attr)) and not isinstance(value, Exception) for attr in self.required_attributes
-        )
+    def __call__(self, extraction: Dict[str, Any]) -> FilterResultWithMissingAttributes:
+        missing_attributes = [
+            attribute
+            for attribute in self.required_attributes or extraction.keys()
+            if not bool(value := extraction.get(attribute)) or isinstance(value, Exception)
+        ]
+        return FilterResultWithMissingAttributes(*missing_attributes)


class RequiresAll(Requires):
    def __init__(self):
        """Name wrapper for Requires().

        This is for readability only. It requires all attributes of the extraction
        to evaluate to True. See the :class:`Requires` docstring for more information.
        """
        super().__init__()
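
To make the new return type concrete, a short sketch (the attribute names are illustrative, not a fixed fundus schema):

```python
# Sketch: Requires now reports *which* attributes are missing while staying
# usable in a boolean context via FilterResultWithMissingAttributes.__bool__.
extraction = {"title": "Headline", "body": "", "publishing_date": None}

result = Requires("title", "body")(extraction)
if result:  # truthy -> at least one required attribute is missing or falsy
    print(result.missing_attributes)  # ('body',) -- the empty string fails bool()

# RequiresAll() leaves required_attributes empty, so every key is checked:
print(bool(RequiresAll()(extraction)))  # True: "body" and "publishing_date" fail
```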
10 changes: 2 additions & 8 deletions src/fundus/scraping/pipeline.py
@@ -21,7 +21,7 @@
from fundus.logging import basic_logger
from fundus.publishers.base_objects import PublisherEnum
from fundus.scraping.article import Article
-from fundus.scraping.filter import ExtractionFilter, Requires, URLFilter
+from fundus.scraping.filter import ExtractionFilter, Requires, RequiresAll, URLFilter
from fundus.scraping.html import URLSource, session_handler
from fundus.scraping.scraper import Scraper
from fundus.utils.more_async import ManagedEventLoop, async_next
@@ -91,13 +91,7 @@ async def crawl_async(

    def build_extraction_filter() -> Optional[ExtractionFilter]:
        if isinstance(only_complete, bool):
-            return (
-                None
-                if only_complete is False
-                else lambda extracted: not all(
-                    bool(v) if not isinstance(v, Exception) else False for _, v in extracted.items()
-                )
-            )
+            return None if only_complete is False else RequiresAll()
        else:
            return only_complete
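
For context, a hedged usage sketch of how `only_complete` reaches this builder (assuming the public `Crawler` API shown in the fundus README; the comments follow the branches above):

```python
# Hedged sketch, assuming fundus' public Crawler API.
from fundus import Crawler, PublisherCollection
from fundus.scraping.filter import Requires

crawler = Crawler(PublisherCollection.us)

# only_complete=True  -> build_extraction_filter() returns RequiresAll()
# only_complete=False -> no extraction filter is applied
# only_complete=Requires("title", "body") -> the custom filter is used as-is
for article in crawler.crawl(max_articles=2, only_complete=Requires("title", "body")):
    print(article.title)
```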

17 changes: 14 additions & 3 deletions src/fundus/scraping/scraper.py
@@ -5,7 +5,12 @@
from fundus.logging import basic_logger
from fundus.parser import ParserProxy
from fundus.scraping.article import Article
-from fundus.scraping.filter import ExtractionFilter, Requires, URLFilter
+from fundus.scraping.filter import (
+    ExtractionFilter,
+    FilterResultWithMissingAttributes,
+    Requires,
+    URLFilter,
+)
from fundus.scraping.html import FundusSource


@@ -66,8 +71,14 @@ async def scrape(
            else:
                raise ValueError(f"Unknown value '{error_handling}' for parameter <error_handling>")

-            if extraction_filter and extraction_filter(extraction):
-                basic_logger.debug(f"Skipped article at '{html.requested_url}' because of extraction filter")
+            if extraction_filter and (filter_result := extraction_filter(extraction)):
+                if isinstance(filter_result, FilterResultWithMissingAttributes):
+                    basic_logger.debug(
+                        f"Skipped article at '{html.requested_url}' because attribute(s) "
+                        f"{', '.join(filter_result.missing_attributes)!r} is(are) missing"
+                    )
+                else:
+                    basic_logger.debug(f"Skipped article at '{html.requested_url}' because of extraction filter")
                yield None
            else:
                article = Article.from_extracted(html=html, extracted=extraction)
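
The design choice worth noting: because the filter result defines `__bool__`, the scraper keeps treating it as a plain truthy/falsy verdict while richer diagnostics ride along. A standalone sketch of the pattern (names are illustrative):

```python
# Standalone sketch of the truthy-result pattern (illustrative names).
class Verdict:
    def __init__(self, *reasons: str) -> None:
        self.reasons = reasons

    def __bool__(self) -> bool:
        # Truthy exactly when there is something to report.
        return bool(self.reasons)

def check(data: dict) -> Verdict:
    return Verdict(*[key for key, value in data.items() if not value])

if verdict := check({"title": "ok", "body": ""}):
    print(f"filtered out, missing: {verdict.reasons}")  # ('body',)
```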
