Skip to content

Commit

Permalink
fix everything reported by mypy --strict
Browse files Browse the repository at this point in the history
Co-authored-by: Andreas Schimmelschulze <[email protected]>
  • Loading branch information
becktob and andreas-sipgate committed May 29, 2024
1 parent 8808e5f commit addaab3
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions http_request_recorder/http_request_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from asyncio import Event
from itertools import tee
from logging import getLogger
from typing import Iterable
from typing import Iterable, Any
from collections.abc import Callable

from aiohttp import web
Expand All @@ -15,10 +15,10 @@

class RecordedRequest:
def __init__(self) -> None:
self.body: bytes | None = None
self.method: str | None = None
self.path: str | None = None
self.headers: dict[str, str] | None = None
self.body: bytes = b''
self.method: str = ""
self.path: str = ""
self.headers: dict[str, str] = dict()

@staticmethod
async def from_base_request(request: BaseRequest) -> "RecordedRequest":
Expand All @@ -34,10 +34,10 @@ async def from_base_request(request: BaseRequest) -> "RecordedRequest":

class ExpectedInteraction:
class SingleRequest:
def __init__(self, response) -> None:
def __init__(self, response: ResponsesType) -> None:
self.request: bytes | None = None
self.was_triggered = Event()
self.response = response
self.response: ResponsesType = response

def __init__(self, matcher: Callable[[RecordedRequest], bool], responses: ResponsesType | Iterable[ResponsesType], name: str | None, timeout: int) -> None:
self.name: str | None = name
Expand All @@ -49,7 +49,9 @@ def __init__(self, matcher: Callable[[RecordedRequest], bool], responses: Respon
self.responses = (ExpectedInteraction.SingleRequest(responses),)
self.expected_count = 1
elif isinstance(responses, Iterable):
self.responses = map(ExpectedInteraction.SingleRequest, responses)
# Mypy thinks `responses` can be an int here - maybe because bytes is almost Iterable[int]
self.responses = (ExpectedInteraction.SingleRequest(chr(r)) if isinstance(
r, int) else ExpectedInteraction.SingleRequest(r) for r in responses)
if hasattr(responses, "__len__"):
self.expected_count = sum(1 for _ in responses)
else:
Expand Down Expand Up @@ -119,24 +121,24 @@ def __init__(self, name: str, port: int) -> None:

self.runner = web.AppRunner(app)

def __repr__(self):
def __repr__(self) -> str:
return f"<{self.__class__.__name__} '{self._name}' on :{self._port}>"

async def __aenter__(self):
async def __aenter__(self) -> "HttpRequestRecorder":
await self.runner.setup()
site = web.TCPSite(self.runner, '0.0.0.0', self._port)
await site.start()

return self

async def __aexit__(self, *args, **kwargs):
async def __aexit__(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None:
if len(self.unsatisfied_expectations()) > 0:
self._logger.warning(
f"{self} is exiting but there are unsatisfied Expectations: {self.unsatisfied_expectations()}")

await self.runner.cleanup()

async def handle_request(self, request: BaseRequest):
async def handle_request(self, request: BaseRequest) -> web.Response:
request_body = await request.read()
self._logger.info(f"{self} got {await self._request_string_for_log(request)}")

Expand All @@ -161,17 +163,17 @@ async def handle_request(self, request: BaseRequest):

return web.Response(status=200, body=response)

def expect(self, matcher: Callable[[RecordedRequest], bool], responses: ResponsesType = "", name: str | None = None, timeout: int = 3) -> ExpectedInteraction:
def expect(self, matcher: Callable[[RecordedRequest], bool], responses: ResponsesType | Iterable[ResponsesType] = "", name: str | None = None, timeout: int = 3) -> ExpectedInteraction:
expectation = ExpectedInteraction(matcher, responses, name, timeout)
self._expectations.append(expectation)
return expectation

def expect_path(self, path: str, responses: ResponsesType = "", timeout: int = 3) -> ExpectedInteraction:
def expect_path(self, path: str, responses: ResponsesType | Iterable[ResponsesType] = "", timeout: int = 3) -> ExpectedInteraction:
return self.expect(lambda request: path == request.path, responses, name=path, timeout=timeout)

def expect_xml_rpc(self, in_body: bytes, responses: ResponsesType = "", timeout: int = 3):
def expect_xml_rpc(self, in_body: bytes, responses: ResponsesType | Iterable[ResponsesType] = "", timeout: int = 3) -> ExpectedInteraction:
# TODO: test
def matcher(request):
def matcher(request: RecordedRequest) -> bool:
return "/RPC2" == request.path and in_body in request.body
return self.expect(matcher,
responses=responses,
Expand All @@ -187,7 +189,7 @@ def unexpected_requests(self) -> list[RecordedRequest]:
return self._unexpected_requests

@staticmethod
async def _request_string_for_log(request):
async def _request_string_for_log(request: BaseRequest) -> str:
request_body = await request.read()

xml_rpc_method = re.search(
Expand All @@ -199,4 +201,4 @@ async def _request_string_for_log(request):
if json_rpc_method is not None:
return f"{request.method} - jsonRpc - {json_rpc_method.group(0).decode('UTF-8')}"

return f"{request.method} to '{request.path}' with body '{request_body[:10]}'"
return f"{request.method} to '{request.path}' with body {request_body[:10]!r}"

0 comments on commit addaab3

Please sign in to comment.