Skip to content

Commit

Permalink
refactor?: add type hints, prepare for mypy
Browse files Browse the repository at this point in the history
Co-authored-by: Andreas Schimmelschulze <[email protected]>
  • Loading branch information
becktob and andreas-sipgate committed May 23, 2024
1 parent 22e7e7e commit 3d65aa6
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 38 deletions.
28 changes: 14 additions & 14 deletions http_request_recorder/http_request_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@
from aiohttp import web
from aiohttp.web_request import BaseRequest

ResponsesType = Union[str, bytes, web.Response, Iterable[str], Iterable[bytes], Iterable[web.Response]]
ResponsesType = str | bytes | web.Response


class RecordedRequest:
def __init__(self, ):
self.body = None
self.method = None
self.path = None
self.headers = None
def __init__(self) -> None:
self.body: bytes | None = None
self.method: str | None = None
self.path: str | None = None
self.headers: dict[str, str] | None = None

@staticmethod
async def from_base_request(request: BaseRequest):
async def from_base_request(request: BaseRequest) -> "RecordedRequest":
recorded_request = RecordedRequest()

recorded_request.body = await request.read()
Expand All @@ -34,12 +34,12 @@ async def from_base_request(request: BaseRequest):

class ExpectedInteraction:
class SingleRequest:
def __init__(self, response):
def __init__(self, response) -> None:
self.request = None
self.was_triggered = Event()
self.response = response

def __init__(self, matcher: Callable[[RecordedRequest], bool], responses: ResponsesType, name: str, timeout: int):
def __init__(self, matcher: Callable[[RecordedRequest], bool], responses: ResponsesType | Iterable[ResponsesType], name: str, timeout: int) -> None:
self.name: str = name
self._timeout: int = timeout

