From fd664f08a3e098876b01d80bbd2a032bbacf5d53 Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Tue, 4 Feb 2025 17:52:31 +0000 Subject: [PATCH] Invert the creation of API routes - Such that decode_access_token can be overriden when serving behind proxied OIDC - Removes injection of password into security obj - Removes use of dependency_override which is intended for use in tests --- tiled/client/context.py | 3 +- tiled/server/app.py | 197 +- tiled/server/authentication.py | 575 +++--- tiled/server/dependencies.py | 259 ++- tiled/server/metrics.py | 4 +- tiled/server/router.py | 3071 ++++++++++++++++---------------- tiled/server/utils.py | 105 +- 7 files changed, 2068 insertions(+), 2146 deletions(-) diff --git a/tiled/client/context.py b/tiled/client/context.py index 502e2c295..42a4437c4 100644 --- a/tiled/client/context.py +++ b/tiled/client/context.py @@ -449,8 +449,7 @@ def from_app( # Extract the API key from the app and set it. from ..server.settings import get_settings - settings = app.dependency_overrides[get_settings]() - api_key = settings.single_user_api_key or None + api_key = get_settings().single_user_api_key or None else: # This is a multi-user server but no API key was passed, # so we will leave it as None on the Context. diff --git a/tiled/server/app.py b/tiled/server/app.py index 851c0eae4..7f1462271 100644 --- a/tiled/server/app.py +++ b/tiled/server/app.py @@ -8,9 +8,9 @@ import urllib.parse import warnings from contextlib import asynccontextmanager -from functools import cache, partial +from functools import partial from pathlib import Path -from typing import List +from typing import Any import anyio import packaging.version @@ -34,30 +34,32 @@ HTTP_500_INTERNAL_SERVER_ERROR, ) +from tiled.server.authentication import ( + get_current_principal_from_api_key, + session_state_getter, +) from tiled.server.protocols import Authenticator from ..config import construct_build_app_kwargs +from ..media_type_registration import CompressionRegistry, SerializationRegistry from ..media_type_registration import ( compression_registry as default_compression_registry, ) +from ..media_type_registration import ( + deserialization_registry as default_deserialization_registry, +) +from ..query_registration import QueryRegistry +from ..query_registration import query_registry as default_query_registry from ..utils import SHARE_TILED_PATH, Conflicts, SpecialUsers, UnsupportedQueryType from ..validation_registration import validation_registry as default_validation_registry -from . 
import schemas -from .authentication import get_current_principal from .compression import CompressionMiddleware -from .dependencies import ( - get_query_registry, - get_root_tree, - get_serialization_registry, - get_validation_registry, -) -from .router import distinct, patch_route_signature, router, search +from .router import get_router from .settings import get_settings from .utils import ( API_KEY_COOKIE_NAME, CSRF_COOKIE_NAME, - get_authenticators, get_root_url, + move_api_key, record_timing, ) @@ -113,9 +115,10 @@ def build_app( tree, authentication=None, server_settings=None, - query_registry=None, - serialization_registry=None, - compression_registry=None, + query_registry: QueryRegistry | None = None, + serialization_registry: SerializationRegistry | None = None, + compression_registry: CompressionRegistry | None = None, + deserialization_registry: SerializationRegistry | None = None, validation_registry=None, tasks=None, scalable=False, @@ -138,10 +141,13 @@ def build_app( spec["provider"]: spec["authenticator"] for spec in authentication.get("providers", []) } - server_settings = server_settings or {} - query_registry = query_registry or get_query_registry() + server_settings = server_settings or get_settings() + query_registry = query_registry or default_query_registry compression_registry = compression_registry or default_compression_registry validation_registry = validation_registry or default_validation_registry + deserialization_registry = ( + deserialization_registry or default_deserialization_registry + ) tasks = tasks or {} tasks.setdefault("startup", []) tasks.setdefault("background", []) @@ -265,9 +271,7 @@ async def lookup_file(path, try_app=True): @app.get("/", response_class=HTMLResponse) async def index( request: Request, - # This dependency is here because it runs the code that moves - # API key from the query parameter to a cookie (if it is valid). - principal=Security(get_current_principal, scopes=[]), + _: str | None = Security(move_api_key), ): return templates.TemplateResponse( request, @@ -348,99 +352,54 @@ async def unhandled_exception_handler( ), ) - app.include_router(router, prefix="/api/v1") - # The Tree and Authenticator have the opportunity to add custom routes to # the server here. (Just for example, a Tree of BlueskyRuns uses this - # hook to add a /documents route.) This has to be done before dependency_overrides - # are processed, so we cannot just inject this configuration via Depends. + # hook to add a /documents route.) for custom_router in getattr(tree, "include_routers", []): app.include_router(custom_router, prefix="/api/v1") if authenticators: # Delay this imports to avoid delaying startup with the SQL and cryptography # imports if they are not needed. - from .authentication import build_authentication_router + from .authentication import ( + build_authentication_router, + current_principal_getter, + ) # For the OpenAPI schema, inject a OAuth2PasswordBearer URL. first_provider = authentication["providers"][0]["provider"] authentication_router = build_authentication_router( - authenticators, first_provider + authenticators, first_provider, server_settings ) # And add this authentication_router itself to the app. app.include_router(authentication_router, prefix="/api/v1/auth") - - # The /search route is defined after import time so that the user has the - # opporunity to register custom query types before startup. 
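# A minimal sketch (illustrative names, not Tiled's exact signatures) of the
# inversion this patch makes: build_app now hands dependencies to a router
# factory up front, instead of registering module-level routes and patching
# them afterwards via app.dependency_overrides.
from fastapi import APIRouter, Depends, FastAPI

def make_router(settings: dict) -> APIRouter:
    router = APIRouter()

    def get_settings() -> dict:
        # Closed over at construction time; nothing to override later.
        return settings

    @router.get("/config")
    async def config(settings: dict = Depends(get_settings)):
        return settings

    return router

app = FastAPI()
app.include_router(make_router({"allow_anonymous_access": True}), prefix="/api/v1")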
- app.get( - "/api/v1/search/{path:path}", - response_model=schemas.Response[ - List[schemas.Resource[schemas.NodeAttributes, dict, dict]], - schemas.PaginationLinks, - dict, - ], - )(patch_route_signature(search, query_registry)) - app.get( - "/api/v1/distinct/{path:path}", - response_model=schemas.GetDistinctResponse, - )(patch_route_signature(distinct, query_registry)) - - @cache - def override_get_authenticators(): - return authenticators - - @cache - def override_get_root_tree(): - return tree - - @cache - def override_get_settings(): - settings = get_settings() - for item in [ - "allow_anonymous_access", - "secret_keys", - "single_user_api_key", - "access_token_max_age", - "refresh_token_max_age", - "session_max_age", - ]: - if authentication.get(item) is not None: - setattr(settings, item, authentication[item]) - if authentication.get("single_user_api_key") is not None: - settings.single_user_api_key_generated = False - for item in [ - "allow_origins", - "response_bytesize_limit", - "reject_undeclared_specs", - "expose_raw_assets", - ]: - if server_settings.get(item) is not None: - setattr(settings, item, server_settings[item]) - database = server_settings.get("database", {}) - if database.get("uri"): - settings.database_uri = database["uri"] - if database.get("pool_size"): - settings.database_pool_size = database["pool_size"] - if database.get("pool_pre_ping"): - settings.database_pool_pre_ping = database["pool_pre_ping"] - if database.get("max_overflow"): - settings.database_max_overflow = database["max_overflow"] - if database.get("init_if_not_exists"): - settings.database_init_if_not_exists = database["init_if_not_exists"] - if authentication.get("providers"): - # If we support authentication providers, we need a database, so if one is - # not set, use a SQLite database in memory. Horizontally scaled deployments - # must specify a persistent database. - settings.database_uri = settings.database_uri or "sqlite://" - return settings + principal_getter = current_principal_getter(authenticators, server_settings) + + else: + principal_getter = get_current_principal_from_api_key() + + get_session_state = session_state_getter(authenticators, server_settings) + + app.include_router( + get_router( + query_registry, + authenticators, + principal_getter, + server_settings.tree, + get_session_state, + serialization_registry, + deserialization_registry, + validation_registry, + ), + prefix="/api/v1", + ) async def startup_event(): from .. import __version__ logger.info(f"Tiled version {__version__}") # Validate the single-user API key. 
- settings = app.dependency_overrides[get_settings]() - single_user_api_key = settings.single_user_api_key + single_user_api_key = server_settings.single_user_api_key API_KEY_MSG = """ Here are two ways to generate a good API key: @@ -485,13 +444,12 @@ async def startup_event(): asyncio_task = asyncio.create_task(task()) app.state.tasks.append(asyncio_task) - app.state.allow_origins.extend(settings.allow_origins) + app.state.allow_origins.extend(server_settings.allow_origins) # Expose the root_tree here to make it easier to access it from tests, # in usages like: # client.context.app.state.root_tree - app.state.root_tree = app.dependency_overrides[get_root_tree]() - if settings.database_uri is not None: + if server_settings.database_uri is not None: from sqlalchemy.ext.asyncio import AsyncSession from ..alembic_utils import ( @@ -512,7 +470,7 @@ async def startup_event(): # This creates a connection pool and stashes it in a module-global # registry, keyed on database_settings, where can be retrieved by # the Dependency get_database_session. - engine = open_database_connection_pool(settings.database_settings) + engine = open_database_connection_pool(server_settings.database_settings) if not engine.url.database: # Special-case for in-memory SQLite: Because it is transient we can # skip over anything related to migrations. @@ -523,7 +481,7 @@ async def startup_event(): try: await check_database(engine, REQUIRED_REVISION, ALL_REVISIONS) except UninitializedDatabase: - if settings.database_init_if_not_exists: + if server_settings.database_init_if_not_exists: # The alembic stamping can only be does synchronously. # The cleanest option available is to start a subprocess # because SQLite is allergic to threads. @@ -616,13 +574,12 @@ async def shutdown_event(): for task in tasks.get("shutdown", []): await task() - settings = app.dependency_overrides[get_settings]() - if settings.database_uri is not None: + if server_settings.database_uri is not None: from ..authn_database.connection_pool import close_database_connection_pool for task in app.state.tasks: task.cancel() - await close_database_connection_pool(settings.database_settings) + await close_database_connection_pool(server_settings.database_settings) app.add_middleware( CompressionMiddleware, @@ -714,35 +671,6 @@ async def set_cookies(request: Request, call_next): return response app.openapi = partial(custom_openapi, app) - app.dependency_overrides[get_authenticators] = override_get_authenticators - app.dependency_overrides[get_root_tree] = override_get_root_tree - app.dependency_overrides[get_settings] = override_get_settings - if query_registry is not None: - - @cache - def override_get_query_registry(): - return query_registry - - app.dependency_overrides[get_query_registry] = override_get_query_registry - if serialization_registry is not None: - - @cache - def override_get_serialization_registry(): - return serialization_registry - - app.dependency_overrides[ - get_serialization_registry - ] = override_get_serialization_registry - - if validation_registry is not None: - - @cache - def override_get_validation_registry(): - return validation_registry - - app.dependency_overrides[ - get_validation_registry - ] = override_get_validation_registry @app.middleware("http") async def capture_metrics(request: Request, call_next): @@ -883,15 +811,16 @@ def __getattr__(name): def print_admin_api_key_if_generated( - web_app: FastAPI, host: str, port: int, force: bool = False + web_app: FastAPI, + host: str, + port: int, + authenticators: dict[str, Any] 
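# Aside: with the override gone, startup reads the key straight from
# server_settings. The "good API key" advice in the message below boils down
# to the stdlib call shown here (a high-entropy hex string):
import secrets

print(secrets.token_hex(32))  # 32 random bytes, hex-encoded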
| None, + force: bool = False, ): "Print message to stderr with API key if server-generated (or force=True)." host = host or "127.0.0.1" port = port or 8000 - settings = web_app.dependency_overrides.get(get_settings, get_settings)() - authenticators = web_app.dependency_overrides.get( - get_authenticators, get_authenticators - )() + settings = get_settings() if settings.allow_anonymous_access: print( """ diff --git a/tiled/server/authentication.py b/tiled/server/authentication.py index 0d014368b..db5c6081a 100644 --- a/tiled/server/authentication.py +++ b/tiled/server/authentication.py @@ -3,8 +3,8 @@ import uuid as uuid_module import warnings from collections.abc import Callable -from datetime import datetime, timedelta, timezone -from functools import partial +from datetime import timedelta, timezone +from functools import cache from pathlib import Path from typing import Any, Optional @@ -19,7 +19,6 @@ Response, Security, ) -from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.security import ( OAuth2, OAuth2AuthorizationCodeBearer, @@ -27,8 +26,6 @@ OAuth2PasswordRequestForm, SecurityScopes, ) -from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyQuery -from fastapi.security.utils import get_authorization_scheme_param from fastapi.templating import Jinja2Templates from sqlalchemy.future import select from sqlalchemy.orm import selectinload @@ -73,7 +70,14 @@ UserSessionState, ) from .settings import Settings, get_settings -from .utils import API_KEY_COOKIE_NAME, get_authenticators, get_base_url +from .utils import ( + API_KEY_COOKIE_NAME, + get_api_key, + get_base_url, + headers_for_401, + move_api_key, + utcnow, +) ALGORITHM = "HS256" UNIT_SECOND = timedelta(seconds=1) @@ -90,11 +94,6 @@ DEVICE_CODE_POLLING_INTERVAL = 5 # seconds -def utcnow(): - "UTC now with second resolution" - return datetime.now(timezone.utc).replace(microsecond=0) - - class Token(BaseModel): access_token: str token_type: str @@ -104,52 +103,6 @@ class TokenData(BaseModel): username: Optional[str] = None -class APIKeyAuthorizationHeader(APIKeyBase): - """ - Expect a header like - - Authorization: Apikey SECRET - - where Apikey is case-insensitive. - """ - - def __init__( - self, - *, - name: str, - scheme_name: Optional[str] = None, - description: Optional[str] = None, - ): - self.model: APIKey = APIKey( - **{"in": APIKeyIn.header}, name=name, description=description - ) - self.scheme_name = scheme_name or self.__class__.__name__ - - async def __call__(self, request: Request) -> Optional[str]: - authorization: str = request.headers.get("Authorization") - scheme, param = get_authorization_scheme_param(authorization) - if not authorization or scheme.lower() == "bearer": - return None - if scheme.lower() != "apikey": - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - "Authorization header must include the authorization type " - "followed by a space and then the secret, as in " - "'Bearer SECRET' or 'Apikey SECRET'. 
" - ), - ) - return param - - -api_key_query = APIKeyQuery(name="api_key", auto_error=False) -api_key_header = APIKeyAuthorizationHeader( - name="Authorization", - description="Prefix value with 'Apikey ' as in, 'Apikey SECRET'", -) -api_key_cookie = APIKeyCookie(name=API_KEY_COOKIE_NAME, auto_error=False) - - def create_access_token(data, secret_key, expires_delta): to_encode = data.copy() expire = utcnow() + expires_delta @@ -169,53 +122,43 @@ def create_refresh_token(session_id, secret_key, expires_delta): return encoded_jwt -def decode_token(token: str, secret_keys: list[str]): - credentials_exception = HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - # The first key in settings.secret_keys is used for *encoding*. - # All keys are tried for *decoding* until one works or they all - # fail. They supports key rotation. - for secret_key in secret_keys: - try: - payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) - break - except ExpiredSignatureError: - # Do not let this be caught below with the other JWTError types. - raise - except JWTError: - # Try the next key in the key rotation. - continue - else: - raise credentials_exception - return payload - - -async def get_api_key( - api_key_query: str = Security(api_key_query), - api_key_header: str = Security(api_key_header), - api_key_cookie: str = Security(api_key_cookie), +def decode_token_for_authenticators( + authenticators: dict[str, Any] | None, settings: Settings ): - for api_key in [api_key_query, api_key_header, api_key_cookie]: - if api_key is not None: - return api_key - return None + if ( + authenticators is not None + and len(authenticators) == 1 + and isinstance( + auth := authenticators.get(next(iter(authenticators))), + ProxiedOIDCAuthenticator, + ) + ): + return auth.decode_access_token + def decode_token(token: str): + credentials_exception = HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + # The first key in settings.secret_keys is used for *encoding*. + # All keys are tried for *decoding* until one works or they all + # fail. They supports key rotation. + for secret_key in settings.secret_keys: + try: + payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) + break + except ExpiredSignatureError: + # Do not let this be caught below with the other JWTError types. + raise + except JWTError: + # Try the next key in the key rotation. + continue + else: + raise credentials_exception + return payload -def headers_for_401(request: Request, security_scopes: SecurityScopes): - # call directly from methods, rather than as a dependency, to avoid calling - # when not needed. - if security_scopes.scopes: - authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' - else: - authenticate_value = "Bearer" - headers_for_401 = { - "WWW-Authenticate": authenticate_value, - "X-Tiled-Root": get_base_url(request), - } - return headers_for_401 + return decode_token async def create_pending_session(db): @@ -243,6 +186,241 @@ async def create_pending_session(db): } +async def get_current_principal_from_api_key( + request: Request, + security_scopes: SecurityScopes, + api_key: str | None = Depends(get_api_key), + settings: Settings = Depends(get_settings), + db=Depends(get_database_session), +): + """ + Get current Principal from: + - API key in 'api_key' query parameter + - API key in header 'Authorization: Apikey ...' 
+ - API key in cookie 'tiled_api_key' + - OAuth2 JWT access token in header 'Authorization: Bearer ...' + + Fall back to SpecialUsers.public, if anonymous access is allowed + If this server is configured with a "single-user API key", then + the Principal will be SpecialUsers.admin always. + """ + + if api_key is not None: + # Tiled is in a "single user" mode with only one API key. + if secrets.compare_digest(api_key, settings.single_user_api_key): + principal = SpecialUsers.admin + scopes = { + "read:metadata", + "read:data", + "write:metadata", + "write:data", + "create", + "register", + "metrics", + } + else: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + headers=headers_for_401(request, security_scopes), + ) + else: + # No form of authentication is present. + principal = SpecialUsers.public + # Is anonymous public access permitted? + if settings.allow_anonymous_access: + # Any user who can see the server can make unauthenticated requests. + # This is a sentinel that has special meaning to the authorization + # code (the access control policies). + scopes = {"read:metadata", "read:data"} + else: + # In this mode, there may still be entries that are visible to all, + # but users have to authenticate as *someone* to see anything. + # They can still access the / and /docs routes. + scopes = {} + # Scope enforcement happens here. + # https://fastapi.tiangolo.com/advanced/security/oauth2-scopes/ + if not set(security_scopes.scopes).issubset(scopes): + # Include a link to the root page which provides a list of + # authenticators. The use case here is: + # 1. User is emailed a link like https://example.com/subpath//metadata/a/b/c + # 2. Tiled Client tries to connect to that and gets 401. + # 3. Client can use this header to find its way to + # https://examples.com/subpath/ and obtain a list of + # authentication providers and endpoints. + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail=( + "Not enough permissions. " + f"Requires scopes {security_scopes.scopes}. " + f"Request had scopes {list(scopes)}" + ), + headers=headers_for_401(request, security_scopes), + ) + # This is used to pass the currently-authenticated principal into the logger. + request.state.principal = principal + return principal + + +@cache +def session_state_getter(authenticators: dict[str, Authenticator], settings: Settings): + decode_token = decode_token_for_authenticators(authenticators, settings) + + async def get_session_state( + decoded_access_token: dict[str, Any] | None = Depends(decode_token) + ): + if decoded_access_token: + return decoded_access_token.get("state") + + return get_session_state + + +@cache +def current_principal_getter( + authenticators: dict[str, Authenticator], + settings: Settings, +): + decode_token = decode_token_for_authenticators(authenticators, settings) + + async def get_current_principal( + request: Request, + security_scopes: SecurityScopes, + decoded_access_token: dict[str, Any] | None = Depends(decode_token), + api_key: str | None = Depends(get_api_key), + settings: Settings = Depends(get_settings), + db=Depends(get_database_session), + ): + """ + Get current Principal from: + - API key in 'api_key' query parameter + - API key in header 'Authorization: Apikey ...' + - API key in cookie 'tiled_api_key' + - OAuth2 JWT access token in header 'Authorization: Bearer ...' 
+ + Fall back to SpecialUsers.public, if anonymous access is allowed + If this server is configured with a "single-user API key", then + the Principal will be SpecialUsers.admin always. + """ + + if api_key is not None: + if authenticators: + # Tiled is in a multi-user configuration with authentication providers. + # We store the hashed value of the API key secret. + # By comparing hashes we protect against timing attacks. + # By storing only the hash of the (high-entropy) secret + # we reduce the value of that an attacker can extracted from a + # stolen database backup. + try: + secret = bytes.fromhex(api_key) + except Exception: + # Not valid hex, therefore not a valid API key + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + headers=headers_for_401(request, security_scopes), + ) + api_key_orm = await lookup_valid_api_key(db, secret) + if api_key_orm is not None: + principal = api_key_orm.principal + principal_scopes = set().union( + *[role.scopes for role in principal.roles] + ) + # This intersection addresses the case where the Principal has + # lost a scope that they had when this key was created. + scopes = set(api_key_orm.scopes).intersection( + principal_scopes | {"inherit"} + ) + if "inherit" in scopes: + # The scope "inherit" is a metascope that confers all the + # scopes for the Principal associated with this API, + # resolved at access time. + scopes.update(principal_scopes) + api_key_orm.latest_activity = utcnow() + await db.commit() + else: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + headers=headers_for_401(request, security_scopes), + ) + else: + # Tiled is in a "single user" mode with only one API key. + if secrets.compare_digest(api_key, settings.single_user_api_key): + principal = SpecialUsers.admin + scopes = { + "read:metadata", + "read:data", + "write:metadata", + "write:data", + "create", + "register", + "metrics", + } + else: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + headers=headers_for_401(request, security_scopes), + ) + # If we made it to this point, we have a valid API key. + # If the API key was given in query param, move to cookie. + # This is convenient for browser-based access. + if ("api_key" in request.query_params) and ( + request.cookies.get(API_KEY_COOKIE_NAME) != api_key + ): + request.state.cookies_to_set.append( + {"key": API_KEY_COOKIE_NAME, "value": api_key} + ) + elif decoded_access_token is not None: + principal = schemas.Principal( + uuid=uuid_module.UUID(hex=decoded_access_token["sub"]), + type=decoded_access_token["sub_typ"], + identities=[ + schemas.Identity(id=identity["id"], provider=identity["idp"]) + for identity in decoded_access_token["ids"] + ], + ) + scopes = decoded_access_token["scp"] + else: + # No form of authentication is present. + principal = SpecialUsers.public + # Is anonymous public access permitted? + if settings.allow_anonymous_access: + # Any user who can see the server can make unauthenticated requests. + # This is a sentinel that has special meaning to the authorization + # code (the access control policies). + scopes = {"read:metadata", "read:data"} + else: + # In this mode, there may still be entries that are visible to all, + # but users have to authenticate as *someone* to see anything. + # They can still access the / and /docs routes. + scopes = {} + # Scope enforcement happens here. 
+ # https://fastapi.tiangolo.com/advanced/security/oauth2-scopes/ + if not set(security_scopes.scopes).issubset(scopes): + # Include a link to the root page which provides a list of + # authenticators. The use case here is: + # 1. User is emailed a link like https://example.com/subpath//metadata/a/b/c + # 2. Tiled Client tries to connect to that and gets 401. + # 3. Client can use this header to find its way to + # https://examples.com/subpath/ and obtain a list of + # authentication providers and endpoints. + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail=( + "Not enough permissions. " + f"Requires scopes {security_scopes.scopes}. " + f"Request had scopes {list(scopes)}" + ), + headers=headers_for_401(request, security_scopes), + ) + # This is used to pass the currently-authenticated principal into the logger. + request.state.principal = principal + return principal + + return get_current_principal + + async def create_session( settings: Settings, db, @@ -644,171 +822,12 @@ async def generate_apikey(db, principal, apikey_params, request): def build_base_authentication_router( - oauth2_scheme: OAuth2, decode_token: Callable[[str], dict[str, Any]] + oauth2_schema: OAuth2, + decode_token: Callable[[str], dict[str, Any]], + authenticators: dict[str, Authenticator], ) -> APIRouter: authentication_router = APIRouter() - - async def get_decoded_access_token( - request: Request, - security_scopes: SecurityScopes, - access_token: str | None = Depends(oauth2_scheme), - ) -> Optional[dict[str, Any]]: - if not access_token: - return None - try: - payload = decode_token(access_token) - except ExpiredSignatureError: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Access token has expired. Refresh token.", - headers=headers_for_401(request, security_scopes), - ) - return payload - - async def get_session_state( - decoded_access_token: Optional[dict[str, Any]] = Depends( - get_decoded_access_token - ) - ): - if decoded_access_token: - return decoded_access_token.get("state") - - async def get_current_principal( - request: Request, - security_scopes: SecurityScopes, - decoded_access_token: str | None = Depends(get_decoded_access_token), - api_key: str | None = Depends(get_api_key), - settings: Settings = Depends(get_settings), - authenticators=Depends(get_authenticators), - db=Depends(get_database_session), - ): - """ - Get current Principal from: - - API key in 'api_key' query parameter - - API key in header 'Authorization: Apikey ...' - - API key in cookie 'tiled_api_key' - - OAuth2 JWT access token in header 'Authorization: Bearer ...' - - Fall back to SpecialUsers.public, if anonymous access is allowed - If this server is configured with a "single-user API key", then - the Principal will be SpecialUsers.admin always. - """ - - if api_key is not None: - if authenticators: - # Tiled is in a multi-user configuration with authentication providers. - # We store the hashed value of the API key secret. - # By comparing hashes we protect against timing attacks. - # By storing only the hash of the (high-entropy) secret - # we reduce the value of that an attacker can extracted from a - # stolen database backup. 
- try: - secret = bytes.fromhex(api_key) - except Exception: - # Not valid hex, therefore not a valid API key - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid API key", - headers=headers_for_401(request, security_scopes), - ) - api_key_orm = await lookup_valid_api_key(db, secret) - if api_key_orm is not None: - principal = api_key_orm.principal - principal_scopes = set().union( - *[role.scopes for role in principal.roles] - ) - # This intersection addresses the case where the Principal has - # lost a scope that they had when this key was created. - scopes = set(api_key_orm.scopes).intersection( - principal_scopes | {"inherit"} - ) - if "inherit" in scopes: - # The scope "inherit" is a metascope that confers all the - # scopes for the Principal associated with this API, - # resolved at access time. - scopes.update(principal_scopes) - api_key_orm.latest_activity = utcnow() - await db.commit() - else: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid API key", - headers=headers_for_401(request, security_scopes), - ) - else: - # Tiled is in a "single user" mode with only one API key. - if secrets.compare_digest(api_key, settings.single_user_api_key): - principal = SpecialUsers.admin - scopes = { - "read:metadata", - "read:data", - "write:metadata", - "write:data", - "create", - "register", - "metrics", - } - else: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid API key", - headers=headers_for_401(request, security_scopes), - ) - # If we made it to this point, we have a valid API key. - # If the API key was given in query param, move to cookie. - # This is convenient for browser-based access. - if ("api_key" in request.query_params) and ( - request.cookies.get(API_KEY_COOKIE_NAME) != api_key - ): - request.state.cookies_to_set.append( - {"key": API_KEY_COOKIE_NAME, "value": api_key} - ) - elif decoded_access_token is not None: - principal = schemas.Principal( - uuid=uuid_module.UUID(hex=decoded_access_token["sub"]), - type=decoded_access_token["sub_typ"], - identities=[ - schemas.Identity(id=identity["id"], provider=identity["idp"]) - for identity in decoded_access_token["ids"] - ], - ) - scopes = decoded_access_token["scp"] - else: - # No form of authentication is present. - principal = SpecialUsers.public - # Is anonymous public access permitted? - if settings.allow_anonymous_access: - # Any user who can see the server can make unauthenticated requests. - # This is a sentinel that has special meaning to the authorization - # code (the access control policies). - scopes = {"read:metadata", "read:data"} - else: - # In this mode, there may still be entries that are visible to all, - # but users have to authenticate as *someone* to see anything. - # They can still access the / and /docs routes. - scopes = {} - # Scope enforcement happens here. - # https://fastapi.tiangolo.com/advanced/security/oauth2-scopes/ - if not set(security_scopes.scopes).issubset(scopes): - # Include a link to the root page which provides a list of - # authenticators. The use case here is: - # 1. User is emailed a link like https://example.com/subpath//metadata/a/b/c - # 2. Tiled Client tries to connect to that and gets 401. - # 3. Client can use this header to find its way to - # https://examples.com/subpath/ and obtain a list of - # authentication providers and endpoints. - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail=( - "Not enough permissions. " - f"Requires scopes {security_scopes.scopes}. 
" - f"Request had scopes {list(scopes)}" - ), - headers=headers_for_401(request, security_scopes), - ) - # This is used to pass the currently-authenticated principal into the logger. - request.state.principal = principal - return principal + get_current_principal = current_principal_getter(oauth2_schema, authenticators) @authentication_router.get( "/principal", @@ -891,7 +910,7 @@ async def create_service_principal( async def principal( request: Request, uuid: uuid_module.UUID, - principal=Security(get_current_principal, scopes=["read:principals"]), + _: str | None = Security(move_api_key, scopes=["read:principals"]), db=Depends(get_database_session), ): "Get information about one Principal (user or service)." @@ -926,7 +945,7 @@ async def revoke_apikey_for_principal( request: Request, uuid: uuid_module.UUID, first_eight: str, - principal=Security(get_current_principal, scopes=["admin:apikeys"]), + _: str | None = Security(move_api_key, scopes=["admin:apikeys"]), db=Depends(get_database_session), ): "Allow Tiled Admins to delete any user's apikeys e.g." @@ -1218,7 +1237,7 @@ async def whoami( async def logout( request: Request, response: Response, - principal=Security(get_current_principal, scopes=[]), + _: str | None = Security(move_api_key), ): "Deprecated. See revoke_session: POST /session/revoke." request.state.endpoint = "auth" @@ -1228,24 +1247,26 @@ async def logout( return authentication_router -def build_authentication_router( - authenticators: dict[str, Authenticator], first_provider: str, settings: Settings -) -> APIRouter: +@cache +def get_oauth2_scheme(authenticators: dict[str, Authenticator], first_provider: str): if len(authenticators) == 1 and isinstance( auth := authenticators[first_provider], ProxiedOIDCAuthenticator ): - oauth2_scheme = OAuth2AuthorizationCodeBearer( + return OAuth2AuthorizationCodeBearer( auth.authorization_endpoint, auth.token_endpoint ) - decode_access_token = auth.decode_access_token else: - oauth2_scheme = OAuth2PasswordBearer( - f"/api/v1/auth/provider/{first_provider}/token" - ) - decode_access_token = partial(decode_token, settings.secret_keys) + return OAuth2PasswordBearer(f"/api/v1/auth/provider/{first_provider}/token") + + +def build_authentication_router( + authenticators: dict[str, Authenticator], first_provider: str, settings: Settings +) -> APIRouter: + oauth2_scheme = get_oauth2_scheme(authenticators, first_provider) + decode_access_token = decode_token_for_authenticators(authenticators, settings) authentication_router = build_base_authentication_router( - oauth2_scheme, decode_access_token + oauth2_scheme, decode_access_token, authenticators ) for provider, authenticator in authenticators.items(): if isinstance(authenticator, ExternalAuthenticator): diff --git a/tiled/server/dependencies.py b/tiled/server/dependencies.py index 09b8d8233..0527431ba 100644 --- a/tiled/server/dependencies.py +++ b/tiled/server/dependencies.py @@ -1,19 +1,9 @@ -from functools import cache from typing import Optional, Tuple, Union import pydantic_settings from fastapi import Depends, HTTPException, Query, Request, Security from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND -from ..media_type_registration import ( - deserialization_registry as default_deserialization_registry, -) -from ..media_type_registration import ( - serialization_registry as default_serialization_registry, -) -from ..query_registration import query_registry as default_query_registry -from ..validation_registration import validation_registry as default_validation_registry 
-from .authentication import get_current_principal, get_session_state from .core import NoEntry from .utils import filter_for_access, record_timing @@ -24,149 +14,122 @@ SLICE_REGEX = rf"^{DIM_REGEX}(?:,{DIM_REGEX})*$" -@cache -def get_query_registry(): - "This may be overridden via dependency_overrides." - return default_query_registry - - -@cache -def get_deserialization_registry(): - "This may be overridden via dependency_overrides." - return default_deserialization_registry - - -@cache -def get_serialization_registry(): - "This may be overridden via dependency_overrides." - return default_serialization_registry - - -@cache -def get_validation_registry(): - "This may be overridden via dependency_overrides." - return default_validation_registry - - -def get_root_tree(): - raise NotImplementedError( - "This should be overridden via dependency_overrides. " - "See tiled.server.app.build_app()." - ) - - -def SecureEntry(scopes, structure_families=None): - async def inner( - path: str, - request: Request, - principal: str = Depends(get_current_principal), - root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), - session_state: dict = Depends(get_session_state), - ): - """ - Obtain a node in the tree from its path. - - Walk down the path from the root tree, discover the access policy - to be used for access to the destination node, and finally filter - access by the specified scope. - - The access policy used for access to the destination node will be - the last one found while walking the tree or, in the case of a catalog adapter, - the access policy of the catalog adapter node. - - session_state is an optional dictionary passed in the session token - """ - path_parts = [segment for segment in path.split("/") if segment] - path_parts_relative = path_parts - entry = root_tree - entry_with_access_policy = ( - entry if getattr(root_tree, "access_policy", None) is not None else None - ) - - # If the entry/adapter can take a session state, pass it in. - # The entry/adapter may return itself or a different object. 
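# The same inversion, sketched: SecureEntryBuilder (added below) closes over
# its concrete dependencies instead of importing module-level singletons; the
# access-policy walk in the real body is elided here.
from fastapi import Depends, Security

def SecureEntryBuilder(get_current_principal, get_root_tree, get_session_state):
    def SecureEntry(scopes, structure_families=None):
        async def inner(
            path: str,
            principal=Depends(get_current_principal),
            root_tree=Depends(get_root_tree),
            session_state: dict = Depends(get_session_state),
        ):
            ...  # walk `path` down from root_tree and enforce `scopes`

        return Security(inner, scopes=scopes)

    return SecureEntry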
- if hasattr(entry, "with_session_state") and session_state: - entry = entry.with_session_state(session_state) - # start at the root - # filter and keep only what we are allowed to see from here - entry = await filter_for_access( - entry, - principal, - ["read:metadata"], - request.state.metrics, - path_parts_relative, - ) - try: - for i, segment in enumerate(path_parts): - if hasattr(entry, "lookup_adapter"): - # New catalog adapter - only has access control at the top level - # Top level means the basename of the path as defined in the config - # This adapter can jump directly to the node of interest - entry = await entry.lookup_adapter(path_parts[i:]) - if entry is None: - raise NoEntry(path_parts) - break - else: - # Old-style dict-like interface - # Traverse into sub-tree(s) to reach the desired entry, and - # to discover the access policy to use for the request - try: - entry = entry[segment] - except (KeyError, TypeError): - raise NoEntry(path_parts) - if getattr(entry, "access_policy", None) is not None: - path_parts_relative = path_parts[i + 1 :] # noqa: E203 - entry_with_access_policy = entry - # filter and keep only what we are allowed to see from here - entry = await filter_for_access( - entry, - principal, - ["read:metadata"], - request.state.metrics, - path_parts_relative, - ) +def SecureEntryBuilder(get_current_principal, get_root_tree, get_session_state): + def SecureEntry(scopes, structure_families=None): + async def inner( + path: str, + request: Request, + principal: str = Depends(get_current_principal), + root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), + session_state: dict = Depends(get_session_state), + ): + """ + Obtain a node in the tree from its path. + + Walk down the path from the root tree, discover the access policy + to be used for access to the destination node, and finally filter + access by the specified scope. + + The access policy used for access to the destination node will be + the last one found while walking the tree or, in the case of a catalog adapter, + the access policy of the catalog adapter node. + + session_state is an optional dictionary passed in the session token + """ + path_parts = [segment for segment in path.split("/") if segment] + path_parts_relative = path_parts + entry = root_tree + entry_with_access_policy = ( + entry if getattr(root_tree, "access_policy", None) is not None else None + ) - # Now check that we have the requested scope according to the discovered access policy - access_policy = getattr(entry_with_access_policy, "access_policy", None) - if access_policy is not None: - with record_timing(request.state.metrics, "acl"): - allowed_scopes = await access_policy.allowed_scopes( - entry_with_access_policy, principal, path_parts_relative - ) - if not set(scopes).issubset(allowed_scopes): - if "read:metadata" not in allowed_scopes: - # If you can't read metadata, it does not exist for you. + # If the entry/adapter can take a session state, pass it in. + # The entry/adapter may return itself or a different object. 
+ if hasattr(entry, "with_session_state") and session_state: + entry = entry.with_session_state(session_state) + # start at the root + # filter and keep only what we are allowed to see from here + entry = await filter_for_access( + entry, + principal, + ["read:metadata"], + request.state.metrics, + path_parts_relative, + ) + try: + for i, segment in enumerate(path_parts): + if hasattr(entry, "lookup_adapter"): + # New catalog adapter - only has access control at the top level + # Top level means the basename of the path as defined in the config + # This adapter can jump directly to the node of interest + entry = await entry.lookup_adapter(path_parts[i:]) + if entry is None: + raise NoEntry(path_parts) + break + else: + # Old-style dict-like interface + # Traverse into sub-tree(s) to reach the desired entry, and + # to discover the access policy to use for the request + try: + entry = entry[segment] + except (KeyError, TypeError): raise NoEntry(path_parts) - else: - # You can see this, but you cannot perform the requested - # operation on it. - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail=( - "Not enough permissions to perform this action on this node. " - f"Requires scopes {scopes}. " - f"Principal had scopes {list(allowed_scopes)} on this node." - ), + if getattr(entry, "access_policy", None) is not None: + path_parts_relative = path_parts[i + 1 :] # noqa: E203 + entry_with_access_policy = entry + # filter and keep only what we are allowed to see from here + entry = await filter_for_access( + entry, + principal, + ["read:metadata"], + request.state.metrics, + path_parts_relative, ) - except NoEntry: + + # Now check that we have the requested scope according to the discovered access policy + access_policy = getattr(entry_with_access_policy, "access_policy", None) + if access_policy is not None: + with record_timing(request.state.metrics, "acl"): + allowed_scopes = await access_policy.allowed_scopes( + entry_with_access_policy, principal, path_parts_relative + ) + if not set(scopes).issubset(allowed_scopes): + if "read:metadata" not in allowed_scopes: + # If you can't read metadata, it does not exist for you. + raise NoEntry(path_parts) + else: + # You can see this, but you cannot perform the requested + # operation on it. + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail=( + "Not enough permissions to perform this action on this node. " + f"Requires scopes {scopes}. " + f"Principal had scopes {list(allowed_scopes)} on this node." 
+ ), + ) + except NoEntry: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail=f"No such entry: {path_parts}", + ) + # Fast path for the common successful case + if (structure_families is None) or ( + entry.structure_family in structure_families + ): + return entry raise HTTPException( - status_code=HTTP_404_NOT_FOUND, detail=f"No such entry: {path_parts}" + status_code=HTTP_404_NOT_FOUND, + detail=( + f"The node at {path} has structure family {entry.structure_family} " + "and this endpoint is compatible with structure families " + f"{structure_families}" + ), ) - # Fast path for the common successful case - if (structure_families is None) or ( - entry.structure_family in structure_families - ): - return entry - raise HTTPException( - status_code=HTTP_404_NOT_FOUND, - detail=( - f"The node at {path} has structure family {entry.structure_family} " - "and this endpoint is compatible with structure families " - f"{structure_families}" - ), - ) - - return Security(inner, scopes=scopes) + + return Security(inner, scopes=scopes) + + return SecureEntry def block( diff --git a/tiled/server/metrics.py b/tiled/server/metrics.py index d6e0d79ee..c5aec5f41 100644 --- a/tiled/server/metrics.py +++ b/tiled/server/metrics.py @@ -10,7 +10,7 @@ from fastapi import APIRouter, Request, Response, Security from prometheus_client import CONTENT_TYPE_LATEST, Histogram, generate_latest -from .authentication import get_current_principal +from tiled.server.utils import move_api_key router = APIRouter() @@ -158,7 +158,7 @@ def prometheus_registry(): @router.get("/metrics") async def metrics( - request: Request, principal=Security(get_current_principal, scopes=["metrics"]) + request: Request, _: str | None = Security(move_api_key, scopes=["metrics"]) ): """ Prometheus metrics diff --git a/tiled/server/router.py b/tiled/server/router.py index d5f613920..636013479 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -1,5 +1,3 @@ -import dataclasses -import inspect import os import re import warnings @@ -27,14 +25,17 @@ ) from tiled.schemas import About -from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator +from tiled.server.protocols import ( + Authenticator, + ExternalAuthenticator, + InternalAuthenticator, +) from .. import __version__ from ..structures.core import Spec, StructureFamily from ..utils import ensure_awaitable, patch_mimetypes, path_from_uri from ..validation_registration import ValidationError from . 
import schemas -from .authentication import get_authenticators, get_current_principal from .core import ( DEFAULT_PAGE_SIZE, DEPTH_LIMIT, @@ -51,13 +52,9 @@ resolve_media_type, ) from .dependencies import ( - SecureEntry, + SecureEntryBuilder, block, expected_shape, - get_deserialization_registry, - get_query_registry, - get_serialization_registry, - get_validation_registry, offset_param, shape_param, slice_, @@ -65,359 +62,380 @@ from .file_response_with_range import FileResponseWithRange from .links import links_for_node from .settings import Settings, get_settings -from .utils import filter_for_access, get_base_url, record_timing - -router = APIRouter() - - -@router.get("/", response_model=About) -async def about( - request: Request, - settings: Settings = Depends(get_settings), - authenticators=Depends(get_authenticators), - serialization_registry=Depends(get_serialization_registry), - query_registry=Depends(get_query_registry), - # This dependency is here because it runs the code that moves - # API key from the query parameter to a cookie (if it is valid). - principal=Security(get_current_principal, scopes=[]), -): - # TODO The lazy import of entry modules and serializers means that the - # lists of formats are not populated until they are first used. Not very - # helpful for discovery! The registration can be made non-lazy, while the - # imports of the underlying I/O libraries themselves (openpyxl, pillow, - # etc.) can remain lazy. - request.state.endpoint = "about" - base_url = get_base_url(request) - authentication = { - "required": not settings.allow_anonymous_access, - } - provider_specs = [] - user_agent = request.headers.get("user-agent", "") - # The name of the "internal" mode used to be "password". - # This ensures back-compat with older Python clients. - internal_mode_name = "internal" - MINIMUM_INTERNAL_PYTHON_CLIENT_VERSION = packaging.version.parse("0.1.0b17") - if user_agent.startswith("python-tiled/"): - agent, _, raw_version = user_agent.partition("/") - try: - parsed_version = packaging.version.parse(raw_version) - except Exception: - pass - else: - if parsed_version < MINIMUM_INTERNAL_PYTHON_CLIENT_VERSION: - internal_mode_name = "password" - for provider, authenticator in authenticators.items(): - if isinstance(authenticator, InternalAuthenticator): - spec = { - "provider": provider, - "mode": internal_mode_name, - "links": { - "auth_endpoint": f"{base_url}/auth/provider/{provider}/token" - }, - "confirmation_message": getattr( - authenticator, "confirmation_message", None - ), - } - elif isinstance(authenticator, ExternalAuthenticator): - spec = { - "provider": provider, - "mode": "external", - "links": { - "auth_endpoint": f"{base_url}/auth/provider/{provider}/authorize" - }, - "confirmation_message": getattr( - authenticator, "confirmation_message", None - ), - } - else: - # It should be impossible to reach here. - assert False - provider_specs.append(spec) - if provider_specs: - # If there are *any* authenticaiton providers, these - # endpoints will be added. 
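# Routes below depend on move_api_key (relocated to tiled/server/utils.py,
# whose hunk is not shown in this patch excerpt) purely for its side effect.
# Judging from the comments it replaces in authentication.py, its job is
# roughly the following -- a sketch, not the actual implementation:
from fastapi import Request

API_KEY_COOKIE_NAME = "tiled_api_key"

async def move_api_key(request: Request, api_key: str | None = None) -> str | None:
    # In the real code the key arrives via Depends(get_api_key). If a valid
    # key came in as a ?api_key=... query parameter, queue a cookie so
    # browser users are not stuck carrying the query parameter around.
    if api_key is not None and request.cookies.get(API_KEY_COOKIE_NAME) != api_key:
        request.state.cookies_to_set.append(
            {"key": API_KEY_COOKIE_NAME, "value": api_key}
        )
    return api_key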
- authentication["links"] = { - "whoami": f"{base_url}/auth/whoami", - "apikey": f"{base_url}/auth/apikey", - "refresh_session": f"{base_url}/auth/session/refresh", - "revoke_session": f"{base_url}/auth/session/revoke/{{session_id}}", - "logout": f"{base_url}/auth/logout", - } - authentication["providers"] = provider_specs - - return json_or_msgpack( - request, - About( - library_version=__version__, - api_version=0, - formats={ - structure_family: list( - serialization_registry.media_types(structure_family) - ) - for structure_family in serialization_registry.structure_families - }, - aliases={ - structure_family: serialization_registry.aliases(structure_family) - for structure_family in serialization_registry.structure_families - }, - queries=list(query_registry.name_to_query_type), - authentication=authentication, - links={ - "self": base_url, - "documentation": f"{base_url}/docs", - }, - meta={"root_path": request.scope.get("root_path") or "" + "/api"}, - ).model_dump(), - expires=datetime.now(timezone.utc) + timedelta(seconds=600), - ) +from .utils import filter_for_access, get_base_url, move_api_key, record_timing -async def search( - request: Request, - path: str, - fields: Optional[List[schemas.EntryFields]] = Query(list(schemas.EntryFields)), - select_metadata: Optional[str] = Query(None), - offset: Optional[int] = Query(0, alias="page[offset]", ge=0), - limit: Optional[int] = Query( - DEFAULT_PAGE_SIZE, alias="page[limit]", ge=0, le=MAX_PAGE_SIZE - ), - sort: Optional[str] = Query(None), - max_depth: Optional[int] = Query(None, ge=0, le=DEPTH_LIMIT), - omit_links: bool = Query(False), - include_data_sources: bool = Query(False), - entry: Any = SecureEntry(scopes=["read:metadata"]), - query_registry=Depends(get_query_registry), - principal: str = Depends(get_current_principal), - **filters, -): - request.state.endpoint = "search" - if entry.structure_family != StructureFamily.container: - raise WrongTypeForRoute("This is not a Node; it cannot be searched or listed.") - try: - resource, metadata_stale_at, must_revalidate = await construct_entries_response( - query_registry, - entry, - "/search", - path, - offset, - limit, - fields, - select_metadata, - omit_links, - include_data_sources, - filters, - sort, - get_base_url(request), - resolve_media_type(request), - max_depth=max_depth, - ) - # We only get one Expires header, so if different parts - # of this response become stale at different times, we - # cite the earliest one. 
- entries_stale_at = getattr(entry, "entries_stale_at", None) - headers = {} - if (metadata_stale_at is None) or (entries_stale_at is None): - expires = None - else: - expires = min(metadata_stale_at, entries_stale_at) - if must_revalidate: - headers["Cache-Control"] = "must-revalidate" - return json_or_msgpack( - request, - resource.model_dump(), - expires=expires, - headers=headers, - ) - except NoEntry: - raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="No such entry.") - except WrongTypeForRoute as err: - raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail=err.args[0]) - except JMESPathError as err: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=f"Malformed 'select_metadata' parameter raised JMESPathError: {err}", - ) - +def get_router( + query_registry, + authenticators: dict[str, Authenticator], + get_current_principal, + get_root_tree, + get_session_state, + serialization_registry, + deserialization_registry, + validation_registry, +) -> APIRouter: + router = APIRouter() + SecureEntry = SecureEntryBuilder( + get_current_principal, get_root_tree, get_session_state + ) -async def distinct( - request: Request, - structure_families: bool = False, - specs: bool = False, - metadata: Optional[List[str]] = Query(default=[]), - counts: bool = False, - entry: Any = SecureEntry(scopes=["read:metadata"]), - query_registry=Depends(get_query_registry), - **filters, -): - if hasattr(entry, "get_distinct"): - filtered = await apply_search(entry, filters, query_registry) - distinct = await ensure_awaitable( - filtered.get_distinct, metadata, structure_families, specs, counts - ) + @router.get("/", response_model=About) + async def about( + request: Request, + settings: Settings = Depends(get_settings), + _: str | None = Security(move_api_key), + ): + # TODO The lazy import of entry modules and serializers means that the + # lists of formats are not populated until they are first used. Not very + # helpful for discovery! The registration can be made non-lazy, while the + # imports of the underlying I/O libraries themselves (openpyxl, pillow, + # etc.) can remain lazy. + request.state.endpoint = "about" + base_url = get_base_url(request) + authentication = { + "required": not settings.allow_anonymous_access, + } + provider_specs = [] + user_agent = request.headers.get("user-agent", "") + # The name of the "internal" mode used to be "password". + # This ensures back-compat with older Python clients. 
+ internal_mode_name = "internal" + MINIMUM_INTERNAL_PYTHON_CLIENT_VERSION = packaging.version.parse("0.1.0b17") + if user_agent.startswith("python-tiled/"): + agent, _, raw_version = user_agent.partition("/") + try: + parsed_version = packaging.version.parse(raw_version) + except Exception: + pass + else: + if parsed_version < MINIMUM_INTERNAL_PYTHON_CLIENT_VERSION: + internal_mode_name = "password" + for provider, authenticator in authenticators.items(): + if isinstance(authenticator, InternalAuthenticator): + spec = { + "provider": provider, + "mode": internal_mode_name, + "links": { + "auth_endpoint": f"{base_url}/auth/provider/{provider}/token" + }, + "confirmation_message": getattr( + authenticator, "confirmation_message", None + ), + } + elif isinstance(authenticator, ExternalAuthenticator): + spec = { + "provider": provider, + "mode": "external", + "links": { + "auth_endpoint": f"{base_url}/auth/provider/{provider}/authorize" + }, + "confirmation_message": getattr( + authenticator, "confirmation_message", None + ), + } + else: + # It should be impossible to reach here. + assert False + provider_specs.append(spec) + if provider_specs: + # If there are *any* authenticaiton providers, these + # endpoints will be added. + authentication["links"] = { + "whoami": f"{base_url}/auth/whoami", + "apikey": f"{base_url}/auth/apikey", + "refresh_session": f"{base_url}/auth/session/refresh", + "revoke_session": f"{base_url}/auth/session/revoke/{{session_id}}", + "logout": f"{base_url}/auth/logout", + } + authentication["providers"] = provider_specs return json_or_msgpack( - request, schemas.GetDistinctResponse.model_validate(distinct).model_dump() - ) - else: - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support distinct.", + request, + About( + library_version=__version__, + api_version=0, + formats={ + structure_family: list( + serialization_registry.media_types(structure_family) + ) + for structure_family in serialization_registry.structure_families + }, + aliases={ + structure_family: serialization_registry.aliases(structure_family) + for structure_family in serialization_registry.structure_families + }, + queries=list(query_registry.name_to_query_type), + authentication=authentication, + links={ + "self": base_url, + "documentation": f"{base_url}/docs", + }, + meta={"root_path": request.scope.get("root_path") or "" + "/api"}, + ).model_dump(), + expires=datetime.now(timezone.utc) + timedelta(seconds=600), ) - -def patch_route_signature(route, query_registry): - """ - This is done dynamically at router startup. - - We check the registry of known search query types, which is user - configurable, and use that to define the allowed HTTP query parameters for - this route. - - Take a route that accept unspecified search queries as **filters. - Return a wrapped version of the route that has the supported - search queries explicitly spelled out in the function signature. - - This has no change in the actual behavior of the function, - but it enables FastAPI to generate good OpenAPI documentation - showing the supported search queries. - - """ - - # Build a wrapper so that we can modify the signature - # without mutating the wrapped original. - - async def route_with_sig(*args, **kwargs): - return await route(*args, **kwargs) - - # Black magic here! FastAPI bases its validation and auto-generated swagger - # documentation on the signature of the route function. We do not know what - # that signature should be at compile-time. 
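# Aside: the gate below parses the client version out of a User-Agent such as
# "python-tiled/0.1.0b16" and keeps the legacy mode name "password" for
# clients older than 0.1.0b17. An illustrative check of that comparison:
import packaging.version

agent, _, raw_version = "python-tiled/0.1.0b16".partition("/")
assert packaging.version.parse(raw_version) < packaging.version.parse("0.1.0b17")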
We only know it once we have a - # chance to check the user-configurable registry of query types. Therefore, - # we modify the signature here, at runtime, just before handing it to - # FastAPI in the usual way. - - # When FastAPI calls the function with these added parameters, they will be - # accepted via **filters. - - # Make a copy of the original parameters. - signature = inspect.signature(route) - parameters = list(signature.parameters.values()) - # Drop the **filters parameter from the signature. - del parameters[-1] - # Add a parameter for each field in each type of query. - for name, query in query_registry.name_to_query_type.items(): - for field in dataclasses.fields(query): - # The structured "alias" here is based on - # https://mglaman.dev/blog/using-json-router-query-your-search-router-indexes - if getattr(field.type, "__origin__", None) is list: - field_type = str + @router.get( + "/api/v1/search/{path:path}", + response_model=schemas.Response[ + List[schemas.Resource[schemas.NodeAttributes, dict, dict]], + schemas.PaginationLinks, + dict, + ], + ) + async def search( + request: Request, + path: str, + fields: Optional[List[schemas.EntryFields]] = Query(list(schemas.EntryFields)), + select_metadata: Optional[str] = Query(None), + offset: Optional[int] = Query(0, alias="page[offset]", ge=0), + limit: Optional[int] = Query( + DEFAULT_PAGE_SIZE, alias="page[limit]", ge=0, le=MAX_PAGE_SIZE + ), + sort: Optional[str] = Query(None), + max_depth: Optional[int] = Query(None, ge=0, le=DEPTH_LIMIT), + omit_links: bool = Query(False), + include_data_sources: bool = Query(False), + entry: Any = SecureEntry(scopes=["read:metadata"]), + _: str | None = Depends(move_api_key), + **filters, + ): + request.state.endpoint = "search" + if entry.structure_family != StructureFamily.container: + raise WrongTypeForRoute( + "This is not a Node; it cannot be searched or listed." + ) + try: + ( + resource, + metadata_stale_at, + must_revalidate, + ) = await construct_entries_response( + query_registry, + entry, + "/search", + path, + offset, + limit, + fields, + select_metadata, + omit_links, + include_data_sources, + filters, + sort, + get_base_url(request), + resolve_media_type(request), + max_depth=max_depth, + ) + # We only get one Expires header, so if different parts + # of this response become stale at different times, we + # cite the earliest one. 
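+            # For example, if the metadata goes stale at 10:00 and the entry
+            # listing at 10:05, the response is marked to expire at 10:00; if
+            # either value is unknown (None), no Expires header is set.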
+ entries_stale_at = getattr(entry, "entries_stale_at", None) + headers = {} + if (metadata_stale_at is None) or (entries_stale_at is None): + expires = None else: - field_type = field.type - injected_parameter = inspect.Parameter( - name=f"filter___{name}___{field.name}", - kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, - default=Query(None, alias=f"filter[{name}][condition][{field.name}]"), - annotation=Optional[List[field_type]], + expires = min(metadata_stale_at, entries_stale_at) + if must_revalidate: + headers["Cache-Control"] = "must-revalidate" + return json_or_msgpack( + request, + resource.model_dump(), + expires=expires, + headers=headers, + ) + except NoEntry: + raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="No such entry.") + except WrongTypeForRoute as err: + raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail=err.args[0]) + except JMESPathError as err: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"Malformed 'select_metadata' parameter raised JMESPathError: {err}", ) - parameters.append(injected_parameter) - route_with_sig.__signature__ = signature.replace(parameters=parameters) - # End black magic - - return route_with_sig + @router.get( + "/api/v1/distinct/{path:path}", + response_model=schemas.GetDistinctResponse, + ) + async def distinct( + request: Request, + structure_families: bool = False, + specs: bool = False, + metadata: Optional[List[str]] = Query(default=[]), + counts: bool = False, + entry: Any = SecureEntry(scopes=["read:metadata"]), + **filters, + ): + if hasattr(entry, "get_distinct"): + filtered = await apply_search(entry, filters, query_registry) + distinct = await ensure_awaitable( + filtered.get_distinct, metadata, structure_families, specs, counts + ) -@router.get( - "/metadata/{path:path}", - response_model=schemas.Response[ - schemas.Resource[schemas.NodeAttributes, dict, dict], dict, dict - ], -) -async def metadata( - request: Request, - path: str, - fields: Optional[List[schemas.EntryFields]] = Query(list(schemas.EntryFields)), - select_metadata: Optional[str] = Query(None), - max_depth: Optional[int] = Query(None, ge=0, le=DEPTH_LIMIT), - omit_links: bool = Query(False), - include_data_sources: bool = Query(False), - entry: Any = SecureEntry(scopes=["read:metadata"]), - root_path: bool = Query(False), -): - """Fetch the metadata and structure information for one entry""" - - request.state.endpoint = "metadata" - base_url = get_base_url(request) - path_parts = [segment for segment in path.split("/") if segment] - try: - resource = await construct_resource( - base_url, - path_parts, - entry, - fields, - select_metadata, - omit_links, - include_data_sources, - resolve_media_type(request), - max_depth=max_depth, - ) - except JMESPathError as err: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=f"Malformed 'select_metadata' parameter raised JMESPathError: {err}", - ) - meta = {"root_path": request.scope.get("root_path") or "/"} if root_path else {} + return json_or_msgpack( + request, + schemas.GetDistinctResponse.model_validate(distinct).model_dump(), + ) + else: + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node does not support distinct.", + ) - return json_or_msgpack( - request, - schemas.Response(data=resource, meta=meta).model_dump(), - expires=getattr(entry, "metadata_stale_at", None), + @router.get( + "/metadata/{path:path}", + response_model=schemas.Response[ + schemas.Resource[schemas.NodeAttributes, dict, dict], dict, dict + ], ) + async def 
metadata( + request: Request, + path: str, + fields: Optional[List[schemas.EntryFields]] = Query(list(schemas.EntryFields)), + select_metadata: Optional[str] = Query(None), + max_depth: Optional[int] = Query(None, ge=0, le=DEPTH_LIMIT), + omit_links: bool = Query(False), + include_data_sources: bool = Query(False), + entry: Any = SecureEntry(scopes=["read:metadata"]), + root_path: bool = Query(False), + ): + """Fetch the metadata and structure information for one entry""" + request.state.endpoint = "metadata" + base_url = get_base_url(request) + path_parts = [segment for segment in path.split("/") if segment] + try: + resource = await construct_resource( + base_url, + path_parts, + entry, + fields, + select_metadata, + omit_links, + include_data_sources, + resolve_media_type(request), + max_depth=max_depth, + ) + except JMESPathError as err: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"Malformed 'select_metadata' parameter raised JMESPathError: {err}", + ) + meta = {"root_path": request.scope.get("root_path") or "/"} if root_path else {} -@router.get( - "/array/block/{path:path}", response_model=schemas.Response, name="array block" -) -async def array_block( - request: Request, - entry=SecureEntry( - scopes=["read:data"], - structure_families={StructureFamily.array, StructureFamily.sparse}, - ), - block=Depends(block), - slice=Depends(slice_), - expected_shape=Depends(expected_shape), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch a chunk of array-like data. - """ - shape = entry.structure().shape - # Check that block dimensionality matches array dimensionality. - ndim = len(shape) - if len(block) != ndim: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - f"Block parameter must have {ndim} comma-separated parameters, " - f"corresponding to the dimensions of this {ndim}-dimensional array." - ), + return json_or_msgpack( + request, + schemas.Response(data=resource, meta=meta).model_dump(), + expires=getattr(entry, "metadata_stale_at", None), ) - if block == (): - # Handle special case of numpy scalar. - if shape != (): + + @router.get( + "/array/block/{path:path}", response_model=schemas.Response, name="array block" + ) + async def array_block( + request: Request, + entry=SecureEntry( + scopes=["read:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), + block=Depends(block), + slice=Depends(slice_), + expected_shape=Depends(expected_shape), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch a chunk of array-like data. + """ + shape = entry.structure().shape + # Check that block dimensionality matches array dimensionality. + ndim = len(shape) + if len(block) != ndim: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=( + f"Block parameter must have {ndim} comma-separated parameters, " + f"corresponding to the dimensions of this {ndim}-dimensional array." + ), + ) + if block == (): + # Handle special case of numpy scalar. 
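+            # A scalar (zero-dimensional) array has shape == () and no block
+            # structure, so it is read whole rather than block-by-block.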
+ if shape != (): + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"Requested scalar but shape is {entry.structure().shape}", + ) + with record_timing(request.state.metrics, "read"): + array = await ensure_awaitable(entry.read) + else: + try: + with record_timing(request.state.metrics, "read"): + array = await ensure_awaitable(entry.read_block, block, slice) + except IndexError: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, detail="Block index out of range" + ) + if (expected_shape is not None) and (expected_shape != array.shape): + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"The expected_shape {expected_shape} does not match the actual shape {array.shape}", + ) + if array.nbytes > settings.response_bytesize_limit: raise HTTPException( status_code=HTTP_400_BAD_REQUEST, - detail=f"Requested scalar but shape is {entry.structure().shape}", + detail=( + f"Response would exceed {settings.response_bytesize_limit}. " + "Use slicing ('?slice=...') to request smaller chunks." + ), ) - with record_timing(request.state.metrics, "read"): - array = await ensure_awaitable(entry.read) - else: + try: + with record_timing(request.state.metrics, "pack"): + return await construct_data_response( + entry.structure_family, + serialization_registry, + array, + entry.metadata(), + request, + format, + specs=getattr(entry, "specs", []), + expires=getattr(entry, "content_stale_at", None), + filename=filename, + ) + except UnsupportedMediaTypes as err: + # raise HTTPException(status_code=406, detail=", ".join(err.supported)) + raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) + + @router.get( + "/array/full/{path:path}", response_model=schemas.Response, name="full array" + ) + async def array_full( + request: Request, + entry=SecureEntry( + scopes=["read:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), + slice=Depends(slice_), + expected_shape=Depends(expected_shape), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch a slice of array-like data. + """ + structure_family = entry.structure_family + # Deferred import because this is not a required dependency of the server + # for some use cases. + import numpy + try: with record_timing(request.state.metrics, "read"): - array = await ensure_awaitable(entry.read_block, block, slice) + array = await ensure_awaitable(entry.read, slice) + if structure_family == StructureFamily.array: + array = numpy.asarray(array) # Force dask or PIMS or ... to do I/O. except IndexError: raise HTTPException( status_code=HTTP_400_BAD_REQUEST, detail="Block index out of range" @@ -427,1330 +445,1227 @@ async def array_block( status_code=HTTP_400_BAD_REQUEST, detail=f"The expected_shape {expected_shape} does not match the actual shape {array.shape}", ) - if array.nbytes > settings.response_bytesize_limit: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - f"Response would exceed {settings.response_bytesize_limit}. " - "Use slicing ('?slice=...') to request smaller chunks." 
- ), - ) - try: - with record_timing(request.state.metrics, "pack"): - return await construct_data_response( - entry.structure_family, - serialization_registry, - array, - entry.metadata(), - request, - format, - specs=getattr(entry, "specs", []), - expires=getattr(entry, "content_stale_at", None), - filename=filename, + if array.nbytes > settings.response_bytesize_limit: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=( + f"Response would exceed {settings.response_bytesize_limit}. " + "Use slicing ('?slice=...') to request smaller chunks." + ), ) - except UnsupportedMediaTypes as err: - # raise HTTPException(status_code=406, detail=", ".join(err.supported)) - raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) - + try: + with record_timing(request.state.metrics, "pack"): + return await construct_data_response( + structure_family, + serialization_registry, + array, + entry.metadata(), + request, + format, + specs=getattr(entry, "specs", []), + expires=getattr(entry, "content_stale_at", None), + filename=filename, + ) + except UnsupportedMediaTypes as err: + raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) -@router.get( - "/array/full/{path:path}", response_model=schemas.Response, name="full array" -) -async def array_full( - request: Request, - entry=SecureEntry( - scopes=["read:data"], - structure_families={StructureFamily.array, StructureFamily.sparse}, - ), - slice=Depends(slice_), - expected_shape=Depends(expected_shape), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch a slice of array-like data. - """ - structure_family = entry.structure_family - # Deferred import because this is not a required dependency of the server - # for some use cases. - import numpy - - try: - with record_timing(request.state.metrics, "read"): - array = await ensure_awaitable(entry.read, slice) - if structure_family == StructureFamily.array: - array = numpy.asarray(array) # Force dask or PIMS or ... to do I/O. - except IndexError: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, detail="Block index out of range" - ) - if (expected_shape is not None) and (expected_shape != array.shape): - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=f"The expected_shape {expected_shape} does not match the actual shape {array.shape}", - ) - if array.nbytes > settings.response_bytesize_limit: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - f"Response would exceed {settings.response_bytesize_limit}. " - "Use slicing ('?slice=...') to request smaller chunks." 
- ), - ) - try: - with record_timing(request.state.metrics, "pack"): - return await construct_data_response( - structure_family, - serialization_registry, - array, - entry.metadata(), - request, - format, - specs=getattr(entry, "specs", []), - expires=getattr(entry, "content_stale_at", None), - filename=filename, + @router.get( + "/table/partition/{path:path}", + response_model=schemas.Response, + name="table partition", + ) + async def get_table_partition( + request: Request, + partition: int, + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.table} + ), + column: Optional[List[str]] = Query(None, min_length=1), + field: Optional[List[str]] = Query(None, min_length=1, deprecated=True), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch a partition (continuous block of rows) from a DataFrame [GET route]. + """ + if (field is not None) and (column is not None): + redundant_field_and_column = " ".join( + ( + "Cannot accept both 'column' and 'field' query parameters", + "in the same /table/partition request.", + "Include these query values using only the 'column' parameter.", + ) ) - except UnsupportedMediaTypes as err: - raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) - - -@router.get( - "/table/partition/{path:path}", - response_model=schemas.Response, - name="table partition", -) -async def get_table_partition( - request: Request, - partition: int, - entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), - column: Optional[List[str]] = Query(None, min_length=1), - field: Optional[List[str]] = Query(None, min_length=1, deprecated=True), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch a partition (continuous block of rows) from a DataFrame [GET route]. 
- """ - if (field is not None) and (column is not None): - redundant_field_and_column = " ".join( - ( - "Cannot accept both 'column' and 'field' query parameters", - "in the same /table/partition request.", - "Include these query values using only the 'column' parameter.", + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, detail=redundant_field_and_column ) - ) - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, detail=redundant_field_and_column - ) - elif field is not None: - field_is_deprecated = " ".join( - ( - "Query parameter 'field' is deprecated for the /table/partition route.", - "Instead use the query parameter 'column'.", + elif field is not None: + field_is_deprecated = " ".join( + ( + "Query parameter 'field' is deprecated for the /table/partition route.", + "Instead use the query parameter 'column'.", + ) ) + warnings.warn(field_is_deprecated, DeprecationWarning) + return await table_partition( + request=request, + partition=partition, + entry=entry, + column=(column or field), + format=format, + filename=filename, + serialization_registry=serialization_registry, + settings=settings, ) - warnings.warn(field_is_deprecated, DeprecationWarning) - return await table_partition( - request=request, - partition=partition, - entry=entry, - column=(column or field), - format=format, - filename=filename, - serialization_registry=serialization_registry, - settings=settings, - ) - -@router.post( - "/table/partition/{path:path}", - response_model=schemas.Response, - name="table partition", -) -async def post_table_partition( - request: Request, - partition: int, - entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), - column: Optional[List[str]] = Body(None, min_length=1), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch a partition (continuous block of rows) from a DataFrame [POST route]. - """ - return await table_partition( - request=request, - partition=partition, - entry=entry, - column=column, - format=format, - filename=filename, - serialization_registry=serialization_registry, - settings=settings, + @router.post( + "/table/partition/{path:path}", + response_model=schemas.Response, + name="table partition", ) - - -async def table_partition( - request: Request, - partition: int, - entry, - column: Optional[List[str]], - format: Optional[str], - filename: Optional[str], - serialization_registry, - settings: Settings, -): - """ - Fetch a partition (continuous block of rows) from a DataFrame. - """ - try: - # The singular/plural mismatch here of "fields" and "field" is - # due to the ?field=A&field=B&field=C... encodes in a URL. - with record_timing(request.state.metrics, "read"): - df = await ensure_awaitable(entry.read_partition, partition, column) - except IndexError: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, detail="Partition out of range" - ) - except KeyError as err: - (key,) = err.args - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}." - ) - if df.memory_usage().sum() > settings.response_bytesize_limit: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - f"Response would exceed {settings.response_bytesize_limit}. " - "Select a subset of the columns ('?field=...') to " - "request a smaller chunks." 
-            ),
+    async def post_table_partition(
+        request: Request,
+        partition: int,
+        entry=SecureEntry(
+            scopes=["read:data"], structure_families={StructureFamily.table}
+        ),
+        column: Optional[List[str]] = Body(None, min_length=1),
+        format: Optional[str] = None,
+        filename: Optional[str] = None,
+        settings: Settings = Depends(get_settings),
+    ):
+        """
+        Fetch a partition (continuous block of rows) from a DataFrame [POST route].
+        """
+        return await table_partition(
+            request=request,
+            partition=partition,
+            entry=entry,
+            column=column,
+            format=format,
+            filename=filename,
+            serialization_registry=serialization_registry,
+            settings=settings,
         )
-    try:
-        with record_timing(request.state.metrics, "pack"):
-            return await construct_data_response(
-                StructureFamily.table,
-                serialization_registry,
-                df,
-                entry.metadata(),
-                request,
-                format,
-                specs=getattr(entry, "specs", []),
-                expires=getattr(entry, "content_stale_at", None),
-                filename=filename,
-            )
-    except UnsupportedMediaTypes as err:
-        raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0])

+    async def table_partition(
+        request: Request,
+        partition: int,
+        entry,
+        column: Optional[List[str]],
+        format: Optional[str],
+        filename: Optional[str],
+        serialization_registry,
+        settings: Settings,
+    ):
+        """
+        Fetch a partition (continuous block of rows) from a DataFrame.
+        """
+        try:
+            # The singular/plural mismatch here of "fields" and "field" is
+            # due to the way ?field=A&field=B&field=C... is encoded in a URL.
+            with record_timing(request.state.metrics, "read"):
+                df = await ensure_awaitable(entry.read_partition, partition, column)
+        except IndexError:
+            raise HTTPException(
+                status_code=HTTP_400_BAD_REQUEST, detail="Partition out of range"
+            )
+        except KeyError as err:
+            (key,) = err.args
+            raise HTTPException(
+                status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}."
+            )
+        if df.memory_usage().sum() > settings.response_bytesize_limit:
+            raise HTTPException(
+                status_code=HTTP_400_BAD_REQUEST,
+                detail=(
+                    f"Response would exceed {settings.response_bytesize_limit}. "
+                    "Select a subset of the columns ('?field=...') to "
+                    "request smaller chunks."
+                ),
+            )
+        try:
+            with record_timing(request.state.metrics, "pack"):
+                return await construct_data_response(
+                    StructureFamily.table,
+                    serialization_registry,
+                    df,
+                    entry.metadata(),
+                    request,
+                    format,
+                    specs=getattr(entry, "specs", []),
+                    expires=getattr(entry, "content_stale_at", None),
+                    filename=filename,
+                )
+        except UnsupportedMediaTypes as err:
+            raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0])
- """ - return await table_full( - request=request, - entry=entry, - column=column, - format=format, - filename=filename, - serialization_registry=serialization_registry, - settings=settings, + @router.get( + "/table/full/{path:path}", + response_model=schemas.Response, + name="full 'table' data", ) + async def get_table_full( + request: Request, + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.table} + ), + column: Optional[List[str]] = Query(None, min_length=1), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch the data for the given table [GET route]. + """ + return await table_full( + request=request, + entry=entry, + column=column, + format=format, + filename=filename, + serialization_registry=serialization_registry, + settings=settings, + ) - -@router.post( - "/table/full/{path:path}", - response_model=schemas.Response, - name="full 'table' data", -) -async def post_table_full( - request: Request, - entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), - column: Optional[List[str]] = Body(None, min_length=1), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch the data for the given table [POST route]. - """ - return await table_full( - request=request, - entry=entry, - column=column, - format=format, - filename=filename, - serialization_registry=serialization_registry, - settings=settings, + @router.post( + "/table/full/{path:path}", + response_model=schemas.Response, + name="full 'table' data", ) - - -async def table_full( - request: Request, - entry, - column: Optional[List[str]], - format: Optional[str], - filename: Optional[str], - serialization_registry, - settings: Settings, -): - """ - Fetch the data for the given table. - """ - try: - with record_timing(request.state.metrics, "read"): - data = await ensure_awaitable(entry.read, column) - except KeyError as err: - (key,) = err.args - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}." - ) - if data.memory_usage().sum() > settings.response_bytesize_limit: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - f"Response would exceed {settings.response_bytesize_limit}. " - "Select a subset of the columns to " - "request a smaller chunks." - ), + async def post_table_full( + request: Request, + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.table} + ), + column: Optional[List[str]] = Body(None, min_length=1), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch the data for the given table [POST route]. 
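+
+        Unlike the GET route, which passes column names as query parameters,
+        this route accepts them as a JSON list in the request body, avoiding
+        URL length limits.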
+        """
+        return await table_full(
+            request=request,
+            entry=entry,
+            column=column,
+            format=format,
+            filename=filename,
+            serialization_registry=serialization_registry,
+            settings=settings,
        )

+    async def table_full(
+        request: Request,
+        entry,
+        column: Optional[List[str]],
+        format: Optional[str],
+        filename: Optional[str],
+        serialization_registry,
+        settings: Settings,
+    ):
+        """
+        Fetch the data for the given table.
+        """
+        try:
+            with record_timing(request.state.metrics, "read"):
+                data = await ensure_awaitable(entry.read, column)
+        except KeyError as err:
+            (key,) = err.args
+            raise HTTPException(
+                status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}."
+            )
+        if data.memory_usage().sum() > settings.response_bytesize_limit:
+            raise HTTPException(
+                status_code=HTTP_400_BAD_REQUEST,
+                detail=(
+                    f"Response would exceed {settings.response_bytesize_limit}. "
+                    "Select a subset of the columns to "
+                    "request smaller chunks."
+                ),
+            )
+        try:
+            with record_timing(request.state.metrics, "pack"):
+                return await construct_data_response(
+                    entry.structure_family,
+                    serialization_registry,
+                    data,
+                    entry.metadata(),
+                    request,
+                    format,
+                    specs=getattr(entry, "specs", []),
+                    expires=getattr(entry, "content_stale_at", None),
+                    filename=filename,
+                    filter_for_access=None,
+                )
+        except UnsupportedMediaTypes as err:
+            raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0])

-@router.get(
-    "/container/full/{path:path}",
-    response_model=schemas.Response,
-    name="full 'container' metadata and data",
-)
-async def get_container_full(
-    request: Request,
-    entry=SecureEntry(
-        scopes=["read:data"], structure_families={StructureFamily.container}
-    ),
-    principal: str = Depends(get_current_principal),
-    field: Optional[List[str]] = Query(None, min_length=1),
-    format: Optional[str] = None,
-    filename: Optional[str] = None,
-    serialization_registry=Depends(get_serialization_registry),
-):
-    """
-    Fetch the data for the given container via a GET request.
-    """
-    return await container_full(
-        request=request,
-        entry=entry,
-        principal=principal,
-        field=field,
-        format=format,
-        filename=filename,
-        serialization_registry=serialization_registry,
+    @router.get(
+        "/container/full/{path:path}",
+        response_model=schemas.Response,
+        name="full 'container' metadata and data",
    )
+    async def get_container_full(
+        request: Request,
+        entry=SecureEntry(
+            scopes=["read:data"], structure_families={StructureFamily.container}
+        ),
+        principal: str = Depends(get_current_principal),
+        field: Optional[List[str]] = Query(None, min_length=1),
+        format: Optional[str] = None,
+        filename: Optional[str] = None,
+    ):
+        """
+        Fetch the data for the given container via a GET request.
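+
+        Children for which the caller lacks the 'read:data' scope are
+        filtered out of the response.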
+ """ + return await container_full( + request=request, + entry=entry, + principal=principal, + field=field, + format=format, + filename=filename, + serialization_registry=serialization_registry, + ) - -@router.post( - "/container/full/{path:path}", - response_model=schemas.Response, - name="full 'container' metadata and data", -) -async def post_container_full( - request: Request, - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.container} - ), - principal: str = Depends(get_current_principal), - field: Optional[List[str]] = Body(None, min_length=1), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), -): - """ - Fetch the data for the given container via a POST request. - """ - return await container_full( - request=request, - entry=entry, - principal=principal, - field=field, - format=format, - filename=filename, - serialization_registry=serialization_registry, + @router.post( + "/container/full/{path:path}", + response_model=schemas.Response, + name="full 'container' metadata and data", ) - - -async def container_full( - request: Request, - entry, - principal: str, - field: Optional[List[str]], - format: Optional[str], - filename: Optional[str], - serialization_registry, -): - """ - Fetch the data for the given container. - """ - try: - with record_timing(request.state.metrics, "read"): - data = await ensure_awaitable(entry.read, fields=field) - except KeyError as err: - (key,) = err.args - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}." + async def post_container_full( + request: Request, + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.container} + ), + principal: str = Depends(get_current_principal), + field: Optional[List[str]] = Body(None, min_length=1), + format: Optional[str] = None, + filename: Optional[str] = None, + ): + """ + Fetch the data for the given container via a POST request. + """ + return await container_full( + request=request, + entry=entry, + principal=principal, + field=field, + format=format, + filename=filename, + serialization_registry=serialization_registry, ) - curried_filter = partial( - filter_for_access, - principal=principal, - scopes=["read:data"], - metrics=request.state.metrics, - ) - # TODO Walk node to determine size before handing off to serializer. - try: - with record_timing(request.state.metrics, "pack"): - return await construct_data_response( - entry.structure_family, - serialization_registry, - data, - entry.metadata(), - request, - format, - specs=getattr(entry, "specs", []), - expires=getattr(entry, "content_stale_at", None), - filename=filename, - filter_for_access=curried_filter, - ) - except UnsupportedMediaTypes as err: - raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) - -@router.get( - "/node/full/{path:path}", - response_model=schemas.Response, - name="full 'container' or 'table'", - deprecated=True, -) -async def node_full( - request: Request, - entry=SecureEntry( - scopes=["read:data"], - structure_families={StructureFamily.table, StructureFamily.container}, - ), - principal: str = Depends(get_current_principal), - field: Optional[List[str]] = Query(None, min_length=1), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch the data below the given node. 
-    """
-    try:
-        with record_timing(request.state.metrics, "read"):
-            data = await ensure_awaitable(entry.read, field)
-    except KeyError as err:
-        (key,) = err.args
-        raise HTTPException(
-            status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}."
-        )
-    if (entry.structure_family == StructureFamily.table) and (
-        data.memory_usage().sum() > settings.response_bytesize_limit
+    async def container_full(
+        request: Request,
+        entry,
+        principal: str,
+        field: Optional[List[str]],
+        format: Optional[str],
+        filename: Optional[str],
+        serialization_registry,
    ):
-        raise HTTPException(
-            status_code=HTTP_400_BAD_REQUEST,
-            detail=(
-                f"Response would exceed {settings.response_bytesize_limit}. "
-                "Select a subset of the columns ('?field=...') to "
-                "request a smaller chunks."
-            ),
-        )
-    if entry.structure_family == StructureFamily.container:
+        """
+        Fetch the data for the given container.
+        """
+        try:
+            with record_timing(request.state.metrics, "read"):
+                data = await ensure_awaitable(entry.read, fields=field)
+        except KeyError as err:
+            (key,) = err.args
+            raise HTTPException(
+                status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}."
+            )
        curried_filter = partial(
            filter_for_access,
            principal=principal,
            scopes=["read:data"],
            metrics=request.state.metrics,
        )
-    else:
-        curried_filter = None
        # TODO Walk node to determine size before handing off to serializer.
-    try:
-        with record_timing(request.state.metrics, "pack"):
-            return await construct_data_response(
-                entry.structure_family,
-                serialization_registry,
-                data,
-                entry.metadata(),
-                request,
-                format,
-                specs=getattr(entry, "specs", []),
-                expires=getattr(entry, "content_stale_at", None),
-                filename=filename,
-                filter_for_access=curried_filter,
+        try:
+            with record_timing(request.state.metrics, "pack"):
+                return await construct_data_response(
+                    entry.structure_family,
+                    serialization_registry,
+                    data,
+                    entry.metadata(),
+                    request,
+                    format,
+                    specs=getattr(entry, "specs", []),
+                    expires=getattr(entry, "content_stale_at", None),
+                    filename=filename,
+                    filter_for_access=curried_filter,
+                )
+        except UnsupportedMediaTypes as err:
+            raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0])
+
+    @router.get(
+        "/node/full/{path:path}",
+        response_model=schemas.Response,
+        name="full 'container' or 'table'",
+        deprecated=True,
+    )
+    async def node_full(
+        request: Request,
+        entry=SecureEntry(
+            scopes=["read:data"],
+            structure_families={StructureFamily.table, StructureFamily.container},
+        ),
+        principal: str = Depends(get_current_principal),
+        field: Optional[List[str]] = Query(None, min_length=1),
+        format: Optional[str] = None,
+        filename: Optional[str] = None,
+        settings: Settings = Depends(get_settings),
+    ):
+        """
+        Fetch the data below the given node.
+        """
+        try:
+            with record_timing(request.state.metrics, "read"):
+                data = await ensure_awaitable(entry.read, field)
+        except KeyError as err:
+            (key,) = err.args
+            raise HTTPException(
+                status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}."
            )
-    except UnsupportedMediaTypes as err:
-        raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0])
-
+        if (entry.structure_family == StructureFamily.table) and (
+            data.memory_usage().sum() > settings.response_bytesize_limit
+        ):
+            raise HTTPException(
+                status_code=HTTP_400_BAD_REQUEST,
+                detail=(
+                    f"Response would exceed {settings.response_bytesize_limit}. "
+                    "Select a subset of the columns ('?field=...') to "
+                    "request smaller chunks."
+                ),
+            )
+        if entry.structure_family == StructureFamily.container:
+            curried_filter = partial(
+                filter_for_access,
+                principal=principal,
+                scopes=["read:data"],
+                metrics=request.state.metrics,
+            )
+        else:
+            curried_filter = None
+        # TODO Walk node to determine size before handing off to serializer.
+        try:
+            with record_timing(request.state.metrics, "pack"):
+                return await construct_data_response(
+                    entry.structure_family,
+                    serialization_registry,
+                    data,
+                    entry.metadata(),
+                    request,
+                    format,
+                    specs=getattr(entry, "specs", []),
+                    expires=getattr(entry, "content_stale_at", None),
+                    filename=filename,
+                    filter_for_access=curried_filter,
+                )
+        except UnsupportedMediaTypes as err:
+            raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0])

-@router.get(
-    "/awkward/buffers/{path:path}",
-    response_model=schemas.Response,
-    name="AwkwardArray buffers",
-)
-async def get_awkward_buffers(
-    request: Request,
-    entry=SecureEntry(
-        scopes=["read:data"], structure_families={StructureFamily.awkward}
-    ),
-    form_key: Optional[List[str]] = Query(None, min_length=1),
-    format: Optional[str] = None,
-    filename: Optional[str] = None,
-    serialization_registry=Depends(get_serialization_registry),
-    settings: Settings = Depends(get_settings),
-):
-    """
-    Fetch a slice of AwkwardArray data.
-
-    Note that there is a POST route on this same path with equivalent functionality.
-    HTTP caches tends to engage with GET but not POST, so that GET route may be
-    preferred for that reason. However, HTTP clients, servers, and proxies
-    typically impose a length limit on URLs. (The HTTP spec does not specify
-    one, but this is a pragmatic measure.) For requests with large numbers of
-    form_key parameters, POST may be the only option.
-    """
-    return await _awkward_buffers(
-        request=request,
-        entry=entry,
-        form_key=form_key,
-        format=format,
-        filename=filename,
-        serialization_registry=serialization_registry,
-        settings=settings,
+    @router.get(
+        "/awkward/buffers/{path:path}",
+        response_model=schemas.Response,
+        name="AwkwardArray buffers",
    )
+    async def get_awkward_buffers(
+        request: Request,
+        entry=SecureEntry(
+            scopes=["read:data"], structure_families={StructureFamily.awkward}
+        ),
+        form_key: Optional[List[str]] = Query(None, min_length=1),
+        format: Optional[str] = None,
+        filename: Optional[str] = None,
+        settings: Settings = Depends(get_settings),
+    ):
+        """
+        Fetch a slice of AwkwardArray data.
+
+        Note that there is a POST route on this same path with equivalent functionality.
+        HTTP caches tend to engage with GET but not POST, so the GET route may be
+        preferred for that reason. However, HTTP clients, servers, and proxies
+        typically impose a length limit on URLs. (The HTTP spec does not specify
+        one, but this is a pragmatic measure.) For requests with large numbers of
+        form_key parameters, POST may be the only option.
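+
+        For example (with hypothetical form keys),
+        GET /awkward/buffers/{path}?form_key=node0&form_key=node1
+        requests only the buffers backing those two form nodes.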
+        """
+        return await _awkward_buffers(
+            request=request,
+            entry=entry,
+            form_key=form_key,
+            format=format,
+            filename=filename,
+            serialization_registry=serialization_registry,
+            settings=settings,
+        )

-@router.post(
-    "/awkward/buffers/{path:path}",
-    response_model=schemas.Response,
-    name="AwkwardArray buffers",
-)
-async def post_awkward_buffers(
-    request: Request,
-    body: List[str],
-    entry=SecureEntry(
-        scopes=["read:data"], structure_families={StructureFamily.awkward}
-    ),
-    format: Optional[str] = None,
-    filename: Optional[str] = None,
-    serialization_registry=Depends(get_serialization_registry),
-    settings: Settings = Depends(get_settings),
-):
-    """
-    Fetch a slice of AwkwardArray data.
-
-    Note that there is a GET route on this same path with equivalent functionality.
-    HTTP caches tends to engage with GET but not POST, so that GET route may be
-    preferred for that reason. However, HTTP clients, servers, and proxies
-    typically impose a length limit on URLs. (The HTTP spec does not specify
-    one, but this is a pragmatic measure.) For requests with large numbers of
-    form_key parameters, POST may be the only option.
-    """
-    return await _awkward_buffers(
-        request=request,
-        entry=entry,
-        form_key=body,
-        format=format,
-        filename=filename,
-        serialization_registry=serialization_registry,
-        settings=settings,
+    @router.post(
+        "/awkward/buffers/{path:path}",
+        response_model=schemas.Response,
+        name="AwkwardArray buffers",
    )
-
-
-async def _awkward_buffers(
-    request: Request,
-    entry,
-    form_key: Optional[List[str]],
-    format: Optional[str],
-    filename: Optional[str],
-    serialization_registry,
-    settings: Settings,
-):
-    structure_family = entry.structure_family
-    structure = entry.structure()
-    with record_timing(request.state.metrics, "read"):
-        # The plural vs. singular mismatch is due to the way query parameters
-        # are given as ?form_key=A&form_key=B&form_key=C.
-        container = await ensure_awaitable(entry.read_buffers, form_key)
-    if (
-        sum(len(buffer) for buffer in container.values())
-        > settings.response_bytesize_limit
+    async def post_awkward_buffers(
+        request: Request,
+        body: List[str],
+        entry=SecureEntry(
+            scopes=["read:data"], structure_families={StructureFamily.awkward}
+        ),
+        format: Optional[str] = None,
+        filename: Optional[str] = None,
+        settings: Settings = Depends(get_settings),
    ):
-        raise HTTPException(
-            status_code=HTTP_400_BAD_REQUEST,
-            detail=(
-                f"Response would exceed {settings.response_bytesize_limit}. "
-                "Use slicing ('?slice=...') to request smaller chunks."
-            ),
+        """
+        Fetch a slice of AwkwardArray data.
+
+        Note that there is a GET route on this same path with equivalent functionality.
+        HTTP caches tend to engage with GET but not POST, so the GET route may be
+        preferred for that reason. However, HTTP clients, servers, and proxies
+        typically impose a length limit on URLs. (The HTTP spec does not specify
+        one, but this is a pragmatic measure.) For requests with large numbers of
+        form_key parameters, POST may be the only option.
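+
+        The form keys are passed as a JSON list in the request body (e.g.
+        ["node0", "node1"], hypothetical keys) rather than as repeated query
+        parameters.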
+ """ + return await _awkward_buffers( + request=request, + entry=entry, + form_key=body, + format=format, + filename=filename, + serialization_registry=serialization_registry, + settings=settings, ) - components = (structure.form, structure.length, container) - try: - with record_timing(request.state.metrics, "pack"): - return await construct_data_response( - structure_family, - serialization_registry, - components, - entry.metadata(), - request, - format, - specs=getattr(entry, "specs", []), - expires=getattr(entry, "content_stale_at", None), - filename=filename, + + async def _awkward_buffers( + request: Request, + entry, + form_key: Optional[List[str]], + format: Optional[str], + filename: Optional[str], + serialization_registry, + settings: Settings, + ): + structure_family = entry.structure_family + structure = entry.structure() + with record_timing(request.state.metrics, "read"): + # The plural vs. singular mismatch is due to the way query parameters + # are given as ?form_key=A&form_key=B&form_key=C. + container = await ensure_awaitable(entry.read_buffers, form_key) + if ( + sum(len(buffer) for buffer in container.values()) + > settings.response_bytesize_limit + ): + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=( + f"Response would exceed {settings.response_bytesize_limit}. " + "Use slicing ('?slice=...') to request smaller chunks." + ), ) - except UnsupportedMediaTypes as err: - raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) + components = (structure.form, structure.length, container) + try: + with record_timing(request.state.metrics, "pack"): + return await construct_data_response( + structure_family, + serialization_registry, + components, + entry.metadata(), + request, + format, + specs=getattr(entry, "specs", []), + expires=getattr(entry, "content_stale_at", None), + filename=filename, + ) + except UnsupportedMediaTypes as err: + raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) + @router.get( + "/awkward/full/{path:path}", + response_model=schemas.Response, + name="Full AwkwardArray", + ) + async def awkward_full( + request: Request, + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.awkward} + ), + # slice=Depends(slice_), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch a slice of AwkwardArray data. + """ + structure_family = entry.structure_family + # Deferred import because this is not a required dependency of the server + # for some use cases. + import awkward -@router.get( - "/awkward/full/{path:path}", - response_model=schemas.Response, - name="Full AwkwardArray", -) -async def awkward_full( - request: Request, - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.awkward} - ), - # slice=Depends(slice_), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch a slice of AwkwardArray data. - """ - structure_family = entry.structure_family - # Deferred import because this is not a required dependency of the server - # for some use cases. 
- import awkward - - with record_timing(request.state.metrics, "read"): - container = await ensure_awaitable(entry.read) - structure = entry.structure() - components = (structure.form, structure.length, container) - array = awkward.from_buffers(*components) - if array.nbytes > settings.response_bytesize_limit: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - f"Response would exceed {settings.response_bytesize_limit}. " - "Use slicing ('?slice=...') to request smaller chunks." - ), - ) - try: - with record_timing(request.state.metrics, "pack"): - return await construct_data_response( - structure_family, - serialization_registry, - components, - entry.metadata(), - request, - format, - specs=getattr(entry, "specs", []), - expires=getattr(entry, "content_stale_at", None), - filename=filename, + with record_timing(request.state.metrics, "read"): + container = await ensure_awaitable(entry.read) + structure = entry.structure() + components = (structure.form, structure.length, container) + array = awkward.from_buffers(*components) + if array.nbytes > settings.response_bytesize_limit: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=( + f"Response would exceed {settings.response_bytesize_limit}. " + "Use slicing ('?slice=...') to request smaller chunks." + ), ) - except UnsupportedMediaTypes as err: - raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) - - -@router.post("/metadata/{path:path}", response_model=schemas.PostMetadataResponse) -async def post_metadata( - request: Request, - path: str, - body: schemas.PostMetadataRequest, - validation_registry=Depends(get_validation_registry), - settings: Settings = Depends(get_settings), - entry=SecureEntry(scopes=["write:metadata", "create"]), -): - for data_source in body.data_sources: - if data_source.assets: + try: + with record_timing(request.state.metrics, "pack"): + return await construct_data_response( + structure_family, + serialization_registry, + components, + entry.metadata(), + request, + format, + specs=getattr(entry, "specs", []), + expires=getattr(entry, "content_stale_at", None), + filename=filename, + ) + except UnsupportedMediaTypes as err: + raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) + + @router.post("/metadata/{path:path}", response_model=schemas.PostMetadataResponse) + async def post_metadata( + request: Request, + path: str, + body: schemas.PostMetadataRequest, + settings: Settings = Depends(get_settings), + entry=SecureEntry(scopes=["write:metadata", "create"]), + ): + for data_source in body.data_sources: + if data_source.assets: + raise HTTPException( + "Externally-managed assets cannot be registered " + "using POST /metadata/{path} Use POST /register/{path} instead." + ) + if body.data_sources and not getattr(entry, "writable", False): raise HTTPException( - "Externally-managed assets cannot be registered " - "using POST /metadata/{path} Use POST /register/{path} instead." 
+ status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail=f"Data cannot be written at the path {path}", ) - if body.data_sources and not getattr(entry, "writable", False): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail=f"Data cannot be written at the path {path}", + return await _create_node( + request=request, + path=path, + body=body, + validation_registry=validation_registry, + settings=settings, + entry=entry, ) - return await _create_node( - request=request, - path=path, - body=body, - validation_registry=validation_registry, - settings=settings, - entry=entry, - ) - - -@router.post("/register/{path:path}", response_model=schemas.PostMetadataResponse) -async def post_register( - request: Request, - path: str, - body: schemas.PostMetadataRequest, - validation_registry=Depends(get_validation_registry), - settings: Settings = Depends(get_settings), - entry=SecureEntry(scopes=["write:metadata", "create", "register"]), -): - return await _create_node( - request=request, - path=path, - body=body, - validation_registry=validation_registry, - settings=settings, - entry=entry, - ) - - -async def _create_node( - request: Request, - path: str, - body: schemas.PostMetadataRequest, - validation_registry, - settings: Settings, - entry, -): - metadata, structure_family, specs = ( - body.metadata, - body.structure_family, - body.specs, - ) - if structure_family == StructureFamily.container: - structure = None - else: - if len(body.data_sources) != 1: - raise NotImplementedError - structure = body.data_sources[0].structure - - metadata_modified, metadata = await validate_metadata( - metadata=metadata, - structure_family=structure_family, - structure=structure, - specs=specs, - validation_registry=validation_registry, - settings=settings, - ) - - key, node = await entry.create_node( - metadata=body.metadata, - structure_family=body.structure_family, - key=body.id, - specs=body.specs, - data_sources=body.data_sources, - ) - links = links_for_node( - structure_family, structure, get_base_url(request), path + f"/{key}" - ) - response_data = { - "id": key, - "links": links, - "data_sources": [ds.model_dump() for ds in node.data_sources], - } - if metadata_modified: - response_data["metadata"] = metadata - - return json_or_msgpack(request, response_data) - - -@router.put("/data_source/{path:path}") -async def put_data_source( - request: Request, - path: str, - data_source: int, - body: schemas.PutDataSourceRequest, - settings: Settings = Depends(get_settings), - entry=SecureEntry(scopes=["write:metadata", "register"]), -): - await entry.put_data_source( - data_source=body.data_source, - ) + @router.post("/register/{path:path}", response_model=schemas.PostMetadataResponse) + async def post_register( + request: Request, + path: str, + body: schemas.PostMetadataRequest, + settings: Settings = Depends(get_settings), + entry=SecureEntry(scopes=["write:metadata", "create", "register"]), + ): + return await _create_node( + request=request, + path=path, + body=body, + validation_registry=validation_registry, + settings=settings, + entry=entry, + ) -@router.delete("/metadata/{path:path}") -async def delete( - request: Request, - entry=SecureEntry(scopes=["write:data", "write:metadata"]), -): - if hasattr(entry, "delete"): - await entry.delete() - else: - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support deletion.", + async def _create_node( + request: Request, + path: str, + body: schemas.PostMetadataRequest, + validation_registry, + settings: 
Settings, + entry, + ): + metadata, structure_family, specs = ( + body.metadata, + body.structure_family, + body.specs, ) - return json_or_msgpack(request, None) - - -@router.delete("/nodes/{path:path}") -async def bulk_delete( - request: Request, - entry=SecureEntry(scopes=["write:data", "write:metadata"]), -): - if hasattr(entry, "delete_tree"): - await entry.delete_tree() - else: - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support bulk deletion.", + if structure_family == StructureFamily.container: + structure = None + else: + if len(body.data_sources) != 1: + raise NotImplementedError + structure = body.data_sources[0].structure + + metadata_modified, metadata = await validate_metadata( + metadata=metadata, + structure_family=structure_family, + structure=structure, + specs=specs, + validation_registry=validation_registry, + settings=settings, ) - return json_or_msgpack(request, None) - - -@router.put("/array/full/{path:path}") -async def put_array_full( - request: Request, - entry=SecureEntry( - scopes=["write:data"], - structure_families={StructureFamily.array, StructureFamily.sparse}, - ), - deserialization_registry=Depends(get_deserialization_registry), -): - body = await request.body() - if not hasattr(entry, "write"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node cannot accept array data.", + + key, node = await entry.create_node( + metadata=body.metadata, + structure_family=body.structure_family, + key=body.id, + specs=body.specs, + data_sources=body.data_sources, ) - media_type = request.headers["content-type"] - if entry.structure_family == "array": - dtype = entry.structure().data_type.to_numpy_dtype() - shape = entry.structure().shape - deserializer = deserialization_registry.dispatch("array", media_type) - data = await ensure_awaitable(deserializer, body, dtype, shape) - elif entry.structure_family == "sparse": - deserializer = deserialization_registry.dispatch("sparse", media_type) - data = await ensure_awaitable(deserializer, body) - else: - raise NotImplementedError(entry.structure_family) - await ensure_awaitable(entry.write, data) - return json_or_msgpack(request, None) - - -@router.put("/array/block/{path:path}") -async def put_array_block( - request: Request, - entry=SecureEntry( - scopes=["write:data"], - structure_families={StructureFamily.array, StructureFamily.sparse}, - ), - deserialization_registry=Depends(get_deserialization_registry), - block=Depends(block), -): - if not hasattr(entry, "write_block"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node cannot accept array data.", + links = links_for_node( + structure_family, structure, get_base_url(request), path + f"/{key}" ) - from tiled.adapters.array import slice_and_shape_from_block_and_chunks + response_data = { + "id": key, + "links": links, + "data_sources": [ds.model_dump() for ds in node.data_sources], + } + if metadata_modified: + response_data["metadata"] = metadata + + return json_or_msgpack(request, response_data) + + @router.put("/data_source/{path:path}") + async def put_data_source( + request: Request, + path: str, + data_source: int, + body: schemas.PutDataSourceRequest, + settings: Settings = Depends(get_settings), + entry=SecureEntry(scopes=["write:metadata", "register"]), + ): + await entry.put_data_source( + data_source=body.data_source, + ) + + @router.delete("/metadata/{path:path}") + async def delete( + request: Request, + entry=SecureEntry(scopes=["write:data", 
"write:metadata"]), + ): + if hasattr(entry, "delete"): + await entry.delete() + else: + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node does not support deletion.", + ) + return json_or_msgpack(request, None) + + @router.delete("/nodes/{path:path}") + async def bulk_delete( + request: Request, + entry=SecureEntry(scopes=["write:data", "write:metadata"]), + ): + if hasattr(entry, "delete_tree"): + await entry.delete_tree() + else: + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node does not support bulk deletion.", + ) + return json_or_msgpack(request, None) + + @router.put("/array/full/{path:path}") + async def put_array_full( + request: Request, + entry=SecureEntry( + scopes=["write:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), + ): + body = await request.body() + if not hasattr(entry, "write"): + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node cannot accept array data.", + ) + media_type = request.headers["content-type"] + if entry.structure_family == "array": + dtype = entry.structure().data_type.to_numpy_dtype() + shape = entry.structure().shape + deserializer = deserialization_registry.dispatch("array", media_type) + data = await ensure_awaitable(deserializer, body, dtype, shape) + elif entry.structure_family == "sparse": + deserializer = deserialization_registry.dispatch("sparse", media_type) + data = await ensure_awaitable(deserializer, body) + else: + raise NotImplementedError(entry.structure_family) + await ensure_awaitable(entry.write, data) + return json_or_msgpack(request, None) + + @router.put("/array/block/{path:path}") + async def put_array_block( + request: Request, + entry=SecureEntry( + scopes=["write:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), + block=Depends(block), + ): + if not hasattr(entry, "write_block"): + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node cannot accept array data.", + ) + from tiled.adapters.array import slice_and_shape_from_block_and_chunks + + body = await request.body() + media_type = request.headers["content-type"] + if entry.structure_family == "array": + dtype = entry.structure().data_type.to_numpy_dtype() + _, shape = slice_and_shape_from_block_and_chunks( + block, entry.structure().chunks + ) + deserializer = deserialization_registry.dispatch("array", media_type) + data = await ensure_awaitable(deserializer, body, dtype, shape) + elif entry.structure_family == "sparse": + deserializer = deserialization_registry.dispatch("sparse", media_type) + data = await ensure_awaitable(deserializer, body) + else: + raise NotImplementedError(entry.structure_family) + await ensure_awaitable(entry.write_block, data, block) + return json_or_msgpack(request, None) + + @router.patch("/array/full/{path:path}") + async def patch_array_full( + request: Request, + offset=Depends(offset_param), + shape=Depends(shape_param), + extend: bool = False, + entry=SecureEntry( + scopes=["write:data"], + structure_families={StructureFamily.array}, + ), + ): + if not hasattr(entry, "patch"): + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node cannot accept array data.", + ) - body = await request.body() - media_type = request.headers["content-type"] - if entry.structure_family == "array": dtype = entry.structure().data_type.to_numpy_dtype() - _, shape = slice_and_shape_from_block_and_chunks( - block, entry.structure().chunks - 
) + body = await request.body() + media_type = request.headers["content-type"] deserializer = deserialization_registry.dispatch("array", media_type) data = await ensure_awaitable(deserializer, body, dtype, shape) - elif entry.structure_family == "sparse": - deserializer = deserialization_registry.dispatch("sparse", media_type) - data = await ensure_awaitable(deserializer, body) - else: - raise NotImplementedError(entry.structure_family) - await ensure_awaitable(entry.write_block, data, block) - return json_or_msgpack(request, None) - - -@router.patch("/array/full/{path:path}") -async def patch_array_full( - request: Request, - offset=Depends(offset_param), - shape=Depends(shape_param), - extend: bool = False, - entry=SecureEntry( - scopes=["write:data"], - structure_families={StructureFamily.array}, - ), - deserialization_registry=Depends(get_deserialization_registry), -): - if not hasattr(entry, "patch"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node cannot accept array data.", - ) - - dtype = entry.structure().data_type.to_numpy_dtype() - body = await request.body() - media_type = request.headers["content-type"] - deserializer = deserialization_registry.dispatch("array", media_type) - data = await ensure_awaitable(deserializer, body, dtype, shape) - structure = await ensure_awaitable(entry.patch, data, offset, extend) - return json_or_msgpack(request, structure) - - -@router.put("/table/full/{path:path}") -@router.put("/node/full/{path:path}", deprecated=True) -async def put_node_full( - request: Request, - entry=SecureEntry( - scopes=["write:data"], structure_families={StructureFamily.table} - ), - deserialization_registry=Depends(get_deserialization_registry), -): - if not hasattr(entry, "write"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support writing.", - ) - body = await request.body() - media_type = request.headers["content-type"] - deserializer = deserialization_registry.dispatch(StructureFamily.table, media_type) - data = await ensure_awaitable(deserializer, body) - await ensure_awaitable(entry.write, data) - return json_or_msgpack(request, None) - - -@router.put("/table/partition/{path:path}") -async def put_table_partition( - partition: int, - request: Request, - entry=SecureEntry(scopes=["write:data"]), - deserialization_registry=Depends(get_deserialization_registry), -): - if not hasattr(entry, "write_partition"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not supporting writing a partition.", - ) - body = await request.body() - media_type = request.headers["content-type"] - deserializer = deserialization_registry.dispatch(StructureFamily.table, media_type) - data = await ensure_awaitable(deserializer, body) - await ensure_awaitable(entry.write_partition, data, partition) - return json_or_msgpack(request, None) - - -@router.patch("/table/partition/{path:path}") -async def patch_table_partition( - partition: int, - request: Request, - entry=SecureEntry(scopes=["write:data"]), - deserialization_registry=Depends(get_deserialization_registry), -): - if not hasattr(entry, "write_partition"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not supporting writing a partition.", - ) - body = await request.body() - media_type = request.headers["content-type"] - deserializer = deserialization_registry.dispatch(StructureFamily.table, media_type) - data = await ensure_awaitable(deserializer, body) - await 
ensure_awaitable(entry.append_partition, data, partition)
-    return json_or_msgpack(request, None)
-
-
-@router.put("/awkward/full/{path:path}")
-async def put_awkward_full(
-    request: Request,
-    entry=SecureEntry(
-        scopes=["write:data"], structure_families={StructureFamily.awkward}
-    ),
-    deserialization_registry=Depends(get_deserialization_registry),
-):
-    body = await request.body()
-    if not hasattr(entry, "write"):
-        raise HTTPException(
-            status_code=HTTP_405_METHOD_NOT_ALLOWED,
-            detail="This node cannot be written to.",
+        structure = await ensure_awaitable(entry.patch, data, offset, extend)
+        return json_or_msgpack(request, structure)
+
+    @router.put("/table/full/{path:path}")
+    @router.put("/node/full/{path:path}", deprecated=True)
+    async def put_node_full(
+        request: Request,
+        entry=SecureEntry(
+            scopes=["write:data"], structure_families={StructureFamily.table}
+        ),
+    ):
+        if not hasattr(entry, "write"):
+            raise HTTPException(
+                status_code=HTTP_405_METHOD_NOT_ALLOWED,
+                detail="This node does not support writing.",
+            )
+        body = await request.body()
+        media_type = request.headers["content-type"]
+        deserializer = deserialization_registry.dispatch(
+            StructureFamily.table, media_type
         )
-    media_type = request.headers["content-type"]
-    deserializer = deserialization_registry.dispatch(
-        StructureFamily.awkward, media_type
-    )
-    structure = entry.structure()
-    data = await ensure_awaitable(deserializer, body, structure.form, structure.length)
-    await ensure_awaitable(entry.write, data)
-    return json_or_msgpack(request, None)
-
-
-@router.patch("/metadata/{path:path}", response_model=schemas.PatchMetadataResponse)
-async def patch_metadata(
-    request: Request,
-    body: schemas.PatchMetadataRequest,
-    validation_registry=Depends(get_validation_registry),
-    settings: Settings = Depends(get_settings),
-    entry=SecureEntry(scopes=["write:metadata"]),
-):
-    if not hasattr(entry, "replace_metadata"):
-        raise HTTPException(
-            status_code=HTTP_405_METHOD_NOT_ALLOWED,
-            detail="This node does not support update of metadata.",
+        data = await ensure_awaitable(deserializer, body)
+        await ensure_awaitable(entry.write, data)
+        return json_or_msgpack(request, None)
+
+    @router.put("/table/partition/{path:path}")
+    async def put_table_partition(
+        partition: int,
+        request: Request,
+        entry=SecureEntry(scopes=["write:data"]),
+    ):
+        if not hasattr(entry, "write_partition"):
+            raise HTTPException(
+                status_code=HTTP_405_METHOD_NOT_ALLOWED,
+                detail="This node does not support writing a partition.",
+            )
+        body = await request.body()
+        media_type = request.headers["content-type"]
+        deserializer = deserialization_registry.dispatch(
+            StructureFamily.table, media_type
         )
-    if body.content_type == patch_mimetypes.JSON_PATCH:
-        metadata = apply_json_patch(entry.metadata(), (body.metadata or []))
-        specs = apply_json_patch((entry.specs or []), (body.specs or []))
-    elif body.content_type == patch_mimetypes.MERGE_PATCH:
-        metadata = apply_merge_patch(entry.metadata(), (body.metadata or {}))
-        # body.specs = [] clears specs, as per json merge patch specification
-        # but we treat body.specs = None as "no modifications"
-        current_specs = entry.specs or []
-        target_specs = current_specs if body.specs is None else body.specs
-        specs = apply_merge_patch(current_specs, target_specs)
-    else:
-        raise HTTPException(
-            status_code=HTTP_406_NOT_ACCEPTABLE,
-            detail=f"valid content types: {', '.join(patch_mimetypes)}",
+        data = await ensure_awaitable(deserializer, body)
+        await ensure_awaitable(entry.write_partition, data, partition)
+        return json_or_msgpack(request, None)
+
+    @router.patch("/table/partition/{path:path}")
+    async def patch_table_partition(
+        partition: int,
+        request: Request,
+        entry=SecureEntry(scopes=["write:data"]),
+    ):
+        if not hasattr(entry, "write_partition"):
+            raise HTTPException(
+                status_code=HTTP_405_METHOD_NOT_ALLOWED,
+                detail="This node does not support writing a partition.",
+            )
+        body = await request.body()
+        media_type = request.headers["content-type"]
+        deserializer = deserialization_registry.dispatch(
+            StructureFamily.table, media_type
         )
-
-    # Manually validate limits that bypass pydantic validation via patch
-    if len(specs) > schemas.MAX_ALLOWED_SPECS:
-        raise HTTPException(
-            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-            detail=f"Update cannot result in more than {schemas.MAX_ALLOWED_SPECS} specs",
+        data = await ensure_awaitable(deserializer, body)
+        await ensure_awaitable(entry.append_partition, data, partition)
+        return json_or_msgpack(request, None)
+
+    @router.put("/awkward/full/{path:path}")
+    async def put_awkward_full(
+        request: Request,
+        entry=SecureEntry(
+            scopes=["write:data"], structure_families={StructureFamily.awkward}
+        ),
+    ):
+        body = await request.body()
+        if not hasattr(entry, "write"):
+            raise HTTPException(
+                status_code=HTTP_405_METHOD_NOT_ALLOWED,
+                detail="This node cannot be written to.",
+            )
+        media_type = request.headers["content-type"]
+        deserializer = deserialization_registry.dispatch(
+            StructureFamily.awkward, media_type
         )
-    if len(specs) != len(set(specs)):
-        raise HTTPException(
-            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-            detail="Update cannot result in non-unique specs",
+        structure = entry.structure()
+        data = await ensure_awaitable(
+            deserializer, body, structure.form, structure.length
         )
+        await ensure_awaitable(entry.write, data)
+        return json_or_msgpack(request, None)
+
+    @router.patch("/metadata/{path:path}", response_model=schemas.PatchMetadataResponse)
+    async def patch_metadata(
+        request: Request,
+        body: schemas.PatchMetadataRequest,
+        settings: Settings = Depends(get_settings),
+        entry=SecureEntry(scopes=["write:metadata"]),
+    ):
+        if not hasattr(entry, "replace_metadata"):
+            raise HTTPException(
+                status_code=HTTP_405_METHOD_NOT_ALLOWED,
+                detail="This node does not support update of metadata.",
+            )
+        if body.content_type == patch_mimetypes.JSON_PATCH:
+            metadata = apply_json_patch(entry.metadata(), (body.metadata or []))
+            specs = apply_json_patch((entry.specs or []), (body.specs or []))
+        elif body.content_type == patch_mimetypes.MERGE_PATCH:
+            metadata = apply_merge_patch(entry.metadata(), (body.metadata or {}))
+            # body.specs = [] clears specs, as per json merge patch specification
+            # but we treat body.specs = None as "no modifications"
+            current_specs = entry.specs or []
+            target_specs = current_specs if body.specs is None else body.specs
+            specs = apply_merge_patch(current_specs, target_specs)
+        else:
+            raise HTTPException(
+                status_code=HTTP_406_NOT_ACCEPTABLE,
+                detail=f"valid content types: {', '.join(patch_mimetypes)}",
+            )
 
-    structure_family, structure = (
-        entry.structure_family,
-        entry.structure(),
-    )
+        # Manually validate limits that bypass pydantic validation via patch
+        if len(specs) > schemas.MAX_ALLOWED_SPECS:
+            raise HTTPException(
+                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                detail=f"Update cannot result in more than {schemas.MAX_ALLOWED_SPECS} specs",
+            )
+        if len(specs) != len(set(specs)):
+            raise HTTPException(
+                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                detail="Update cannot 
result in non-unique specs", + ) - metadata_modified, metadata = await validate_metadata( - metadata=metadata, - structure_family=structure_family, - structure=structure, - specs=[Spec(x) for x in specs], - validation_registry=validation_registry, - settings=settings, - ) + structure_family, structure = ( + entry.structure_family, + entry.structure(), + ) - await entry.replace_metadata(metadata=metadata, specs=specs) - - response_data = {"id": entry.key} - if metadata_modified: - response_data["metadata"] = metadata - return json_or_msgpack(request, response_data) - - -@router.put("/metadata/{path:path}", response_model=schemas.PutMetadataResponse) -async def put_metadata( - request: Request, - body: schemas.PutMetadataRequest, - validation_registry=Depends(get_validation_registry), - settings: Settings = Depends(get_settings), - entry=SecureEntry(scopes=["write:metadata"]), -): - if not hasattr(entry, "replace_metadata"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support update of metadata.", + metadata_modified, metadata = await validate_metadata( + metadata=metadata, + structure_family=structure_family, + structure=structure, + specs=[Spec(x) for x in specs], + validation_registry=validation_registry, + settings=settings, ) - metadata, structure_family, structure, specs = ( - body.metadata if body.metadata is not None else entry.metadata(), - entry.structure_family, - entry.structure(), - body.specs if body.specs is not None else entry.specs, - ) + await entry.replace_metadata(metadata=metadata, specs=specs) - metadata_modified, metadata = await validate_metadata( - metadata=metadata, - structure_family=structure_family, - structure=structure, - specs=specs, - validation_registry=validation_registry, - settings=settings, - ) + response_data = {"id": entry.key} + if metadata_modified: + response_data["metadata"] = metadata + return json_or_msgpack(request, response_data) - await entry.replace_metadata(metadata=metadata, specs=specs) - - response_data = {"id": entry.key} - if metadata_modified: - response_data["metadata"] = metadata - return json_or_msgpack(request, response_data) - - -@router.get("/revisions/{path:path}") -async def get_revisions( - request: Request, - path: str, - offset: Optional[int] = Query(0, alias="page[offset]", ge=0), - limit: Optional[int] = Query( - DEFAULT_PAGE_SIZE, alias="page[limit]", ge=0, le=MAX_PAGE_SIZE - ), - entry=SecureEntry(scopes=["read:metadata"]), -): - if not hasattr(entry, "revisions"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support revisions.", - ) + @router.put("/metadata/{path:path}", response_model=schemas.PutMetadataResponse) + async def put_metadata( + request: Request, + body: schemas.PutMetadataRequest, + settings: Settings = Depends(get_settings), + entry=SecureEntry(scopes=["write:metadata"]), + ): + if not hasattr(entry, "replace_metadata"): + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node does not support update of metadata.", + ) - base_url = get_base_url(request) - resource = await construct_revisions_response( - entry, - base_url, - "/revisions", - path, - offset, - limit, - resolve_media_type(request), - ) - return json_or_msgpack(request, resource.model_dump()) - - -@router.delete("/revisions/{path:path}") -async def delete_revision( - request: Request, - number: int, - entry=SecureEntry(scopes=["write:metadata"]), -): - if not hasattr(entry, "revisions"): - raise HTTPException( - 
status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support a del request for revisions.", + metadata, structure_family, structure, specs = ( + body.metadata if body.metadata is not None else entry.metadata(), + entry.structure_family, + entry.structure(), + body.specs if body.specs is not None else entry.specs, ) - await entry.delete_revision(number) - return json_or_msgpack(request, None) - - -# For simplicity of implementation, we support a restricted subset of the full -# Range spec. This could be extended if the need arises. -# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Range -RANGE_HEADER_PATTERN = re.compile(r"^bytes=(\d+)-(\d+)$") - - -@router.get("/asset/bytes/{path:path}") -async def get_asset( - request: Request, - id: int, - relative_path: Optional[Path] = None, - entry=SecureEntry(scopes=["read:data"]), # TODO: Separate scope for assets? - settings: Settings = Depends(get_settings), -): - if not settings.expose_raw_assets: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail=( - "This Tiled server is configured not to allow " - "downloading raw assets." - ), - ) - if not hasattr(entry, "asset_by_id"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support downloading assets.", + metadata_modified, metadata = await validate_metadata( + metadata=metadata, + structure_family=structure_family, + structure=structure, + specs=specs, + validation_registry=validation_registry, + settings=settings, ) - asset = await entry.asset_by_id(id) - if asset is None: - raise HTTPException( - status_code=HTTP_404_NOT_FOUND, - detail=f"This node exists but it does not have an Asset with id {id}", - ) - if asset.is_directory: - if relative_path is None: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - "This asset is a directory. Must specify relative path, " - f"from manifest provided by /asset/manifest/...?id={id}" - ), - ) - if relative_path.is_absolute(): + await entry.replace_metadata(metadata=metadata, specs=specs) + + response_data = {"id": entry.key} + if metadata_modified: + response_data["metadata"] = metadata + return json_or_msgpack(request, response_data) + + @router.get("/revisions/{path:path}") + async def get_revisions( + request: Request, + path: str, + offset: Optional[int] = Query(0, alias="page[offset]", ge=0), + limit: Optional[int] = Query( + DEFAULT_PAGE_SIZE, alias="page[limit]", ge=0, le=MAX_PAGE_SIZE + ), + entry=SecureEntry(scopes=["read:metadata"]), + ): + if not hasattr(entry, "revisions"): raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail="relative_path query parameter must be a *relative* path", + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node does not support revisions.", ) - else: - if relative_path is not None: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail="This asset is not a directory. The relative_path query parameter must not be set.", - ) - if not asset.data_uri.startswith("file:"): - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail="Only download assets stored as file:// is currently supported.", + + base_url = get_base_url(request) + resource = await construct_revisions_response( + entry, + base_url, + "/revisions", + path, + offset, + limit, + resolve_media_type(request), ) - path = path_from_uri(asset.data_uri) - if relative_path is not None: - # Be doubly sure that this is under the Asset's data_uri directory - # and not sneakily escaping it. 
-        if not os.path.commonpath([path, path / relative_path]) != path:
-            # This should not be possible.
-            raise RuntimeError(
-                f"Refusing to serve {path / relative_path} because it is outside "
-                "of the Asset's directory"
+        return json_or_msgpack(request, resource.model_dump())
+
+    @router.delete("/revisions/{path:path}")
+    async def delete_revision(
+        request: Request,
+        number: int,
+        entry=SecureEntry(scopes=["write:metadata"]),
+    ):
+        if not hasattr(entry, "revisions"):
+            raise HTTPException(
+                status_code=HTTP_405_METHOD_NOT_ALLOWED,
+                detail="This node does not support a DELETE request for revisions.",
             )
-        full_path = path / relative_path
-    else:
-        full_path = path
-    stat_result = await anyio.to_thread.run_sync(os.stat, full_path)
-    filename = full_path.name
-    if "range" in request.headers:
-        range_header = request.headers["range"]
-        match = RANGE_HEADER_PATTERN.match(range_header)
-        if match is None:
+
+        await entry.delete_revision(number)
+        return json_or_msgpack(request, None)
+
+    # For simplicity of implementation, we support a restricted subset of the full
+    # Range spec. This could be extended if the need arises.
+    # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Range
+    RANGE_HEADER_PATTERN = re.compile(r"^bytes=(\d+)-(\d+)$")
+
+    @router.get("/asset/bytes/{path:path}")
+    async def get_asset(
+        request: Request,
+        id: int,
+        relative_path: Optional[Path] = None,
+        entry=SecureEntry(scopes=["read:data"]),  # TODO: Separate scope for assets?
+        settings: Settings = Depends(get_settings),
+    ):
+        if not settings.expose_raw_assets:
             raise HTTPException(
-                status_code=HTTP_400_BAD_REQUEST,
+                status_code=HTTP_403_FORBIDDEN,
                 detail=(
-                    "Only a Range headers of the form 'bytes=start-end' are supported. "
-                    f"Could not parse Range header: {range_header}",
+                    "This Tiled server is configured not to allow "
+                    "downloading raw assets."
                 ),
             )
-        range = start, _ = (int(match.group(1)), int(match.group(2)))
-        if start > stat_result.st_size:
+        if not hasattr(entry, "asset_by_id"):
             raise HTTPException(
-                status_code=HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
-                headers={"content-range": f"bytes */{stat_result.st_size}"},
+                status_code=HTTP_405_METHOD_NOT_ALLOWED,
+                detail="This node does not support downloading assets.",
             )
-        status_code = HTTP_206_PARTIAL_CONTENT
-    else:
-        range = None
-        status_code = HTTP_200_OK
-    return FileResponseWithRange(
-        full_path,
-        stat_result=stat_result,
-        status_code=status_code,
-        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
-        range=range,
-    )
-
-@router.get("/asset/manifest/{path:path}")
-async def get_asset_manifest(
-    request: Request,
-    id: int,
-    entry=SecureEntry(scopes=["read:data"]),  # TODO: Separate scope for assets?
-    settings: Settings = Depends(get_settings),
-):
-    if not settings.expose_raw_assets:
-        raise HTTPException(
-            status_code=HTTP_403_FORBIDDEN,
-            detail=(
-                "This Tiled server is configured not to allow "
-                "downloading raw assets."
-            ),
-        )
-    if not hasattr(entry, "asset_by_id"):
-        raise HTTPException(
-            status_code=HTTP_405_METHOD_NOT_ALLOWED,
-            detail="This node does not support downloading assets.",
-        )
-
-    asset = await entry.asset_by_id(id)
-    if asset is None:
-        raise HTTPException(
-            status_code=HTTP_404_NOT_FOUND,
-            detail=f"This node exists but it does not have an Asset with id {id}",
-        )
-    if not asset.is_directory:
-        raise HTTPException(
-            status_code=HTTP_400_BAD_REQUEST,
-            detail="This asset is not a directory. There is no manifest.",
-        )
-    if not asset.data_uri.startswith("file:"):
-        raise HTTPException(
-            status_code=HTTP_400_BAD_REQUEST,
-            detail="Only download assets stored as file:// is currently supported.",
-        )
-    path = path_from_uri(asset.data_uri)
-    manifest = []
-    # Walk the directory and any subdirectories. Aggregate a list of all the
-    # files, given as paths relative to the directory root.
-    for root, _directories, files in os.walk(path):
-        manifest.extend(Path(root, file) for file in files)
-    return json_or_msgpack(request, {"manifest": manifest})
-
-
-async def validate_metadata(
-    metadata: dict,
-    structure_family: StructureFamily,
-    structure,
-    specs: List[Spec],
-    validation_registry=Depends(get_validation_registry),
-    settings: Settings = Depends(get_settings),
-):
-    metadata_modified = False
-
-    # Specs should be ordered from most specific/constrained to least.
-    # Validate them in reverse order, with the least constrained spec first,
-    # because it may do normalization that helps pass the more constrained one.
-    # Known Issue:
-    # When there is more than one spec, it's possible for the validator for
-    # Spec 2 to make a modification that breaks the validation for Spec 1.
-    # For now we leave it to the server maintainer to ensure that validators
-    # won't step on each other in this way, but this may need revisiting.
-    for spec in reversed(specs):
-        if spec.name not in validation_registry:
-            if settings.reject_undeclared_specs:
+        asset = await entry.asset_by_id(id)
+        if asset is None:
+            raise HTTPException(
+                status_code=HTTP_404_NOT_FOUND,
+                detail=f"This node exists but it does not have an Asset with id {id}",
+            )
+        if asset.is_directory:
+            if relative_path is None:
+                raise HTTPException(
+                    status_code=HTTP_400_BAD_REQUEST,
+                    detail=(
+                        "This asset is a directory. Must specify relative path, "
+                        f"from manifest provided by /asset/manifest/...?id={id}"
+                    ),
+                )
+            if relative_path.is_absolute():
                 raise HTTPException(
                     status_code=HTTP_400_BAD_REQUEST,
-                    detail=f"Unrecognized spec: {spec.name}",
+                    detail="relative_path query parameter must be a *relative* path",
                 )
         else:
-            validator = validation_registry(spec.name)
-            try:
-                result = validator(metadata, structure_family, structure, spec)
-            except ValidationError as e:
+            if relative_path is not None:
                 raise HTTPException(
                     status_code=HTTP_400_BAD_REQUEST,
-                    detail=f"failed validation for spec {spec.name}:\n{e}",
+                    detail="This asset is not a directory. The relative_path query parameter must not be set.",
+                )
+        if not asset.data_uri.startswith("file:"):
+            raise HTTPException(
+                status_code=HTTP_400_BAD_REQUEST,
+                detail="Only downloading assets stored as file:// is currently supported.",
+            )
+        path = path_from_uri(asset.data_uri)
+        if relative_path is not None:
+            # Be doubly sure that this is under the Asset's data_uri directory
+            # and not sneakily escaping it.
+            if os.path.commonpath([path, path / relative_path]) != str(path):
+                # This should not be possible.
+                raise RuntimeError(
+                    f"Refusing to serve {path / relative_path} because it is outside "
+                    "of the Asset's directory"
                )
+            full_path = path / relative_path
+        else:
+            full_path = path
+        stat_result = await anyio.to_thread.run_sync(os.stat, full_path)
+        filename = full_path.name
+        if "range" in request.headers:
+            range_header = request.headers["range"]
+            match = RANGE_HEADER_PATTERN.match(range_header)
+            if match is None:
+                raise HTTPException(
+                    status_code=HTTP_400_BAD_REQUEST,
+                    detail=(
+                        "Only Range headers of the form 'bytes=start-end' are supported. "
+                        f"Could not parse Range header: {range_header}"
+                    ),
+                )
+            range = start, _ = (int(match.group(1)), int(match.group(2)))
+            if start > stat_result.st_size:
+                raise HTTPException(
+                    status_code=HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
+                    headers={"content-range": f"bytes */{stat_result.st_size}"},
+                )
+            status_code = HTTP_206_PARTIAL_CONTENT
+        else:
+            range = None
+            status_code = HTTP_200_OK
+        return FileResponseWithRange(
+            full_path,
+            stat_result=stat_result,
+            status_code=status_code,
+            headers={"Content-Disposition": f'attachment; filename="{filename}"'},
+            range=range,
+        )
+
+    @router.get("/asset/manifest/{path:path}")
+    async def get_asset_manifest(
+        request: Request,
+        id: int,
+        entry=SecureEntry(scopes=["read:data"]),  # TODO: Separate scope for assets?
+        settings: Settings = Depends(get_settings),
+    ):
+        if not settings.expose_raw_assets:
+            raise HTTPException(
+                status_code=HTTP_403_FORBIDDEN,
+                detail=(
+                    "This Tiled server is configured not to allow "
+                    "downloading raw assets."
+                ),
+            )
+        if not hasattr(entry, "asset_by_id"):
+            raise HTTPException(
+                status_code=HTTP_405_METHOD_NOT_ALLOWED,
+                detail="This node does not support downloading assets.",
+            )

-    return metadata_modified, metadata
+        asset = await entry.asset_by_id(id)
+        if asset is None:
+            raise HTTPException(
+                status_code=HTTP_404_NOT_FOUND,
+                detail=f"This node exists but it does not have an Asset with id {id}",
+            )
+        if not asset.is_directory:
+            raise HTTPException(
+                status_code=HTTP_400_BAD_REQUEST,
+                detail="This asset is not a directory. There is no manifest.",
+            )
+        if not asset.data_uri.startswith("file:"):
+            raise HTTPException(
+                status_code=HTTP_400_BAD_REQUEST,
+                detail="Only downloading assets stored as file:// is currently supported.",
+            )
+        path = path_from_uri(asset.data_uri)
+        manifest = []
+        # Walk the directory and any subdirectories. Aggregate a list of all the
+        # files, given as paths relative to the directory root.
+        for root, _directories, files in os.walk(path):
+            manifest.extend(Path(root, file) for file in files)
+        return json_or_msgpack(request, {"manifest": manifest})
+
+    async def validate_metadata(
+        metadata: dict,
+        structure_family: StructureFamily,
+        structure,
+        specs: List[Spec],
+        validation_registry,
+        settings: Settings = Depends(get_settings),
+    ):
+        metadata_modified = False
+
+        # Specs should be ordered from most specific/constrained to least.
+        # Validate them in reverse order, with the least constrained spec first,
+        # because it may do normalization that helps pass the more constrained one.
+        # Known Issue:
+        # When there is more than one spec, it's possible for the validator for
+        # Spec 2 to make a modification that breaks the validation for Spec 1.
+        # For now we leave it to the server maintainer to ensure that validators
+        # won't step on each other in this way, but this may need revisiting.
+ for spec in reversed(specs): + if spec.name not in validation_registry: + if settings.reject_undeclared_specs: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"Unrecognized spec: {spec.name}", + ) + else: + validator = validation_registry(spec.name) + try: + result = validator(metadata, structure_family, structure, spec) + except ValidationError as e: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"failed validation for spec {spec.name}:\n{e}", + ) + if result is not None: + metadata_modified = True + metadata = result + + return metadata_modified, metadata + + return router diff --git a/tiled/server/utils.py b/tiled/server/utils.py index d28611471..94a4b74ae 100644 --- a/tiled/server/utils.py +++ b/tiled/server/utils.py @@ -1,5 +1,14 @@ import contextlib import time +from datetime import datetime, timezone +from typing import Optional + +from fastapi import Depends, HTTPException, Request, Security +from fastapi.openapi.models import APIKey, APIKeyIn +from fastapi.security import SecurityScopes +from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyQuery +from fastapi.security.utils import get_authorization_scheme_param +from starlette.status import HTTP_400_BAD_REQUEST from ..access_policies import NO_ACCESS from ..adapters.mapping import MapAdapter @@ -10,11 +19,78 @@ CSRF_COOKIE_NAME = "tiled_csrf" -def get_authenticators(): - raise NotImplementedError( - "This should be overridden via dependency_overrides. " - "See tiled.server.app.build_app()." - ) +def utcnow(): + "UTC now with second resolution" + return datetime.now(timezone.utc).replace(microsecond=0) + + +def headers_for_401(request: Request, security_scopes: SecurityScopes): + # call directly from methods, rather than as a dependency, to avoid calling + # when not needed. + if security_scopes.scopes: + authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' + else: + authenticate_value = "Bearer" + return { + "WWW-Authenticate": authenticate_value, + "X-Tiled-Root": get_base_url(request), + } + + +class APIKeyAuthorizationHeader(APIKeyBase): + """ + Expect a header like + + Authorization: Apikey SECRET + + where Apikey is case-insensitive. + """ + + def __init__( + self, + *, + name: str, + scheme_name: Optional[str] = None, + description: Optional[str] = None, + ): + self.model: APIKey = APIKey( + **{"in": APIKeyIn.header}, name=name, description=description + ) + self.scheme_name = scheme_name or self.__class__.__name__ + + async def __call__(self, request: Request) -> Optional[str]: + authorization: str = request.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(authorization) + if not authorization or scheme.lower() == "bearer": + return None + if scheme.lower() != "apikey": + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=( + "Authorization header must include the authorization type " + "followed by a space and then the secret, as in " + "'Bearer SECRET' or 'Apikey SECRET'. 
" + ), + ) + return param + + +async def get_api_key( + api_key_query: str = Security(APIKeyQuery(name="api_key", auto_error=False)), + api_key_header: str = Security( + APIKeyAuthorizationHeader( + name="Authorization", + description="Prefix value with 'Apikey ' as in, 'Apikey SECRET'", + ) + ), + api_key_cookie: str = Security( + APIKeyCookie(name=API_KEY_COOKIE_NAME, auto_error=False) + ), +): + for api_key in [api_key_query, api_key_header, api_key_cookie]: + if api_key is not None: + return api_key + return None @contextlib.contextmanager @@ -82,3 +158,22 @@ async def filter_for_access(entry, principal, scopes, metrics, path_parts): for query in queries: entry = entry.search(query) return entry + + +async def move_api_key( + request: Request, + api_key: str | None = Depends(get_api_key), +): + """ + Moves API key if given as a query parameter into a cookie + """ + + if ( + api_key is not None + and "api_key" in request.query_params + and request.cookies.get(API_KEY_COOKIE_NAME) != api_key + ): + request.state.cookies_to_set.append( + {"key": API_KEY_COOKIE_NAME, "value": api_key} + ) + return api_key