Skip to content

Commit

Permalink
Refactor introduce recording storage (#274)
Browse files Browse the repository at this point in the history
* refactor: separate injection and enable/disable logic
* refactor: add class that handles request records
  • Loading branch information
betaboon authored Nov 26, 2024
1 parent a5b5e34 commit e529319
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 130 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', ' 3.13', 'pypy3.10']
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13', 'pypy3.10']

steps:
- uses: actions/checkout@v4
Expand Down
18 changes: 1 addition & 17 deletions mocket/inject.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import contextlib
import os
import socket
import ssl
from types import ModuleType
Expand All @@ -23,10 +22,7 @@ def _restore(module: ModuleType, name: str) -> None:
module.__dict__[name] = original_value


def enable(
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
def enable() -> None:
from mocket.socket import (
MocketSocket,
mock_create_connection,
Expand Down Expand Up @@ -73,14 +69,6 @@ def enable(

extract_from_urllib3()

from mocket.mocket import Mocket

Mocket._namespace = namespace
Mocket._truesocket_recording_dir = truesocket_recording_dir
if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
# JSON dumps will be saved here
raise AssertionError


def disable() -> None:
for module, name in list(_patches_restore.keys()):
Expand All @@ -90,7 +78,3 @@ def disable() -> None:
from urllib3.contrib.pyopenssl import inject_into_urllib3

inject_into_urllib3()

from mocket.mocket import Mocket

Mocket.reset()
46 changes: 39 additions & 7 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import collections
import itertools
import os
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar

import mocket.inject
from mocket.recording import MocketRecordStorage

# NOTE this is here for backwards-compat to keep old import-paths working
# from mocket.socket import MocketSocket as MocketSocket
Expand All @@ -20,11 +22,36 @@ class Mocket:
_address: ClassVar[Address] = (None, None)
_entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list)
_requests: ClassVar[list] = []
_namespace: ClassVar[str] = str(id(_entries))
_truesocket_recording_dir: ClassVar[str | None] = None
_record_storage: ClassVar[MocketRecordStorage | None] = None

enable = mocket.inject.enable
disable = mocket.inject.disable
@classmethod
def enable(
cls,
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
if namespace is None:
namespace = str(id(cls._entries))

if truesocket_recording_dir is not None:
recording_dir = Path(truesocket_recording_dir)

if not recording_dir.is_dir():
# JSON dumps will be saved here
raise AssertionError

cls._record_storage = MocketRecordStorage(
directory=recording_dir,
namespace=namespace,
)

mocket.inject.enable()

@classmethod
def disable(cls) -> None:
cls.reset()

mocket.inject.disable()

@classmethod
def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]:
Expand Down Expand Up @@ -69,6 +96,7 @@ def reset(cls) -> None:
cls._socket_pairs = {}
cls._entries = collections.defaultdict(list)
cls._requests = []
cls._record_storage = None

@classmethod
def last_request(cls):
Expand All @@ -89,12 +117,16 @@ def has_requests(cls) -> bool:
return bool(cls.request_list())

@classmethod
def get_namespace(cls) -> str:
return cls._namespace
def get_namespace(cls) -> str | None:
if not cls._record_storage:
return None
return cls._record_storage.namespace

@classmethod
def get_truesocket_recording_dir(cls) -> str | None:
return cls._truesocket_recording_dir
if not cls._record_storage:
return None
return str(cls._record_storage.directory)

@classmethod
def assert_fail_if_entries_not_served(cls) -> None:
Expand Down
147 changes: 147 additions & 0 deletions mocket/recording.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from __future__ import annotations

import contextlib
import hashlib
import json
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path

from mocket.compat import decode_from_bytes, encode_to_bytes
from mocket.types import Address
from mocket.utils import hexdump, hexload

hash_function = hashlib.md5

with contextlib.suppress(ImportError):
from xxhash_cffi import xxh32 as xxhash_cffi_xxh32

hash_function = xxhash_cffi_xxh32

with contextlib.suppress(ImportError):
from xxhash import xxh32 as xxhash_xxh32

hash_function = xxhash_xxh32


def _hash_prepare_request(data: bytes) -> bytes:
_data = decode_from_bytes(data)
return encode_to_bytes("".join(sorted(_data.split("\r\n"))))


def _hash_request(data: bytes) -> str:
_data = _hash_prepare_request(data)
return hash_function(_data).hexdigest()


def _hash_request_fallback(data: bytes) -> str:
_data = _hash_prepare_request(data)
return hashlib.md5(_data).hexdigest()


@dataclass
class MocketRecord:
host: str
port: int
request: bytes
response: bytes


class MocketRecordStorage:
def __init__(self, directory: Path, namespace: str) -> None:
self._directory = directory
self._namespace = namespace
self._records: defaultdict[Address, defaultdict[str, MocketRecord]] = (
defaultdict(defaultdict)
)

self._load()

@property
def directory(self) -> Path:
return self._directory

@property
def namespace(self) -> str:
return self._namespace

@property
def file(self) -> Path:
return self._directory / f"{self._namespace}.json"

def _load(self) -> None:
if not self.file.exists():
return

json_data = self.file.read_text()
records = json.loads(json_data)
for host, port_signature_record in records.items():
for port, signature_record in port_signature_record.items():
for signature, record in signature_record.items():
# NOTE backward-compat
try:
request_data = hexload(record["request"])
except ValueError:
request_data = record["request"]

self._records[(host, int(port))][signature] = MocketRecord(
host=host,
port=port,
request=request_data,
response=hexload(record["response"]),
)

def _save(self) -> None:
data: dict[str, dict[str, dict[str, dict[str, str]]]] = defaultdict(
lambda: defaultdict(defaultdict)
)
for address, signature_record in self._records.items():
host, port = address
for signature, record in signature_record.items():
data[host][str(port)][signature] = dict(
request=decode_from_bytes(record.request),
response=hexdump(record.response),
)

json_data = json.dumps(data, indent=4, sort_keys=True)
self.file.parent.mkdir(exist_ok=True)
self.file.write_text(json_data)

def get_records(self, address: Address) -> list[MocketRecord]:
return list(self._records[address].values())

def get_record(self, address: Address, request: bytes) -> MocketRecord | None:
# NOTE for backward-compat
request_signature_fallback = _hash_request_fallback(request)
if request_signature_fallback in self._records[address]:
return self._records[address].get(request_signature_fallback)

request_signature = _hash_request(request)
if request_signature in self._records[address]:
return self._records[address][request_signature]

return None

def put_record(
self,
address: Address,
request: bytes,
response: bytes,
) -> None:
host, port = address
record = MocketRecord(
host=host,
port=port,
request=request,
response=response,
)

# NOTE for backward-compat
request_signature_fallback = _hash_request_fallback(request)
if request_signature_fallback in self._records[address]:
self._records[address][request_signature_fallback] = record
return

request_signature = _hash_request(request)
self._records[address][request_signature] = record
self._save()
Loading

0 comments on commit e529319

Please sign in to comment.