Skip to content

Commit

Permalink
Merge pull request #443 from LACMTA/dev
Browse files Browse the repository at this point in the history
Add websocket endpoint for vehicle positions
  • Loading branch information
albertkun authored Jan 24, 2024
2 parents be2e563 + 86b6788 commit 4d41505
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 30 deletions.
2 changes: 1 addition & 1 deletion fastapi/app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ async def get_data_async(async_session: Session, model: Type[DeclarativeMeta], a
item.shape_direction_1 = mapping(load_wkb(item.shape_direction_1.desc))
else:
if hasattr(item, 'geometry') and item.geometry is not None:
item.geometry = mapping(load_wkb(bytes(item.geometry)))
item.geometry = mapping(load_wkb(bytes(item.geometry.desc, 'utf-8')))

# Cache the result in Redis with the specified expiration time
try:
Expand Down
80 changes: 51 additions & 29 deletions fastapi/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,26 @@
import os
import aioredis
import asyncio
import async_timeout
import pytz
import time


from datetime import timedelta, date, datetime

from fastapi import FastAPI, Request, Response, Depends, HTTPException, status, Query, WebSocket,WebSocketDisconnect
from fastapi import FastAPI, Request, Response, Depends, HTTPException, status, Query, WebSocket, WebSocketDisconnect
from fastapi import Path as FastAPIPath
# from fastapi import FastAPI, Request, Response, Depends, HTTPException, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, RedirectResponse, HTMLResponse,PlainTextResponse
from fastapi.responses import JSONResponse, RedirectResponse, HTMLResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates


from sqlalchemy import false, distinct, inspect
from sqlalchemy.orm import aliased
from sqlalchemy.future import select
from sqlalchemy.exc import SQLAlchemyError

from collections import defaultdict


from pydantic import BaseModel, Json, ValidationError
import functools
import io
Expand All @@ -46,20 +43,11 @@
from starlette.requests import Request
from starlette.responses import Response





# from redi as aioredis
from enum import Enum


# for OAuth2
from fastapi.security import OAuth2PasswordBearer,OAuth2PasswordRequestForm

# from app.models import *
# from app.security import *

# Pagination
from fastapi_pagination import Page, add_pagination, paginate
from fastapi_pagination.ext.sqlalchemy import paginate as paginate_sqlalchemy
Expand Down Expand Up @@ -522,23 +510,57 @@ async def get_trip_detail_by_vehicle(agency_id: AgencyIdEnum, vehicle_id: Option
from shapely.geometry import mapping
from geoalchemy2 import WKBElement


@app.websocket("/ws/{agency_id}/vehicle_positions")
async def websocket_endpoint(websocket: WebSocket, agency_id: str, db: Session = Depends(get_db)):
await websocket.accept()
while True:
try:
# Query the database directly
data = db.query(models.VehiclePositions).filter_by(agency_id=agency_id).all()
# Convert the data to dictionaries
data = [item.to_dict() for item in data]
for item in data:
# Convert WKBElement to GeoJSON
if 'geometry' in item and isinstance(item['geometry'], WKBElement):
item['geometry'] = mapping(to_shape(item['geometry']))
await websocket.send_text(json.dumps(item))
except Exception as e:
await websocket.send_text(f"Error: {str(e)}")
await asyncio.sleep(1) # Sleep for a bit to prevent flooding the client with messages
redis = aioredis.Redis.from_url(Config.REDIS_URL, decode_responses=True)
psub = redis.pubsub()

async def reader(channel: aioredis.client.PubSub):
while True:
try:
async with async_timeout.timeout(1):
message = await channel.get_message(ignore_subscribe_messages=True)
if message is not None:
if message["type"] == "message":
try:
item = json.loads(message['data'])
await websocket.send_text(json.dumps(item))
except Exception as e:
await websocket.send_text(f"Error: {str(e)}")
await asyncio.sleep(0.01)
except asyncio.TimeoutError:
pass

async def publisher():
while True:
try:
# Query the database directly
data = db.query(models.VehiclePositions).filter_by(agency_id=agency_id).all()
# Convert the data to dictionaries
data = [item.to_dict() for item in data]
for item in data:
# Convert WKBElement to GeoJSON
if 'geometry' in item and isinstance(item['geometry'], WKBElement):
item['geometry'] = mapping(to_shape(item['geometry']))
# Only publish items that have a trip_id
if item.get('trip_id') is not None:
await redis.publish(f'vehicle_positions_{agency_id}', json.dumps(item))
await asyncio.sleep(1) # Sleep for a bit to prevent flooding the client with messages
except Exception as e:
print(f"Error: {str(e)}")

# Start the publisher and reader as separate tasks
asyncio.create_task(publisher())

async with psub as p:
await p.subscribe(f'vehicle_positions_{agency_id}')
await reader(p) # wait for reader to complete
await p.unsubscribe(f'vehicle_positions_{agency_id}')

# closing all open connections
await psub.close()

@app.websocket("/ws/{agency_id}/trip_detail/route_code/{route_code}")
async def websocket_endpoint(websocket: WebSocket, agency_id: str, route_code: str, db: AsyncSession = Depends(get_db)):
Expand Down

0 comments on commit 4d41505

Please sign in to comment.