Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor websocket endpoint and add Redis caching for vehicle positions #458

Merged
merged 2 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions fastapi/app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,29 +346,28 @@ def get_unique_stop_ids(the_query):
return stop_id_list

### websocket endpoint handling
def get_vehicle_positions_for_websocket(db: Session, agency_id: str):
async def get_vehicle_positions_for_websocket(db: Session, agency_id: str):
# Try to get data from Redis cache first
cache_key = f'realtime_vehicle_websocket:{str(filters)}:{agency_id}:{include_stop_time_updates}'
cache_key = f'realtime_vehicle_websocket:{agency_id}'
if redis_connection is None:
initialize_redis()
cached_result = redis_connection.get(cache_key)
if cached_result is not None:
return pickle.loads(cached_result)
return json.loads(cached_result)

# If not in cache, query the database
data = db.query(models.VehiclePositions).filter_by(agency_id=agency_id).all()
data = [item.to_dict() for item in data]
for item in data:
if 'geometry' in item and isinstance(item['geometry'], WKBElement):
item['geometry'] = mapping(to_shape(item['geometry']))
item['geometry'] = mapping(item['geometry'])

# Store the result in Redis cache
redis_connection.set(agency_id, json.dumps(data))
redis_connection.set(cache_key, json.dumps(data), ex=60) # Set an expiration time of 60 seconds

return data



async def get_gtfs_rt_line_detail_updates_for_route_code(session,route_code: str, geojson:bool,agency_id:str):
the_query = await session.execute(select(models.StopTimeUpdates).where(models.StopTimeUpdates.route_code == route_code,models.StopTimeUpdates.agency_id == agency_id))

Expand Down
5 changes: 3 additions & 2 deletions fastapi/app/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlalchemy import create_engine,MetaData
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import sessionmaker,scoped_session
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.pool import NullPool

Expand All @@ -18,7 +18,8 @@ def create_async_uri(uri):
async_engine = create_async_engine(create_async_uri(Config.API_DB_URI), echo=False, poolclass=NullPool)
async_session = sessionmaker(async_engine, expire_on_commit=False, class_=AsyncSession)
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)

session_factory = sessionmaker(bind=engine)
scoped_session = scoped_session(session_factory)
# LocalSession = sessionmaker(autocommit=False, autoflush=True, bind=async_engine, expire_on_commit=True)

session = Session()
Expand Down
56 changes: 39 additions & 17 deletions fastapi/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

from .utils.log_helper import *

from .database import Session, AsyncSession, engine, session, get_db,get_async_db
from .database import Session, AsyncSession, engine, session,get_db,get_async_db
from . import crud, models, security, schemas
from .config import Config
from pathlib import Path
Expand Down Expand Up @@ -510,11 +510,11 @@ async def get_trip_detail_by_vehicle(agency_id: AgencyIdEnum, vehicle_id: Option
from shapely.geometry import mapping
from geoalchemy2 import WKBElement


@app.websocket("/ws/{agency_id}/vehicle_positions")
async def websocket_endpoint(websocket: WebSocket, agency_id: str, db: Session = Depends(get_db)):
async def websocket_endpoint(websocket: WebSocket, agency_id: str):
await websocket.accept()
redis = aioredis.Redis.from_url(Config.REDIS_URL, decode_responses=True)

redis = app.state.redis_pool
psub = redis.pubsub()

async def reader(channel: aioredis.client.PubSub):
Expand All @@ -529,25 +529,44 @@ async def reader(channel: aioredis.client.PubSub):
await websocket.send_text(json.dumps(item))
except Exception as e:
await websocket.send_text(f"Error: {str(e)}")
await asyncio.sleep(0.01)
await asyncio.sleep(0.1)
except asyncio.TimeoutError:
pass

async def publisher():
last_data = None
while True:
try:
# Query the database directly
data = db.query(models.VehiclePositions).filter_by(agency_id=agency_id).all()
# Convert the data to dictionaries
data = [item.to_dict() for item in data]
for item in data:
# Convert WKBElement to GeoJSON
if 'geometry' in item and isinstance(item['geometry'], WKBElement):
item['geometry'] = mapping(to_shape(item['geometry']))
# Only publish items that have a trip_id
if item.get('trip_id') is not None:
await redis.publish(f'vehicle_positions_{agency_id}', json.dumps(item))
# Get a new session from the pool
db = Session()
# Try to get data from Redis cache first
cache_key = f'vehicle_positions_cache:{agency_id}'
cached_data = await redis.get(cache_key)
if cached_data is not None:
data = json.loads(cached_data)
else:
# If not in cache, query the database directly
data = db.query(models.VehiclePositions).filter_by(agency_id=agency_id).all()
# Convert the data to dictionaries
data = [item.to_dict() for item in data]
for item in data:
# Convert WKBElement to GeoJSON
if 'geometry' in item and isinstance(item['geometry'], WKBElement):
item['geometry'] = mapping(to_shape(item['geometry']))
# Store the result in Redis cache
await redis.set(cache_key, json.dumps(data), ex=60) # Set an expiration time of 60 seconds

# Only publish the data if it has changed
if data != last_data:
for item in data:
# Only publish items that have a trip_id
if item.get('trip_id') is not None:
await redis.publish(f'vehicle_positions_{agency_id}', json.dumps(item))
last_data = data

await asyncio.sleep(1) # Sleep for a bit to prevent flooding the client with messages
# Close the session
Session.remove()
except Exception as e:
print(f"Error: {str(e)}")

Expand All @@ -561,7 +580,9 @@ async def publisher():

# closing all open connections
await psub.close()

redis.close()
await redis.wait_closed()

@app.websocket("/ws/{agency_id}/trip_detail/route_code/{route_code}")
async def websocket_endpoint(websocket: WebSocket, agency_id: str, route_code: str, db: AsyncSession = Depends(get_db)):
await websocket.accept()
Expand Down Expand Up @@ -991,6 +1012,7 @@ class RedisNotConnected(Exception):
async def startup_event():
try:
crud.redis_connection = crud.initialize_redis()
app.state.redis_pool = aioredis.from_url(Config.REDIS_URL, decode_responses=True)
setup_logging()
except Exception as e:
print(f"Failed to connect to Redis: {e}")
Expand Down
Loading