Skip to content

Commit

Permalink
Merge pull request #513 from LACMTA:dev
Browse files Browse the repository at this point in the history
Refactor endpoints and remove unnecessary endpoints
  • Loading branch information
albertkun authored Mar 19, 2024
2 parents 851ba8e + fac6fd2 commit a7ded39
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 95 deletions.
10 changes: 5 additions & 5 deletions data-loading-service/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ def retry_on_failure(task, retries=5, delay=15):
time.sleep(delay)
raise Exception('Task failed after all retries') # If all retries fail, raise an exception

# @crython.job(second='*/15')
# def gtfs_rt_scheduler():
# if not lock.locked():
# with lock:
# asyncio.run(retry_on_failure(gtfs_rt_helper.update_gtfs_realtime_data))
@crython.job(second='*/15')
def gtfs_rt_scheduler():
if not lock.locked():
with lock:
asyncio.run(retry_on_failure(gtfs_rt_helper.update_gtfs_realtime_data))

@crython.job(expr='@daily')
def go_pass_data_scheduler():
Expand Down
17 changes: 11 additions & 6 deletions data-loading-service/app/utils/gtfs_rt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# from ..gtfs_rt import *
# from ..models import *

import json
import re
import requests
import pandas as pd
import geopandas as gpd
Expand Down Expand Up @@ -136,11 +136,15 @@ async def update_gtfs_realtime_data():
]

for endpoint in websocket_endpoints:
match = re.search(r'/ws/(.*?)/', endpoint)
if match:
agency = match.group(1)
async with websockets.connect(endpoint) as websocket:
feed = FeedMessage()
response_data = await websocket.recv()
if not response_data:
break
print(f"No response data from {endpoint}. Skipping this endpoint.")
continue
feed.ParseFromString(response_data)

trip_update_array = []
Expand Down Expand Up @@ -218,7 +222,9 @@ async def update_gtfs_realtime_data():

vehicle_position_updates_gdf = gpd.GeoDataFrame(vehicle_position_updates, geometry=gpd.points_from_xy(vehicle_position_updates.position_longitude, vehicle_position_updates.position_latitude))
combined_vehicle_position_dataframes.append(vehicle_position_updates_gdf)
# logging('vehicle_position_updates Data Frame: ' + str(vehicle_position_updates))
process_dataframes_and_update_db(combined_trip_update_dataframes, combined_stop_time_dataframes, combined_vehicle_position_dataframes)

def process_dataframes_and_update_db(combined_trip_update_dataframes, combined_stop_time_dataframes, combined_vehicle_position_dataframes):
combined_trip_update_df = pd.concat(combined_trip_update_dataframes)
combined_stop_time_df = pd.concat(combined_stop_time_dataframes)
combined_vehicle_position_df = gpd.GeoDataFrame(pd.concat(combined_vehicle_position_dataframes, ignore_index=True), geometry='geometry')
Expand Down Expand Up @@ -246,10 +252,9 @@ async def update_gtfs_realtime_data():
del combined_stop_time_dataframes
del combined_vehicle_position_dataframes


if __name__ == "__main__":
process_start = timeit.default_timer()
# update_gtfs_realtime_data()
# process_start = timeit.default_timer()
update_gtfs_realtime_data()
process_end = timeit.default_timer()
session.close()
print('Process took {} seconds'.format(process_end - process_start))
69 changes: 9 additions & 60 deletions fastapi/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@

from enum import Enum

from shapely import wkt
from geojson import LineString

# for OAuth2
from fastapi.security import OAuth2PasswordBearer,OAuth2PasswordRequestForm

