Skip to content

Commit

Permalink
enforce atomic onboarding, return progress tracking url
Browse files Browse the repository at this point in the history
  • Loading branch information
jaismith committed May 21, 2024
1 parent eb06210 commit ad60f1e
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 16 deletions.
15 changes: 12 additions & 3 deletions backend/src/handlers/access.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
from aws_lambda_powertools.utilities.typing import LambdaContext
from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEventV2
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, CORSConfig
import pandas as pd
from botocore.exceptions import ClientError

from utils import usgs, forecast, claude, db

Expand All @@ -25,13 +27,20 @@ def get_site():

@app.post('/site/register')
def register_site():
    ''' Register a USGS site for onboarding and report where to track progress.

    Reads the `usgs_site` query-string parameter and hands it to
    `db.register_new_site`.

    Returns:
        `{ 'site': ..., 'progress_url': ... }` on success, where `progress_url`
        is the value of the `WEBSOCKET_API_ENDPOINT` environment variable, or
        `{ 'message': ... }` when the registration is rejected with a
        `ConditionalCheckFailedException`.

    Raises:
        KeyError: if `WEBSOCKET_API_ENDPOINT` is not set in the environment.
        ClientError: any client error other than the conditional-check
            failure is re-raised unchanged.
    '''
    websocket_api_endpoint = os.environ['WEBSOCKET_API_ENDPOINT']

    # query_string_parameters may be None when the request carries no query
    # string at all; fall back to an empty mapping so .get() cannot fail.
    query_params = app.current_event.query_string_parameters or {}
    usgs_site = query_params.get('usgs_site')

    # Register exactly once, and only inside the try, so that a rejected
    # conditional write is reported to the caller instead of escaping as an
    # unhandled error.
    try:
        site_info = db.register_new_site(usgs_site)
    except ClientError as err:
        if err.response['Error']['Code'] != 'ConditionalCheckFailedException':
            raise
        return { 'message': 'Site is already (or currently being) onboarded.' }

    return { 'site': site_info, 'progress_url': websocket_api_endpoint }

@app.get('/report')
def get_report():
Expand Down
2 changes: 1 addition & 1 deletion backend/src/handlers/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def handler(event, _context):
db.push_fcst_entries(fcst_rows)
if (is_onboarding):
db.push_site_onboarding_log(usgs_site, f'\tfinished forecasting at {utils.get_current_local_time()}')
db.update_site_status(usgs_site, db.SiteStatus.READY)
db.update_site_status(usgs_site, db.SiteStatus.ACTIVE)

return { 'statusCode': 200 }

Expand Down
7 changes: 7 additions & 0 deletions backend/src/handlers/onboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,10 @@ def process_stream(event, _context):
db.remove_site_subscription(usgs_site, subscription_id)

return { 'statusCode': 200 }

def register_failure(event, _context):
    ''' Mark a site's onboarding as failed.

    Reads `usgs_site` from the invocation event and sets that site's status to
    `db.SiteStatus.FAILED`.

    Returns:
        A dict with `statusCode` 200.

    Raises:
        KeyError: if the event has no `usgs_site` entry.
    '''
    failed_site = event['usgs_site']
    db.update_site_status(failed_site, db.SiteStatus.FAILED)
    return { 'statusCode': 200 }
28 changes: 25 additions & 3 deletions backend/src/utils/db.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import json
from datetime import datetime
import boto3
from boto3.dynamodb.conditions import Key, Attr
Expand All @@ -14,6 +16,8 @@
report_table = dynamodb.Table('flowcast-reports')
site_table = dynamodb.Table('flowcast-sites')

stepfunctions = boto3.client('stepfunctions')

