Skip to content

Commit

Permalink
refactor: move enable- and disable-functions from mocket.mocket to mo…
Browse files Browse the repository at this point in the history
…cket.inject
  • Loading branch information
betaboon committed Nov 17, 2024
1 parent 2beb8f1 commit d6b3bb1
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 127 deletions.
128 changes: 128 additions & 0 deletions mocket/inject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from __future__ import annotations

import os
import socket
import ssl

import urllib3
from urllib3.connection import match_hostname as urllib3_match_hostname
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket

try:
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
except ImportError:
urllib3_wrap_socket = None


try: # pragma: no cover
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3

pyopenssl_override = True
except ImportError:
pyopenssl_override = False

true_socket = socket.socket
true_create_connection = socket.create_connection
true_gethostbyname = socket.gethostbyname
true_gethostname = socket.gethostname
true_getaddrinfo = socket.getaddrinfo
true_socketpair = socket.socketpair
true_ssl_wrap_socket = getattr(
ssl, "wrap_socket", None
) # from Py3.12 it's only under SSLContext
true_ssl_socket = ssl.SSLSocket
true_ssl_context = ssl.SSLContext
true_inet_pton = socket.inet_pton
true_urllib3_wrap_socket = urllib3_wrap_socket
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
true_urllib3_match_hostname = urllib3_match_hostname


def enable(
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
from mocket.mocket import Mocket
from mocket.socket import MocketSocket, create_connection, socketpair
from mocket.ssl import FakeSSLContext

Mocket._namespace = namespace
Mocket._truesocket_recording_dir = truesocket_recording_dir

if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
# JSON dumps will be saved here
raise AssertionError

socket.socket = socket.__dict__["socket"] = MocketSocket
socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
socket.create_connection = socket.__dict__["create_connection"] = create_connection
socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost"
socket.gethostbyname = socket.__dict__["gethostbyname"] = lambda host: "127.0.0.1"
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = (
lambda host, port, family=None, socktype=None, proto=None, flags=None: [
(2, 1, 6, "", (host, port))
]
)
socket.socketpair = socket.__dict__["socketpair"] = socketpair
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket
ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext
socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes(
"\x7f\x00\x00\x01", "utf-8"
)
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
FakeSSLContext.wrap_socket
)
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
"ssl_wrap_socket"
] = FakeSSLContext.wrap_socket
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
FakeSSLContext.wrap_socket
)
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
"ssl_wrap_socket"
] = FakeSSLContext.wrap_socket
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = lambda *args: None
if pyopenssl_override: # pragma: no cover
# Take out the pyopenssl version - use the default implementation
extract_from_urllib3()


def disable() -> None:
from mocket.mocket import Mocket

socket.socket = socket.__dict__["socket"] = true_socket
socket._socketobject = socket.__dict__["_socketobject"] = true_socket
socket.SocketType = socket.__dict__["SocketType"] = true_socket
socket.create_connection = socket.__dict__["create_connection"] = (
true_create_connection
)
socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
if true_ssl_wrap_socket:
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
true_urllib3_wrap_socket
)
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
"ssl_wrap_socket"
] = true_urllib3_ssl_wrap_socket
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
true_urllib3_ssl_wrap_socket
)
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
"ssl_wrap_socket"
] = true_urllib3_ssl_wrap_socket
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = true_urllib3_match_hostname
Mocket.reset()
if pyopenssl_override: # pragma: no cover
# Put the pyopenssl version back in place
inject_into_urllib3()
128 changes: 9 additions & 119 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,12 @@
import collections
import itertools
import os
import socket
import ssl
from typing import Optional, Tuple

import urllib3
from urllib3.connection import match_hostname as urllib3_match_hostname
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
import mocket.inject

try:
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
except ImportError:
urllib3_wrap_socket = None


from mocket.socket import MocketSocket, create_connection, socketpair
from mocket.ssl import FakeSSLContext

try: # pragma: no cover
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3

pyopenssl_override = True
except ImportError:
pyopenssl_override = False

