Skip to content

Commit

Permalink
feat: new columns to StopTimes and TripDepartureTimes models
Browse files Browse the repository at this point in the history
  • Loading branch information
albertkun committed Mar 1, 2024
1 parent 8a2dff5 commit 24c2d86
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 57 deletions.
146 changes: 132 additions & 14 deletions fastapi/app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fastapi.encoders import jsonable_encoder
from sqlalchemy.future import select

from sqlalchemy import and_, inspect, cast, Integer,or_
from sqlalchemy import and_, inspect, cast, Integer,or_, any_
from sqlalchemy.orm import joinedload
from sqlalchemy import exists
from sqlalchemy.sql import text
Expand Down Expand Up @@ -593,19 +593,137 @@ def get_stops_id(db, stop_code: str,agency_id: str):
# user_dict = models.User[username]
# return schemas.UserInDB(**user_dict)

def get_trips_data(db,trip_id: str,agency_id: str):
if trip_id == 'list':
the_query = db.query(models.Trips).filter(models.Trips.agency_id == agency_id).all()
result = []
for row in the_query:
result.append(row.trip_id)
return result
elif trip_id == 'all':
the_query = db.query(models.Trips).filter(models.Trips.agency_id == agency_id).all()
return the_query
else:
the_query = db.query(models.Trips).filter(models.Trips.trip_id == trip_id,models.Trips.agency_id == agency_id).all()
return the_query
async def get_geometry_by_shape_id_async(
async_session: Session,
model: Type[DeclarativeMeta],
agency_id: str,
shape_id: str,
cache_expiration: int = None
):
logging.info(f"Executing query for model={model}, agency_id={agency_id}, shape_id={shape_id}")

key = f"{model.__name__}:{agency_id}:{shape_id}"

redis = aioredis.from_url(Config.REDIS_URL, socket_connect_timeout=5)

result = await redis.get(key)
if result is not None:
try:
data = pickle.loads(result)
except (pickle.UnpicklingError, AttributeError, EOFError, ImportError, IndexError) as e:
logging.error(f"Error unpickling data from Redis: {e}")
data = None
if data is not None and isinstance(data, list):
return data

with async_session.no_autoflush:
conditions = [
getattr(model, 'shape_id') == shape_id,
getattr(model, 'agency_id') == agency_id
]
stmt = select(model.geometry).where(and_(*conditions))
result = await async_session.execute(stmt)
data = result.scalars().all()

try:
await redis.set(key, pickle.dumps(data), ex=cache_expiration)
except pickle.PicklingError as e:
logging.error(f"Error pickling data for Redis: {e}")

await redis.close()

return data

async def get_stops_by_shape_id_async(
async_session: Session,
model: Type[DeclarativeMeta],
agency_id: str,
shape_id: str,
cache_expiration: int = None
):
logging.info(f"Executing query for model={model}, agency_id={agency_id}, shape_id={shape_id}")

key = f"{model.__name__}:{agency_id}:{shape_id}"

redis = aioredis.from_url(Config.REDIS_URL, socket_connect_timeout=5)

result = await redis.get(key)
if result is not None:
try:
data = pickle.loads(result)
except (pickle.UnpicklingError, AttributeError, EOFError, ImportError, IndexError) as e:
logging.error(f"Error unpickling data from Redis: {e}")
data = None
if data is not None and isinstance(data, list):
return data

with async_session.no_autoflush:
conditions = [
getattr(model, 'shape_id') == shape_id,
getattr(model, 'agency_id') == agency_id
]
stmt = select(model.stop_ids).where(and_(*conditions))
result = await async_session.execute(stmt)
data = result.scalars().all()

try:
await redis.set(key, pickle.dumps(data), ex=cache_expiration)
except pickle.PicklingError as e:
logging.error(f"Error pickling data for Redis: {e}")

await redis.close()

return data

def time_to_minutes_past_midnight(time_obj: datetime.time) -> int:
"""Convert a datetime.time object to minutes past midnight."""
return time_obj.hour * 60 + time_obj.minute

from sqlalchemy import func
from sqlalchemy import select, and_, Integer, func
from sqlalchemy.orm import Session
from typing import Type
from sqlalchemy.orm.decl_api import DeclarativeMeta

async def get_trips_by_shape_id_async(
async_session: Session,
model: Type[DeclarativeMeta],
shape_id: str,
agency_id: str
):
conditions = [
getattr(model, 'shape_id') == shape_id,
getattr(model, 'agency_id') == agency_id
]

stmt = select(model).where(and_(*conditions))
result = await async_session.execute(stmt)
return result.scalars().all()

from sqlalchemy import text

async def get_stop_times_by_trip_id_and_time_range_async(
async_session: Session,
model: Type[DeclarativeMeta],
trip_id: str,
time: datetime.time,
agency_id: str
):
conditions = [
getattr(model, 'trip_id') == trip_id,
getattr(model, 'agency_id') == agency_id,
or_(
and_(
getattr(model, 'departure_time_clean') >= time,
getattr(model, 'is_next_day') == False
),
getattr(model, 'is_next_day') == True
)
]

stmt = select(model).where(and_(*conditions))
result = await async_session.execute(stmt)
return result.scalars().all()