def get_latest_hist_entry(usgs_site):
res = data_table.query(
KeyConditionExpression=Key('usgs_site#type')
Expand Down Expand Up @@ -123,10 +127,14 @@ class SiteStatus(Enum):
''' Site feature models are being trained. '''
FORECASTING = 'FORECASTING'
''' Future datapoints are being forecast. '''
READY = 'READY'
ACTIVE = 'ACTIVE'
''' Site is onboarded and ready for usage. '''
FAILED = 'FAILED'
''' Site failed to onboard. '''

def register_new_site(usgs_site: str, registration_date=datetime.now(), status=SiteStatus.SCHEDULED):
UPDATE_AND_FORECAST_STATE_MACHINE_ARN = os.environ['UPDATE_AND_FORECAST_STATE_MACHINE_ARN']

usgs_site_data = usgs.get_site_info(usgs_site)
item = {
'usgs_site': usgs_site,
Expand All @@ -142,10 +150,24 @@ def register_new_site(usgs_site: str, registration_date=datetime.now(), status=S
}

site_table.put_item(
Item=item
Item=item,
ConditionExpression="attribute_not_exists(usgs_site) OR #status = :failed_status",
ExpressionAttributeNames={
'#status': 'status'
},
ExpressionAttributeValues={
':failed_status': SiteStatus.FAILED.value
}
)

stepfunctions.start_execution(
stateMachineArn=UPDATE_AND_FORECAST_STATE_MACHINE_ARN,
input=json.dumps({
'usgs_site': usgs_site,
'is_onboarding': True
})
)

item['onboarding_logs'] = list(item['onboarding_logs'])
item['subscription_ids'] = list(item['subscription_ids'])
return item

Expand Down
44 changes: 35 additions & 9 deletions infra/lib/flowcast.ts
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,18 @@ export class FlowcastStack extends Stack {
memorySize: 256,
logRetention: DEFAULT_LOG_RETENTION
});
const onboardFailed = new lambda.DockerImageFunction(this, 'onboard_failed_function', {
code: lambda.DockerImageCode.fromEcr(sharedLambdaImage.repository, {
tagOrDigest: sharedLambdaImage.imageTag,
entrypoint: ['python', '-m', 'awslambdaric'],
cmd: ['index.handle_onboard_failed']
}),
environment: env,
architecture: lambda.Architecture.X86_64,
timeout: cdk.Duration.seconds(30),
memorySize: 256,
logRetention: DEFAULT_LOG_RETENTION
});

const websocketApi = new apigatewayv2.WebSocketApi(this, 'flowcast-websocket-api', {
connectRouteOptions: { integration: new integrations.WebSocketLambdaIntegration('connect-integration', onboardConnect) },
Expand Down Expand Up @@ -208,6 +220,8 @@ export class FlowcastStack extends Stack {
tumblingWindow: cdk.Duration.seconds(30)
}));

access.addEnvironment('WEBSOCKET_API_ENDPOINT', websocketApiStage.url)

// * fargate

const trainVpc = new ec2.Vpc(this, 'flowcast-vpc', {
Expand Down Expand Up @@ -276,7 +290,7 @@ export class FlowcastStack extends Stack {

// * permissions

[update, forecast, access, exportFunc, onboardConnect, onboardDisconnect, onboardProcessStream].forEach(func => {
[update, forecast, access, exportFunc, onboardConnect, onboardDisconnect, onboardProcessStream, onboardFailed].forEach(func => {
db.grantFullAccess(func);
reportsDb.grantFullAccess(func);
sitesDb.grantFullAccess(func);
Expand Down Expand Up @@ -307,18 +321,22 @@ export class FlowcastStack extends Stack {

// * sfn

const failTask = new sfnTasks.LambdaInvoke(this, `fail_task_${id}`, {
lambdaFunction: onboardFailed,
resultPath: '$.Result'
});
const updateTask = new sfnTasks.LambdaInvoke(this, 'update_task', {
lambdaFunction: update,
resultPath: '$.Result'
});
}).addCatch(failTask);
const forecastTask = new sfnTasks.LambdaInvoke(this, 'forecast_task', {
lambdaFunction: forecast,
resultPath: '$.Result'
});
}).addCatch(failTask);
const exportTask = new sfnTasks.LambdaInvoke(this, 'export_task', {
lambdaFunction: exportFunc,
resultPath: '$.Result'
});
}).addCatch(failTask);
const trainTask = new sfnTasks.BatchSubmitJob(this, 'train_task', {
jobQueueArn: trainJobQueue.jobQueueArn,
jobDefinitionArn: trainJobDefinition.jobDefinitionArn,
Expand All @@ -327,18 +345,23 @@ export class FlowcastStack extends Stack {
'usgs_site.$': '$.usgs_site'
}),
resultPath: '$.Result'
});
}).addCatch(failTask);

const wait = new sfn.Wait(this, 'wait', {
time: sfn.WaitTime.duration(cdk.Duration.seconds(30))
});
const onboardCondition = sfn.Condition.booleanEquals('$.is_onboarding', true);
const failCondition = sfn.Condition.not(sfn.Condition.numberEquals('$.Result.Payload.statusCode', 200));

const failState = (id: string) => new sfnTasks.LambdaInvoke(this, `fail_task_${id}`, {
lambdaFunction: onboardFailed,
resultPath: '$.Result'
}).next(new sfn.Fail(this, id));

const exportCompleteChoice = new sfn.Choice(this, 'check_export_complete');
exportCompleteChoice
.when(sfn.Condition.numberEquals('$.Result.Payload.statusCode', 200), new sfn.Pass(this, 'export_complete'))
.when(sfn.Condition.numberGreaterThanEquals('$.Result.Payload.statusCode', 400), new sfn.Fail(this, 'export_failed'))
.when(sfn.Condition.numberGreaterThanEquals('$.Result.Payload.statusCode', 400), failState('export_failed'))
.otherwise(wait
.next(new sfnTasks.LambdaInvoke(this, 'poll_export_task', {
lambdaFunction: exportFunc,
Expand All @@ -349,28 +372,31 @@ export class FlowcastStack extends Stack {
const updateAndForecastSfn = new sfn.StateMachine(this, 'update_and_forecast', {
definitionBody: sfn.DefinitionBody.fromChainable(sfn.Chain.start(updateTask)
.next(new sfn.Choice(this, 'verify_update')
.when(failCondition, new sfn.Fail(this, 'update_failed'))
.when(failCondition, failState('update_failed'))
.otherwise(new sfn.Pass(this, 'update_successful'))
.afterwards())
.next(new sfn.Choice(this, 'check_onboarding')
.when(onboardCondition, exportTask
.next(exportCompleteChoice.afterwards())
.next(trainTask)
.next(new sfn.Choice(this, 'verify_train')
.when(failCondition, new sfn.Fail(this, 'train_failed'))
.when(failCondition, failState('train_failed'))
.otherwise(new sfn.Pass(this, 'train_successful'))
.afterwards()))
.otherwise(new sfn.Pass(this, 'not_onboarding'))
.afterwards())
.next(forecastTask)
.next(new sfn.Choice(this, 'verify_forecast')
.when(failCondition, new sfn.Fail(this, 'forecast_failed'))
.when(failCondition, failState('forecast_failed'))
.otherwise(new sfn.Pass(this, 'forecast_successful'))
.afterwards())
.next(new sfn.Succeed(this, 'update_and_forecast_successful'))),
timeout: cdk.Duration.minutes(25)
});

access.addEnvironment('UPDATE_AND_FORECAST_STATE_MACHINE_ARN', updateAndForecastSfn.stateMachineArn);
updateAndForecastSfn.grantStartExecution(access);

// * public access url

new cdk.CfnResource(this, 'public_access_url', {
Expand Down

0 comments on commit ad60f1e

Please sign in to comment.