Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring and Typing #261

Merged
merged 2 commits into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mocket/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .async_mocket import async_mocketize
from .mocket import FakeSSLContext, Mocket, MocketEntry, Mocketizer, mocketize
from mocket.async_mocket import async_mocketize
from mocket.mocket import FakeSSLContext, Mocket, MocketEntry, Mocketizer, mocketize

__all__ = (
"async_mocketize",
Expand Down
4 changes: 2 additions & 2 deletions mocket/async_mocket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .mocket import Mocketizer
from .utils import get_mocketize
from mocket.mocket import Mocketizer
from mocket.utils import get_mocketize


async def wrapper(
Expand Down
12 changes: 4 additions & 8 deletions mocket/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,17 @@

ENCODING: Final[str] = os.getenv("MOCKET_ENCODING", "utf-8")

text_type = str
byte_type = bytes
basestring = (str,)


def encode_to_bytes(s: str | bytes, encoding: str = ENCODING) -> bytes:
if isinstance(s, text_type):
if isinstance(s, str):
s = s.encode(encoding)
return byte_type(s)
return bytes(s)


def decode_from_bytes(s: str | bytes, encoding: str = ENCODING) -> str:
if isinstance(s, byte_type):
if isinstance(s, bytes):
s = codecs.decode(s, encoding, "ignore")
return text_type(s)
return str(s)


def shsplit(s: str | bytes) -> list[str]:
Expand Down
22 changes: 10 additions & 12 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
urllib3_wrap_socket = None


from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type
from .utils import (
from mocket.compat import decode_from_bytes, encode_to_bytes
from mocket.utils import (
MocketMode,
MocketSocketCore,
get_mocketize,
Expand Down Expand Up @@ -317,7 +317,7 @@ def true_sendall(self, data, *args, **kwargs):
# make request unique again
req_signature = _hash_request(hasher, req)
# port should be always a string
port = text_type(self._port)
port = str(self._port)

# prepare responses dictionary
responses = {}
Expand Down Expand Up @@ -427,7 +427,7 @@ class Mocket:
_address = (None, None)
_entries = collections.defaultdict(list)
_requests = []
_namespace = text_type(id(_entries))
_namespace = str(id(_entries))
_truesocket_recording_dir = None

@classmethod
Expand Down Expand Up @@ -518,7 +518,7 @@ def enable(namespace=None, truesocket_recording_dir=None):
socket.socketpair = socket.__dict__["socketpair"] = socketpair
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket
ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext
socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: byte_type(
socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes(
"\x7f\x00\x00\x01", "utf-8"
)
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
Expand Down Expand Up @@ -592,13 +592,13 @@ def assert_fail_if_entries_not_served(cls):


class MocketEntry:
class Response(byte_type):
class Response(bytes):
@property
def data(self):
return self

response_index = 0
request_cls = byte_type
request_cls = bytes
response_cls = Response
responses = None
_served = None
Expand All @@ -607,9 +607,7 @@ def __init__(self, location, responses):
self._served = False
self.location = location

if not isinstance(responses, collections_abc.Iterable) or isinstance(
responses, basestring
):
if not isinstance(responses, collections_abc.Iterable):
responses = [responses]

if not responses:
Expand All @@ -618,7 +616,7 @@ def __init__(self, location, responses):
self.responses = []
for r in responses:
if not isinstance(r, BaseException) and not getattr(r, "data", False):
if isinstance(r, text_type):
if isinstance(r, str):
r = encode_to_bytes(r)
r = self.response_cls(r)
self.responses.append(r)
Expand Down Expand Up @@ -658,7 +656,7 @@ def __init__(
):
self.instance = instance
self.truesocket_recording_dir = truesocket_recording_dir
self.namespace = namespace or text_type(id(self))
self.namespace = namespace or str(id(self))
MocketMode().STRICT = strict_mode
if strict_mode:
MocketMode().STRICT_ALLOWED = strict_mode_allowed or []
Expand Down
4 changes: 2 additions & 2 deletions mocket/mockhttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from h11 import SERVER, Connection, Data
from h11 import Request as H11Request

from .compat import ENCODING, decode_from_bytes, do_the_magic, encode_to_bytes
from .mocket import Mocket, MocketEntry
from mocket.compat import ENCODING, decode_from_bytes, do_the_magic, encode_to_bytes
from mocket.mocket import Mocket, MocketEntry

STATUS = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()}
CRLF = "\r\n"
Expand Down
18 changes: 11 additions & 7 deletions mocket/mockredis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from itertools import chain

from .compat import byte_type, decode_from_bytes, encode_to_bytes, shsplit, text_type
from .mocket import Mocket, MocketEntry
from mocket.compat import (
decode_from_bytes,
encode_to_bytes,
shsplit,
)
from mocket.mocket import Mocket, MocketEntry


class Request:
Expand All @@ -14,7 +18,7 @@ def __init__(self, data=None):
self.data = Redisizer.redisize(data or OK)


class Redisizer(byte_type):
class Redisizer(bytes):
@staticmethod
def tokens(iterable):
iterable = [encode_to_bytes(x) for x in iterable]
Expand All @@ -30,15 +34,15 @@ def get_conversion(t):
Redisizer.tokens(list(chain(*tuple(x.items()))))
),
int: lambda x: f":{x}".encode(),
text_type: lambda x: "${}\r\n{}".format(
len(x.encode("utf-8")), x
).encode("utf-8"),
str: lambda x: "${}\r\n{}".format(len(x.encode("utf-8")), x).encode(
"utf-8"
),
list: lambda x: b"\r\n".join(Redisizer.tokens(x)),
}[t]

if isinstance(data, Redisizer):
return data
if isinstance(data, byte_type):
if isinstance(data, bytes):
data = decode_from_bytes(data)
return Redisizer(get_conversion(data.__class__)(data) + b"\r\n")

Expand Down
6 changes: 3 additions & 3 deletions mocket/plugins/httpretty/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from mocket import Mocket, mocketize
from mocket.async_mocket import async_mocketize
from mocket.compat import ENCODING, byte_type, text_type
from mocket.compat import ENCODING
from mocket.mockhttp import Entry as MocketHttpEntry
from mocket.mockhttp import Request as MocketHttpRequest
from mocket.mockhttp import Response as MocketHttpResponse
Expand Down Expand Up @@ -129,6 +129,6 @@ def __getattr__(self, name):
"HEAD",
"PATCH",
"register_uri",
"text_type",
"byte_type",
"str",
"bytes",
)
6 changes: 3 additions & 3 deletions mocket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import ssl
from typing import TYPE_CHECKING, Any, Callable, ClassVar

from .compat import decode_from_bytes, encode_to_bytes
from .exceptions import StrictMocketException
from mocket.compat import decode_from_bytes, encode_to_bytes
from mocket.exceptions import StrictMocketException

if TYPE_CHECKING: # pragma: no cover
from typing import NoReturn
Expand Down Expand Up @@ -83,7 +83,7 @@ def is_allowed(self, location: str | tuple[str, int]) -> bool:

@staticmethod
def raise_not_allowed() -> NoReturn:
from .mocket import Mocket
from mocket.mocket import Mocket

current_entries = [
(location, "\n ".join(map(str, entries)))
Expand Down