Skip to content

Commit

Permalink
Merge pull request #503 from LACMTA:dev
Browse files Browse the repository at this point in the history
Refactor get_data_from_many_fields_async function and add new API endpoint for trip departure times
  • Loading branch information
albertkun authored Feb 29, 2024
2 parents 94800b5 + 0bfd5fe commit da02425
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 28 deletions.
12 changes: 9 additions & 3 deletions fastapi/app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,18 @@ async def get_data_async(async_session: Session, model: Type[DeclarativeMeta], a

return [item.to_dict() for item in data]


async def get_data_from_many_fields_async(async_session: Session, model: Type[DeclarativeMeta], agency_id: str, fields: Optional[Dict[str, Any]] = None, cache_expiration: int = None):
async def get_data_from_many_fields_async(
async_session: Session,
model: Type[DeclarativeMeta],
agency_id: str,
fields: Optional[Dict[str, Any]] = None,
cache_expiration: int = None,
geometry: bool = False
):
# Create a unique key for this query
logging.info(f"Executing query for model={model}, agency_id={agency_id}, fields={fields}")

key = f"{model.__name__}:{agency_id}:{fields}"
key = f"{model.__name__}:{agency_id}:{fields}:{geometry}"

# Create a new Redis connection for each function call
redis = aioredis.from_url(Config.REDIS_URL, socket_connect_timeout=5)
Expand Down
75 changes: 56 additions & 19 deletions fastapi/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import pytz
import time

from datetime import timedelta, date, datetime
from datetime import timedelta, date, datetime,time


from fastapi import FastAPI, Request, Response, Depends, HTTPException, status, Query, WebSocket, WebSocketDisconnect
from fastapi import Path as FastAPIPath
Expand Down Expand Up @@ -825,39 +826,75 @@ async def get_trip_shape_stop_times(
raise HTTPException(status_code=404, detail=f"Data not found for route code {route_code}, day type {day_type}, direction id {direction_id}, and time {time}")
return result

@app.get("/{agency_id}/route_stops_grouped/{route_code}", tags=["Static data"])
async def get_route_stops_grouped_by_route_code(
agency_id: AgencyIdEnum,
# @app.get("/{agency_id}/route_stops_grouped/{route_code}", tags=["Static data"])
# async def get_route_stops_grouped_by_route_code(
# agency_id: AgencyIdEnum,
# route_code: str,
# day_type: Optional[str] = None,
# direction_id: Optional[int] = None,
# async_db: AsyncSession = Depends(get_async_db)
# ):
# """
# Get route stops grouped data by route code, day type, and direction id.
# """
# model = models.RouteStopsGrouped
# if route_code.lower() == 'all':
# # Return all routes
# result = await crud.get_all_data_async(async_db, model, agency_id.value)
# elif route_code.lower() == 'list':
# # Return a list of route codes
# result = await crud.get_list_of_unique_values_async(async_db, model, 'route_code', agency_id.value)
# else:
# # Return data for a specific route code, and optionally day type and direction id
# if day_type is None and direction_id is None:
# result = await crud.get_data_async(async_db, model, agency_id.value, 'route_code', route_code)
# else:
# fields = {'route_code': route_code}
# if day_type is not None:
# fields['day_type'] = day_type
# if direction_id is not None:
# fields['direction_id'] = direction_id
# result = await crud.get_data_from_many_fields_async(async_db, model, agency_id.value, fields)
# if result is None:
# raise HTTPException(status_code=404, detail=f"Data not found for route code {route_code}, day type {day_type}, and direction id {direction_id}")
# return result

@app.get("/{agency_id}/trip_departure_times/{route_code}/{direction_id}/{day_type}", tags=["Static data"])
async def get_trip_departure_times(
agency_id: str,
route_code: str,
day_type: Optional[str] = None,
direction_id: Optional[int] = None,
direction_id: int,
day_type: str,
time: Optional[str] = None,
async_db: AsyncSession = Depends(get_async_db)
):
"""
Get route stops grouped data by route code, day type, and direction id.
Get trip departure times data by route code, day type, direction id.
"""
model = models.RouteStopsGrouped
model = models.TripDepartureTimes
if route_code.lower() == 'all':
# Return all routes
result = await crud.get_all_data_async(async_db, model, agency_id.value)
result = await crud.get_all_data_async(async_db, model, agency_id)
elif route_code.lower() == 'list':
# Return a list of route codes
result = await crud.get_list_of_unique_values_async(async_db, model, 'route_code', agency_id.value)
result = await crud.get_list_of_unique_values_async(async_db, model, 'route_code', agency_id)
else:
# Return data for a specific route code, and optionally day type and direction id
if day_type is None and direction_id is None:
result = await crud.get_data_async(async_db, model, agency_id.value, 'route_code', route_code)
fields = {'route_code': route_code, 'direction_id': direction_id, 'day_type': day_type}
result = await crud.get_data_from_many_fields_async(async_db, model, agency_id, fields)

if time is not None:
# Convert the time string to a datetime.time object
time_obj = datetime.strptime(time, "%H:%M:%S").time()
# Filter the result to get the closest time
result = min(result, key=lambda r: abs(datetime.combine(date.today(), r['start_time']) - datetime.combine(date.today(), time_obj)))
else:
fields = {'route_code': route_code}
if day_type is not None:
fields['day_type'] = day_type
if direction_id is not None:
fields['direction_id'] = direction_id
result = await crud.get_data_from_many_fields_async(async_db, model, agency_id.value, fields)
# If no time is provided, return the first record
result = result[0] if result else None

if result is None:
raise HTTPException(status_code=404, detail=f"Data not found for route code {route_code}, day type {day_type}, and direction id {direction_id}")
return result

@app.get("/calendar_dates",tags=["Static data"])
async def get_calendar_dates_from_db(db: Session = Depends(get_db)):
result = crud.get_calendar_dates(db)
Expand Down
27 changes: 21 additions & 6 deletions fastapi/app/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Float,PrimaryKeyConstraint,JSON, join, ARRAY, inspect, Time, TIMESTAMP
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Float,PrimaryKeyConstraint,JSON, join, inspect, Time, TIMESTAMP
from sqlalchemy.orm import class_mapper
from sqlalchemy.dialects.postgresql import ARRAY

Expand Down Expand Up @@ -167,14 +167,29 @@ class TripShapeStopTimes(BaseModel):
is_next_day = Column(Boolean)
payload = Column(String)

class TripDepartureTimes(BaseModel):
__tablename__ = "trip_departure_times"

trip_id = Column(String, primary_key=True, index=True)
route_code = Column(String)
agency_id = Column(String)
day_type = Column(String)
direction_id = Column(Integer)
shape_id = Column(String)
start_time = Column(Time)
end_time = Column(Time)
stops = Column(ARRAY(String))
departure_times = Column(ARRAY(String))
is_next_day = Column(Boolean)

class RouteStopsGrouped(BaseModel):
__tablename__ = "route_stops_grouped"
route_code = Column(String,primary_key=True, index=True)
payload = Column(JSON)
agency_id = Column(String)
direction_id = Column(Integer)
day_type = Column(String)
shape_direction = Column(Geometry('LINESTRING', srid=4326))
# payload = Column(JSON)
# agency_id = Column(String)
# # direction_id = Column(Integer)
# day_type = Column(String)
# shape_direction = Column(Geometry('LINESTRING', srid=4326))
# shape_direction_0 = Column(Geometry('LINESTRING', srid=4326))
# shape_direction_1 = Column(Geometry('LINESTRING', srid=4326))

Expand Down

0 comments on commit da02425

Please sign in to comment.