From 24c2d86526fc82689213efd9a8481e5eca1f7cee Mon Sep 17 00:00:00 2001 From: albertkun Date: Thu, 29 Feb 2024 23:49:21 -0800 Subject: [PATCH] feat: new columns to StopTimes and TripDepartureTimes models --- fastapi/app/crud.py | 146 ++++++++++++++++++++++++++++++++++++++---- fastapi/app/main.py | 100 +++++++++++++++++------------ fastapi/app/models.py | 8 ++- 3 files changed, 197 insertions(+), 57 deletions(-) diff --git a/fastapi/app/crud.py b/fastapi/app/crud.py index 09d2f70..a5b9402 100644 --- a/fastapi/app/crud.py +++ b/fastapi/app/crud.py @@ -7,7 +7,7 @@ from fastapi.encoders import jsonable_encoder from sqlalchemy.future import select -from sqlalchemy import and_, inspect, cast, Integer,or_ +from sqlalchemy import and_, inspect, cast, Integer,or_, any_ from sqlalchemy.orm import joinedload from sqlalchemy import exists from sqlalchemy.sql import text @@ -593,19 +593,137 @@ def get_stops_id(db, stop_code: str,agency_id: str): # user_dict = models.User[username] # return schemas.UserInDB(**user_dict) -def get_trips_data(db,trip_id: str,agency_id: str): - if trip_id == 'list': - the_query = db.query(models.Trips).filter(models.Trips.agency_id == agency_id).all() - result = [] - for row in the_query: - result.append(row.trip_id) - return result - elif trip_id == 'all': - the_query = db.query(models.Trips).filter(models.Trips.agency_id == agency_id).all() - return the_query - else: - the_query = db.query(models.Trips).filter(models.Trips.trip_id == trip_id,models.Trips.agency_id == agency_id).all() - return the_query +async def get_geometry_by_shape_id_async( + async_session: Session, + model: Type[DeclarativeMeta], + agency_id: str, + shape_id: str, + cache_expiration: int = None +): + logging.info(f"Executing query for model={model}, agency_id={agency_id}, shape_id={shape_id}") + + key = f"{model.__name__}:{agency_id}:{shape_id}" + + redis = aioredis.from_url(Config.REDIS_URL, socket_connect_timeout=5) + + result = await redis.get(key) + if result is not None: + try: + data = pickle.loads(result) + except (pickle.UnpicklingError, AttributeError, EOFError, ImportError, IndexError) as e: + logging.error(f"Error unpickling data from Redis: {e}") + data = None + if data is not None and isinstance(data, list): + return data + + with async_session.no_autoflush: + conditions = [ + getattr(model, 'shape_id') == shape_id, + getattr(model, 'agency_id') == agency_id + ] + stmt = select(model.geometry).where(and_(*conditions)) + result = await async_session.execute(stmt) + data = result.scalars().all() + + try: + await redis.set(key, pickle.dumps(data), ex=cache_expiration) + except pickle.PicklingError as e: + logging.error(f"Error pickling data for Redis: {e}") + + await redis.close() + + return data + +async def get_stops_by_shape_id_async( + async_session: Session, + model: Type[DeclarativeMeta], + agency_id: str, + shape_id: str, + cache_expiration: int = None +): + logging.info(f"Executing query for model={model}, agency_id={agency_id}, shape_id={shape_id}") + + key = f"{model.__name__}:{agency_id}:{shape_id}" + + redis = aioredis.from_url(Config.REDIS_URL, socket_connect_timeout=5) + + result = await redis.get(key) + if result is not None: + try: + data = pickle.loads(result) + except (pickle.UnpicklingError, AttributeError, EOFError, ImportError, IndexError) as e: + logging.error(f"Error unpickling data from Redis: {e}") + data = None + if data is not None and isinstance(data, list): + return data + + with async_session.no_autoflush: + conditions = [ + getattr(model, 'shape_id') == shape_id, + getattr(model, 'agency_id') == agency_id + ] + stmt = select(model.stop_ids).where(and_(*conditions)) + result = await async_session.execute(stmt) + data = result.scalars().all() + + try: + await redis.set(key, pickle.dumps(data), ex=cache_expiration) + except pickle.PicklingError as e: + logging.error(f"Error pickling data for Redis: {e}") + + await redis.close() + + return data + +def time_to_minutes_past_midnight(time_obj: datetime.time) -> int: + """Convert a datetime.time object to minutes past midnight.""" + return time_obj.hour * 60 + time_obj.minute + +from sqlalchemy import func +from sqlalchemy import select, and_, Integer, func +from sqlalchemy.orm import Session +from typing import Type +from sqlalchemy.orm.decl_api import DeclarativeMeta + +async def get_trips_by_shape_id_async( + async_session: Session, + model: Type[DeclarativeMeta], + shape_id: str, + agency_id: str +): + conditions = [ + getattr(model, 'shape_id') == shape_id, + getattr(model, 'agency_id') == agency_id + ] + + stmt = select(model).where(and_(*conditions)) + result = await async_session.execute(stmt) + return result.scalars().all() + +from sqlalchemy import text + +async def get_stop_times_by_trip_id_and_time_range_async( + async_session: Session, + model: Type[DeclarativeMeta], + trip_id: str, + time: datetime.time, + agency_id: str +): + conditions = [ + getattr(model, 'trip_id') == trip_id, + getattr(model, 'agency_id') == agency_id, + or_( + and_( + getattr(model, 'departure_time_clean') >= time, + getattr(model, 'is_next_day') == False + ), + getattr(model, 'is_next_day') == True + ) + ] + + stmt = select(model).where(and_(*conditions)) + result = await async_session.execute(stmt) + return result.scalars().all() def get_agency_data(db, tablename,agency_id): aliased_table = aliased(tablename) diff --git a/fastapi/app/main.py b/fastapi/app/main.py index e137b25..8b4ef4e 100644 --- a/fastapi/app/main.py +++ b/fastapi/app/main.py @@ -883,53 +883,69 @@ async def get_trip_departure_times( fields = {'route_code': route_code, 'direction_id': direction_id, 'day_type': day_type} result = await crud.get_data_from_many_fields_async(async_db, model, agency_id, fields) - # Group the results by trip_id - trips = defaultdict(list) + # Convert the current time string to a datetime.time object + current_time = current_time or datetime.now().strftime("%H:%M:%S") + current_time_obj = datetime.strptime(current_time, "%H:%M:%S").time() + + # Group the results by shape_id and filter based on current_time + # Group the results by shape_id and filter based on current_time + shapes = defaultdict(set) for record in result: - trips[record['trip_id']].append(record) - # Iterate over each trip - # Iterate over each trip - filtered_trips = {} - for trip_id, records in trips.items(): - if current_time is not None: - # Convert the current time string to a datetime.time object - current_time_obj = datetime.strptime(current_time, "%H:%M:%S").time() - - # Filter the records based on the current_time - filtered_records = [record for record in records if record['start_time'] <= current_time_obj <= record['end_time']] - - if filtered_records: # Only proceed if there are records within the time range - for record in filtered_records: - # Convert each time in the departure_times list to a datetime.time object - departure_times = [t if isinstance(t, time) else datetime.strptime(t, "%H:%M:%S").time() for t in record['departure_times']] - - # If current_time is in departure_times, set it as the closest_time - if current_time_obj in departure_times: - record['closest_time'] = current_time - else: - # Calculate the difference between the current time and each time in the departure_times list - time_diffs = [abs(datetime.combine(date.today(), t) - datetime.combine(date.today(), current_time_obj)) for t in departure_times] - - # Find the minimum difference and the corresponding time in the departure_times list - min_diff, closest_time = min(zip(time_diffs, departure_times), key=lambda x: x[0]) - - # Set the closest_time for the record - record['closest_time'] = closest_time.strftime("%H:%M:%S") - - # Update the filtered_trips dictionary with the filtered records - filtered_trips[trip_id] = filtered_records - - trips = filtered_trips - result = dict(trips) - - # If no trips are found, find the trip with the closest departure time - if not result: - closest_trip = min(trips.items(), key=lambda x: min(abs(datetime.combine(date.today(), t) - datetime.combine(date.today(), current_time_obj)) for t in [t if isinstance(t, time) else datetime.strptime(t, "%H:%M:%S").time() for record in x[1] for t in record['departure_times']])) - result = {closest_trip[0]: closest_trip[1]} + start_time = record['start_time'] + end_time = record['end_time'] + + # Check if the trip goes into the next day + if start_time > end_time: + # The trip goes into the next day + if current_time_obj >= start_time or current_time_obj <= end_time: + shapes[record['shape_id']].update(record['stops']) + else: + # The trip does not go into the next day + if start_time <= current_time_obj <= end_time: + shapes[record['shape_id']].update(record['stops']) + + # Convert defaultdict to dict and assign it to result + # Also convert sets to lists + result = {shape_id: list(stops) for shape_id, stops in shapes.items()} if result is None: raise HTTPException(status_code=404, detail=f"Data not found for route code {route_code}, day type {day_type}, and direction id {direction_id}") return result +@app.get("/{agency_id}/shape_info/{shape_id}", tags=["Static data"]) +async def get_shape_info( + agency_id: str, + shape_id: str, + time: str, + async_db: AsyncSession = Depends(get_async_db) +): + """ + Get shape info by shape_id and time. + """ + # Convert the time string to a datetime.time object + time_obj = datetime.strptime(time, "%H:%M:%S").time() + + # Get all trips for the given shape_id + trips = await crud.get_trips_by_shape_id_async(async_db, models.Trips, shape_id, agency_id) + print(f"Trips: {trips}") + + stops = [] + # For each trip, get the stop times + for trip in trips: + stop_times = await crud.get_stop_times_by_trip_id_and_time_range_async(async_db, models.StopTimes, trip.trip_id, time_obj, agency_id) + print(f"Stop times for trip {trip.trip_id}: {stop_times}") + + # For each stop time, get the stop details + for stop_time in stop_times: + stop = await crud.get_stop_by_id_async(async_db, models.Stops, stop_time.stop_id, agency_id) + print(f"Stop details for stop {stop_time.stop_id}: {stop}") + # Add the stop to the list of stops + stops.append(stop) + + # Return the result + return { + 'stops': stops + } + @app.get("/calendar_dates",tags=["Static data"]) async def get_calendar_dates_from_db(db: Session = Depends(get_db)): result = crud.get_calendar_dates(db) diff --git a/fastapi/app/models.py b/fastapi/app/models.py index ea8421f..c6ed004 100644 --- a/fastapi/app/models.py +++ b/fastapi/app/models.py @@ -1,7 +1,8 @@ -from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Float,PrimaryKeyConstraint,JSON, join, inspect, Time, TIMESTAMP +from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Float,PrimaryKeyConstraint,JSON, join, inspect, Time, TIMESTAMP, UniqueConstraint from sqlalchemy.orm import class_mapper from sqlalchemy.dialects.postgresql import ARRAY + from geoalchemy2 import * from geoalchemy2.shape import to_shape from geoalchemy2.elements import WKBElement @@ -61,6 +62,9 @@ class StopTimes(BaseModel): __tablename__ = "stop_times" arrival_time = Column(String) departure_time = Column(String) + arrival_time_clean = Column(Time) + departure_time_clean = Column(Time) + is_next_day = Column(Boolean) stop_id = Column(Integer, index=True) stop_sequence = Column(Integer,primary_key=True, index=True) stop_headsign = Column(String) @@ -180,7 +184,9 @@ class TripDepartureTimes(BaseModel): end_time = Column(Time) stops = Column(ARRAY(String)) departure_times = Column(ARRAY(String)) + stop_ids = Column(ARRAY(String)) is_next_day = Column(Boolean) + __table_args__ = (UniqueConstraint('trip_id', 'route_code', name='trip_id_route_code_key'),) class RouteStopsGrouped(BaseModel): __tablename__ = "route_stops_grouped"