Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CAI-291] security improvements #1352

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions apps/chatbot/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ AUTH_COGNITO_ALLOW_ACCOUNT_LINKING=true
AUTH_COGNITO_CLIENT_ID=...
AUTH_COGNITO_CLIENT_SECRET=
AUTH_COGNITO_ISSUER=https://cognito-idp.eu-south-1.amazonaws.com/eu-south-1_xxxxxxxx
AUTH_COGNITO_USERPOOL_ID=...
AUTH_DISABLE_SIGNUP=false
AUTH_DISABLE_USERNAME_PASSWORD=true
AWS_ENDPOINT_URL_DYNAMODB=http://localhost:8000
CHB_AWS_ACCESS_KEY_ID=...
CHB_AWS_BEDROCK_EMBED_REGION=eu-central-1
CHB_AWS_BEDROCK_LLM_REGION=eu-west-3
Expand Down
2 changes: 2 additions & 0 deletions apps/chatbot/docker/app.local.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ ARG DEBIAN_FRONTEND=noninteractive

RUN apt-get update && \
apt-get install -y \
gcc \
curl \
wget \
jq \
zip

RUN wget https://github.com/rphrp1985/selenium_support/raw/main/chrome_114_amd64.deb && \
Expand Down
27 changes: 14 additions & 13 deletions apps/chatbot/docker/compose.test.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
---
services:
api-test:
build:
Expand All @@ -16,31 +17,31 @@ services:
condition: service_started
dynamodb:
condition: service_started
motoserver:
condition: service_started
networks:
- ntw
- ntwtest

dynamodb:
image: amazon/dynamodb-local:2.5.2
environment:
- AWS_ACCESS_KEY_ID=dummy
- AWS_SECRET_ACCESS_KEY=dummy
- AWS_DEFAULT_REGION=local
healthcheck:
test:
[
"CMD-SHELL",
'[ "$(curl -s -o /dev/null -I -w ''%{http_code}'' http://localhost:8000)" == "400" ]',
]
interval: 10s
timeout: 10s
retries: 10
networks:
- ntw
- ntwtest

redis:
image: redis/redis-stack:7.2.0-v13
networks:
- ntw
- ntwtest

motoserver:
image: motoserver/moto:5.1.0
environment:
- MOTO_PORT=3001
networks:
- ntwtest

networks:
ntw:
ntwtest:
2 changes: 1 addition & 1 deletion apps/chatbot/docker/compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ services:
context: ..
dockerfile: docker/app.local.Dockerfile
command: "./scripts/run.local.sh"
env_file: ../.env.local
ports:
- "8080:8080"
volumes:
Expand All @@ -18,7 +19,6 @@ services:
condition: service_started
langfuse:
condition: service_started
env_file: ../.env.local
networks:
- ntw

Expand Down
3,885 changes: 2,200 additions & 1,685 deletions apps/chatbot/poetry.lock

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions apps/chatbot/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ llama-index-embeddings-gemini = "^0.2.0"
llama-index-llms-bedrock-converse = "^0.3.0"
llama-index-postprocessor-presidio = "^0.2.0"
langfuse = "^2.53.9"
nh3 = "^0.2.20"
pyproject-hooks = "^1.2.0"
python-jose = "^3.3.0"
requests-auth-aws-sigv4 = "^0.7"
moto = {extras = ["cognitoidp"], version = "^5.1.0"}

[tool.poetry.group.test.dependencies]
httpx = "^0.27.2"
Expand Down
2 changes: 2 additions & 0 deletions apps/chatbot/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
addopts = -vv -p no:warnings
2 changes: 2 additions & 0 deletions apps/chatbot/scripts/cognito_id.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/bin/bash
# Create a Cognito user pool named "test-pool" against the endpoint at
# http://motoserver:3001 and print only the id of the new pool.
aws cognito-idp create-user-pool \
  --pool-name test-pool \
  --endpoint-url http://motoserver:3001 \
  --region eu-south-1 \
  | jq -r '.UserPool.Id'
4 changes: 2 additions & 2 deletions apps/chatbot/scripts/run.local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
echo '-=-=-=-=-=-=-=-=-= init DynamoDB -==-=-=-=-=-=-=-=-'
./scripts/dynamodb-init.sh

#echo '-=-=-=-=-=-=-= create redis index =-=-=-=-=-=-=-=-'
#./scripts/create_redis_index.sh
echo '-=-=-=-=-=-=-= create redis index =-=-=-=-=-=-=-=-'
./scripts/create_redis_index.sh

echo '-=-=-=-=-=-=-=-=-= run FastAPI =-==-=-=-=-=-=-=-=-'
fastapi dev src/app/main.py --port 8080 --host 0.0.0.0
2 changes: 1 addition & 1 deletion apps/chatbot/scripts/run.test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ echo '-=-=-=-=-=-=-= create redis index =-=-=-=-=-=-=-=-'
./scripts/create_redis_index.sh

