Skip to content

Commit

Permalink
Merge pull request #1217 from lsst-sqre:tickets/DM-48432
Browse files Browse the repository at this point in the history
DM-48432: Move quota calculation to QuotaConfig model
  • Loading branch information
rra authored Jan 14, 2025
2 parents 27d25e0 + ec5d3bd commit 30d7ed3
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 98 deletions.
87 changes: 43 additions & 44 deletions src/gafaelfawr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from .exceptions import InvalidTokenError
from .keypair import RSAKeyPair
from .models.token import Token
from .models.userinfo import Quota
from .util import group_name_for_github_team

HttpsUrl = Annotated[
Expand All @@ -79,12 +80,10 @@
"GitHubGroupTeam",
"HttpsUrl",
"LDAPConfig",
"NotebookQuota",
"OIDCClient",
"OIDCConfig",
"OIDCServerConfig",
"QuotaConfig",
"QuotaGrant",
]


Expand Down Expand Up @@ -657,56 +656,16 @@ def keypair(self) -> RSAKeyPair:
return self._keypair


class NotebookQuota(BaseModel):
"""Quota settings for the Notebook Aspect."""

model_config = ConfigDict(extra="forbid")

cpu: float = Field(
..., title="CPU limit", description="Maximum number of CPU equivalents"
)

memory: float = Field(
...,
title="Memory limit (GiB)",
description="Maximum memory usage in GiB",
)


class QuotaGrant(BaseModel):
"""One grant of quotas.
There may be one of these per group, as well as a default one, in the
overall quota configuration.
"""

model_config = ConfigDict(extra="forbid")

api: dict[str, int] = Field(
{},
title="Service quotas",
description=(
"Mapping of service names to quota of requests per 15 minutes"
),
)

notebook: NotebookQuota | None = Field(
None,
title="Notebook quota",
description="Quota settings for the Notebook Aspect",
)


class QuotaConfig(BaseModel):
"""Quota configuration."""

model_config = ConfigDict(extra="forbid")

default: QuotaGrant = Field(
default: Quota = Field(
..., title="Default quota", description="Default quotas for all users"
)

groups: dict[str, QuotaGrant] = Field(
groups: dict[str, Quota] = Field(
{},
title="Quota grants by group",
description="Additional quota grants by group name",
Expand All @@ -718,6 +677,46 @@ class QuotaConfig(BaseModel):
description="Groups whose members bypass all quota restrictions",
)

def calculate_quota(self, groups: set[str]) -> Quota | None:
"""Calculate user's quota given their group membership.
Parameters
----------
groups
Group membership of the user.
Returns
-------
Quota or None
Quota information for that user or `None` if no quotas apply.
"""
if groups & self.bypass:
return None

# Start with the defaults.
api = dict(self.default.api)
notebook = None
if self.default.notebook:
notebook = self.default.notebook.model_copy()

# Look for group-specific rules.
for group in groups & set(self.groups.keys()):
extra = self.groups[group]
if extra.notebook:
if notebook:
notebook.cpu += extra.notebook.cpu
notebook.memory += extra.notebook.memory
else:
notebook = extra.notebook.model_copy()
for service, quota in extra.api.items():
if service in api:
api[service] += quota
else:
api[service] = quota

# Return the results.
return Quota(api=api, notebook=notebook)


class GitHubGroupTeam(BaseModel):
"""Specification for a GitHub team."""
Expand Down
6 changes: 5 additions & 1 deletion src/gafaelfawr/models/userinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from datetime import datetime

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from ..constants import GROUPNAME_REGEX
from ..pydantic import Timestamp
Expand Down Expand Up @@ -77,6 +77,8 @@ class Group(BaseModel):
class NotebookQuota(BaseModel):
"""Notebook Aspect quota information for a user."""

model_config = ConfigDict(extra="forbid")

cpu: float = Field(..., title="CPU equivalents", examples=[4.0])

memory: float = Field(
Expand All @@ -87,6 +89,8 @@ class NotebookQuota(BaseModel):
class Quota(BaseModel):
"""Quota information for a user."""

model_config = ConfigDict(extra="forbid")

api: dict[str, int] = Field(
{},
title="API quotas",
Expand Down
61 changes: 8 additions & 53 deletions src/gafaelfawr/services/userinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..exceptions import FirestoreError
from ..models.ldap import LDAPUserData
from ..models.token import TokenData, TokenUserInfo
from ..models.userinfo import Group, NotebookQuota, Quota, UserInfo
from ..models.userinfo import Group, UserInfo
from .firestore import FirestoreService
from .ldap import LDAPService

Expand Down Expand Up @@ -112,6 +112,12 @@ async def get_user_info_from_token(
if not gid and not ldap_data.gid and self._config.add_user_group:
gid = uid or ldap_data.uid

# Calculate the quota.
quota = None
if self._config.quota:
group_names = {g.name for g in groups}
quota = self._config.quota.calculate_quota(group_names)

# Return the results.
return UserInfo(
username=username,
Expand All @@ -120,7 +126,7 @@ async def get_user_info_from_token(
gid=gid or ldap_data.gid,
email=token_data.email or ldap_data.email,
groups=sorted(groups, key=lambda g: g.name),
quota=self._calculate_quota(groups),
quota=quota,
)

async def get_scopes(self, user_info: TokenUserInfo) -> set[str] | None:
Expand Down Expand Up @@ -210,57 +216,6 @@ async def invalidate_cache(self, username: str) -> None:
if self._ldap:
await self._ldap.invalidate_cache(username)

def _calculate_quota(self, groups: list[Group]) -> Quota | None:
"""Calculate the quota for a user.
Parameters
----------
groups
The user's group membership.
Returns
-------
gafaelfawr.models.token.Quota
Quota information for that user.
"""
if not self._config.quota:
return None
group_names = {g.name for g in groups}
if group_names & self._config.quota.bypass:
return Quota()

# Start with the defaults.
api = dict(self._config.quota.default.api)
notebook = None
if self._config.quota.default.notebook:
notebook = NotebookQuota(
cpu=self._config.quota.default.notebook.cpu,
memory=self._config.quota.default.notebook.memory,
)

# Look for group-specific rules.
for group in group_names:
if group not in self._config.quota.groups:
continue
extra = self._config.quota.groups[group]
if extra.notebook:
if notebook:
notebook.cpu += extra.notebook.cpu
notebook.memory += extra.notebook.memory
else:
notebook = NotebookQuota(
cpu=extra.notebook.cpu,
memory=extra.notebook.memory,
)
for service in extra.api:
if service in api:
api[service] += extra.api[service]
else:
api[service] = extra.api[service]

# Return the results.
return Quota(api=api, notebook=notebook)

async def _get_groups_from_ldap(
self,
username: str,
Expand Down

0 comments on commit 30d7ed3

Please sign in to comment.