From 9db4679579da6f11eb32e84c4ef223209ca51eaf Mon Sep 17 00:00:00 2001 From: albertkun Date: Tue, 21 Nov 2023 14:36:38 -0800 Subject: [PATCH 1/4] fix: redundant code and update Redis connection logic --- fastapi/app/main.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/fastapi/app/main.py b/fastapi/app/main.py index 80da2cd..5850d8c 100644 --- a/fastapi/app/main.py +++ b/fastapi/app/main.py @@ -86,23 +86,6 @@ size=Query(100, ge=1, le=500), ) -# Define the redis variable at the top level -redis = None - -async def initialize_redis(): - global redis - logging.info(f"Connecting to Redis at {Config.REDIS_URL}") - for i in range(5): # Retry up to 5 times - try: - redis = await aioredis.from_url(Config.REDIS_URL) - break # If the connection is successful, break out of the loop - except aioredis.exceptions.ConnectionError as e: - logging.error(f"Failed to connect to Redis: {e}") - if i < 4: # If this was not the last attempt, wait a bit before retrying - await asyncio.sleep(5) # Wait for 5 seconds - else: # If this was the last attempt, re-raise the exception - raise - async def get_data(db: Session, key: str, fetch_func): # Get data from Redis data = await redis.get(key) From 085a1973511a48645b358619883992093ce2ab4e Mon Sep 17 00:00:00 2001 From: albertkun Date: Tue, 21 Nov 2023 14:39:18 -0800 Subject: [PATCH 2/4] refactor: Redis calls in main.py --- fastapi/app/main.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/fastapi/app/main.py b/fastapi/app/main.py index 5850d8c..cb89cb1 100644 --- a/fastapi/app/main.py +++ b/fastapi/app/main.py @@ -88,14 +88,14 @@ async def get_data(db: Session, key: str, fetch_func): # Get data from Redis - data = await redis.get(key) + data = await crud.redis.get(key) if data is None: # If data is not in Redis, get it from the database data = fetch_func(db, key) if data is None: return None # Set data in Redis - await redis.set(key, data) + await crud.redis.set(key, data) return data @@ 
-460,10 +460,10 @@ async def websocket_endpoint(websocket: WebSocket, agency_id: str, async_db: Asy data = await asyncio.wait_for(crud.get_all_data_async(async_db, models.VehiclePositions, agency_id), timeout=120) if data is not None: # Publish the received data to a Redis channel - await redis.publish('live_vehicle_positions', data) + await crud.redis.publish('live_vehicle_positions', data) await asyncio.sleep(10) # Subscribe to the Redis channel and send any received messages to the WebSocket client - ch = await redis.subscribe('live_vehicle_positions') + ch = await crud.redis.subscribe('live_vehicle_positions') while await ch.wait_message(): message = await ch.get() await websocket.send_json(message) @@ -489,13 +489,13 @@ async def websocket_vehicle_positions_by_ids(websocket: WebSocket, agency_id: Ag if result is not None: data[id] = result # Publish the received data to a Redis channel - await redis.publish('live_vehicle_positions_by_ids', data) + await crud.redis.publish('live_vehicle_positions_by_ids', data) except asyncio.TimeoutError: raise HTTPException(status_code=408, detail="Request timed out") if data: await asyncio.sleep(5) # Subscribe to the Redis channel and send any received messages to the WebSocket client - ch = await redis.subscribe('live_vehicle_positions_by_ids') + ch = await crud.redis.subscribe('live_vehicle_positions_by_ids') while await ch.wait_message(): message = await ch.get() await websocket.send_json(message) From ce57d4c9d5decc4d6b09f0b98a7b737c49b1f965 Mon Sep 17 00:00:00 2001 From: albertkun Date: Wed, 22 Nov 2023 06:07:36 -0800 Subject: [PATCH 3/4] feat: requirements.txt files and add websocket test --- fastapi/app/crud.py | 63 ++++++++----- fastapi/app/main.py | 124 +++++++++++--------------- fastapi/app/requirements.txt | 1 - fastapi/requirements.txt | 2 - fastapi/tests/requirements.txt | 4 +- fastapi/tests/test_endpoints_local.py | 17 ++++ 6 files changed, 113 insertions(+), 98 deletions(-) diff --git a/fastapi/app/crud.py 
b/fastapi/app/crud.py index db9e146..2eceed0 100644 --- a/fastapi/app/crud.py +++ b/fastapi/app/crud.py @@ -39,6 +39,7 @@ import aioredis import pickle +import time from sqlalchemy import select @@ -51,7 +52,26 @@ from sqlalchemy.orm import Session from sqlalchemy.orm.decl_api import DeclarativeMeta -redis = aioredis.from_url(Config.REDIS_URL, socket_connect_timeout=5) + +redis_connection = None + +def initialize_redis(retries=5, delay=5): + global redis_connection + for i in range(retries): + try: + redis_connection = aioredis.from_url(Config.REDIS_URL, socket_connect_timeout=5) + # If connection is successful, break the loop + if redis_connection.ping(): + break + except Exception as e: + print(f"Failed to connect to Redis: {e}") + redis_connection = None + if i < retries - 1: # no delay on the last attempt + time.sleep(delay) + else: + raise Exception("Failed to connect to Redis after several attempts") + +initialize_redis() # import sqlalchemy def asdict(obj): @@ -75,7 +95,7 @@ async def get_data_redis(db, model, id_field, id_value): key = f'{model.__tablename__}:{id_value}' # Try to get data from Redis - data = await redis.get(key) + data = await redis_connection.get(key) if data is None: # If data is not in Redis, get it from the database @@ -90,7 +110,7 @@ async def get_data_redis(db, model, id_field, id_value): for key, value in row.__dict__.items() if not key.startswith('_sa_instance_state') } for row in result]) - await redis.set(key, data) + await redis_connection.set(key, data) else: # Parse the JSON-formatted string back into a Python data structure data = json.loads(data) @@ -119,13 +139,14 @@ async def get_vehicle_data_async(db: AsyncSession, agency_id: str, vehicle_id: s return data import pickle - async def get_data_async(async_session: Session, model: Type[DeclarativeMeta], agency_id: str, field_name: Optional[str] = None, field_value: Optional[str] = None): # Create a unique key for this query key = 
f"{model.__name__}:{agency_id}:{field_name}:{field_value}" # Try to get the result from Redis - result = await redis.get(key) + if redis_connection is None: + initialize_redis() + result = await redis_connection.get(key) if result is not None: data = pickle.loads(result) if isinstance(data, model): @@ -163,7 +184,7 @@ async def get_list_of_unique_values_async(session: AsyncSession, model, agency_i logging.info(f"Generated key: {key}") # Try to get the result from Redis - result = await redis.get(key) + result = await redis_connection.get(key) if result is not None: logging.info("Found result in Redis") return pickle.loads(result) @@ -186,7 +207,7 @@ async def get_list_of_unique_values_async(session: AsyncSession, model, agency_i logging.info(f"Unique values from database: {unique_values}") # Store the result in Redis - await redis.set(key, pickle.dumps(unique_values)) + await redis_connection.set(key, pickle.dumps(unique_values)) return unique_values @@ -224,7 +245,7 @@ def get_stop_times_by_route_code(db, route_code: str,agency_id: str): async def get_stop_times_by_trip_id(db, trip_id: str, agency_id: str): # Try to get the result from Redis first cache_key = f'stop_times:{trip_id}:{agency_id}' - cached_result = await redis.get(cache_key) + cached_result = await redis_connection.get(cache_key) if cached_result is not None: return pickle.loads(cached_result) @@ -242,7 +263,7 @@ async def get_stop_times_by_trip_id(db, trip_id: str, agency_id: str): # If result is not empty, store it in Redis for future use if result: - await redis.set(cache_key, pickle.dumps(result)) + await redis_connection.set(cache_key, pickle.dumps(result)) return result @@ -276,7 +297,7 @@ def list_gtfs_rt_vehicle_positions_by_field_name(db, field_name: str,agency_id: async def get_gtfs_rt_trips_by_field_name(db, field_name: str, field_value: str, agency_id: str): # Try to get the result from Redis first cache_key = f'trips:{field_name}:{field_value}:{agency_id}' - cached_result = await 
redis.get(cache_key) + cached_result = await redis_connection.get(cache_key) if cached_result is not None: return pickle.loads(cached_result) @@ -296,14 +317,14 @@ async def get_gtfs_rt_trips_by_field_name(db, field_name: str, field_value: str, # If result is not empty, store it in Redis for future use if result: - await redis.set(cache_key, pickle.dumps(result)) + await redis_connection.set(cache_key, pickle.dumps(result)) return result async def get_all_gtfs_rt_trips(db, agency_id: str): # Try to get the result from Redis first cache_key = f'trips:{agency_id}' - cached_result = await redis.get(cache_key) + cached_result = await redis_connection.get(cache_key) if cached_result is not None: return pickle.loads(cached_result) @@ -315,7 +336,7 @@ async def get_all_gtfs_rt_trips(db, agency_id: str): # If result is not empty, store it in Redis for future use if result: - await redis.set(cache_key, pickle.dumps(result)) + await redis_connection.set(cache_key, pickle.dumps(result)) return result @@ -323,7 +344,7 @@ async def get_all_gtfs_rt_vehicle_positions(db, agency_id: str, geojson: bool): try: # Try to get the result from Redis first cache_key = f'vehicle_positions:{agency_id}:{geojson}' - cached_result = await redis.get(cache_key) + cached_result = await redis_connection.get(cache_key) if cached_result is not None: return pickle.loads(cached_result) @@ -350,7 +371,7 @@ async def get_all_gtfs_rt_vehicle_positions(db, agency_id: str, geojson: bool): # If result is not empty, store it in Redis for future use if result: - await redis.set(cache_key, pickle.dumps(result)) + await redis_connection.set(cache_key, pickle.dumps(result)) return result except Exception as e: @@ -432,7 +453,7 @@ def _async(db, agency_id: str, geojson: bool): from sqlalchemy.orm import joinedload async def get_gtfs_rt_vehicle_positions_trip_data_by_route_code(session: AsyncSession, route_code: str, geojson:bool, agency_id:str): cache_key = f'trip_data:{route_code}:{agency_id}' - cached_data = 
await redis.get(cache_key) + cached_data = await redis_connection.get(cache_key) if cached_data is not None: return pickle.loads(cached_data) stmt = ( @@ -577,7 +598,7 @@ async def get_gtfs_rt_vehicle_positions_trip_data_redis(db, vehicle_id: str): key = f'vehicle:{vehicle_id}' # Try to get data from Redis - data = await redis.get(key) + data = await redis_connection.get(key) if data is None: # If data is not in Redis, get it from the database @@ -588,7 +609,7 @@ async def get_gtfs_rt_vehicle_positions_trip_data_redis(db, vehicle_id: str): # Convert the result to JSON and store it in Redis data = json.dumps([dict(row) for row in result]) - await redis.set(key, data) + await redis_connection.set(key, data) return data @@ -596,7 +617,7 @@ async def get_gtfs_rt_vehicle_positions_trip_data_redis(db, vehicle_id: str): async def get_gtfs_rt_vehicle_positions_trip_data(db, vehicle_id: str, geojson: bool, agency_id: str): # Try to get the result from Redis first cache_key = f'vehicle_positions:{vehicle_id}:{geojson}:{agency_id}' - result = await redis.get(cache_key) + result = await redis_connection.get(cache_key) if result is not None: return pickle.loads(result) @@ -617,7 +638,7 @@ async def get_gtfs_rt_vehicle_positions_trip_data(db, vehicle_id: str, geojson: this_json['type'] = "FeatureCollection" this_json['features'] = features if this_json: - await redis.set(cache_key, pickle.dumps(this_json)) + await redis_connection.set(cache_key, pickle.dumps(this_json)) return this_json for row in the_query: @@ -641,7 +662,7 @@ async def get_gtfs_rt_vehicle_positions_trip_data(db, vehicle_id: str, geojson: return message_object else: if result: - await redis.set(cache_key, pickle.dumps(result)) + await redis_connection.set(cache_key, pickle.dumps(result)) return result diff --git a/fastapi/app/main.py b/fastapi/app/main.py index cb89cb1..4a549d4 100644 --- a/fastapi/app/main.py +++ b/fastapi/app/main.py @@ -48,9 +48,7 @@ from starlette.responses import Response -from 
fastapi_cache import FastAPICache -from fastapi_cache.backends.redis import RedisBackend -from fastapi_cache.decorator import cache + # from redis import asyncio as aioredis @@ -88,14 +86,14 @@ async def get_data(db: Session, key: str, fetch_func): # Get data from Redis - data = await crud.redis.get(key) + data = await crud.redis_connection.get(key) if data is None: # If data is not in Redis, get it from the database data = fetch_func(db, key) if data is None: return None # Set data in Redis - await crud.redis.set(key, data) + await crud.redis_connection.set(key, data) return data @@ -355,8 +353,7 @@ def standardize_string(input_string): #### Begin GTFS-RT Routes #### -@app.get("/{agency_id}/trip_updates", tags=["Real-Time data"]) -@cache(expire=REALTIME_UDPATE_INTERVAL) +@app.get("/{agency_id}/trip_updates", tags=["Real-Time data"]) async def get_all_trip_updates(agency_id: AgencyIdEnum, async_db: AsyncSession = Depends(get_async_db)): """ Get all trip updates. @@ -367,8 +364,7 @@ async def get_all_trip_updates(agency_id: AgencyIdEnum, async_db: AsyncSession = raise HTTPException(status_code=404, detail="Data not found") return data -@app.get("/{agency_id}/trip_updates/{field}/{ids}", tags=["Real-Time data"]) -@cache(expire=REALTIME_UDPATE_INTERVAL) +@app.get("/{agency_id}/trip_updates/{field}/{ids}", tags=["Real-Time data"]) async def get_trip_updates_by_ids(agency_id: AgencyIdEnum, field: TripUpdatesFieldsEnum, ids: str, format: FormatEnum = Query(FormatEnum.json), async_db: AsyncSession = Depends(get_async_db)): """ Get specific trip updates by IDs dependant on the `field` selected. IDs can be provided as a comma-separated list. 
@@ -397,8 +393,7 @@ async def get_list_of_field_values(agency_id: AgencyIdEnum, field: TripUpdatesFi raise HTTPException(status_code=404, detail="Data not found") return data -@app.get("/{agency_id}/vehicle_positions", tags=["Real-Time data"]) -@cache(expire=REALTIME_UDPATE_INTERVAL) +@app.get("/{agency_id}/vehicle_positions", tags=["Real-Time data"]) async def get_all_vehicle_positions(agency_id: AgencyIdEnum, format: FormatEnum = Query(FormatEnum.json), async_db: AsyncSession = Depends(get_async_db)): """ Get all vehicle positions updates. @@ -412,8 +407,7 @@ async def get_all_vehicle_positions(agency_id: AgencyIdEnum, format: FormatEnum data = to_geojson(data) return data -@app.get("/{agency_id}/vehicle_positions/{field}/{ids}", tags=["Real-Time data"]) -@cache(expire=REALTIME_UDPATE_INTERVAL) +@app.get("/{agency_id}/vehicle_positions/{field}/{ids}", tags=["Real-Time data"]) async def get_vehicle_positions_by_ids(agency_id: AgencyIdEnum, field: VehiclePositionsFieldsEnum, ids: str, format: FormatEnum = Query(FormatEnum.json), async_db: AsyncSession = Depends(get_async_db)): """ Get specific vehicle position updates by IDs dependant on the `field` selected. IDs can be provided as a comma-separated list. @@ -432,7 +426,6 @@ async def get_vehicle_positions_by_ids(agency_id: AgencyIdEnum, field: VehiclePo return data @app.get("/{agency_id}/vehicle_positions/{field}", tags=["Real-Time data"]) -@cache(expire=REALTIME_UDPATE_INTERVAL) async def get_list_of_field_values(agency_id: AgencyIdEnum, field: VehiclePositionsFieldsEnum, async_db: AsyncSession = Depends(get_async_db)): """ Get a list of all values for a specific field in the vehicle positions updates. 
@@ -443,38 +436,40 @@ async def get_list_of_field_values(agency_id: AgencyIdEnum, field: VehiclePositi raise HTTPException(status_code=404, detail="Data not found") return data +import json +import asyncio +import aioredis @app.websocket("/ws/{agency_id}/vehicle_positions") async def websocket_endpoint(websocket: WebSocket, agency_id: str, async_db: AsyncSession = Depends(get_async_db)): await websocket.accept() + + channel = (await crud.redis_connection.subscribe('live_vehicle_positions'))[0] + + async def listen_to_redis(): + async for message in channel.iter(encoding='utf-8'): + # Unserialize message with json before sending + message_data = json.loads(message) + await websocket.send_json(message_data) + + listen_task = asyncio.create_task(listen_to_redis()) + try: - while True: - # Check server load - load1, load5, load15 = os.getloadavg() - if load1 > SERVER_OVERLOAD_THRESHOLD: - await websocket.send_json({"type": "server_overload"}) - await websocket.close() - return - - try: - data = await asyncio.wait_for(crud.get_all_data_async(async_db, models.VehiclePositions, agency_id), timeout=120) - if data is not None: - # Publish the received data to a Redis channel - await crud.redis.publish('live_vehicle_positions', data) - await asyncio.sleep(10) - # Subscribe to the Redis channel and send any received messages to the WebSocket client - ch = await crud.redis.subscribe('live_vehicle_positions') - while await ch.wait_message(): - message = await ch.get() - await websocket.send_json(message) - # Send a ping every 10 seconds - await websocket.send_json({"type": "ping"}) - except asyncio.TimeoutError: - raise HTTPException(status_code=408, detail="Request timed out") + data = await asyncio.wait_for(crud.get_all_data_async(async_db, models.VehiclePositions, agency_id), timeout=120) + if data is not None: + # Serialize data with json before publishing + data_json = json.dumps(data) + await crud.redis_connection.publish('live_vehicle_positions', data_json) + except 
asyncio.TimeoutError: + raise HTTPException(status_code=408, detail="Request timed out") except WebSocketDisconnect: # Handle the WebSocket disconnect event print("WebSocket disconnected") - + finally: + listen_task.cancel() + crud.redis_connection.unsubscribe('live_vehicle_positions') + crud.redis_connection.close() + await crud.redis_connection.wait_closed() @app.websocket("/ws/{agency_id}/vehicle_positions/{field}/{ids}") async def websocket_vehicle_positions_by_ids(websocket: WebSocket, agency_id: AgencyIdEnum, field: VehiclePositionsFieldsEnum, ids: str, async_db: AsyncSession = Depends(get_async_db)): await websocket.accept() @@ -489,13 +484,13 @@ async def websocket_vehicle_positions_by_ids(websocket: WebSocket, agency_id: Ag if result is not None: data[id] = result # Publish the received data to a Redis channel - await crud.redis.publish('live_vehicle_positions_by_ids', data) + await crud.redis_connection.publish('live_vehicle_positions_by_ids', data) except asyncio.TimeoutError: raise HTTPException(status_code=408, detail="Request timed out") if data: await asyncio.sleep(5) # Subscribe to the Redis channel and send any received messages to the WebSocket client - ch = await crud.redis.subscribe('live_vehicle_positions_by_ids') + ch = await crud.redis_connection.subscribe('live_vehicle_positions_by_ids') while await ch.wait_message(): message = await ch.get() await websocket.send_json(message) @@ -508,13 +503,11 @@ async def websocket_vehicle_positions_by_ids(websocket: WebSocket, agency_id: Ag ##### todo: Needs to be tested @app.get("/{agency_id}/trip_detail/route_code/{route_code}",tags=["Real-Time data"]) -@cache(expire=REALTIME_UDPATE_INTERVAL) async def get_trip_detail_by_route_code(agency_id: AgencyIdEnum, route_code: str, geojson:bool=False, db: AsyncSession = Depends(get_db)): - result = await crud.get_gtfs_rt_vehicle_positions_trip_data_by_route_code_for_async(session=db, route_code=route_code, geojson=geojson, agency_id=agency_id.value) + result = 
await crud.get_gtfs_rt_vehicle_positions_trip_data_by_route_code(session=db, route_code=route_code, geojson=geojson, agency_id=agency_id.value) return result @app.get("/{agency_id}/trip_detail/vehicle/{vehicle_id?}", tags=["Real-Time data"]) -@cache(expire=REALTIME_UDPATE_INTERVAL) async def get_trip_detail_by_vehicle(agency_id: AgencyIdEnum, vehicle_id: Optional[str] = None, operation: OperationEnum = Depends(), geojson: bool = False, async_db: AsyncSession = Depends(get_async_db)): if operation == OperationEnum.ALL: result = await crud.get_all_data_async(async_db, models.VehiclePositions, operation.value) @@ -536,7 +529,6 @@ async def get_trip_detail_by_vehicle(agency_id: AgencyIdEnum, vehicle_id: Option @app.get("/{agency_id}/trip_detail/route/{route_code?}", tags=["Real-Time data"]) -@cache(expire=REALTIME_UDPATE_INTERVAL) async def get_trip_detail_by_route(agency_id: AgencyIdEnum, route_code: Optional[OperationEnum] = None, geojson: bool = False, async_db: AsyncSession = Depends(get_async_db)): if route_code == OperationEnum.ALL: result = await crud.get_all_data_async(async_db, models.VehiclePositions, route_code.value) @@ -552,7 +544,6 @@ async def get_trip_detail_by_route(agency_id: AgencyIdEnum, route_code: Optional @app.get("/canceled_service_summary",tags=["Canceled Service Data"]) -@cache(expire=CANCELED_UDPATE_INTERVAL) async def get_canceled_trip_summary(db: AsyncSession = Depends(get_async_db)): result = await crud.get_canceled_trips(db,'all') canceled_trips_summary = {} @@ -577,14 +568,12 @@ async def get_canceled_trip_summary(db: AsyncSession = Depends(get_async_db)): "last_updated":update_time} @app.get("/canceled_service/line/{line}",tags=["Canceled Service Data"]) -@cache(expire=CANCELED_UDPATE_INTERVAL) async def get_canceled_trip(db: Session = Depends(get_db),line: str = None): result = crud.get_canceled_trips(db,line) json_compatible_item_data = jsonable_encoder(result) return JSONResponse(content=json_compatible_item_data) 
@app.get("/canceled_service/all",tags=["Canceled Service Data"]) -@cache(expire=CANCELED_UDPATE_INTERVAL) async def get_canceled_trip(db: Session = Depends(get_db)): result = crud.get_canceled_trips(db,'all') json_compatible_item_data = jsonable_encoder(result) @@ -592,53 +581,45 @@ async def get_canceled_trip(db: Session = Depends(get_db)): ### GTFS Static data ### @app.get("/{agency_id}/route_stops/{route_code}",tags=["Static data"]) -@cache(expire=STATIC_UDPATE_INTERVAL) async def populate_route_stops(agency_id: AgencyIdEnum,route_code:str, daytype: DayTypesEnum = DayTypesEnum.all, db: Session = Depends(get_db)): result = crud.get_gtfs_route_stops(db,route_code,daytype.value,agency_id.value) json_compatible_item_data = jsonable_encoder(result) return JSONResponse(content=json_compatible_item_data) @app.get("/{agency_id}/route_stops_grouped/{route_code}",tags=["Static data"]) -@cache(expire=STATIC_UDPATE_INTERVAL) async def populate_route_stops_grouped(agency_id: AgencyIdEnum,route_code:str, db: Session = Depends(get_db)): result = crud.get_gtfs_route_stops_grouped(db,route_code,agency_id.value) json_compatible_item_data = jsonable_encoder(result[0]) return JSONResponse(content=json_compatible_item_data) @app.get("/calendar_dates",tags=["Static data"]) -@cache(expire=STATIC_UDPATE_INTERVAL) async def get_calendar_dates_from_db(db: Session = Depends(get_db)): result = crud.get_calendar_dates(db) calendar_dates = jsonable_encoder(result) return JSONResponse(content={"calendar_dates":calendar_dates}) @app.get("/{agency_id}/stop_times/route_code/{route_code}",tags=["Static data"],response_model=Page[schemas.StopTimes]) -@cache(expire=STATIC_UDPATE_INTERVAL) async def get_stop_times_by_route_code_and_agency(agency_id: AgencyIdEnum,route_code, db: Session = Depends(get_db)): result = crud.get_stop_times_by_route_code(db,route_code,agency_id.value) return result @app.get("/{agency_id}/stop_times/trip_id/{trip_id}",tags=["Static 
data"],response_model=Page[schemas.StopTimes]) -@cache(expire=STATIC_UDPATE_INTERVAL) async def get_stop_times_by_trip_id_and_agency(agency_id: AgencyIdEnum,trip_id, db: Session = Depends(get_db)): result = crud.get_stop_times_by_trip_id(db,trip_id,agency_id.value) return result @app.get("/{agency_id}/stops/{stop_id}",tags=["Static data"]) -@cache(expire=STATIC_UDPATE_INTERVAL) async def get_stops(agency_id: AgencyIdEnum,stop_id, db: Session = Depends(get_db)): result = crud.get_stops_id(db,stop_id,agency_id.value) return result @app.get("/{agency_id}/trips/{trip_id}",tags=["Static data"]) -@cache(expire=STATIC_UDPATE_INTERVAL) async def get_bus_trips(agency_id: AgencyIdEnum,trip_id, db: Session = Depends(get_db)): result = crud.get_trips_data(db,trip_id,agency_id.value) return result @app.get("/{agency_id}/shapes/{shape_id}",tags=["Static data"]) -@cache(expire=STATIC_UDPATE_INTERVAL) async def get_shapes(agency_id: AgencyIdEnum,shape_id, geojson: bool = False,db: Session = Depends(get_db)): if shape_id == "list": result = crud.get_trip_shapes_list(db,agency_id.value) @@ -647,7 +628,6 @@ async def get_shapes(agency_id: AgencyIdEnum,shape_id, geojson: bool = False,db: return result @app.get("/{agency_id}/trip_shapes/{shape_id}",tags=["Static data"]) -@cache(expire=STATIC_UDPATE_INTERVAL) async def get_trip_shapes(agency_id: AgencyIdEnum,shape_id, db: Session = Depends(get_db)): if shape_id == "all": result = crud.get_trip_shapes_all(db,agency_id.value) @@ -658,7 +638,6 @@ async def get_trip_shapes(agency_id: AgencyIdEnum,shape_id, db: Session = Depend return result @app.get("/{agency_id}/calendar/{service_id}",tags=["Static data"]) -@cache(expire=STATIC_UDPATE_INTERVAL) async def get_calendar_list(agency_id: AgencyIdEnum,service_id, db: Session = Depends(get_db)): if service_id == "list": result = crud.get_calendar_list(db,agency_id.value) @@ -668,19 +647,16 @@ async def get_calendar_list(agency_id: AgencyIdEnum,service_id, db: Session = De 
@app.get("/{agency_id}/calendar/{service_id}",tags=["Static data"]) -@cache(expire=STATIC_UDPATE_INTERVAL) async def get_calendar(agency_id: AgencyIdEnum,service_id, db: Session = Depends(get_db)): result = crud.get_calendar_data_by_id(db,models.Calendar,service_id,agency_id.value) return result @app.get("/{agency_id}/routes/{route_id}",tags=["Static data"]) -@cache(expire=STATIC_UDPATE_INTERVAL) async def get_routes(agency_id: AgencyIdEnum,route_id, db: Session = Depends(get_db)): result = crud.get_routes_by_route_id(db,route_id,agency_id.value) return result @app.get("/{agency_id}/route_overview", tags=["Static data"]) -@cache(expire=STATIC_UDPATE_INTERVAL) async def get_route_overview(agency_id: AllAgencyIdEnum, async_db: AsyncSession = Depends(get_async_db)): """ Get route overview data for all routes. @@ -702,7 +678,6 @@ async def get_route_overview(agency_id: AllAgencyIdEnum, async_db: AsyncSession return result @app.get("/{agency_id}/route_overview/{route_code}", tags=["Static data"]) -@cache(expire=STATIC_UDPATE_INTERVAL) async def get_route_overview_by_route_code(agency_id: AgencyIdEnum, route_code: str, async_db: AsyncSession = Depends(get_async_db)): """ Get route overview data by route code. 
@@ -733,7 +708,6 @@ async def get_agency(agency_id: AgencyIdEnum, db: Session = Depends(get_db)): #### Begin Other data endpoints #### @app.get("/get_gopass_schools",tags=["Other data"]) -@cache(expire=GO_PASS_UPDATE_INTERVAL) async def get_gopass_schools(db: AsyncSession = Depends(get_async_db), show_missing: bool = False, combine_phone:bool = False, groupby_column:GoPassGroupEnum = None): if combine_phone == True: result = await crud.get_gopass_schools_combined_phone(db, groupby_column.value) @@ -829,11 +803,8 @@ def read_user(username: str, db: Session = Depends(get_db),token: str = Depends( async def get_all_routes(): return [route.path for route in app.routes] -@app.on_event("startup") -async def startup_event(): +def setup_logging(): try: - crud.redis = await aioredis.from_url(Config.REDIS_URL) - FastAPICache.init(backend=crud.redis, prefix="fastapi-cache") uvicorn_access_logger = logging.getLogger("uvicorn.access") uvicorn_error_logger = logging.getLogger("uvicorn.error") logger = logging.getLogger("uvicorn.app") @@ -857,9 +828,20 @@ async def startup_event(): uvicorn_access_logger.addFilter(LogFilter()) uvicorn_error_logger.addFilter(LogFilter()) logger.addFilter(LogFilter()) - except: - exc_type, exc_value, exc_traceback = sys.exc_info() - traceback.print_exception(exc_type, exc_value, exc_traceback, file=sys.stderr) + except Exception as e: + print(f"Failed to set up logging: {e}") + +class RedisNotConnected(Exception): + """Raised when Redis is not connected""" + pass +@app.on_event("startup") +async def startup_event(): + try: + crud.redis_connection = crud.initialize_redis() + setup_logging() + except Exception as e: + print(f"Failed to connect to Redis: {e}") + raise e app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -869,7 +851,3 @@ async def startup_event(): expose_headers=["*"], ) add_pagination(app) -# @app.on_event("startup") -# async def startup_redis(): - # redis = aioredis.from_url("redis://localhost", encoding="utf8", 
decode_responses=True,port=6379) -# FastAPICache.init(RedisBackend(redis), prefix="fastapi-cache") diff --git a/fastapi/app/requirements.txt b/fastapi/app/requirements.txt index 41a94d1..aa9554a 100644 --- a/fastapi/app/requirements.txt +++ b/fastapi/app/requirements.txt @@ -23,7 +23,6 @@ shapely pandas polyline websockets -fastapi-cache2[redis] pymemcache asyncio asyncpg diff --git a/fastapi/requirements.txt b/fastapi/requirements.txt index 926abaf..15ddf75 100644 --- a/fastapi/requirements.txt +++ b/fastapi/requirements.txt @@ -24,8 +24,6 @@ shapely polyline websockets redis_om -fastapi-redis-cache -fastapi-cache2[redis] asyncio aioredis sqlalchemy[asyncio] diff --git a/fastapi/tests/requirements.txt b/fastapi/tests/requirements.txt index b904197..0eedb66 100644 --- a/fastapi/tests/requirements.txt +++ b/fastapi/tests/requirements.txt @@ -3,4 +3,6 @@ pytest fastapi httpx sqlalchemy -GeoAlchemy2 \ No newline at end of file +GeoAlchemy2 +websockets +pytest-asyncio \ No newline at end of file diff --git a/fastapi/tests/test_endpoints_local.py b/fastapi/tests/test_endpoints_local.py index 0de0b4e..37c4eee 100644 --- a/fastapi/tests/test_endpoints_local.py +++ b/fastapi/tests/test_endpoints_local.py @@ -1,9 +1,13 @@ import os import requests import pytest +import websockets +import asyncio +import json # Set the URL url = 'http://localhost:80' +websocket_url = 'ws://localhost:80' agency_ids = ["LACMTA", "LACMTA_Rail"] @@ -37,3 +41,16 @@ def test_get_vehicle_positions_route_code_geojson(): def test_get_vehicle_positions_route_code_geojson(): response = requests.get(f"{url}/LACMTA_Rail/vehicle_positions/route_code/801?format=geojson") assert response.status_code == 200 +@pytest.mark.asyncio +async def test_websocket_endpoint(): + # Include the agency_id in the URL + websocket_url_with_agency_id = f"{websocket_url}/ws/LACMTA_Rail/vehicle_positions" + async with websockets.connect(websocket_url_with_agency_id) as websocket: + # Wait for a response + response = await 
websocket.recv() + + # Parse the response + response_data = json.loads(response) + + # Assert that the response is what you expected + assert response_data["type"] == "ping" \ No newline at end of file From 86082fca7844bafaceb392237a7d720754f0af35 Mon Sep 17 00:00:00 2001 From: albertkun Date: Wed, 22 Nov 2023 06:11:28 -0800 Subject: [PATCH 4/4] Update GTFS-RT scheduler interval --- data-loading-service/app/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data-loading-service/app/main.py b/data-loading-service/app/main.py index ace5f13..5cbd331 100644 --- a/data-loading-service/app/main.py +++ b/data-loading-service/app/main.py @@ -10,7 +10,7 @@ import crython # import schedule -@crython.job(second='*/'+str(main_helper.set_interval_time())+'') +@crython.job(second='*/15') def gtfs_rt_scheduler(): try: gtfs_rt_helper.update_gtfs_realtime_data()