echo '-=-=-=-=-=-=-=-=-=- run pytest -=-==-=-=-=-=-=-=-=-'
pytest -p no:warnings
pytest -vv -p no:warnings
42 changes: 42 additions & 0 deletions apps/chatbot/src/app/jwt_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
import requests
from jose import jwk, jwt
from jose.utils import base64url_decode
from fastapi import HTTPException

# Region used when building the Cognito JWKS URL: CHB_AWS_DEFAULT_REGION takes
# precedence, then AWS_DEFAULT_REGION; None when neither variable is set.
AWS_DEFAULT_REGION = os.getenv('CHB_AWS_DEFAULT_REGION', os.getenv('AWS_DEFAULT_REGION', None))
# Cognito user pool identifier read from the environment (None when unset).
AUTH_COGNITO_USERPOOL_ID = os.getenv('AUTH_COGNITO_USERPOOL_ID')

def get_jwks():
    """Fetch the JSON Web Key Set (JWKS) of the configured Cognito user pool.

    The endpoint URL is built from the module-level ``AWS_DEFAULT_REGION`` and
    ``AUTH_COGNITO_USERPOOL_ID`` values.

    Returns:
        The decoded JSON body returned by the JWKS endpoint.

    Raises:
        HTTPException: 401 "Auth error" when the endpoint cannot be reached,
            answers with a status other than 200, or returns a non-JSON body.
    """
    keys_url = f"https://cognito-idp.{AWS_DEFAULT_REGION}.amazonaws.com/{AWS_DEFAULT_REGION}_{AUTH_COGNITO_USERPOOL_ID}/.well-known/jwks.json"
    try:
        # Without a timeout a stalled endpoint would block the caller forever.
        response = requests.get(keys_url, timeout=10)
    except requests.RequestException as exc:
        # Transport failures are reported like any other auth failure instead
        # of leaking out as an unhandled exception.
        raise HTTPException(status_code=401, detail="Auth error") from exc

    if response.status_code != 200:
        raise HTTPException(status_code=401, detail="Auth error")

    try:
        return response.json()
    except ValueError as exc:
        # The JSON decoding errors raised by requests derive from ValueError.
        raise HTTPException(status_code=401, detail="Auth error") from exc

def verify_jwt(token: str):
    """Verify a JWT against the user pool's public keys and return its claims.

    The key set is fetched with ``get_jwks()``; the key whose ``kid`` matches
    the token header is used to check the token signature. The ``exp`` claim,
    when present as a number, is compared with the current time.

    Args:
        token: The encoded JWT (header.payload.signature).

    Returns:
        The claims dictionary decoded from the token payload.

    Raises:
        HTTPException: 401 when the token's key id is unknown or missing, the
            signature does not verify, the token has expired, or the token is
            otherwise rejected by the JOSE library.
    """
    # Local imports: the module's import block provides neither ``time`` nor
    # the JOSE exception classes the handlers below need.
    import time
    from jose.exceptions import ExpiredSignatureError, JWTError

    jwks = get_jwks()
    public_keys = {key["kid"]: key for key in jwks["keys"]}

    try:
        headers = jwt.get_unverified_header(token)
        # .get() so a header without "kid" yields a 401 instead of a KeyError.
        kid = headers.get("kid")
        if kid not in public_keys:
            raise HTTPException(status_code=401, detail="Invalid token key")

        public_key = jwk.construct(public_keys[kid])

        # The signature covers everything before the last dot.
        message, encoded_signature = str(token).rsplit('.', 1)
        decoded_signature = base64url_decode(encoded_signature.encode('utf-8'))
        if not public_key.verify(message.encode("utf8"), decoded_signature):
            raise HTTPException(status_code=401, detail="error in public_key.verify")

        # The signature is valid, so the claims can be read without a second
        # verification pass.
        claims = jwt.get_unverified_claims(token)

        # Reading the claims does not validate expiry, so check it explicitly.
        exp = claims.get("exp")
        if isinstance(exp, (int, float)) and time.time() > exp:
            raise ExpiredSignatureError("Signature has expired")

        return claims
    except ExpiredSignatureError as exc:
        raise HTTPException(status_code=401, detail="Token has expired") from exc
    except JWTError as e:
        raise HTTPException(status_code=401, detail=f"Invalid token: {e}") from e

45 changes: 22 additions & 23 deletions apps/chatbot/src/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import uuid
import boto3
import datetime
import jwt
import nh3
import time
from typing import Annotated, List
from boto3.dynamodb.conditions import Key
from botocore.exceptions import BotoCoreError, ClientError
Expand All @@ -17,11 +18,14 @@
from pydantic import BaseModel, Field

from src.modules.chatbot import Chatbot
from src.app.jwt_check import get_jwks, verify_jwt

