Skip to content

Commit

Permalink
Sync streaming interface on responses (#695)
Browse files Browse the repository at this point in the history
* Sync streaming interface on responses

* Fix test case

* Test coverage for sync response APIs

* Address review comments
  • Loading branch information
tomchristie authored Jan 2, 2020
1 parent b0bf2a7 commit 11e7604
Show file tree
Hide file tree
Showing 4 changed files with 353 additions and 81 deletions.
66 changes: 64 additions & 2 deletions httpx/content_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from .exceptions import StreamConsumed
from .utils import format_form_param

RequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
RequestData = typing.Union[
dict, str, bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]
]

RequestFiles = typing.Dict[
str,
Expand Down Expand Up @@ -47,6 +49,12 @@ def can_replay(self) -> bool:
"""
return True

def __iter__(self) -> typing.Iterator[bytes]:
    # Base-class sync iteration: an empty body, emitted as a single chunk.
    yield b""

def close(self) -> None:
    # No resources to release by default; subclasses may override.
    pass

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
    # Async counterpart of __iter__: an empty body as a single chunk.
    yield b""

Expand All @@ -68,10 +76,46 @@ def get_headers(self) -> typing.Dict[str, str]:
content_length = str(len(self.body))
return {"Content-Length": content_length}

def __iter__(self) -> typing.Iterator[bytes]:
    # The whole body is held in memory, so yield it as one chunk.
    yield self.body

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
    # Async counterpart: the whole in-memory body as one chunk.
    yield self.body


class IteratorStream(ContentStream):
    """
    Request content encoded as plain bytes, using a byte iterator.
    """

    def __init__(
        self,
        iterator: typing.Iterator[bytes],
        close_func: typing.Optional[typing.Callable] = None,
    ) -> None:
        # `close_func`, if given, is invoked on close() to release any
        # resource backing the iterator (e.g. an open file).
        self.iterator = iterator
        self.close_func = close_func
        self.is_stream_consumed = False

    def can_replay(self) -> bool:
        # The underlying iterator can only be consumed once.
        return False

    def get_headers(self) -> typing.Dict[str, str]:
        # Total length is unknown up front, so use chunked transfer encoding.
        return {"Transfer-Encoding": "chunked"}

    def __iter__(self) -> typing.Iterator[bytes]:
        # Guard against a second iteration attempt on a one-shot stream.
        if self.is_stream_consumed:
            raise StreamConsumed()
        self.is_stream_consumed = True
        for part in self.iterator:
            yield part

    def __aiter__(self) -> typing.AsyncIterator[bytes]:
        raise RuntimeError("Attempted to call an async iterator on a sync stream.")

    def close(self) -> None:
        if self.close_func is not None:
            self.close_func()


class AsyncIteratorStream(ContentStream):
"""
Request content encoded as plain bytes, using an async byte iterator.
Expand All @@ -90,6 +134,9 @@ def can_replay(self) -> bool:
def get_headers(self) -> typing.Dict[str, str]:
return {"Transfer-Encoding": "chunked"}

def __iter__(self) -> typing.Iterator[bytes]:
    # Sync iteration is not supported on an async stream; fail loudly.
    raise RuntimeError("Attempted to call a sync iterator on an async stream.")

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
if self.is_stream_consumed:
raise StreamConsumed()
Expand All @@ -115,6 +162,9 @@ def get_headers(self) -> typing.Dict[str, str]:
content_type = "application/json"
return {"Content-Length": content_length, "Content-Type": content_type}

def __iter__(self) -> typing.Iterator[bytes]:
    # The encoded body is held in memory, so yield it as one chunk.
    yield self.body

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
    # Async counterpart: the whole in-memory body as one chunk.
    yield self.body

Expand All @@ -132,6 +182,9 @@ def get_headers(self) -> typing.Dict[str, str]:
content_type = "application/x-www-form-urlencoded"
return {"Content-Length": content_length, "Content-Type": content_type}

def __iter__(self) -> typing.Iterator[bytes]:
    # The urlencoded body is held in memory, so yield it as one chunk.
    yield self.body

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
    # Async counterpart: the whole in-memory body as one chunk.
    yield self.body

Expand Down Expand Up @@ -252,6 +305,9 @@ def get_headers(self) -> typing.Dict[str, str]:
content_type = self.content_type
return {"Content-Length": content_length, "Content-Type": content_type}

def __iter__(self) -> typing.Iterator[bytes]:
    # The multipart-encoded body is held in memory; yield it as one chunk.
    yield self.body

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
    # Async counterpart: the whole in-memory body as one chunk.
    yield self.body

Expand Down Expand Up @@ -280,5 +336,11 @@ def encode(
return URLEncodedStream(data=data)
elif isinstance(data, (str, bytes)):
return ByteStream(body=data)
else:
elif hasattr(data, "__aiter__"):
data = typing.cast(typing.AsyncIterator[bytes], data)
return AsyncIteratorStream(aiterator=data)
elif hasattr(data, "__iter__"):
data = typing.cast(typing.Iterator[bytes], data)
return IteratorStream(iterator=data)

raise TypeError(f"Unexpected type for 'data', {type(data)!r}")
131 changes: 95 additions & 36 deletions httpx/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
import rfc3986

from .config import USER_AGENT
from .content_streams import ContentStream, RequestData, RequestFiles, encode
from .content_streams import (
ByteStream,
ContentStream,
RequestData,
RequestFiles,
encode,
)
from .decoders import (
ACCEPT_ENCODING,
SUPPORTED_DECODERS,
Expand Down Expand Up @@ -665,15 +671,13 @@ def __init__(

self.history = [] if history is None else list(history)

if stream is None:
self.is_closed = True
self.is_stream_consumed = True
self._raw_content = content or b""
self._elapsed = request.timer.elapsed
else:
self.is_closed = False
self.is_stream_consumed = False
self.is_closed = False
self.is_stream_consumed = False
if stream is not None:
self._raw_stream = stream
else:
self._raw_stream = ByteStream(body=content or b"")
self.read()

@property
def elapsed(self) -> datetime.timedelta:
Expand Down Expand Up @@ -702,13 +706,7 @@ def url(self) -> typing.Optional[URL]:
@property
def content(self) -> bytes:
if not hasattr(self, "_content"):
if hasattr(self, "_raw_content"):
raw_content = self._raw_content # type: ignore
content = self.decoder.decode(raw_content)
content += self.decoder.flush()
self._content = content
else:
raise ResponseNotRead()
raise ResponseNotRead()
return self._content

@property
Expand Down Expand Up @@ -850,14 +848,6 @@ def links(self) -> typing.Dict[typing.Optional[str], typing.Dict[str, str]]:
def __repr__(self) -> str:
    # e.g. "<Response [200 OK]>"
    return f"<Response [{self.status_code} {self.reason_phrase}]>"

async def aread(self) -> bytes:
"""
Read and return the response content.
"""
if not hasattr(self, "_content"):
self._content = b"".join([part async for part in self.aiter_bytes()])
return self._content

@property
def stream(self): # type: ignore
warnings.warn( # pragma: nocover
Expand All @@ -874,6 +864,78 @@ def raw(self): # type: ignore
)
return self.aiter_raw # pragma: nocover

def read(self) -> bytes:
    """
    Read and return the response content.

    The decoded body is cached on first call, so the response can
    be read repeatedly at no extra cost.
    """
    if not hasattr(self, "_content"):
        # `bytes.join` accepts the iterator directly; no need to
        # materialise an intermediate list first.
        self._content = b"".join(self.iter_bytes())
    return self._content

def iter_bytes(self) -> typing.Iterator[bytes]:
    """
    A byte-iterator over the decoded response content.
    This allows us to handle gzip, deflate, and brotli encoded responses.
    """
    if hasattr(self, "_content"):
        # Body already read and cached: serve it as a single chunk.
        yield self._content
        return
    for raw_chunk in self.iter_raw():
        yield self.decoder.decode(raw_chunk)
    yield self.decoder.flush()

def iter_text(self) -> typing.Iterator[str]:
    """
    A str-iterator over the decoded response content
    that handles both gzip, deflate, etc but also detects the content's
    string encoding.
    """
    text_decoder = TextDecoder(encoding=self.charset_encoding)
    yield from (text_decoder.decode(chunk) for chunk in self.iter_bytes())
    yield text_decoder.flush()

def iter_lines(self) -> typing.Iterator[str]:
    """
    A str-iterator over the response content, split line by line.
    """
    line_decoder = LineDecoder()
    for text_chunk in self.iter_text():
        yield from line_decoder.decode(text_chunk)
    yield from line_decoder.flush()

def iter_raw(self) -> typing.Iterator[bytes]:
    """
    A byte-iterator over the raw response content.
    """
    # The raw stream is single-use: refuse a second pass, and refuse
    # iteration once the connection has been released.
    if self.is_stream_consumed:
        raise StreamConsumed()
    if self.is_closed:
        raise ResponseClosed()

    self.is_stream_consumed = True
    yield from self._raw_stream
    # Reading to completion releases the connection.
    self.close()

def close(self) -> None:
    """
    Close the response and release the connection.
    Automatically called if the response body is read to completion.
    """
    if self.is_closed:
        # Idempotent: closing twice is a no-op.
        return
    self.is_closed = True
    self._elapsed = self.request.timer.elapsed
    if hasattr(self, "_raw_stream"):
        self._raw_stream.close()

async def aread(self) -> bytes:
    """
    Read and return the response content.
    """
    if not hasattr(self, "_content"):
        # Drain the decoded byte stream and cache the joined result.
        chunks = [chunk async for chunk in self.aiter_bytes()]
        self._content = b"".join(chunks)
    return self._content

async def aiter_bytes(self) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the decoded response content.
Expand Down Expand Up @@ -909,18 +971,15 @@ async def aiter_raw(self) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the raw response content.
"""
if hasattr(self, "_raw_content"):
yield self._raw_content
else:
if self.is_stream_consumed:
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()

self.is_stream_consumed = True
async for part in self._raw_stream:
yield part
await self.aclose()
if self.is_stream_consumed:
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()

self.is_stream_consumed = True
async for part in self._raw_stream:
yield part
await self.aclose()

async def anext(self) -> "Response":
"""
Expand Down
Loading

0 comments on commit 11e7604

Please sign in to comment.