From 41da3a9fe3fd41575b3be7ba47e2535f8b8742f3 Mon Sep 17 00:00:00 2001 From: suecharo Date: Tue, 8 Oct 2024 18:28:04 +0900 Subject: [PATCH] Fix issue where CORS headers were not being returned properly --- sapporo/app.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/sapporo/app.py b/sapporo/app.py index 7f1b6d1..4fe9141 100644 --- a/sapporo/app.py +++ b/sapporo/app.py @@ -8,7 +8,10 @@ from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +from starlette.datastructures import Headers, MutableHeaders from starlette.exceptions import HTTPException as StarletteHTTPException +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send from sapporo.auth import get_auth_config from sapporo.config import (LOGGER, PKG_DIR, add_openapi_info, get_config, @@ -59,6 +62,52 @@ async def generic_exception_handler(_request: Request, _exc: Exception) -> JSONR ) +class CustomCORSMiddleware(CORSMiddleware): + """\ + CORSMiddleware that returns CORS headers even if the Origin header is not present + """ + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + method = scope["method"] + headers = Headers(scope=scope) + + if method == "OPTIONS" and "access-control-request-method" in headers: + response = self.preflight_response(request_headers=headers) + await response(scope, receive, send) + return + + await self.simple_response(scope, receive, send, request_headers=headers) + + async def send( + self, message: Message, send: Send, request_headers: Headers + ) -> None: + if message["type"] != "http.response.start": + await send(message) + return + + message.setdefault("headers", []) + headers = MutableHeaders(scope=message) + headers.update(self.simple_headers) + origin = request_headers.get("Origin", "*") + has_cookie = "cookie" in request_headers + + # If request includes any cookie headers, then we must respond + # with the specific origin instead of '*'. + if self.allow_all_origins and has_cookie: + self.allow_explicit_origin(headers, origin) + + # If we only allow specific origins, then we have to mirror back + # the Origin header in the response. + elif not self.allow_all_origins and self.is_allowed_origin(origin=origin): + self.allow_explicit_origin(headers, origin) + + await send(message) + + def init_app_state() -> None: """ Perform validation, initialize the cache, and log the configuration contents. @@ -146,9 +195,8 @@ def create_app() -> FastAPI: ) app.add_middleware( - CORSMiddleware, + CustomCORSMiddleware, allow_origins=[app_config.allow_origin], - allow_credentials=True, allow_methods=["*"], allow_headers=["*"], )