Skip to content

Commit

Permalink
Move BackgroundTask execution outside of request/response cycle
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Jul 16, 2024
1 parent 6f863b0 commit 8eab239
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 36 deletions.
2 changes: 2 additions & 0 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from starlette.datastructures import State, URLPath
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.background import BackgroundTaskMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
Expand Down Expand Up @@ -96,6 +97,7 @@ def build_middleware_stack(self) -> ASGIApp:

middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ [Middleware(BackgroundTaskMiddleware)]
+ self.user_middleware
+ [
Middleware(
Expand Down
37 changes: 37 additions & 0 deletions starlette/middleware/background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import List, cast

from starlette.background import BackgroundTask
from starlette.types import ASGIApp, Receive, Scope, Send

# consider this a private implementation detail subject to change
# do not rely on this key
_SCOPE_KEY = "starlette._background"


_BackgroundTaskList = List[BackgroundTask]


def is_background_task_middleware_installed(scope: Scope) -> bool:
return _SCOPE_KEY in scope


def add_tasks(scope: Scope, task: BackgroundTask, /) -> None:
if _SCOPE_KEY not in scope: # pragma: no cover
raise RuntimeError(
"`add_tasks` can only be used if `BackgroundTaskMIddleware is installed"
)
cast(_BackgroundTaskList, scope[_SCOPE_KEY]).append(task)


class BackgroundTaskMiddleware:
def __init__(self, app: ASGIApp) -> None:
self._app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
tasks: _BackgroundTaskList
scope[_SCOPE_KEY] = tasks = []
try:
await self._app(scope, receive, send)
finally:
for task in tasks:
await task()
19 changes: 19 additions & 0 deletions starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import anyio.to_thread

from starlette._compat import md5_hexdigest
from starlette.middleware import background
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, MutableHeaders
Expand Down Expand Up @@ -148,6 +149,12 @@ def delete_cookie(
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if (
self.background is not None
and background.is_background_task_middleware_installed(scope)
):
background.add_tasks(scope, self.background)
self.background = None
prefix = "websocket." if scope["type"] == "websocket" else ""
await send(
{
Expand Down Expand Up @@ -255,6 +262,12 @@ async def stream_response(self, send: Send) -> None:
await send({"type": "http.response.body", "body": b"", "more_body": False})

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if (
self.background is not None
and background.is_background_task_middleware_installed(scope)
):
background.add_tasks(scope, self.background)
self.background = None
async with anyio.create_task_group() as task_group:

async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
Expand Down Expand Up @@ -322,6 +335,12 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None:
self.headers.setdefault("etag", etag)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if (
self.background is not None
and background.is_background_task_middleware_installed(scope)
):
background.add_tasks(scope, self.background)
self.background = None
if self.stat_result is None:
try:
stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
Expand Down
142 changes: 141 additions & 1 deletion tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from anyio.abc import TaskStatus

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.background import BackgroundTask, BackgroundTasks
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.background import BackgroundTaskMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
Expand Down Expand Up @@ -1035,3 +1036,142 @@ async def endpoint(request: Request) -> Response:
resp.raise_for_status()

assert bodies == [b"Hello, World!-foo"]

@pytest.mark.anyio
async def test_background_tasks_client_disconnect() -> None:
# test for https://github.com/encode/starlette/issues/1438
container: list[str] = []

disconnected = anyio.Event()

async def slow_background() -> None:
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
container.append("called")

app: ASGIApp
app = PlainTextResponse("hi!", background=BackgroundTask(slow_background))

async def dispatch(
request: Request, call_next: RequestResponseEndpoint
) -> Response:
return await call_next(request)

app = BaseHTTPMiddleware(app, dispatch=dispatch)

app = BackgroundTaskMiddleware(app)

async def recv_gen() -> AsyncGenerator[Message, None]:
yield {"type": "http.request"}
await disconnected.wait()
while True:
yield {"type": "http.disconnect"}

async def send_gen() -> AsyncGenerator[None, Message]:
while True:
msg = yield
if msg["type"] == "http.response.body" and not msg.get("more_body", False):
disconnected.set()

scope = {"type": "http", "method": "GET", "path": "/"}

async with AsyncExitStack() as stack:
recv = recv_gen()
stack.push_async_callback(recv.aclose)
send = send_gen()
stack.push_async_callback(send.aclose)
await send.__anext__()
await app(scope, recv.__aiter__().__anext__, send.asend)

assert container == ["called"]

@pytest.mark.anyio
async def test_background_tasks_client_disconnect() -> None:
# test for https://github.com/encode/starlette/issues/1438
container: list[str] = []

disconnected = anyio.Event()

async def slow_background() -> None:
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
container.append("called")

app: ASGIApp
app = PlainTextResponse("hi!", background=BackgroundTask(slow_background))

async def dispatch(
request: Request, call_next: RequestResponseEndpoint
) -> Response:
return await call_next(request)

app = BaseHTTPMiddleware(app, dispatch=dispatch)

app = BackgroundTaskMiddleware(app)

async def recv_gen() -> AsyncGenerator[Message, None]:
yield {"type": "http.request"}
await disconnected.wait()
while True:
yield {"type": "http.disconnect"}

async def send_gen() -> AsyncGenerator[None, Message]:
while True:
msg = yield
if msg["type"] == "http.response.body" and not msg.get("more_body", False):
disconnected.set()

scope = {"type": "http", "method": "GET", "path": "/"}

async with AsyncExitStack() as stack:
recv = recv_gen()
stack.push_async_callback(recv.aclose)
send = send_gen()
stack.push_async_callback(send.aclose)
await send.__anext__()
await app(scope, recv.__aiter__().__anext__, send.asend)

assert container == ["called"]


@pytest.mark.anyio
async def test_background_tasks_failure(
test_client_factory: TestClientFactory,
) -> None:
# test for https://github.com/encode/starlette/discussions/2640
container: list[str] = []

def task1() -> None:
container.append("task1 called")
raise ValueError("task1 failed")

def task2() -> None:
container.append("task2 called")

async def endpoint(request: Request) -> Response:
background = BackgroundTasks()
background.add_task(task1)
background.add_task(task2)
return PlainTextResponse("hi!", background=background)

async def dispatch(
request: Request, call_next: RequestResponseEndpoint
) -> Response:
return await call_next(request)

app = Starlette(
routes=[Route("/", endpoint)],
middleware=[Middleware(BaseHTTPMiddleware, dispatch=dispatch)],
)

client = test_client_factory(app)

response = client.get("/")
assert response.status_code == 200
assert response.text == "hi!"

assert container == ["task1 called"]
Loading

0 comments on commit 8eab239

Please sign in to comment.