diff --git a/src/gafaelfawr/config.py b/src/gafaelfawr/config.py index 4136aaea..634df1dd 100644 --- a/src/gafaelfawr/config.py +++ b/src/gafaelfawr/config.py @@ -54,6 +54,7 @@ from .exceptions import InvalidTokenError from .keypair import RSAKeyPair from .models.token import Token +from .models.userinfo import Quota from .util import group_name_for_github_team HttpsUrl = Annotated[ @@ -79,12 +80,10 @@ "GitHubGroupTeam", "HttpsUrl", "LDAPConfig", - "NotebookQuota", "OIDCClient", "OIDCConfig", "OIDCServerConfig", "QuotaConfig", - "QuotaGrant", ] @@ -657,56 +656,16 @@ def keypair(self) -> RSAKeyPair: return self._keypair -class NotebookQuota(BaseModel): - """Quota settings for the Notebook Aspect.""" - - model_config = ConfigDict(extra="forbid") - - cpu: float = Field( - ..., title="CPU limit", description="Maximum number of CPU equivalents" - ) - - memory: float = Field( - ..., - title="Memory limit (GiB)", - description="Maximum memory usage in GiB", - ) - - -class QuotaGrant(BaseModel): - """One grant of quotas. - - There may be one of these per group, as well as a default one, in the - overall quota configuration. - """ - - model_config = ConfigDict(extra="forbid") - - api: dict[str, int] = Field( - {}, - title="Service quotas", - description=( - "Mapping of service names to quota of requests per 15 minutes" - ), - ) - - notebook: NotebookQuota | None = Field( - None, - title="Notebook quota", - description="Quota settings for the Notebook Aspect", - ) - - class QuotaConfig(BaseModel): """Quota configuration.""" model_config = ConfigDict(extra="forbid") - default: QuotaGrant = Field( + default: Quota = Field( ..., title="Default quota", description="Default quotas for all users" ) - groups: dict[str, QuotaGrant] = Field( + groups: dict[str, Quota] = Field( {}, title="Quota grants by group", description="Additional quota grants by group name", @@ -718,6 +677,46 @@ class QuotaConfig(BaseModel): description="Groups whose members bypass all quota restrictions", ) + def calculate_quota(self, groups: set[str]) -> Quota | None: + """Calculate user's quota given their group membership. + + Parameters + ---------- + groups + Group membership of the user. + + Returns + ------- + Quota or None + Quota information for that user or `None` if no quotas apply. + """ + if groups & self.bypass: + return None + + # Start with the defaults. + api = dict(self.default.api) + notebook = None + if self.default.notebook: + notebook = self.default.notebook.model_copy() + + # Look for group-specific rules. + for group in groups & set(self.groups.keys()): + extra = self.groups[group] + if extra.notebook: + if notebook: + notebook.cpu += extra.notebook.cpu + notebook.memory += extra.notebook.memory + else: + notebook = extra.notebook.model_copy() + for service, quota in extra.api.items(): + if service in api: + api[service] += quota + else: + api[service] = quota + + # Return the results. + return Quota(api=api, notebook=notebook) + class GitHubGroupTeam(BaseModel): """Specification for a GitHub team.""" diff --git a/src/gafaelfawr/models/userinfo.py b/src/gafaelfawr/models/userinfo.py index 61850bce..b0f08a3f 100644 --- a/src/gafaelfawr/models/userinfo.py +++ b/src/gafaelfawr/models/userinfo.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from datetime import datetime -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from ..constants import GROUPNAME_REGEX from ..pydantic import Timestamp @@ -77,6 +77,8 @@ class Group(BaseModel): class NotebookQuota(BaseModel): """Notebook Aspect quota information for a user.""" + model_config = ConfigDict(extra="forbid") + cpu: float = Field(..., title="CPU equivalents", examples=[4.0]) memory: float = Field( @@ -87,6 +89,8 @@ class NotebookQuota(BaseModel): class Quota(BaseModel): """Quota information for a user.""" + model_config = ConfigDict(extra="forbid") + api: dict[str, int] = Field( {}, title="API quotas", diff --git a/src/gafaelfawr/services/userinfo.py b/src/gafaelfawr/services/userinfo.py index 7d1e23e8..b1040062 100644 --- a/src/gafaelfawr/services/userinfo.py +++ b/src/gafaelfawr/services/userinfo.py @@ -8,7 +8,7 @@ from ..exceptions import FirestoreError from ..models.ldap import LDAPUserData from ..models.token import TokenData, TokenUserInfo -from ..models.userinfo import Group, NotebookQuota, Quota, UserInfo +from ..models.userinfo import Group, UserInfo from .firestore import FirestoreService from .ldap import LDAPService @@ -112,6 +112,12 @@ async def get_user_info_from_token( if not gid and not ldap_data.gid and self._config.add_user_group: gid = uid or ldap_data.uid + # Calculate the quota. + quota = None + if self._config.quota: + group_names = {g.name for g in groups} + quota = self._config.quota.calculate_quota(group_names) + # Return the results. return UserInfo( username=username, @@ -120,7 +126,7 @@ async def get_user_info_from_token( gid=gid or ldap_data.gid, email=token_data.email or ldap_data.email, groups=sorted(groups, key=lambda g: g.name), - quota=self._calculate_quota(groups), + quota=quota, ) async def get_scopes(self, user_info: TokenUserInfo) -> set[str] | None: @@ -210,57 +216,6 @@ async def invalidate_cache(self, username: str) -> None: if self._ldap: await self._ldap.invalidate_cache(username) - def _calculate_quota(self, groups: list[Group]) -> Quota | None: - """Calculate the quota for a user. - - Parameters - ---------- - groups - The user's group membership. - - Returns - ------- - gafaelfawr.models.token.Quota - Quota information for that user. - """ - if not self._config.quota: - return None - group_names = {g.name for g in groups} - if group_names & self._config.quota.bypass: - return Quota() - - # Start with the defaults. - api = dict(self._config.quota.default.api) - notebook = None - if self._config.quota.default.notebook: - notebook = NotebookQuota( - cpu=self._config.quota.default.notebook.cpu, - memory=self._config.quota.default.notebook.memory, - ) - - # Look for group-specific rules. - for group in group_names: - if group not in self._config.quota.groups: - continue - extra = self._config.quota.groups[group] - if extra.notebook: - if notebook: - notebook.cpu += extra.notebook.cpu - notebook.memory += extra.notebook.memory - else: - notebook = NotebookQuota( - cpu=extra.notebook.cpu, - memory=extra.notebook.memory, - ) - for service in extra.api: - if service in api: - api[service] += extra.api[service] - else: - api[service] = extra.api[service] - - # Return the results. - return Quota(api=api, notebook=notebook) - async def _get_groups_from_ldap( self, username: str,