Skip to content

Commit

Permalink
Merge pull request #515 from LACMTA:2023-api-optimization
Browse files Browse the repository at this point in the history
refactor: route_stops and trip_shapes endpoints
  • Loading branch information
albertkun authored Apr 3, 2024
2 parents fac6fd2 + f1b49ed commit 16af494
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 40 deletions.
51 changes: 12 additions & 39 deletions fastapi/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
from .config import Config
from pathlib import Path

from logzio.handler import LogzioHandler
import logging
import typing as t

Expand Down Expand Up @@ -812,8 +811,12 @@ async def get_canceled_trip(db: Session = Depends(get_db)):

### GTFS Static data ###
@app.get("/{agency_id}/route_stops/{route_code}",tags=["Static data"])
async def populate_route_stops(agency_id: AgencyIdEnum,route_code:str, daytype: DayTypesEnum = DayTypesEnum.all, db: Session = Depends(get_db)):
result = crud.get_gtfs_route_stops(db,route_code,daytype.value,agency_id.value)
async def populate_route_stops(agency_id: AgencyIdEnum, route_code:int, daytype: DayTypesEnum = DayTypesEnum.all, async_session: Session = Depends(get_async_db)):
if daytype.value != 'all':
result = await crud.get_data_async(async_session, models.RouteStops, agency_id.value, 'route_code', route_code, cache_expiration=60*60*24*7)
result = [item for item in result if item['day_type'] == daytype.value]
else:
result = await crud.get_data_async(async_session, models.RouteStops, agency_id.value, 'route_code', route_code, cache_expiration=60*60*24*7)
json_compatible_item_data = jsonable_encoder(result)
return JSONResponse(content=json_compatible_item_data)

Expand Down Expand Up @@ -917,8 +920,8 @@ async def get_stops(agency_id: AgencyIdEnum,stop_id, db: Session = Depends(get_d
return result

@app.get("/{agency_id}/trips/{trip_id}",tags=["Static data"])
async def get_bus_trips(agency_id: AgencyIdEnum,trip_id, db: Session = Depends(get_db)):
result = crud.get_trips_data(db,trip_id,agency_id.value)
async def get_bus_trips(agency_id: AgencyIdEnum, trip_id, async_session: Session = Depends(get_async_db)):
result = await crud.get_data_async(async_session, models.Trips, agency_id.value, 'trip_id', trip_id)
return result

@app.get("/{agency_id}/shapes/{shape_id}",tags=["Static data"])
Expand All @@ -930,28 +933,15 @@ async def get_shapes(agency_id: AgencyIdEnum,shape_id, geojson: bool = False,db:
return result

@app.get("/{agency_id}/trip_shapes/{shape_id}",tags=["Static data"])
async def get_trip_shapes(agency_id: AgencyIdEnum,shape_id, db: Session = Depends(get_db)):
async def get_trip_shapes(agency_id: AgencyIdEnum, shape_id, async_session: Session = Depends(get_async_db)):
if shape_id == "all":
result = crud.get_trip_shapes_all(db,agency_id.value)
result = crud.get_all_data_async(async_session, models.TripShapes, agency_id.value)
elif shape_id == "list":
result = crud.get_trip_shapes_list(db,agency_id.value)
result = crud.get_list_of_unique_values_async(async_session, models.TripShapes, 'shape_id', agency_id.value)
else:
result = crud.get_trip_shape(db,shape_id,agency_id.value)
result = await crud.get_data_async(async_session, models.TripShapes, agency_id.value, 'shape_id', shape_id)
return result

@app.get("/{agency_id}/trip_shapes/routes/{route_code}", tags=["Dynamic data"])
async def get_route_details(agency_id: AgencyIdEnum, route_code: str, db: Session = Depends(get_db)):
# Query the database for all trip_shapes associated with the route
trip_shapes = crud.get_trip_shapes_for_route(db, route_code, agency_id.value)

# For each trip_shape, query the database for all associated stops
for shape in trip_shapes:
shape_dict = {key: value for key, value in shape.__dict__.items() if not key.startswith('_')}
shape_dict['stops'] = crud.get_stops_for_trip_shape(db, shape_dict['shape_id'], agency_id.value)

# Return the trip_shapes and their associated stops
return trip_shapes

@app.get("/{agency_id}/calendar/{service_id}",tags=["Static data"])
async def get_calendar_list(agency_id: AgencyIdEnum,service_id, db: Session = Depends(get_db)):
if service_id == "list":
Expand Down Expand Up @@ -1123,23 +1113,6 @@ def setup_logging():
uvicorn_access_logger = logging.getLogger("uvicorn.access")
uvicorn_error_logger = logging.getLogger("uvicorn.error")
logger = logging.getLogger("uvicorn.app")
logzio_formatter = logging.Formatter("%(message)s")

logzio_uvicorn_access_handler = LogzioHandler(Config.LOGZIO_TOKEN, 'uvicorn.access', 5, Config.LOGZIO_URL)
logzio_uvicorn_access_handler.setLevel(logging.ERROR) # Set level to ERROR
logzio_uvicorn_access_handler.setFormatter(logzio_formatter)

logzio_uvicorn_error_handler = LogzioHandler(Config.LOGZIO_TOKEN, 'uvicorn.error', 5, Config.LOGZIO_URL)
logzio_uvicorn_error_handler.setLevel(logging.ERROR) # Set level to ERROR
logzio_uvicorn_error_handler.setFormatter(logzio_formatter)

logzio_app_handler = LogzioHandler(Config.LOGZIO_TOKEN, 'fastapi.app', 5, Config.LOGZIO_URL)
logzio_app_handler.setLevel(logging.ERROR) # Set level to ERROR
logzio_app_handler.setFormatter(logzio_formatter)

uvicorn_access_logger.addHandler(logzio_uvicorn_access_handler)
uvicorn_error_logger.addHandler(logzio_uvicorn_error_handler)
logger.addHandler(logzio_app_handler)

uvicorn_access_logger.addFilter(LogFilter())
uvicorn_error_logger.addFilter(LogFilter())
Expand Down
9 changes: 8 additions & 1 deletion fastapi/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ class RouteStops(Base):
# latitude = Column(Float)
# longitude = Column(Float)
agency_id = Column(String)
def to_dict(self):
return {c.key: getattr(self, c.key) for c in inspect(self).mapper.column_attrs}

class TripShapeStopTimes(BaseModel):
__tablename__ = "trip_shape_stop_times"
Expand Down Expand Up @@ -188,6 +190,8 @@ class TripShapes(Base):
shape_id = Column(String, primary_key=True, index=True)
geometry = Column(Geometry('LINESTRING', srid=4326))
agency_id = Column(String)
def to_dict(self):
return {c.key: getattr(self, c.key) for c in inspect(self).mapper.column_attrs}

class Shapes(Base):
__tablename__ = "shapes"
Expand All @@ -198,7 +202,8 @@ class Shapes(Base):
geometry = Column(Geometry('POINT', srid=4326))
shape_pt_sequence = Column(Integer)
agency_id = Column(String)

def to_dict(self):
return {c.key: getattr(self, c.key) for c in inspect(self).mapper.column_attrs}

class Trips(Base):
__tablename__ = "trips"
Expand All @@ -211,6 +216,8 @@ class Trips(Base):
shape_id = Column(String)
trip_id_event = Column(String)
agency_id = Column(String)
def to_dict(self):
return {c.key: getattr(self, c.key) for c in inspect(self).mapper.column_attrs}

class TripShapeStops(Base):
__tablename__ = "trip_shape_stops"
Expand Down

0 comments on commit 16af494

Please sign in to comment.