def get_agency_data(db, tablename,agency_id):
aliased_table = aliased(tablename)
Expand Down
100 changes: 58 additions & 42 deletions fastapi/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,53 +883,69 @@ async def get_trip_departure_times(
fields = {'route_code': route_code, 'direction_id': direction_id, 'day_type': day_type}
result = await crud.get_data_from_many_fields_async(async_db, model, agency_id, fields)

# Group the results by trip_id
trips = defaultdict(list)
# Convert the current time string to a datetime.time object
current_time = current_time or datetime.now().strftime("%H:%M:%S")
current_time_obj = datetime.strptime(current_time, "%H:%M:%S").time()

# Group the results by shape_id and filter based on current_time
# Group the results by shape_id and filter based on current_time
shapes = defaultdict(set)
for record in result:
trips[record['trip_id']].append(record)
# Iterate over each trip
# Iterate over each trip
filtered_trips = {}
for trip_id, records in trips.items():
if current_time is not None:
# Convert the current time string to a datetime.time object
current_time_obj = datetime.strptime(current_time, "%H:%M:%S").time()

# Filter the records based on the current_time
filtered_records = [record for record in records if record['start_time'] <= current_time_obj <= record['end_time']]

if filtered_records: # Only proceed if there are records within the time range
for record in filtered_records:
# Convert each time in the departure_times list to a datetime.time object
departure_times = [t if isinstance(t, time) else datetime.strptime(t, "%H:%M:%S").time() for t in record['departure_times']]

# If current_time is in departure_times, set it as the closest_time
if current_time_obj in departure_times:
record['closest_time'] = current_time
else:
# Calculate the difference between the current time and each time in the departure_times list
time_diffs = [abs(datetime.combine(date.today(), t) - datetime.combine(date.today(), current_time_obj)) for t in departure_times]

# Find the minimum difference and the corresponding time in the departure_times list
min_diff, closest_time = min(zip(time_diffs, departure_times), key=lambda x: x[0])

# Set the closest_time for the record
record['closest_time'] = closest_time.strftime("%H:%M:%S")

# Update the filtered_trips dictionary with the filtered records
filtered_trips[trip_id] = filtered_records

trips = filtered_trips
result = dict(trips)

# If no trips are found, find the trip with the closest departure time
if not result:
closest_trip = min(trips.items(), key=lambda x: min(abs(datetime.combine(date.today(), t) - datetime.combine(date.today(), current_time_obj)) for t in [t if isinstance(t, time) else datetime.strptime(t, "%H:%M:%S").time() for record in x[1] for t in record['departure_times']]))
result = {closest_trip[0]: closest_trip[1]}
start_time = record['start_time']
end_time = record['end_time']

# Check if the trip goes into the next day
if start_time > end_time:
# The trip goes into the next day
if current_time_obj >= start_time or current_time_obj <= end_time:
shapes[record['shape_id']].update(record['stops'])
else:
# The trip does not go into the next day
if start_time <= current_time_obj <= end_time:
shapes[record['shape_id']].update(record['stops'])

# Convert defaultdict to dict and assign it to result
# Also convert sets to lists
result = {shape_id: list(stops) for shape_id, stops in shapes.items()}

if result is None:
raise HTTPException(status_code=404, detail=f"Data not found for route code {route_code}, day type {day_type}, and direction id {direction_id}")
return result
@app.get("/{agency_id}/shape_info/{shape_id}", tags=["Static data"])
async def get_shape_info(
agency_id: str,
shape_id: str,
time: str,
async_db: AsyncSession = Depends(get_async_db)
):
"""
Get shape info by shape_id and time.
"""
# Convert the time string to a datetime.time object
time_obj = datetime.strptime(time, "%H:%M:%S").time()

# Get all trips for the given shape_id
trips = await crud.get_trips_by_shape_id_async(async_db, models.Trips, shape_id, agency_id)
print(f"Trips: {trips}")

stops = []
# For each trip, get the stop times
for trip in trips:
stop_times = await crud.get_stop_times_by_trip_id_and_time_range_async(async_db, models.StopTimes, trip.trip_id, time_obj, agency_id)
print(f"Stop times for trip {trip.trip_id}: {stop_times}")

# For each stop time, get the stop details
for stop_time in stop_times:
stop = await crud.get_stop_by_id_async(async_db, models.Stops, stop_time.stop_id, agency_id)
print(f"Stop details for stop {stop_time.stop_id}: {stop}")
# Add the stop to the list of stops
stops.append(stop)

# Return the result
return {
'stops': stops
}

@app.get("/calendar_dates",tags=["Static data"])
async def get_calendar_dates_from_db(db: Session = Depends(get_db)):
result = crud.get_calendar_dates(db)
Expand Down
8 changes: 7 additions & 1 deletion fastapi/app/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Float,PrimaryKeyConstraint,JSON, join, inspect, Time, TIMESTAMP
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Float,PrimaryKeyConstraint,JSON, join, inspect, Time, TIMESTAMP, UniqueConstraint
from sqlalchemy.orm import class_mapper
from sqlalchemy.dialects.postgresql import ARRAY


from geoalchemy2 import *
from geoalchemy2.shape import to_shape
from geoalchemy2.elements import WKBElement
Expand Down Expand Up @@ -61,6 +62,9 @@ class StopTimes(BaseModel):
__tablename__ = "stop_times"
arrival_time = Column(String)
departure_time = Column(String)
arrival_time_clean = Column(Time)
departure_time_clean = Column(Time)
is_next_day = Column(Boolean)
stop_id = Column(Integer, index=True)
stop_sequence = Column(Integer,primary_key=True, index=True)
stop_headsign = Column(String)
Expand Down Expand Up @@ -180,7 +184,9 @@ class TripDepartureTimes(BaseModel):
end_time = Column(Time)
stops = Column(ARRAY(String))
departure_times = Column(ARRAY(String))
stop_ids = Column(ARRAY(String))
is_next_day = Column(Boolean)
__table_args__ = (UniqueConstraint('trip_id', 'route_code', name='trip_id_route_code_key'),)

class RouteStopsGrouped(BaseModel):
__tablename__ = "route_stops_grouped"
Expand Down

0 comments on commit 24c2d86

Please sign in to comment.