Skip to content

Commit

Permalink
Merge pull request #414 from LACMTA:dev
Browse files Browse the repository at this point in the history
Update Redis connection logic and refactor Redis calls in main.py
  • Loading branch information
albertkun authored Nov 22, 2023
2 parents 474d0e2 + a5bddb1 commit 8c9799d
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 116 deletions.
2 changes: 1 addition & 1 deletion data-loading-service/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import crython
# import schedule
@crython.job(second='*/'+str(main_helper.set_interval_time())+'')
@crython.job(second='*/15')
def gtfs_rt_scheduler():
try:
gtfs_rt_helper.update_gtfs_realtime_data()
Expand Down
63 changes: 42 additions & 21 deletions fastapi/app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import aioredis
import pickle
import time

from sqlalchemy import select

Expand All @@ -51,7 +52,26 @@
from sqlalchemy.orm import Session
from sqlalchemy.orm.decl_api import DeclarativeMeta

redis = aioredis.from_url(Config.REDIS_URL, socket_connect_timeout=5)

redis_connection = None

def initialize_redis(retries=5, delay=5):
global redis_connection
for i in range(retries):
try:
redis_connection = aioredis.from_url(Config.REDIS_URL, socket_connect_timeout=5)
# If connection is successful, break the loop
if redis_connection.ping():
break
except Exception as e:
print(f"Failed to connect to Redis: {e}")
redis_connection = None
if i < retries - 1: # no delay on the last attempt
time.sleep(delay)
else:
raise Exception("Failed to connect to Redis after several attempts")

initialize_redis()
# import sqlalchemy

def asdict(obj):
Expand All @@ -75,7 +95,7 @@ async def get_data_redis(db, model, id_field, id_value):
key = f'{model.__tablename__}:{id_value}'

# Try to get data from Redis
data = await redis.get(key)
data = await redis_connection.get(key)

