Skip to content

Commit

Permalink
Merge pull request #536 from LACMTA:dev
Browse files Browse the repository at this point in the history
Refactor get_route_details function for performance and caching
  • Loading branch information
albertkun authored May 22, 2024
2 parents 00a4473 + 84a55ab commit 8365825
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions fastapi/app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,73 @@ async def get_route_details(db: AsyncSession, route_code: str, direction_id: int
return cached_data


query = text("SELECT * FROM metro_api.get_route_details_with_shape_ids(:p_route_code, :p_direction_id, :p_day_type, :p_input_time, :p_num_results)")
result = db.execute(query, {'p_route_code': route_code, 'p_direction_id': direction_id, 'p_day_type': day_type.value, 'p_input_time': p_time.strftime("%H:%M:%S"), 'p_num_results': num_results})
raw_data = result.fetchall()

stop_times = defaultdict(list)
shape_ids = set()
for row in raw_data:
stop_name, departure_times, shape_id = row
shape_ids.add(shape_id)
for time in departure_times:
if time not in stop_times[stop_name]:
stop_times[stop_name].append(time)
stop_times[stop_name].sort()

# Prepare the list of stop times and shape_ids
stop_times_list = [(stop_name, times, shape_id) for stop_name, times in stop_times.items()]

# Query the trip_shapes table for the geometries of the distinct shape_ids
query = text("SELECT shape_id, ST_AsGeoJSON(geometry) FROM metro_api.trip_shapes WHERE shape_id IN :shape_ids")
result = db.execute(query, {'shape_ids': tuple(shape_ids)})

# Fetch all rows from the result
geometries_result = result.fetchall()

# Process the geometries
geometries = {shape_id: geometry for shape_id, geometry in geometries_result}

# Prepare the debugging information
debug_info = {
'queried_shape_ids': list(shape_ids),
'num_queried_shape_ids': len(shape_ids),
'num_returned_geometries': len(geometries_result),
}

# Prepare the final data
final_data = {
'stop_times': stop_times_list,
'geometries': geometries,
'debug_info': debug_info,
}
try:
await redis.set(key, pickle.dumps(final_data), ex=cache_expiration)
except pickle.PicklingError as e:
logging.error(f"Error pickling data for Redis: {e}")

await redis.close()

return final_data

async def get_route_details_dev(db: AsyncSession, route_code: str, direction_id: int, day_type: str, p_time: str, num_results: int, cache_expiration: Optional[int] = None):
logging.info(f"Executing query for route_code={route_code}, direction_id={direction_id}, day_type={day_type}, time={p_time}, num_results={num_results}")

key = f"get_route_details:{route_code}:{direction_id}:{day_type}:{p_time}:{num_results}"

redis = aioredis.from_url(Config.REDIS_URL, socket_connect_timeout=5)

cached_result = await redis.get(key)
if cached_result is not None:
try:
cached_data = pickle.loads(cached_result)
except (pickle.UnpicklingError, AttributeError, EOFError, ImportError, IndexError) as e:
logging.error(f"Error unpickling data from Redis: {e}")
cached_data = None
if cached_data is not None:
return cached_data


query = text("SELECT * FROM metro_api.get_route_details_with_shape_ids(:p_route_code, :p_direction_id, :p_day_type, :p_input_time, :p_num_results)")
result = db.execute(query, {'p_route_code': route_code, 'p_direction_id': direction_id, 'p_day_type': day_type.value, 'p_input_time': p_time.strftime("%H:%M:%S"), 'p_num_results': num_results})
raw_data = result.fetchall()
Expand Down

0 comments on commit 8365825

Please sign in to comment.