From e89e0c1519a825d493c886b84748ecfeb04d1dca Mon Sep 17 00:00:00 2001 From: Dan Yazovsky Date: Sun, 14 May 2023 10:04:45 +0100 Subject: [PATCH] v0.1.7 (#9) * Adding support for setting scraper state shared between jobs * Added support for downloading multiple urls at once * Added scraper context tests * Added helper functions to download and process files in parallel * Update version --- .gitignore | 6 +- Makefile | 2 +- poetry.lock | 17 +- pyproject.toml | 3 +- sneakpeek/lib/models.py | 2 + sneakpeek/runner.py | 50 ++++- sneakpeek/scraper_context.py | 392 +++++++++++++++++++++++++++------- tests/test_scraper_context.py | 322 ++++++++++++++++++++++++++++ 8 files changed, 706 insertions(+), 88 deletions(-) create mode 100644 tests/test_scraper_context.py diff --git a/.gitignore b/.gitignore index 4352bd6..ef3dc20 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,9 @@ **/__pycache__/* .venv - *.install.stamp - dist .dist **/.pytest_cache/* .pytest_cache/ - -.coverage \ No newline at end of file +.coverage +htmlcov \ No newline at end of file diff --git a/Makefile b/Makefile index b6688fb..f9f402f 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ test: $(PY_INSTALL_STAMP) ##Run tests .PHONE: coverage coverage: $(PY_INSTALL_STAMP) ##Run tests - $(POETRY) run pytest --cov=sneakpeek tests --cov-fail-under=70 + $(POETRY) run pytest --cov=sneakpeek tests --cov-fail-under=70 --cov-report term-missing --cov-report html build-ui: ##Build frontend $(YARN) --cwd $(ROOT_DIR)/front/ quasar build diff --git a/poetry.lock b/poetry.lock index 7a9a384..0f84f38 100644 --- a/poetry.lock +++ b/poetry.lock @@ -127,6 +127,21 @@ async-timeout = ">=4.0.0" [package.extras] aiohttp = ["aiohttp (>=3.8.0)"] +[[package]] +name = "aioresponses" +version = "0.7.4" +description = "Mock out requests made by ClientSession from aiohttp package" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "aioresponses-0.7.4-py2.py3-none-any.whl", hash = "sha256:1160486b5ea96fcae6170cf2bdef029b9d3a283b7dbeabb3d7f1182769bfb6b7"}, + {file = "aioresponses-0.7.4.tar.gz", hash = "sha256:9b8c108b36354c04633bad0ea752b55d956a7602fe3e3234b939fc44af96f1d8"}, +] + +[package.dependencies] +aiohttp = ">=2.0.0,<4.0.0" + [[package]] name = "aiosignal" version = "1.3.1" @@ -1637,4 +1652,4 @@ docs = ["Sphinx", "sphinx-rtd-theme", "sphinxcontrib-napoleon"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "d800062c13d978b1e03542b1a3b00a35f9b38b81471d12cd73d70fd68705de1c" +content-hash = "655b2ffeb244dee0b37c0d757072db9afd8f35f0299a4a8cb469477e930d0c41" diff --git a/pyproject.toml b/pyproject.toml index 1c4f452..a4110d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "sneakpeek-py" packages = [{ include = "sneakpeek" }] -version = "0.1.6" +version = "0.1.7" description = "Sneakpeek is a framework that helps to quickly and conviniently develop scrapers. It's the best choice for scrapers that have some specific complex scraping logic that needs to be run on a constant basis." 
authors = ["Dan Yazovsky "] maintainers = ["Dan Yazovsky "] @@ -49,6 +49,7 @@ black = "^23.3.0" pytest-lazy-fixture = "^0.6.3" pytest-asyncio = "^0.21.0" pytest-cov = "^4.0.0" +aioresponses = "^0.7.4" [build-system] requires = ["poetry-core"] diff --git a/sneakpeek/lib/models.py b/sneakpeek/lib/models.py index 0b07993..62c8a9a 100644 --- a/sneakpeek/lib/models.py +++ b/sneakpeek/lib/models.py @@ -55,6 +55,8 @@ class Scraper(BaseModel): config: ScraperConfig #: Scraper configuration that is passed to the handler #: Default priority to enqueue scraper jobs with schedule_priority: ScraperJobPriority = ScraperJobPriority.NORMAL + #: Scraper state (might be useful to optimise scraping, e.g. only process pages that weren't processed in the last jobs) + state: str | None = None class ScraperJob(BaseModel): diff --git a/sneakpeek/runner.py b/sneakpeek/runner.py index 0af9ad4..32f842b 100644 --- a/sneakpeek/runner.py +++ b/sneakpeek/runner.py @@ -8,7 +8,7 @@ from prometheus_client import Counter from sneakpeek.lib.errors import ScraperJobPingFinishedError, UnknownScraperHandlerError -from sneakpeek.lib.models import ScraperJob, ScraperJobStatus +from sneakpeek.lib.models import Scraper, ScraperJob, ScraperJobStatus from sneakpeek.lib.queue import QueueABC from sneakpeek.lib.storage.base import ScraperJobsStorage from sneakpeek.logging import configure_logging, scraper_job_context @@ -124,26 +124,43 @@ async def ping_session(): class LocalRunner: """Scraper runner that is meant to be used for local debugging""" + @staticmethod + async def _ping_session(): + logging.debug("Pinging session") + + @staticmethod + async def _update_scraper_state(state: str) -> Scraper | None: + logging.debug(f"Updating scraper state with: {state}") + return None + @staticmethod async def run_async( handler: ScraperHandler, config: ScraperConfig, plugins: list[Plugin] | None = None, + scraper_state: str | None = None, logging_level: int = logging.DEBUG, ) -> None: """ Execute scraper locally. Args: - config (ScraperConfig): Scraper config + handler (ScraperHandler): Scraper handler to execute + config (ScraperConfig): Scraper config to pass to the handler + plugins (list[Plugin] | None, optional): List of plugins that will be used by scraper runner. Defaults to None. + scraper_state (str | None, optional): Scraper state to pass to the handler. Defaults to None. + logging_level (int, optional): Minimum logging level. Defaults to logging.DEBUG. """ configure_logging(logging_level) logging.info("Starting scraper") - async def ping_session(): - pass - - context = ScraperContext(config, plugins, ping_session) + context = ScraperContext( + config, + plugins, + scraper_state=scraper_state, + ping_session_func=LocalRunner._ping_session, + update_scraper_state_func=LocalRunner._update_scraper_state, + ) try: await context.start_session() result = await handler.run(context) @@ -159,6 +176,25 @@ def run( handler: ScraperHandler, config: ScraperConfig, plugins: list[Plugin] | None = None, + scraper_state: str | None = None, logging_level: int = logging.DEBUG, ) -> None: - asyncio.run(LocalRunner.run_async(handler, config, plugins, logging_level)) + """ + Execute scraper locally. + + Args: + handler (ScraperHandler): Scraper handler to execute + config (ScraperConfig): Scraper config to pass to the handler + plugins (list[Plugin] | None, optional): List of plugins that will be used by scraper runner. Defaults to None. + scraper_state (str | None, optional): Scraper state to pass to the handler. Defaults to None. 
+ logging_level (int, optional): Minimum logging level. Defaults to logging.DEBUG. + """ + asyncio.run( + LocalRunner.run_async( + handler, + config, + plugins, + scraper_state, + logging_level, + ) + ) diff --git a/sneakpeek/scraper_context.py b/sneakpeek/scraper_context.py index d03adde..14f4e58 100644 --- a/sneakpeek/scraper_context.py +++ b/sneakpeek/scraper_context.py @@ -1,9 +1,14 @@ +import asyncio import logging +import os import re +import sys +import tempfile from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Any, Callable +from typing import Any, Awaitable, Callable +from uuid import uuid4 import aiohttp @@ -11,6 +16,7 @@ ScraperJobPingFinishedError, ScraperJobPingNotStartedError, ) +from sneakpeek.lib.models import Scraper from sneakpeek.scraper_config import ScraperConfig HttpHeaders = dict[str, str] @@ -23,7 +29,7 @@ class HttpMethod(str, Enum): GET = "get" POST = "post" HEAD = "head" - PUT = "PUT" + PUT = "put" DELETE = "delete" OPTIONS = "options" @@ -38,6 +44,27 @@ class Request: kwargs: dict[str, Any] | None = None +@dataclass +class _BatchRequest: + """HTTP Batch request metadata""" + + method: HttpMethod + urls: list[str] + headers: HttpHeaders | None = None + kwargs: dict[str, Any] | None = None + + def to_single_requests(self) -> list[Request]: + return [ + Request( + method=self.method, + url=url, + headers=self.headers, + kwargs=self.kwargs, + ) + for url in self.urls + ] + + @dataclass class RegexMatch: """Regex match""" @@ -105,6 +132,7 @@ async def after_response( Plugin = BeforeRequestPlugin | AfterResponsePlugin +Response = aiohttp.ClientResponse | list[aiohttp.ClientResponse | Exception] class ScraperContext: @@ -117,16 +145,22 @@ def __init__( self, config: ScraperConfig, plugins: list[Plugin] | None = None, + scraper_state: str | None = None, ping_session_func: Callable | None = None, + update_scraper_state_func: Callable | None = None, ) -> None: """ Args: config (ScraperConfig): Scraper configuration plugins (list[BeforeRequestPlugin | AfterResponsePlugin] | None, optional): List of available plugins. Defaults to None. + scraper_state (str | None, optional): Scraper state. Defaults to None. ping_session_func (Callable | None, optional): Function that pings scraper job. Defaults to None. + update_scraper_state_func (Callable | None, optional): Function that update scraper state. Defaults to None. """ self.params = config.params - self.ping_session_func = ping_session_func + self.state = scraper_state + self._ping_session_func = ping_session_func + self._update_scraper_state_func = update_scraper_state_func self._logger = logging.getLogger(__name__) self._plugins_configs = config.plugins or {} self._session: aiohttp.ClientSession | None = None @@ -136,7 +170,7 @@ def __init__( def _init_plugins(self, plugins: list[Plugin] | None = None) -> None: for plugin in plugins or []: - if not plugin.name.isidentifier: + if not plugin.name.isidentifier(): raise ValueError( "Plugin name must be a Python identifier. 
" f"Plugin {plugin.__class__} has invalid name: {plugin.name}" @@ -179,8 +213,7 @@ async def _after_response( ) return response - async def _request(self, request: Request) -> aiohttp.ClientResponse: - await self.ping_session() + async def _single_request(self, request: Request) -> aiohttp.ClientResponse: request = await self._before_request(request) response = await getattr(self._session, request.method)( request.url, @@ -188,18 +221,41 @@ async def _request(self, request: Request) -> aiohttp.ClientResponse: **(request.kwargs or {}), ) response = await self._after_response(request, response) - await self.ping_session() return response + async def _request( + self, + request: _BatchRequest, + max_concurrency: int = 0, + return_exceptions: bool = False, + ) -> Response: + await self.ping_session() + single_requests = request.to_single_requests() + if len(single_requests) == 1: + return await self._single_request(single_requests[0]) + + semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency > 0 else None + + async def process_request(request: Request): + if semaphore: + async with semaphore: + return await self._single_request(request) + return await self._single_request(request) + + return await asyncio.gather( + *[process_request(request) for request in single_requests], + return_exceptions=return_exceptions, + ) + async def ping_session(self) -> None: """Ping scraper job, so it's not considered dead""" - if not self.ping_session_func: + if not self._ping_session_func: self._logger.warning( "Tried to ping scraper job, but the function to ping session is None" ) return try: - await self.ping_session_func() + await self._ping_session_func() except ScraperJobPingNotStartedError as e: self._logger.error( f"Failed to ping PENDING scraper job because due to some infra error: {e}" @@ -213,142 +269,314 @@ async def ping_session(self) -> None: except Exception as e: self._logger.error(f"Failed to ping scraper job: {e}") - async def get( + async def request( self, - url: str, + method: HttpMethod, + url: str | list[str], *, headers: HttpHeaders | None = None, + max_concurrency: int = 0, + return_exceptions: bool = False, **kwargs, - ) -> aiohttp.ClientResponse: - """Make GET request to the given URL + ) -> Response: + """Perform HTTP request to the given URL(s) Args: - url (str): URL to send GET request to + method (HttpMethod): HTTP request method to perform + url (str | list[str]): URL(s) to send HTTP request to headers (HttpHeaders | None, optional): HTTP headers. Defaults to None. - **kwargs: See aiohttp.get() for the full list of arguments + max_concurrency (int, optional): Maximum number of concurrent requests. If set to 0 no limit is applied. Defaults to 0. + return_exceptions (bool, optional): Whether to return exceptions instead of raising if there are multiple URLs provided. 
Defaults to False, + **kwargs: See aiohttp.request() for the full list of arguments + + Returns: + Response: HTTP response(s) """ return await self._request( - Request( - method=HttpMethod.GET, - url=url, + _BatchRequest( + method=method, + urls=url if isinstance(url, list) else [url], headers=headers, kwargs=kwargs, - ) + ), + max_concurrency=max_concurrency, + return_exceptions=return_exceptions, ) - async def post( + async def download_file( self, + method: HttpMethod, url: str, *, + file_path: str | None = None, + file_process_fn: Callable[[str], Awaitable[Any]] | None = None, headers: HttpHeaders | None = None, **kwargs, - ) -> aiohttp.ClientResponse: - """Make POST request to the given URL + ) -> str | Any: + """Perform HTTP request and save it to the specified file Args: - url (str): URL to send POST request to + method (HttpMethod): HTTP request method to perform + url (str): URL to send HTTP request to + file_path (str, optional): Path of the file to save request to. If not specified temporary file name will be generated. Defaults to None. + file_process_fn (Callable[[str], Any], optional): Function to process the file. If specified then function will be applied to the file and its result will be returned, the file will be removed after the function call. Defaults to None. headers (HttpHeaders | None, optional): HTTP headers. Defaults to None. + **kwargs: See aiohttp.request() for the full list of arguments + + Returns: + str | Any: File path if file process function is not defined or file process function result otherwise + """ + if not file_path: + file_path = os.path.join(tempfile.mkdtemp(), str(uuid4())) + response = await self.request( + method=method, + url=url, + headers=headers, + **kwargs, + ) + contents = await response.read() + with open(file_path, "wb") as f: + f.write(contents) + if not file_process_fn: + return file_path + result = await file_process_fn(file_path) + os.remove(file_path) + return result + + async def download_files( + self, + method: HttpMethod, + urls: list[str], + *, + file_paths: list[str] | None = None, + file_process_fn: Callable[[str], Awaitable[Any]] | None = None, + headers: HttpHeaders | None = None, + max_concurrency: int = 0, + return_exceptions: bool = False, + **kwargs, + ) -> list[str | Any | Exception]: + """Perform HTTP requests and save them to the specified files + + Args: + method (HttpMethod): HTTP request method to perform + urls (list[str]): URLs to send HTTP request to + file_paths (list[str], optional): Path of the files to save requests to. If not specified temporary file names will be generated. Defaults to None. + file_process_fn (Callable[[str], Any], optional): Function to process the file. If specified then function will be applied to the file and its result will be returned, the file will be removed after the function call. Defaults to None. + headers (HttpHeaders | None, optional): HTTP headers. Defaults to None. + max_concurrency (int, optional): Maximum number of concurrent requests. If set to 0 no limit is applied. Defaults to 0. + return_exceptions (bool, optional): Whether to return exceptions instead of raising if there are multiple URLs provided. 
Defaults to False, + **kwargs: See aiohttp.request() for the full list of arguments + + Returns: + list[str | Any | Exception]: For each URL: file path if file process function is not defined or file process function result otherwise + """ + if file_paths: + if len(file_paths) != len(urls): + raise ValueError( + f"Expected to have 1 file path per 1 URL, only have {len(file_paths)} for {len(urls)} URLs" + ) + else: + file_paths = [None] * len(urls) + + semaphore = asyncio.Semaphore( + max_concurrency if max_concurrency > 0 else sys.maxsize + ) + + async def process_request(url: str, file_path: str): + async with semaphore: + return await self.download_file( + method, + url, + file_path=file_path, + file_process_fn=file_process_fn, + headers=headers, + **kwargs, + ) + + return await asyncio.gather( + *[ + process_request(url, file_path) + for url, file_path in zip(urls, file_paths) + ], + return_exceptions=return_exceptions, + ) + + async def get( + self, + url: str | list[str], + *, + headers: HttpHeaders | None = None, + max_concurrency: int = 0, + return_exceptions: bool = False, + **kwargs, + ) -> Response: + """Make GET request to the given URL(s) + + Args: + url (str | list[str]): URL(s) to send GET request to + headers (HttpHeaders | None, optional): HTTP headers. Defaults to None. + max_concurrency (int, optional): Maximum number of concurrent requests. If set to 0 no limit is applied. Defaults to 0. + return_exceptions (bool, optional): Whether to return exceptions instead of raising if there are multiple URLs provided. Defaults to False, **kwargs: See aiohttp.get() for the full list of arguments + + Returns: + Response: HTTP response(s) """ - return await self._request( - Request( - method=HttpMethod.GET, - url=url, - headers=headers, - kwargs=kwargs, - ) + return await self.request( + HttpMethod.GET, + url, + headers=headers, + max_concurrency=max_concurrency, + return_exceptions=return_exceptions, + **kwargs, + ) + + async def post( + self, + url: str | list[str], + *, + headers: HttpHeaders | None = None, + max_concurrency: int = 0, + return_exceptions: bool = False, + **kwargs, + ) -> Response: + """Make POST request to the given URL(s) + + Args: + url (str | list[str]): URL(s) to send POST request to + headers (HttpHeaders | None, optional): HTTP headers. Defaults to None. + max_concurrency (int, optional): Maximum number of concurrent requests. If set to 0 no limit is applied. Defaults to 0. + return_exceptions (bool, optional): Whether to return exceptions instead of raising if there are multiple URLs provided. Defaults to False, + **kwargs: See aiohttp.post() for the full list of arguments + + Returns: + Response: HTTP response(s) + """ + return await self.request( + HttpMethod.POST, + url, + headers=headers, + max_concurrency=max_concurrency, + return_exceptions=return_exceptions, + **kwargs, ) async def head( self, - url: str, + url: str | list[str], *, headers: HttpHeaders | None = None, + max_concurrency: int = 0, + return_exceptions: bool = False, **kwargs, - ) -> aiohttp.ClientResponse: - """Make HEAD request to the given URL + ) -> Response: + """Make HEAD request to the given URL(s) Args: - url (str): URL to send HEAD request to + url (str | list[str]): URL(s) to send HEAD request to headers (HttpHeaders | None, optional): HTTP headers. Defaults to None. + max_concurrency (int, optional): Maximum number of concurrent requests. If set to 0 no limit is applied. Defaults to 0. + return_exceptions (bool, optional): Whether to return exceptions instead of raising if there are multiple URLs provided.
Defaults to False, **kwargs: See aiohttp.head() for the full list of arguments + + Returns: + Response: HTTP response(s) """ - return await self._request( - Request( - method=HttpMethod.HEAD, - url=url, - headers=headers, - kwargs=kwargs, - ) + return await self.request( + HttpMethod.HEAD, + url, + headers=headers, + max_concurrency=max_concurrency, + return_exceptions=return_exceptions, + **kwargs, ) async def delete( self, - url: str, + url: str | list[str], *, headers: HttpHeaders | None = None, + max_concurrency: int = 0, + return_exceptions: bool = False, **kwargs, - ) -> aiohttp.ClientResponse: - """Make DELETE request to the given URL + ) -> Response: + """Make DELETE request to the given URL(s) Args: - url (str): URL to send DELETE request to + url (str | list[str]): URL(s) to send DELETE request to headers (HttpHeaders | None, optional): HTTP headers. Defaults to None. + max_concurrency (int, optional): Maximum number of concurrent requests. If set to 0 no limit is applied. Defaults to 0. + return_exceptions (bool, optional): Whether to return exceptions instead of raising if there are multiple URLs provided. Defaults to False, **kwargs: See aiohttp.delete() for the full list of arguments + + Returns: + Response: HTTP response(s) """ - return await self._request( - Request( - method=HttpMethod.DELETE, - url=url, - headers=headers, - kwargs=kwargs, - ) + return await self.request( + HttpMethod.DELETE, + url, + headers=headers, + max_concurrency=max_concurrency, + return_exceptions=return_exceptions, + **kwargs, ) async def put( self, - url: str, + url: str | list[str], *, headers: HttpHeaders | None = None, + max_concurrency: int = 0, + return_exceptions: bool = False, **kwargs, - ) -> aiohttp.ClientResponse: - """Make PUT request to the given URL + ) -> Response: + """Make PUT request to the given URL(s) Args: - url (str): URL to send PUT request to + url (str | list[str]): URL(s) to send PUT request to headers (HttpHeaders | None, optional): HTTP headers. Defaults to None. + max_concurrency (int, optional): Maximum number of concurrent requests. If set to 0 no limit is applied. Defaults to 0. + return_exceptions (bool, optional): Whether to return exceptions instead of raising if there are multiple URLs provided. Defaults to False, **kwargs: See aiohttp.put() for the full list of arguments + + Returns: + Response: HTTP response(s) """ - return await self._request( - Request( - method=HttpMethod.PUT, - url=url, - headers=headers, - kwargs=kwargs, - ) + return await self.request( + HttpMethod.PUT, + url, + headers=headers, + max_concurrency=max_concurrency, + return_exceptions=return_exceptions, + **kwargs, ) async def options( self, - url: str, + url: str | list[str], *, headers: HttpHeaders | None = None, + max_concurrency: int = 0, + return_exceptions: bool = False, **kwargs, - ) -> aiohttp.ClientResponse: - """Make OPTIONS request to the given URL + ) -> Response: + """Make OPTIONS request to the given URL(s) Args: - url (str): URL to send OPTIONS request to + url (str | list[str]): URL(s) to send OPTIONS request to headers (HttpHeaders | None, optional): HTTP headers. Defaults to None. + max_concurrency (int, optional): Maximum number of concurrent requests. If set to 0 no limit is applied. Defaults to 0. + return_exceptions (bool, optional): Whether to return exceptions instead of raising if there are multiple URLs provided. 
Defaults to False, **kwargs: See aiohttp.options() for the full list of arguments + + Returns: + Response: HTTP response(s) """ - return await self._request( - Request( - method=HttpMethod.OPTIONS, - url=url, - headers=headers, - kwargs=kwargs, - ) + return await self.request( + HttpMethod.OPTIONS, + url, + headers=headers, + max_concurrency=max_concurrency, + return_exceptions=return_exceptions, + **kwargs, ) def regex( @@ -371,3 +599,19 @@ def regex( RegexMatch(full_match=match.group(0), groups=match.groupdict()) for match in re.finditer(pattern, text, flags) ] + + async def update_scraper_state(self, state: str) -> Scraper: + """Update scraper state + + Args: + state (str): State to persist + + Returns: + Scraper: Updated scraper metadata + """ + if not self._update_scraper_state_func: + self._logger.warning( + "Tried to update scraper state, but the function to do it is not set" + ) + return + return await self._update_scraper_state_func(state) diff --git a/tests/test_scraper_context.py b/tests/test_scraper_context.py new file mode 100644 index 0000000..23e89e7 --- /dev/null +++ b/tests/test_scraper_context.py @@ -0,0 +1,322 @@ +import os +from typing import Any +from unittest.mock import AsyncMock, call, patch + +import aiohttp +import pytest +from aioresponses import aioresponses + +from sneakpeek.scraper_config import ScraperConfig +from sneakpeek.scraper_context import ( + AfterResponsePlugin, + BeforeRequestPlugin, + HttpMethod, + Plugin, + Request, + ScraperContext, +) + + +class MockPlugin(BeforeRequestPlugin, AfterResponsePlugin): + def __init__(self) -> None: + self.before_request_mock = AsyncMock() + self.after_response_mock = AsyncMock() + + @property + def name(self) -> str: + return "test" + + async def before_request( + self, + request: Request, + config: Any | None = None, + ) -> Request: + await self.before_request_mock(request.url, config) + return request + + async def after_response( + self, + request: Request, + response: aiohttp.ClientResponse, + config: Any | None = None, + ) -> aiohttp.ClientResponse: + await self.after_response_mock(request.url, config) + return response + + +def context( + plugins: list[Plugin] | None = None, + plugins_configs: dict[str, Any] | None = None, +) -> ScraperContext: + async def ping(): + pass + + return ScraperContext( + ScraperConfig(plugins=plugins_configs), + plugins=plugins, + ping_session_func=ping, + ) + + +@pytest.mark.parametrize("method", ["get", "post", "put", "delete", "options", "head"]) +@pytest.mark.asyncio +async def test_http_methods(method: str): + url = "test_url" + headers = {"header1": "value1"} + ctx = context() + await ctx.start_session() + with patch( + f"aiohttp.ClientSession.{method}", + new_callable=AsyncMock, + ) as mocked_request: + await getattr(ctx, method)(url, headers=headers) + mocked_request.assert_called_once_with(url, headers=headers) + + +@pytest.mark.parametrize("max_concurrency", [-1, 0, 1]) +@pytest.mark.parametrize("method", ["get", "post", "put", "delete", "options", "head"]) +@pytest.mark.asyncio +async def test_http_methods_multiple(method: str, max_concurrency: int): + urls = [f"url{i}" for i in range(10)] + headers = {"header1": "value1"} + ctx = context() + await ctx.start_session() + with patch( + f"aiohttp.ClientSession.{method}", + new_callable=AsyncMock, + ) as mocked_request: + responses = await getattr(ctx, method)( + urls, + headers=headers, + max_concurrency=max_concurrency, + ) + assert len(responses) == len( + urls + ), f"Expected {len(urls)} responses but received 
{len(responses)}" + mocked_request.assert_has_awaits( + [call(url, headers=headers) for url in urls], + any_order=True, + ) + + +@pytest.mark.parametrize("method", ["get", "post", "put", "delete", "options", "head"]) +@pytest.mark.asyncio +async def test_plugin_is_called(method: str): + urls = [f"url{i}" for i in range(10)] + headers = {"header1": "value1"} + plugin = MockPlugin() + plugin_config = {"config1": "value1"} + ctx = context( + plugins=[plugin], + plugins_configs={plugin.name: plugin_config}, + ) + await ctx.start_session() + with patch( + f"aiohttp.ClientSession.{method}", + new_callable=AsyncMock, + ) as mocked_request: + responses = await getattr(ctx, method)(urls, headers=headers) + assert len(responses) == len( + urls + ), f"Expected {len(urls)} responses but received {len(responses)}" + mocked_request.assert_has_awaits( + [call(url, headers=headers) for url in urls], + any_order=True, + ) + plugin.before_request_mock.assert_has_awaits( + [call(url, plugin_config) for url in urls], + any_order=True, + ) + plugin.after_response_mock.assert_has_awaits( + [call(url, plugin_config) for url in urls], + any_order=True, + ) + + +@pytest.mark.asyncio +async def test_invalid_plugin(): + class InvalidPlugin(BeforeRequestPlugin, AfterResponsePlugin): + @property + def name(self) -> str: + return "not a python identifier" + + async def before_request( + self, request: Request, config: Any | None = None + ) -> Request: + return request + + async def after_response( + self, + request: Request, + response: aiohttp.ClientResponse, + config: Any | None = None, + ) -> aiohttp.ClientResponse: + return response + + with pytest.raises(ValueError): + ScraperContext(ScraperConfig(), plugins=[InvalidPlugin()]) + + +def test_regex(): + text = 'some content' + pattern = r']*href="(?P[^"]+)' + matches = context().regex(text, pattern) + assert len(matches) == 1, "Expected to find a single match" + match = matches[0] + assert match.full_match == '