Expand All @@ -58,22 +58,22 @@ def __init__(self, matcher: Callable[[RecordedRequest], bool], responses: Respon
self._next_for_response, self._next_to_return = tee(self.responses)
self._matcher: Callable[[RecordedRequest], bool] = matcher

def __repr__(self):
def __repr__(self) -> str:
return f"<{self.__class__.__name__} '{self.name}'>"

def record_once(self, request_body: bytes):
def record_once(self, request_body: bytes) -> ResponsesType:
for_response = next(self._next_for_response)
for_response.request = request_body
for_response.was_triggered.set()
self._recorded.append(for_response)
return for_response.response

def is_still_expecting_requests(self):
def is_still_expecting_requests(self) -> bool:
if self.expected_count is None:
return False
return len(self._recorded) < self.expected_count

def can_respond(self, request: RecordedRequest):
def can_respond(self, request: RecordedRequest) -> bool:
responds_infinitely = self.expected_count is None
if responds_infinitely:
will_respond = True
Expand All @@ -99,7 +99,7 @@ async def wait(self) -> str:


class HttpRequestRecorder:
def __init__(self, name: str, port: int):
def __init__(self, name: str, port: int) -> None:
self._logger = getLogger("recorder")

self._name = name
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
aiohttp~=3.8.4
aiohttp~=3.8.4
types-setuptools
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

setup(
name='http_request_recorder',
version='0.4.0',
version='0.4.1',
description='A package to record an respond to http requests, primarily for use in black box testing.',
long_description=readme,
author='',
Expand Down
2 changes: 1 addition & 1 deletion tests/test_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from http_request_recorder.http_request_recorder import HttpRequestRecorder


async def main():
async def main() -> None:
async with (
HttpRequestRecorder('any_recorder_name', 8080) as recorder,
ClientSession() as http_session
Expand Down
42 changes: 21 additions & 21 deletions tests/test_http_request_recorder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
import re
import unittest
from typing import Generator

from aiohttp import web, ClientSession

Expand All @@ -28,18 +28,18 @@


class TestHttpRequestRecorder(unittest.IsolatedAsyncioTestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, name: str) -> None:
super().__init__(name)
self.port = 18080

async def test_recorder_with_no_routes_yields_404(self):
async def test_recorder_with_no_routes_yields_404(self) -> None:
async with (HttpRequestRecorder(name="testrecorder", port=self.port),
ClientSession() as http_session):
response = await http_session.get(f"http://localhost:{self.port}/")

self.assertEqual(404, response.status)

async def test_recorder_with_a_route_returns_given_response(self):
async def test_recorder_with_a_route_returns_given_response(self) -> None:
async with (HttpRequestRecorder(name="testrecorder", port=self.port) as recorder,
ClientSession() as http_session):
recorder.expect_path(path="/path", responses="response")
Expand All @@ -49,7 +49,7 @@ async def test_recorder_with_a_route_returns_given_response(self):
self.assertEqual(200, response.status)
self.assertEqual(b"response", await response.content.read())

async def test_recorder_with_a_route_records_request(self):
async def test_recorder_with_a_route_records_request(self) -> None:
async with (HttpRequestRecorder(name="testrecorder", port=self.port) as recorder,
ClientSession() as http_session):
expectation = recorder.expect_path(path="/path", responses="response")
Expand All @@ -60,7 +60,7 @@ async def test_recorder_with_a_route_records_request(self):

self.assertEqual(b'testbody', recorded_request)

async def test_recorder_records_requests_to_different_paths(self):
async def test_recorder_records_requests_to_different_paths(self) -> None:
async with (HttpRequestRecorder(name="testrecorder", port=self.port) as recorder,
ClientSession() as http_session):
expectation1 = recorder.expect_path(path="/path1", responses="response1")
Expand All @@ -78,7 +78,7 @@ async def test_recorder_records_requests_to_different_paths(self):
self.assertEqual(b'response1', await response1.content.read())
self.assertEqual(b'response3', await response3.content.read())

async def test_multiple_replies_to_same_path(self):
async def test_multiple_replies_to_same_path(self) -> None:
async with (HttpRequestRecorder(name="multi-responding recorder", port=self.port) as recorder,
ClientSession() as http_session):
responses = ("response_0", "response_1", "response_2")
Expand All @@ -101,7 +101,7 @@ async def test_multiple_replies_to_same_path(self):

# TODO: error on more requests to route than prepared/expected responses

async def test_successful_response_logs_debug_message(self):
async def test_successful_response_logs_debug_message(self) -> None:
with self.assertLogs("recorder", level=logging.INFO) as log_recorder:
logging.getLogger("recorder").addHandler(logging.StreamHandler())

Expand All @@ -116,7 +116,7 @@ async def test_successful_response_logs_debug_message(self):
self.assertIn(request_path, logs[0])
self.assertIn("PUT", logs[0])

async def test_handle_unsatisfied_expectations(self):
async def test_handle_unsatisfied_expectations(self) -> None:
with self.assertLogs("recorder", level=logging.INFO) as log_recorder:
logging.getLogger("recorder").addHandler(logging.StreamHandler()) # also output logging

Expand All @@ -140,7 +140,7 @@ async def test_handle_unsatisfied_expectations(self):
unsatisfied = {e.name for e in recorder.unsatisfied_expectations()}
self.assertSetEqual({"/never_gets_called", "/neither"}, unsatisfied)

async def test_handle_unexpected_requests(self):
async def test_handle_unexpected_requests(self) -> None:
with self.assertLogs("recorder", level=logging.INFO) as log_recorder:
logging.getLogger("recorder").addHandler(logging.StreamHandler()) # also output logging

Expand All @@ -161,11 +161,11 @@ async def test_handle_unexpected_requests(self):
self.assertEqual("/called", unexpected_requests[0].path)
self.assertEqual("GET", unexpected_requests[0].method)

async def test_should_handle_late_request(self):
async def test_should_handle_late_request(self) -> None:
async with HttpRequestRecorder(name="patient recorder", port=self.port) as recorder, ClientSession() as http_session:
expectation = recorder.expect_path(path='/called-late', responses="response")

async def late_post_request():
async def late_post_request() -> None:
await asyncio.sleep(0.2)
await http_session.post(f"http://localhost:{self.port}/called-late", data='late_data')

Expand All @@ -176,15 +176,15 @@ async def late_post_request():
self.assertIn(b'late_data', recorded_request)

# TODO: re-enable and define assertion(s)
async def disabled_test_timeout_on_unrequested_expected_request(self):
async def disabled_test_timeout_on_unrequested_expected_request(self) -> None:
async with HttpRequestRecorder(name="disappointed recorder", port=self.port) as recorder:
expectation = recorder.expect_path(path='never called', responses="unused response")
# no request is sent.

await expectation.wait()


async def test_matching_on_body(self):
async def test_matching_on_body(self) -> None:
async with (HttpRequestRecorder(name="different body recorder", port=self.port) as recorder,
ClientSession() as http_session):
recorder.expect(lambda request: b"foo" in request.body, responses="foo called", name="foo-matcher")
Expand All @@ -199,7 +199,7 @@ async def test_matching_on_body(self):
self.assertIn(b"foo", foo_response_body)
self.assertIn(b"bar", bar_response_body)

async def test_exception_for_ambiguous_matching(self):
async def test_exception_for_ambiguous_matching(self) -> None:
with self.assertLogs("aiohttp", level=logging.ERROR) as aiohttp_recorder:
logging.getLogger("aiohttp").addHandler(logging.StreamHandler())

Expand All @@ -219,7 +219,7 @@ async def test_exception_for_ambiguous_matching(self):
record = logs[0]
self.assertIn("Error handling request", record.msg)

async def test_bytes_response(self):
async def test_bytes_response(self) -> None:
async with (HttpRequestRecorder(name="byte-returning recorder", port=self.port) as recorder,
ClientSession() as http_session):
recorder.expect_path("/", b'nom.')
Expand All @@ -228,7 +228,7 @@ async def test_bytes_response(self):

self.assertEqual(b'nom.', await response.read())

async def test_native_response(self):
async def test_native_response(self) -> None:
async with (HttpRequestRecorder(name="native-response-returning recorder", port=self.port) as recorder,
ClientSession() as http_session):
recorder.expect_path("/", web.Response(status=214, body='{}', content_type='application/json'))
Expand All @@ -239,14 +239,14 @@ async def test_native_response(self):
self.assertEqual('application/json', response.content_type)
self.assertEqual(b'{}', await response.read())

async def test_responds_infinitely(self):
async def test_responds_infinitely(self) -> None:
# neither "unexpected" nor "unsatisfied"
with self.assertNoLogs("recorder", level=logging.WARNING):
logging.getLogger("recorder").addHandler(logging.StreamHandler())

async with (HttpRequestRecorder(name="infinite responder", port=self.port) as recorder,
ClientSession() as http_session):
def inifinite_responses():
def inifinite_responses() -> Generator[bytes, None, None]:
while True:
yield b'on and on...'

Expand All @@ -255,7 +255,7 @@ def inifinite_responses():
for _ in range(10):
await http_session.post(f"http://localhost:{self.port}/")

async def test_matches_on_headers(self):
async def test_matches_on_headers(self) -> None:
async with (HttpRequestRecorder(name="header-sensitive recorder", port=self.port) as recorder,
ClientSession() as http_session):
foo_expect = recorder.expect(lambda req: "foo" in req.headers, "foo-response")
Expand Down

0 comments on commit 3d65aa6

Please sign in to comment.