diff --git a/mocket/entry.py b/mocket/entry.py index 8fa28bc7..9dbbf442 100644 --- a/mocket/entry.py +++ b/mocket/entry.py @@ -1,6 +1,7 @@ import collections.abc from mocket.compat import encode_to_bytes +from mocket.mocket import Mocket class MocketEntry: @@ -41,8 +42,6 @@ def can_handle(data): return True def collect(self, data): - from mocket import Mocket - req = self.request_cls(data) Mocket.collect(req) diff --git a/mocket/inject.py b/mocket/inject.py new file mode 100644 index 00000000..cba0b40b --- /dev/null +++ b/mocket/inject.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import os +import socket +import ssl + +import urllib3 +from urllib3.connection import match_hostname as urllib3_match_hostname +from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket + +try: + from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket +except ImportError: + urllib3_wrap_socket = None + + +try: # pragma: no cover + from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 + + pyopenssl_override = True +except ImportError: + pyopenssl_override = False + +true_socket = socket.socket +true_create_connection = socket.create_connection +true_gethostbyname = socket.gethostbyname +true_gethostname = socket.gethostname +true_getaddrinfo = socket.getaddrinfo +true_socketpair = socket.socketpair +true_ssl_wrap_socket = getattr( + ssl, "wrap_socket", None +) # from Py3.12 it's only under SSLContext +true_ssl_socket = ssl.SSLSocket +true_ssl_context = ssl.SSLContext +true_inet_pton = socket.inet_pton +true_urllib3_wrap_socket = urllib3_wrap_socket +true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket +true_urllib3_match_hostname = urllib3_match_hostname + + +def enable( + namespace: str | None = None, + truesocket_recording_dir: str | None = None, +) -> None: + from mocket.mocket import Mocket + from mocket.socket import MocketSocket, create_connection, socketpair + from mocket.ssl import FakeSSLContext + + Mocket._namespace = namespace + Mocket._truesocket_recording_dir = truesocket_recording_dir + + if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir): + # JSON dumps will be saved here + raise AssertionError + + socket.socket = socket.__dict__["socket"] = MocketSocket + socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket + socket.SocketType = socket.__dict__["SocketType"] = MocketSocket + socket.create_connection = socket.__dict__["create_connection"] = create_connection + socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost" + socket.gethostbyname = socket.__dict__["gethostbyname"] = lambda host: "127.0.0.1" + socket.getaddrinfo = socket.__dict__["getaddrinfo"] = ( + lambda host, port, family=None, socktype=None, proto=None, flags=None: [ + (2, 1, 6, "", (host, port)) + ] + ) + socket.socketpair = socket.__dict__["socketpair"] = socketpair + ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket + ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext + socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes( + "\x7f\x00\x00\x01", "utf-8" + ) + urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( + FakeSSLContext.wrap_socket + ) + urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ + "ssl_wrap_socket" + ] = FakeSSLContext.wrap_socket + urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( + FakeSSLContext.wrap_socket + ) + urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ + "ssl_wrap_socket" + ] = FakeSSLContext.wrap_socket + urllib3.connection.match_hostname = urllib3.connection.__dict__[ + "match_hostname" + ] = lambda *args: None + if pyopenssl_override: # pragma: no cover + # Take out the pyopenssl version - use the default implementation + extract_from_urllib3() + + +def disable() -> None: + from mocket.mocket import Mocket + + socket.socket = socket.__dict__["socket"] = true_socket + socket._socketobject = socket.__dict__["_socketobject"] = true_socket + socket.SocketType = socket.__dict__["SocketType"] = true_socket + socket.create_connection = socket.__dict__["create_connection"] = ( + true_create_connection + ) + socket.gethostname = socket.__dict__["gethostname"] = true_gethostname + socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname + socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo + socket.socketpair = socket.__dict__["socketpair"] = true_socketpair + if true_ssl_wrap_socket: + ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket + ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context + socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton + urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( + true_urllib3_wrap_socket + ) + urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ + "ssl_wrap_socket" + ] = true_urllib3_ssl_wrap_socket + urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( + true_urllib3_ssl_wrap_socket + ) + urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ + "ssl_wrap_socket" + ] = true_urllib3_ssl_wrap_socket + urllib3.connection.match_hostname = urllib3.connection.__dict__[ + "match_hostname" + ] = true_urllib3_match_hostname + Mocket.reset() + if pyopenssl_override: # pragma: no cover + # Put the pyopenssl version back in place + inject_into_urllib3() diff --git a/mocket/io.py b/mocket/io.py index 45bb8272..648b16dd 100644 --- a/mocket/io.py +++ b/mocket/io.py @@ -1,6 +1,8 @@ import io import os +from mocket.mocket import Mocket + class MocketSocketCore(io.BytesIO): def __init__(self, address) -> None: @@ -8,8 +10,6 @@ def __init__(self, address) -> None: super().__init__() def write(self, content): - from mocket import Mocket - super().write(content) _, w_fd = Mocket.get_pair(self._address) diff --git a/mocket/mocket.py b/mocket/mocket.py index 6bb0e566..3476902d 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -1,57 +1,33 @@ +from __future__ import annotations + import collections import itertools import os -import socket -import ssl -from typing import Optional, Tuple - -import urllib3 -from urllib3.connection import match_hostname as urllib3_match_hostname -from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket - -try: - from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket -except ImportError: - urllib3_wrap_socket = None - - -from mocket.socket import MocketSocket, create_connection, socketpair -from mocket.ssl import FakeSSLContext - -try: # pragma: no cover - from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 - - pyopenssl_override = True -except ImportError: - pyopenssl_override = False - -true_socket = socket.socket -true_create_connection = socket.create_connection -true_gethostbyname = socket.gethostbyname -true_gethostname = socket.gethostname -true_getaddrinfo = socket.getaddrinfo -true_socketpair = socket.socketpair -true_ssl_wrap_socket = getattr( - ssl, "wrap_socket", None -) # from Py3.12 it's only under SSLContext -true_ssl_socket = ssl.SSLSocket -true_ssl_context = ssl.SSLContext -true_inet_pton = socket.inet_pton -true_urllib3_wrap_socket = urllib3_wrap_socket -true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket -true_urllib3_match_hostname = urllib3_match_hostname +from typing import TYPE_CHECKING, ClassVar + +import mocket.inject + +# NOTE this is here for backwards-compat to keep old import-paths working +# from mocket.socket import MocketSocket as MocketSocket + +if TYPE_CHECKING: + from mocket.entry import MocketEntry + from mocket.types import Address class Mocket: - _socket_pairs = {} - _address = (None, None) - _entries = collections.defaultdict(list) - _requests = [] - _namespace = str(id(_entries)) - _truesocket_recording_dir = None + _socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {} + _address: ClassVar[Address] = (None, None) + _entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list) + _requests: ClassVar[list] = [] + _namespace: ClassVar[str] = str(id(_entries)) + _truesocket_recording_dir: ClassVar[str | None] = None + + enable = mocket.inject.enable + disable = mocket.inject.disable @classmethod - def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]: + def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]: """ Given the id() of the caller, return a pair of file descriptors as a tuple of two integers: (, ) @@ -59,7 +35,7 @@ def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]: return cls._socket_pairs.get(address, (None, None)) @classmethod - def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None: + def set_pair(cls, address: Address, pair: tuple[int, int]) -> None: """ Store a pair of file descriptors under the key `id_` as a tuple of two integers: (, ) @@ -67,25 +43,26 @@ def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None: cls._socket_pairs[address] = pair @classmethod - def register(cls, *entries): + def register(cls, *entries: MocketEntry) -> None: for entry in entries: cls._entries[entry.location].append(entry) @classmethod - def get_entry(cls, host, port, data): - host = host or Mocket._address[0] - port = port or Mocket._address[1] + def get_entry(cls, host: str, port: int, data) -> MocketEntry | None: + host = host or cls._address[0] + port = port or cls._address[1] entries = cls._entries.get((host, port), []) for entry in entries: if entry.can_handle(data): return entry + return None @classmethod - def collect(cls, data): - cls.request_list().append(data) + def collect(cls, data) -> None: + cls._requests.append(data) @classmethod - def reset(cls): + def reset(cls) -> None: for r_fd, w_fd in cls._socket_pairs.values(): os.close(r_fd) os.close(w_fd) @@ -96,116 +73,31 @@ def reset(cls): @classmethod def last_request(cls): if cls.has_requests(): - return cls.request_list()[-1] + return cls._requests[-1] @classmethod def request_list(cls): return cls._requests @classmethod - def remove_last_request(cls): + def remove_last_request(cls) -> None: if cls.has_requests(): del cls._requests[-1] @classmethod - def has_requests(cls): + def has_requests(cls) -> bool: return bool(cls.request_list()) - @staticmethod - def enable(namespace=None, truesocket_recording_dir=None): - Mocket._namespace = namespace - Mocket._truesocket_recording_dir = truesocket_recording_dir - - if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir): - # JSON dumps will be saved here - raise AssertionError - - socket.socket = socket.__dict__["socket"] = MocketSocket - socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket - socket.SocketType = socket.__dict__["SocketType"] = MocketSocket - socket.create_connection = socket.__dict__["create_connection"] = ( - create_connection - ) - socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost" - socket.gethostbyname = socket.__dict__["gethostbyname"] = ( - lambda host: "127.0.0.1" - ) - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = ( - lambda host, port, family=None, socktype=None, proto=None, flags=None: [ - (2, 1, 6, "", (host, port)) - ] - ) - socket.socketpair = socket.__dict__["socketpair"] = socketpair - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext - socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes( - "\x7f\x00\x00\x01", "utf-8" - ) - urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - FakeSSLContext.wrap_socket - ) - urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ - "ssl_wrap_socket" - ] = FakeSSLContext.wrap_socket - urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - FakeSSLContext.wrap_socket - ) - urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ - "ssl_wrap_socket" - ] = FakeSSLContext.wrap_socket - urllib3.connection.match_hostname = urllib3.connection.__dict__[ - "match_hostname" - ] = lambda *args: None - if pyopenssl_override: # pragma: no cover - # Take out the pyopenssl version - use the default implementation - extract_from_urllib3() - - @staticmethod - def disable(): - socket.socket = socket.__dict__["socket"] = true_socket - socket._socketobject = socket.__dict__["_socketobject"] = true_socket - socket.SocketType = socket.__dict__["SocketType"] = true_socket - socket.create_connection = socket.__dict__["create_connection"] = ( - true_create_connection - ) - socket.gethostname = socket.__dict__["gethostname"] = true_gethostname - socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo - socket.socketpair = socket.__dict__["socketpair"] = true_socketpair - if true_ssl_wrap_socket: - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context - socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton - urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - true_urllib3_wrap_socket - ) - urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ - "ssl_wrap_socket" - ] = true_urllib3_ssl_wrap_socket - urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - true_urllib3_ssl_wrap_socket - ) - urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ - "ssl_wrap_socket" - ] = true_urllib3_ssl_wrap_socket - urllib3.connection.match_hostname = urllib3.connection.__dict__[ - "match_hostname" - ] = true_urllib3_match_hostname - Mocket.reset() - if pyopenssl_override: # pragma: no cover - # Put the pyopenssl version back in place - inject_into_urllib3() - @classmethod - def get_namespace(cls): + def get_namespace(cls) -> str: return cls._namespace @classmethod - def get_truesocket_recording_dir(cls): + def get_truesocket_recording_dir(cls) -> str | None: return cls._truesocket_recording_dir @classmethod - def assert_fail_if_entries_not_served(cls): + def assert_fail_if_entries_not_served(cls) -> None: """Mocket checks that all entries have been served at least once.""" if not all(entry._served for entry in itertools.chain(*cls._entries.values())): raise AssertionError("Some Mocket entries have not been served") diff --git a/mocket/mocketizer.py b/mocket/mocketizer.py index 5a988c77..2bf2b9cd 100644 --- a/mocket/mocketizer.py +++ b/mocket/mocketizer.py @@ -1,3 +1,4 @@ +from mocket.mocket import Mocket from mocket.mode import MocketMode from mocket.utils import get_mocketize @@ -23,8 +24,6 @@ def __init__( ) def enter(self): - from mocket import Mocket - Mocket.enable( namespace=self.namespace, truesocket_recording_dir=self.truesocket_recording_dir, @@ -39,7 +38,6 @@ def __enter__(self): def exit(self): if self.instance: self.check_and_call("mocketize_teardown") - from mocket import Mocket Mocket.disable() diff --git a/mocket/mode.py b/mocket/mode.py index 3c0638e5..e1da7955 100644 --- a/mocket/mode.py +++ b/mocket/mode.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, ClassVar from mocket.exceptions import StrictMocketException +from mocket.mocket import Mocket if TYPE_CHECKING: # pragma: no cover from typing import NoReturn @@ -31,8 +32,6 @@ def is_allowed(self, location: str | tuple[str, int]) -> bool: @staticmethod def raise_not_allowed() -> NoReturn: - from mocket.mocket import Mocket - current_entries = [ (location, "\n ".join(map(str, entries))) for location, entries in Mocket._entries.items() diff --git a/mocket/plugins/httpretty/__init__.py b/mocket/plugins/httpretty/__init__.py index d5e41e30..fac61840 100644 --- a/mocket/plugins/httpretty/__init__.py +++ b/mocket/plugins/httpretty/__init__.py @@ -1,6 +1,7 @@ -from mocket import Mocket, mocketize +from mocket import mocketize from mocket.async_mocket import async_mocketize from mocket.compat import ENCODING +from mocket.mocket import Mocket from mocket.mockhttp import Entry as MocketHttpEntry from mocket.mockhttp import Request as MocketHttpRequest from mocket.mockhttp import Response as MocketHttpResponse diff --git a/mocket/socket.py b/mocket/socket.py index 3a971af5..e4be00b6 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -10,7 +10,13 @@ from json.decoder import JSONDecodeError from mocket.compat import decode_from_bytes, encode_to_bytes +from mocket.inject import ( + true_gethostbyname, + true_socket, + true_urllib3_ssl_wrap_socket, +) from mocket.io import MocketSocketCore +from mocket.mocket import Mocket from mocket.mode import MocketMode from mocket.utils import hexdump, hexload @@ -63,8 +69,6 @@ class MocketSocket: def __init__( self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs ): - from mocket.mocket import true_socket - self.true_socket = true_socket(family, type, proto) self._buflen = 65536 self._entry = None @@ -90,8 +94,6 @@ def io(self): return self._io def fileno(self): - from mocket.mocket import Mocket - address = (self._host, self._port) r_fd, _ = Mocket.get_pair(address) if not r_fd: @@ -133,8 +135,6 @@ def getsockname(self): return socket.gethostbyname(self._address[0]), self._address[1] def getpeercert(self, *args, **kwargs): - from mocket.mocket import Mocket - if not (self._host and self._port): self._address = self._host, self._port = Mocket._address @@ -161,8 +161,6 @@ def write(self, data): return self.send(encode_to_bytes(data)) def connect(self, address): - from mocket.mocket import Mocket - self._address = self._host, self._port = address Mocket._address = address @@ -172,8 +170,6 @@ def makefile(self, mode="r", bufsize=-1): return self.io def get_entry(self, data): - from mocket.mocket import Mocket - return Mocket.get_entry(self._host, self._port, data) def sendall(self, data, entry=None, *args, **kwargs): @@ -210,8 +206,6 @@ def recv_into(self, buffer, buffersize=None, flags=None): return len(data) def recv(self, buffersize, flags=None): - from mocket.mocket import Mocket - r_fd, _ = Mocket.get_pair((self._host, self._port)) if r_fd: return os.read(r_fd, buffersize) @@ -225,13 +219,6 @@ def recv(self, buffersize, flags=None): raise exc def true_sendall(self, data, *args, **kwargs): - from mocket.mocket import ( - Mocket, - true_gethostbyname, - true_socket, - true_urllib3_ssl_wrap_socket, - ) - if not MocketMode().is_allowed((self._host, self._port)): MocketMode.raise_not_allowed() @@ -246,7 +233,8 @@ def true_sendall(self, data, *args, **kwargs): if Mocket.get_truesocket_recording_dir(): path = os.path.join( - Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json" + Mocket.get_truesocket_recording_dir(), + Mocket.get_namespace() + ".json", ) # check if there's already a recorded session dumped to a JSON file try: @@ -319,8 +307,6 @@ def true_sendall(self, data, *args, **kwargs): return encoded_response def send(self, data, *args, **kwargs): # pragma: no cover - from mocket.mocket import Mocket - entry = self.get_entry(data) if not entry or (entry and self._entry != entry): kwargs["entry"] = entry diff --git a/mocket/types.py b/mocket/types.py new file mode 100644 index 00000000..61b7a4d5 --- /dev/null +++ b/mocket/types.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from typing import Tuple + +Address = Tuple[str, int] diff --git a/tests/test_socket.py b/tests/test_socket.py index 8a6e65ad..112a9089 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -2,7 +2,7 @@ import pytest -from mocket.mocket import MocketSocket +from mocket.socket import MocketSocket @pytest.mark.parametrize("blocking", (False, True))