Skip to content

Commit

Permalink
Merge pull request #1184 from lsst-sqre/tickets/DM-48088
Browse files Browse the repository at this point in the history
DM-48088: Convert internal scopes representation to set
  • Loading branch information
rra authored Dec 11, 2024
2 parents 57dd7f4 + c8f6c5f commit 1676ccc
Show file tree
Hide file tree
Showing 28 changed files with 315 additions and 299 deletions.
29 changes: 16 additions & 13 deletions src/gafaelfawr/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,18 +335,18 @@ class InternalTokenCache(TokenCache):
"""Cache for internal tokens."""

def get(
self, token_data: TokenData, service: str, scopes: list[str]
self, token_data: TokenData, service: str, scopes: set[str]
) -> Token | None:
"""Retrieve an internal token from the cache.
Parameters
----------
token_data
The authentication data for the parent token.
Authentication data for the parent token.
service
The service of the internal token.
Service of the internal token.
scopes
The scopes the internal token should have.
Scopes the internal token should have.
Returns
-------
Expand All @@ -367,7 +367,7 @@ def store(
self,
token_data: TokenData,
service: str,
scopes: list[str],
scopes: set[str],
token: Token,
) -> None:
"""Store an internal token in the cache.
Expand All @@ -377,19 +377,19 @@ def store(
Parameters
----------
token_data
The authentication data for the parent token.
Authentication data for the parent token.
service
The service of the internal token.
Service of the internal token.
scopes
The scopes the internal token should have.
Scopes the internal token should have.
token
The token to cache.
Token to cache.
"""
key = self._build_key(token_data, service, scopes)
self._cache[key] = token

def _build_key(
self, token_data: TokenData, service: str, scopes: list[str]
self, token_data: TokenData, service: str, scopes: set[str]
) -> tuple[str, ...]:
"""Build the cache key for an internal token.
Expand All @@ -407,9 +407,12 @@ def _build_key(
tuple
An object suitable for use as a hash key for this internal token.
"""
expires = str(token_data.expires) if token_data.expires else "None"
scope = ",".join(sorted(scopes))
return (token_data.token.key, expires, service, scope)
return (
token_data.token.key,
str(token_data.expires) if token_data.expires else "None",
service,
" ".join(sorted(scopes)),
)


class NotebookTokenCache(TokenCache):
Expand Down
8 changes: 4 additions & 4 deletions src/gafaelfawr/handlers/ingress.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,18 +597,18 @@ async def build_delegated_token(
return str(token)
elif auth_config.delegate_to:
# Delegated scopes are optional; if the authenticating token doesn't
# have the scope, it's omitted from the delegated token. (To make it
# have the scope, it's omitted from the delegated token. (To make it
# mandatory, require that scope via the scope parameter as well, and
# then the authenticating token will always have it.) Therefore,
# then the authenticating token will always have it.) Therefore,
# reduce the scopes of the internal token to the intersection between
# the requested delegated scopes and the scopes of the authenticating
# token.
delegate_scopes = auth_config.delegate_scopes & set(token_data.scopes)
delegate_scopes = auth_config.delegate_scopes & token_data.scopes
token_service = context.factory.create_token_service()
token = await token_service.get_internal_token(
token_data,
service=auth_config.delegate_to,
scopes=sorted(delegate_scopes),
scopes=delegate_scopes,
ip_address=context.ip_address,
minimum_lifetime=auth_config.minimum_lifetime,
)
Expand Down
6 changes: 4 additions & 2 deletions src/gafaelfawr/models/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from pydantic import BaseModel, Field

from ..pydantic import Scopes

__all__ = [
"APIConfig",
"APILoginResponse",
Expand Down Expand Up @@ -165,11 +167,11 @@ class APILoginResponse(BaseModel):
examples=["someuser"],
)

scopes: list[str] = Field(
scopes: Scopes = Field(
...,
title="Access scopes",
description="Access scopes for this authenticated user",
examples=["read:all", "user:token"],
examples=[["read:all", "user:token"]],
)

config: APIConfig = Field(
Expand Down
8 changes: 4 additions & 4 deletions src/gafaelfawr/models/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def bootstrap_token(cls) -> Self:
token=Token(),
username="<bootstrap>",
token_type=TokenType.service,
scopes=["admin:token"],
scopes={"admin:token"},
)

@classmethod
Expand All @@ -332,7 +332,7 @@ def internal_token(cls) -> Self:
token=Token(),
username="<internal>",
token_type=TokenType.service,
scopes=["admin:token"],
scopes={"admin:token"},
)


Expand Down Expand Up @@ -380,7 +380,7 @@ class AdminTokenRequest(BaseModel):
)

scopes: Scopes = Field(
[],
set(),
title="Token scopes",
examples=[["read:all"]],
)
Expand Down Expand Up @@ -486,7 +486,7 @@ class UserTokenRequest(BaseModel):
)

scopes: Scopes = Field(
[],
set(),
title="Token scope",
examples=[["read:all"]],
)
Expand Down
16 changes: 11 additions & 5 deletions src/gafaelfawr/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _normalize_ip_address(v: str | IPv4Address | IPv6Address) -> str:
"""


def _normalize_scopes(v: str | Iterable[str]) -> list[str]:
def _normalize_scopes(v: str | Iterable[str]) -> set[str]:
"""Pydantic validator for scope fields.
Scopes are stored in the database as a comma-delimited, sorted list.
Expand All @@ -65,15 +65,21 @@ def _normalize_scopes(v: str | Iterable[str]) -> list[str]:
Scopes as a set.
"""
if isinstance(v, str):
return [] if not v else sorted(v.split(","))
return set() if not v else set(v.split(","))
else:
return sorted(v)
return set(v)


Scopes: TypeAlias = Annotated[list[str], PlainValidator(_normalize_scopes)]
Scopes: TypeAlias = Annotated[
set[str],
PlainValidator(_normalize_scopes),
PlainSerializer(
lambda s: sorted(s), return_type=list[str], when_used="json"
),
]
"""Type for a list of scopes.
The scopes will be forced to sorted order by validation.
The scopes will be forced to sorted order on serialization.
"""


Expand Down
4 changes: 2 additions & 2 deletions src/gafaelfawr/services/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ async def _create_token(self, parent: GafaelfawrServiceToken) -> Token:
request = AdminTokenRequest(
username=parent.spec.service,
token_type=TokenType.service,
scopes=parent.spec.scopes,
scopes=set(parent.spec.scopes),
)
return await self._token_service.create_token_from_admin_request(
request, TokenData.internal_token(), ip_address=None
Expand All @@ -342,7 +342,7 @@ async def _is_token_valid(
return False
if token_data.username != parent.spec.service:
return False
return sorted(token_data.scopes) == sorted(parent.spec.scopes)
return token_data.scopes == set(parent.spec.scopes)

async def _secret_needs_update(
self, parent: GafaelfawrServiceToken, secret: V1Secret | None
Expand Down
62 changes: 28 additions & 34 deletions src/gafaelfawr/services/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ async def audit(self, *, fix: bool = False) -> list[str]:
return alerts

async def create_session_token(
self, user_info: TokenUserInfo, *, scopes: list[str], ip_address: str
self, user_info: TokenUserInfo, *, scopes: set[str], ip_address: str
) -> Token:
"""Create a new session token.
Expand Down Expand Up @@ -217,9 +217,7 @@ async def create_session_token(
async with self._session.begin():
admins = await self._admin_store.list()
if any(user_info.username == a.username for a in admins):
scopes = sorted({*scopes, "admin:token"})
else:
scopes = sorted(scopes)
scopes.add("admin:token")

data = TokenData(
token=token,
Expand Down Expand Up @@ -256,7 +254,7 @@ async def create_session_token(
token_key=token.key,
token_username=data.username,
token_expires=format_datetime_for_logging(expires),
token_scopes=scopes,
token_scopes=sorted(scopes),
token_userinfo=data.to_userinfo_dict(),
)

Expand Down Expand Up @@ -294,7 +292,7 @@ async def create_oidc_token(
token=token,
username=auth_data.username,
token_type=TokenType.oidc,
scopes=[],
scopes=set(),
created=created,
expires=expires,
name=auth_data.name,
Expand Down Expand Up @@ -342,7 +340,7 @@ async def create_user_token(
username: str,
*,
token_name: str,
scopes: list[str],
scopes: set[str],
expires: datetime | None = None,
ip_address: str,
) -> Token:
Expand Down Expand Up @@ -393,7 +391,6 @@ async def create_user_token(
self._validate_scopes(scopes, auth_data)
if expires:
expires = expires.replace(microsecond=0)
scopes = sorted(scopes)

token = Token()
created = current_datetime()
Expand Down Expand Up @@ -498,7 +495,7 @@ async def create_token_from_admin_request(
token=token,
username=request.username,
token_type=request.token_type,
scopes=sorted(request.scopes),
scopes=request.scopes,
created=created,
expires=expires,
name=request.name,
Expand Down Expand Up @@ -537,7 +534,7 @@ async def create_token_from_admin_request(
token_username=request.username,
token_expires=format_datetime_for_logging(expires),
token_name=request.token_name,
token_scopes=data.scopes,
token_scopes=sorted(data.scopes),
token_userinfo=data.to_userinfo_dict(),
)
else:
Expand All @@ -546,7 +543,7 @@ async def create_token_from_admin_request(
token_key=token.key,
token_username=request.username,
token_expires=format_datetime_for_logging(expires),
token_scopes=data.scopes,
token_scopes=sorted(data.scopes),
token_userinfo=data.to_userinfo_dict(),
)
return token
Expand Down Expand Up @@ -751,7 +748,7 @@ async def get_internal_token(
self,
token_data: TokenData,
service: str,
scopes: list[str],
scopes: set[str],
*,
ip_address: str,
minimum_lifetime: timedelta | None = None,
Expand Down Expand Up @@ -783,7 +780,6 @@ async def get_internal_token(
"""
self._validate_scopes(scopes, token_data)
self._validate_username(token_data.username)
scopes = sorted(scopes)
return await self._token_cache.get_internal_token(
token_data,
service,
Expand Down Expand Up @@ -919,7 +915,7 @@ async def modify_token(
*,
ip_address: str,
token_name: str | None = None,
scopes: list[str] | None = None,
scopes: set[str] | None = None,
expires: datetime | None = None,
no_expire: bool = False,
) -> TokenInfo | None:
Expand Down Expand Up @@ -992,7 +988,7 @@ async def modify_token(
username=info.username,
token_type=TokenType.user,
token_name=token_name if token_name else info.token_name,
scopes=sorted(scopes) if scopes is not None else info.scopes,
scopes=scopes if scopes is not None else info.scopes,
expires=info.expires if not (expires or no_expire) else expires,
actor=auth_data.username,
action=TokenChange.edit,
Expand All @@ -1006,7 +1002,7 @@ async def modify_token(
info = await self._token_db_store.modify(
key,
token_name=token_name,
scopes=sorted(scopes) if scopes else scopes,
scopes=scopes,
expires=expires,
no_expire=no_expire,
)
Expand Down Expand Up @@ -1111,7 +1107,7 @@ async def _audit_token(
mismatches.append("username")
if db.token_type != redis.token_type:
mismatches.append("type")
if db.scopes != sorted(redis.scopes):
if db.scopes != redis.scopes:
# There was a bug where Redis wasn't updated when the scopes were
# changed but the database was. Redis is canonical, so set the
# database scopes to match.
Expand Down Expand Up @@ -1168,19 +1164,18 @@ def _audit_unknown_scopes(self, tokens: Iterable[TokenData]) -> list[str]:
alerts = []
for token_data in tokens:
known_scopes = set(self._config.known_scopes.keys())
for scope in token_data.scopes:
if scope not in known_scopes:
self._logger.warning(
"Token has unknown scope",
token=token_data.token.key,
user=token_data.username,
scope=scope,
)
alerts.append(
f"Token `{token_data.token.key}` for"
f" `{token_data.username}` has unknown scope"
f" (`{scope}`)"
)
for scope in token_data.scopes - known_scopes:
self._logger.warning(
"Token has unknown scope",
token=token_data.token.key,
user=token_data.username,
scope=scope,
)
alerts.append(
f"Token `{token_data.token.key}` for"
f" `{token_data.username}` has unknown scope"
f" (`{scope}`)"
)
return alerts

def _check_authorization(
Expand Down Expand Up @@ -1419,7 +1414,7 @@ def _validate_expires(self, expires: datetime | None) -> None:

def _validate_scopes(
self,
scopes: list[str],
scopes: set[str],
auth_data: TokenData | None = None,
) -> None:
"""Check that the requested scopes are valid.
Expand All @@ -1439,12 +1434,11 @@ def _validate_scopes(
"""
if not scopes:
return
scopes_set = set(scopes)
if auth_data and "admin:token" not in auth_data.scopes:
if not (scopes_set <= set(auth_data.scopes)):
if not (scopes <= auth_data.scopes):
msg = "Requested scopes are broader than your current scopes"
raise InvalidScopesError(msg)
if not (scopes_set <= self._config.known_scopes.keys()):
if not (scopes <= set(self._config.known_scopes.keys())):
msg = "Unknown scopes requested"
raise InvalidScopesError(msg)

Expand Down
Loading

0 comments on commit 1676ccc

Please sign in to comment.