Skip to content

Commit

Permalink
auth: add retry logic and unique JTI to JWT claims
Browse files Browse the repository at this point in the history
- retry mechanism aims to handle race conditions
- JTI claim in JWT to ensure token is unique
- create device record before token creation to prevent foreign key issues
  • Loading branch information
joshschmelzle committed Feb 1, 2025
1 parent d37e751 commit 4f9536b
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 49 deletions.
2 changes: 1 addition & 1 deletion debian/changelog
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
wlanpi-core (2.0.0~dev20250131.3) unstable; urgency=high
wlanpi-core (2.0.0~dev20250201.1) unstable; urgency=high

* Development build towards 2.0.0
* Breaking auth changes
Expand Down
121 changes: 73 additions & 48 deletions wlanpi_core/core/tokenmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from wlanpi_core.core.config import settings
from wlanpi_core.core.logging import get_logger
from wlanpi_core.core.models import SigningKey, Token
from wlanpi_core.core.repositories import TokenRepository
from wlanpi_core.core.repositories import DeviceRepository, TokenRepository
from wlanpi_core.services import system_service

log = get_logger(__name__)
Expand Down Expand Up @@ -244,58 +244,83 @@ async def create_token(
Returns:
JWT token string
"""
async with self.app_state.db_manager.session() as session:
try:
signing_key = await self._get_or_create_signing_key(session)

now = datetime.now(timezone.utc)
expires = now + (
expires_delta or timedelta(days=settings.ACCESS_TOKEN_EXPIRE_DAYS)
)

claims = {
"sub": system_service.get_hostname(),
"iss": "wlanpi-core",
"did": device_id,
"exp": int(expires.timestamp()),
"iat": int(now.timestamp()),
"kid": str(signing_key.id),
}
max_retries = 3
retry_count = 0

jwt_token = jwt.encode(
header={"alg": "HS256", "kid": str(signing_key.id)},
payload=claims,
key=signing_key.key,
).decode("utf-8")

token_model = Token(
token=jwt_token,
device_id=device_id,
key_id=signing_key.id,
expires_at=expires,
)
session.add(token_model)
await session.commit()
while retry_count < max_retries:
async with self.app_state.db_manager.session() as session:
try:
device_repo = DeviceRepository(session)
await device_repo.get_or_create_device(device_id)

log.debug(token_model)
log.debug(vars(token_model))
signing_key = await self._get_or_create_signing_key(session)

self.token_cache.cache_token(jwt_token, claims)
now = datetime.now(timezone.utc)
expires = now + (
expires_delta
or timedelta(days=settings.ACCESS_TOKEN_EXPIRE_DAYS)
)

return jwt_token
claims = {
"sub": system_service.get_hostname(),
"iss": "wlanpi-core",
"did": device_id,
"exp": int(expires.timestamp()),
"iat": int(now.timestamp()),
"kid": str(signing_key.id),
"jti": secrets.token_hex(8),
}

except Exception as e:
await session.rollback()
log.exception(
"Token creation failed",
extra={
"component": "auth",
"action": "create_token_error",
"device_id": device_id,
"error": str(e),
},
)
raise HTTPException(status_code=500, detail=str(e))
jwt_token = jwt.encode(
header={"alg": "HS256", "kid": str(signing_key.id)},
payload=claims,
key=signing_key.key,
).decode("utf-8")

token_model = Token(
token=jwt_token,
device_id=device_id,
key_id=signing_key.id,
expires_at=expires,
)
session.add(token_model)
await session.commit()

log.debug(token_model)
log.debug(vars(token_model))

self.token_cache.cache_token(jwt_token, claims)

return jwt_token

except sqlalchemy.exc.IntegrityError as e:
await session.rollback()
retry_count += 1
if retry_count >= max_retries:
log.exception(
"Token creation failed after retries",
extra={
"component": "auth",
"action": "create_token_error",
"device_id": device_id,
"error": str(e),
"retries": retry_count,
},
)
raise HTTPException(status_code=500, detail=str(e))
continue
except Exception as e:
await session.rollback()
log.exception(
"Token creation failed",
extra={
"component": "auth",
"action": "create_token_error",
"device_id": device_id,
"error": str(e),
},
)
raise HTTPException(status_code=500, detail=str(e))

async def verify_token(self, token: str) -> TokenValidationResult:
"""Verify JWT token and return validation result"""
Expand Down

0 comments on commit 4f9536b

Please sign in to comment.