true_socket = socket.socket
true_create_connection = socket.create_connection
true_gethostbyname = socket.gethostbyname
true_gethostname = socket.gethostname
true_getaddrinfo = socket.getaddrinfo
true_socketpair = socket.socketpair
true_ssl_wrap_socket = getattr(
ssl, "wrap_socket", None
) # from Py3.12 it's only under SSLContext
true_ssl_socket = ssl.SSLSocket
true_ssl_context = ssl.SSLContext
true_inet_pton = socket.inet_pton
true_urllib3_wrap_socket = urllib3_wrap_socket
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
true_urllib3_match_hostname = urllib3_match_hostname
# NOTE this is here for backwards-compat to keep old import-paths working
from mocket.socket import MocketSocket as MocketSocket


class Mocket:
Expand Down Expand Up @@ -111,94 +78,17 @@ def remove_last_request(cls):
def has_requests(cls):
return bool(cls.request_list())

@classmethod
def get_namespace(cls):
return cls._namespace

@staticmethod
def enable(namespace=None, truesocket_recording_dir=None):
Mocket._namespace = namespace
Mocket._truesocket_recording_dir = truesocket_recording_dir

if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
# JSON dumps will be saved here
raise AssertionError

socket.socket = socket.__dict__["socket"] = MocketSocket
socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
socket.create_connection = socket.__dict__["create_connection"] = (
create_connection
)
socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost"
socket.gethostbyname = socket.__dict__["gethostbyname"] = (
lambda host: "127.0.0.1"
)
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = (
lambda host, port, family=None, socktype=None, proto=None, flags=None: [
(2, 1, 6, "", (host, port))
]
)
socket.socketpair = socket.__dict__["socketpair"] = socketpair
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket
ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext
socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes(
"\x7f\x00\x00\x01", "utf-8"
)
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
FakeSSLContext.wrap_socket
)
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
"ssl_wrap_socket"
] = FakeSSLContext.wrap_socket
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
FakeSSLContext.wrap_socket
)
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
"ssl_wrap_socket"
] = FakeSSLContext.wrap_socket
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = lambda *args: None
if pyopenssl_override: # pragma: no cover
# Take out the pyopenssl version - use the default implementation
extract_from_urllib3()
mocket.inject.enable(namespace, truesocket_recording_dir)

@staticmethod
def disable():
socket.socket = socket.__dict__["socket"] = true_socket
socket._socketobject = socket.__dict__["_socketobject"] = true_socket
socket.SocketType = socket.__dict__["SocketType"] = true_socket
socket.create_connection = socket.__dict__["create_connection"] = (
true_create_connection
)
socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
if true_ssl_wrap_socket:
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
true_urllib3_wrap_socket
)
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
"ssl_wrap_socket"
] = true_urllib3_ssl_wrap_socket
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
true_urllib3_ssl_wrap_socket
)
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
"ssl_wrap_socket"
] = true_urllib3_ssl_wrap_socket
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = true_urllib3_match_hostname
Mocket.reset()
if pyopenssl_override: # pragma: no cover
# Put the pyopenssl version back in place
inject_into_urllib3()

@classmethod
def get_namespace(cls):
return cls._namespace
mocket.inject.disable()

@classmethod
def get_truesocket_recording_dir(cls):
Expand Down
14 changes: 6 additions & 8 deletions mocket/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
from json.decoder import JSONDecodeError

from mocket.compat import decode_from_bytes, encode_to_bytes
from mocket.inject import (
true_gethostbyname,
true_socket,
true_urllib3_ssl_wrap_socket,
)
from mocket.io import MocketSocketCore
from mocket.mode import MocketMode
from mocket.utils import hexdump, hexload
Expand Down Expand Up @@ -63,8 +68,6 @@ class MocketSocket:
def __init__(
self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
):
from mocket.mocket import true_socket

self.true_socket = true_socket(family, type, proto)
self._buflen = 65536
self._entry = None
Expand Down Expand Up @@ -225,12 +228,7 @@ def recv(self, buffersize, flags=None):
raise exc

def true_sendall(self, data, *args, **kwargs):
from mocket.mocket import (
Mocket,
true_gethostbyname,
true_socket,
true_urllib3_ssl_wrap_socket,
)
from mocket.mocket import Mocket

if not MocketMode().is_allowed((self._host, self._port)):
MocketMode.raise_not_allowed()
Expand Down

0 comments on commit d6b3bb1

Please sign in to comment.