if data is None:
# If data is not in Redis, get it from the database
Expand All @@ -90,7 +110,7 @@ async def get_data_redis(db, model, id_field, id_value):
for key, value in row.__dict__.items()
if not key.startswith('_sa_instance_state')
} for row in result])
await redis.set(key, data)
await redis_connection.set(key, data)
else:
# Parse the JSON-formatted string back into a Python data structure
data = json.loads(data)
Expand Down Expand Up @@ -119,13 +139,14 @@ async def get_vehicle_data_async(db: AsyncSession, agency_id: str, vehicle_id: s
return data

import pickle

async def get_data_async(async_session: Session, model: Type[DeclarativeMeta], agency_id: str, field_name: Optional[str] = None, field_value: Optional[str] = None):
# Create a unique key for this query
key = f"{model.__name__}:{agency_id}:{field_name}:{field_value}"

# Try to get the result from Redis
result = await redis.get(key)
if redis_connection is None:
initialize_redis()
result = await redis_connection.get(key)
if result is not None:
data = pickle.loads(result)
if isinstance(data, model):
Expand Down Expand Up @@ -163,7 +184,7 @@ async def get_list_of_unique_values_async(session: AsyncSession, model, agency_i
logging.info(f"Generated key: {key}")

# Try to get the result from Redis
result = await redis.get(key)
result = await redis_connection.get(key)
if result is not None:
logging.info("Found result in Redis")
return pickle.loads(result)
Expand All @@ -186,7 +207,7 @@ async def get_list_of_unique_values_async(session: AsyncSession, model, agency_i
logging.info(f"Unique values from database: {unique_values}")

# Store the result in Redis
await redis.set(key, pickle.dumps(unique_values))
await redis_connection.set(key, pickle.dumps(unique_values))

return unique_values

Expand Down Expand Up @@ -224,7 +245,7 @@ def get_stop_times_by_route_code(db, route_code: str,agency_id: str):
async def get_stop_times_by_trip_id(db, trip_id: str, agency_id: str):
# Try to get the result from Redis first
cache_key = f'stop_times:{trip_id}:{agency_id}'
cached_result = await redis.get(cache_key)
cached_result = await redis_connection.get(cache_key)
if cached_result is not None:
return pickle.loads(cached_result)

Expand All @@ -242,7 +263,7 @@ async def get_stop_times_by_trip_id(db, trip_id: str, agency_id: str):

# If result is not empty, store it in Redis for future use
if result:
await redis.set(cache_key, pickle.dumps(result))
await redis_connection.set(cache_key, pickle.dumps(result))

return result

Expand Down Expand Up @@ -276,7 +297,7 @@ def list_gtfs_rt_vehicle_positions_by_field_name(db, field_name: str,agency_id:
async def get_gtfs_rt_trips_by_field_name(db, field_name: str, field_value: str, agency_id: str):
# Try to get the result from Redis first
cache_key = f'trips:{field_name}:{field_value}:{agency_id}'
cached_result = await redis.get(cache_key)
cached_result = await redis_connection.get(cache_key)
if cached_result is not None:
return pickle.loads(cached_result)

Expand All @@ -296,14 +317,14 @@ async def get_gtfs_rt_trips_by_field_name(db, field_name: str, field_value: str,

# If result is not empty, store it in Redis for future use
if result:
await redis.set(cache_key, pickle.dumps(result))
await redis_connection.set(cache_key, pickle.dumps(result))

return result

async def get_all_gtfs_rt_trips(db, agency_id: str):
# Try to get the result from Redis first
cache_key = f'trips:{agency_id}'
cached_result = await redis.get(cache_key)
cached_result = await redis_connection.get(cache_key)
if cached_result is not None:
return pickle.loads(cached_result)

Expand All @@ -315,15 +336,15 @@ async def get_all_gtfs_rt_trips(db, agency_id: str):

# If result is not empty, store it in Redis for future use
if result:
await redis.set(cache_key, pickle.dumps(result))
await redis_connection.set(cache_key, pickle.dumps(result))

return result

async def get_all_gtfs_rt_vehicle_positions(db, agency_id: str, geojson: bool):
try:
# Try to get the result from Redis first
cache_key = f'vehicle_positions:{agency_id}:{geojson}'
cached_result = await redis.get(cache_key)
cached_result = await redis_connection.get(cache_key)
if cached_result is not None:
return pickle.loads(cached_result)

Expand All @@ -350,7 +371,7 @@ async def get_all_gtfs_rt_vehicle_positions(db, agency_id: str, geojson: bool):

# If result is not empty, store it in Redis for future use
if result:
await redis.set(cache_key, pickle.dumps(result))
await redis_connection.set(cache_key, pickle.dumps(result))

return result
except Exception as e:
Expand Down Expand Up @@ -432,7 +453,7 @@ def _async(db, agency_id: str, geojson: bool):
from sqlalchemy.orm import joinedload
async def get_gtfs_rt_vehicle_positions_trip_data_by_route_code(session: AsyncSession, route_code: str, geojson:bool, agency_id:str):
cache_key = f'trip_data:{route_code}:{agency_id}'
cached_data = await redis.get(cache_key)
cached_data = await redis_connection.get(cache_key)
if cached_data is not None:
return pickle.loads(cached_data)
stmt = (
Expand Down Expand Up @@ -577,7 +598,7 @@ async def get_gtfs_rt_vehicle_positions_trip_data_redis(db, vehicle_id: str):
key = f'vehicle:{vehicle_id}'

# Try to get data from Redis
data = await redis.get(key)
data = await redis_connection.get(key)

if data is None:
# If data is not in Redis, get it from the database
Expand All @@ -588,15 +609,15 @@ async def get_gtfs_rt_vehicle_positions_trip_data_redis(db, vehicle_id: str):

# Convert the result to JSON and store it in Redis
data = json.dumps([dict(row) for row in result])
await redis.set(key, data)
await redis_connection.set(key, data)

return data


async def get_gtfs_rt_vehicle_positions_trip_data(db, vehicle_id: str, geojson: bool, agency_id: str):
# Try to get the result from Redis first
cache_key = f'vehicle_positions:{vehicle_id}:{geojson}:{agency_id}'
result = await redis.get(cache_key)
result = await redis_connection.get(cache_key)
if result is not None:
return pickle.loads(result)

Expand All @@ -617,7 +638,7 @@ async def get_gtfs_rt_vehicle_positions_trip_data(db, vehicle_id: str, geojson:
this_json['type'] = "FeatureCollection"
this_json['features'] = features
if this_json:
await redis.set(cache_key, pickle.dumps(this_json))
await redis_connection.set(cache_key, pickle.dumps(this_json))

return this_json
for row in the_query:
Expand All @@ -641,7 +662,7 @@ async def get_gtfs_rt_vehicle_positions_trip_data(db, vehicle_id: str, geojson:
return message_object
else:
if result:
await redis.set(cache_key, pickle.dumps(result))
await redis_connection.set(cache_key, pickle.dumps(result))

return result

Expand Down
Loading

0 comments on commit 8c9799d

Please sign in to comment.