Expand Down Expand Up @@ -524,7 +527,7 @@ async def get_trip_detail_by_vehicle(agency_id: AgencyIdEnum, vehicle_id: Option

connected_clients = 0

@app.router.get("/ws/{agency_id}/{endpoint}/{route_codes}")
@app.router.get("/ws/{agency_id}/{endpoint}/{route_codes}" ,tags=["Real-Time data"])
async def dummy_websocket_endpoint(agency_id: str, endpoint: str, route_codes: Optional[str] = None):
"""
Dummy HTTP endpoint for WebSocket documentation.
Expand Down Expand Up @@ -642,7 +645,8 @@ async def reader(channel: aioredis.client.PubSub):
await psub.close()
redis.close()
await redis.wait_closed()
@app.router.get("/ws/{agency_id}/{endpoint}/{route_codes}")

@app.router.get("/ws/{agency_id}/{endpoint}/{route_codes}" ,tags=["Real-Time data"])
async def dummy_websocket_endpoint(agency_id: str, endpoint: str, route_codes: Optional[str] = None):
"""
Dummy HTTP endpoint for WebSocket documentation.
Expand Down Expand Up @@ -813,61 +817,6 @@ async def populate_route_stops(agency_id: AgencyIdEnum,route_code:str, daytype:
json_compatible_item_data = jsonable_encoder(result)
return JSONResponse(content=json_compatible_item_data)

@app.get("/{agency_id}/trip_departure_times/{route_code}/{direction_id}/{day_type}", tags=["Static data"])
async def get_trip_departure_times(
agency_id: str,
route_code: str,
direction_id: int,
day_type: str,
current_time: Optional[str] = None,
async_db: AsyncSession = Depends(get_async_db)
):
"""
Get trip departure times data by route code, day type, direction id.
"""
model = models.TripDepartureTimes
if route_code.lower() == 'all':
# Return all routes
result = await crud.get_all_data_async(async_db, model, agency_id)
elif route_code.lower() == 'list':
# Return a list of route codes
result = await crud.get_list_of_unique_values_async(async_db, model, 'route_code', agency_id)
else:
# Return data for a specific route code, and optionally day type and direction id
fields = {'route_code': route_code, 'direction_id': direction_id, 'day_type': day_type}
result = await crud.get_data_from_many_fields_async(async_db, model, agency_id, fields)

# Convert the current time string to a datetime.time object
current_time = current_time or datetime.now().strftime("%H:%M:%S")
current_time_obj = datetime.strptime(current_time, "%H:%M:%S").time()

# Group the results by shape_id and filter based on current_time
# Group the results by shape_id and filter based on current_time
shapes = defaultdict(set)
for record in result:
start_time = record['start_time']
end_time = record['end_time']

# Check if the trip goes into the next day
if start_time > end_time:
# The trip goes into the next day
if current_time_obj >= start_time or current_time_obj <= end_time:
shapes[record['shape_id']].update(record['stops'])
else:
# The trip does not go into the next day
if start_time <= current_time_obj <= end_time:
shapes[record['shape_id']].update(record['stops'])

# Convert defaultdict to dict and assign it to result
# Also convert sets to lists
result = {shape_id: list(stops) for shape_id, stops in shapes.items()}

if result is None:
raise HTTPException(status_code=404, detail=f"Data not found for route code {route_code}, day type {day_type}, and direction id {direction_id}")
return result
from shapely import wkt
from geojson import LineString

@app.get("/{agency_id}/route_details/{route_code}", tags=["Static data"])
async def route_details_endpoint(
agency_id: str,
Expand Down Expand Up @@ -1083,7 +1032,7 @@ async def get_gopass_schools(db: AsyncSession = Depends(get_async_db), show_miss
json_compatible_item_data = jsonable_encoder(result)
return JSONResponse(content=json_compatible_item_data)

@app.get("/time")
@app.get("/time", tags=["User Methods"])
async def get_time():
current_time = datetime.now()
return {current_time}
Expand All @@ -1092,7 +1041,7 @@ async def get_time():
# async def root():


@app.get("/",response_class=HTMLResponse)
@app.get("/",response_class=HTMLResponse, tags=["User Methods"])
def index(request:Request):
# return templates.TemplateResponse("index.html",context={"request":request})
human_readable_default_update = None
Expand Down Expand Up @@ -1165,7 +1114,7 @@ def read_user(username: str, db: Session = Depends(get_db),token: str = Depends(
return db_user


@app.get("/routes", response_model=List[str])
@app.get("/routes", response_model=List[str], tags=["Static Data"])
async def get_all_routes():
return [route.path for route in app.routes]

Expand Down
31 changes: 7 additions & 24 deletions fastapi/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,33 +172,16 @@ class TripShapeStopTimes(BaseModel):
is_next_day = Column(Boolean)
payload = Column(String)

class TripDepartureTimes(BaseModel):
__tablename__ = "trip_departure_times"

trip_id = Column(String, primary_key=True, index=True)
route_code = Column(String)
agency_id = Column(String)
day_type = Column(String)
direction_id = Column(Integer)
shape_id = Column(String)
start_time = Column(Time)
end_time = Column(Time)
stops = Column(ARRAY(String))
departure_times = Column(ARRAY(String))
stop_ids = Column(ARRAY(String))
is_next_day = Column(Boolean)
__table_args__ = (UniqueConstraint('trip_id', 'route_code', name='trip_id_route_code_key'),)

class RouteStopsGrouped(BaseModel):
__tablename__ = "route_stops_grouped"
route_code = Column(String,primary_key=True, index=True)
# payload = Column(JSON)
# agency_id = Column(String)
# # direction_id = Column(Integer)
# day_type = Column(String)
# shape_direction = Column(Geometry('LINESTRING', srid=4326))
# shape_direction_0 = Column(Geometry('LINESTRING', srid=4326))
# shape_direction_1 = Column(Geometry('LINESTRING', srid=4326))
payload = Column(JSON)
agency_id = Column(String)
# direction_id = Column(Integer)
day_type = Column(String)
shape_direction = Column(Geometry('LINESTRING', srid=4326))
shape_direction_0 = Column(Geometry('LINESTRING', srid=4326))
shape_direction_1 = Column(Geometry('LINESTRING', srid=4326))

class TripShapes(Base):
__tablename__ = "trip_shapes"
Expand Down

0 comments on commit a7ded39

Please sign in to comment.