logging.basicConfig(level=logging.INFO)
params = yaml.safe_load(open("config/params.yaml", "r"))
prompts = yaml.safe_load(open("config/prompts.yaml", "r"))
AWS_DEFAULT_REGION = os.getenv('CHB_AWS_DEFAULT_REGION', os.getenv('AWS_DEFAULT_REGION', None))
AUTH_COGNITO_USERPOOL_ID = os.getenv('AUTH_COGNITO_USERPOOL_ID')
ENVIRONMENT = os.getenv('environment', 'dev')

chatbot = Chatbot(params, prompts)

Expand All @@ -43,17 +47,11 @@ class QueryFeedback(BaseModel):
region_name=AWS_DEFAULT_REGION
)

if (os.getenv('environment', 'dev') == 'local'):
dynamodb = boto3_session.resource(
'dynamodb',
endpoint_url=os.getenv('CHB_DYNAMODB_URL', 'http://localhost:8000'),
region_name=AWS_DEFAULT_REGION
)
else:
dynamodb = boto3_session.resource(
'dynamodb',
region_name=AWS_DEFAULT_REGION
)
# endpoint_url is set by AWS_ENDPOINT_URL_DYNAMODB
dynamodb = boto3_session.resource(
'dynamodb',
region_name=AWS_DEFAULT_REGION
)

table_queries = dynamodb.Table(
f"{os.getenv('CHB_QUERY_TABLE_PREFIX', 'chatbot')}-queries"
Expand Down Expand Up @@ -94,14 +92,13 @@ async def query_creation (
salt = session_salt(session['id'])

answer = chatbot.chat_generate(
query_str = query.question,
query_str = nh3.clean(query.question),
messages = [item.dict() for item in query.history] if query.history else None,
trace_id = trace_id,
user_id = hash_func(userId, salt),
session_id = session["id"]
)


if query.queriedAt is None:
queriedAt = now.isoformat()
else:
Expand All @@ -126,19 +123,20 @@ async def query_creation (
raise HTTPException(status_code=422, detail=f"[POST /queries] error: {e}")
return bodyToReturn


def current_user_id(authorization: str) -> str:
if authorization is None:
return None
raise HTTPException(status_code=401, detail="Unauthorized")

else:
token = authorization.split(' ')[1]
decoded = jwt.decode(
token,
algorithms=["RS256"],
options={"verify_signature": False}
)
return decoded['cognito:username']

decoded = verify_jwt(token)
if decoded is False:
raise HTTPException(status_code=401, detail="Unauthorized")
else:
if "cognito:username" in decoded:
return decoded['cognito:username']
else:
return decoded['username']

def find_or_create_session(userId: str, now: datetime.datetime):
if userId is None:
Expand Down Expand Up @@ -210,6 +208,7 @@ async def queries_fetching(
authorization: Annotated[str | None, Header()] = None
):
userId = current_user_id(authorization)

if sessionId is None:
sessionId = last_session_id(userId)
else:
Expand Down
63 changes: 63 additions & 0 deletions apps/chatbot/src/app/mock_cognito.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import boto3
import os
from moto import mock_aws

@mock_aws
def mock_client():
    """Return a boto3 'cognito-idp' client created under the moto decorator.

    The region comes from the AWS_DEFAULT_REGION environment variable.
    """
    region = os.getenv('AWS_DEFAULT_REGION')
    return boto3.client('cognito-idp', region_name=region)

# Module-level Cognito client, created at import time and used by the helper
# functions in this module.
client = mock_client()

@mock_aws
def mock_user_pool_id():
    """Create a user pool named 'test_pool' and return its id."""
    created_pool = client.create_user_pool(PoolName='test_pool')
    return created_pool['UserPool']['Id']

@mock_aws
def mock_signup():
    """Sign up and authenticate a test user, returning its access token.

    Steps: create a user pool and an app client, sign up 'test_user', confirm
    the user as an administrator, then authenticate with the
    USER_PASSWORD_AUTH flow. The access token and pool id are printed and
    also returned.

    Returns:
        A dict with the keys "access_token" and "user_pool_id".
    """
    username = 'test_user'
    password = 'TestPassword123!'

    pool_id = mock_user_pool_id()

    # App client the user signs up and authenticates through.
    app_client = client.create_user_pool_client(
        UserPoolId=pool_id,
        ClientName='test_client',
    )
    app_client_id = app_client['UserPoolClient']['ClientId']

    client.sign_up(
        ClientId=app_client_id,
        Username=username,
        Password=password,
    )

    # Admin confirmation skips the regular confirmation step.
    client.admin_confirm_sign_up(UserPoolId=pool_id, Username=username)

    auth_result = client.initiate_auth(
        ClientId=app_client_id,
        AuthFlow='USER_PASSWORD_AUTH',
        AuthParameters={'USERNAME': username, 'PASSWORD': password},
    )
    token = auth_result['AuthenticationResult']['AccessToken']

    print('Access Token:', token)
    print('User Pool ID:', pool_id)

    return {"access_token": token, "user_pool_id": pool_id}

if __name__ == "__main__":
mock_signup()

Loading
Loading