Skip to content

Commit

Permalink
Moved model loading step out of predict_labels_with_n
Browse files Browse the repository at this point in the history
Refactored code to pass in the trip model directly to predict_labels_with_n() in eamur.

Moved the load model step to eacil.inferrers by using load_model() of eamur.

Modified TestRunGreedyModel to use this refactored function.
  • Loading branch information
Mahadik, Mukul Chandrakant authored and Mahadik, Mukul Chandrakant committed Dec 1, 2023
1 parent 425966a commit 64232ab
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 20 deletions.
18 changes: 17 additions & 1 deletion emission/analysis/classification/inference/labels/inferrers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
import random
import copy
import time
import arrow

import emission.analysis.modelling.tour_model_first_only.load_predict as lp
import emission.analysis.modelling.trip_model.run_model as eamur
Expand Down Expand Up @@ -164,7 +166,21 @@ def predict_cluster_confidence_discounting(trip_list, max_confidence=None, first
# load application config
model_type = eamtc.get_model_type()
model_storage = eamtc.get_model_storage()
labels_n_list = eamur.predict_labels_with_n(trip_list, model_type, model_storage)

# assert and fetch unique user id for trip_list
user_id_list = []
for trip in trip_list:
user_id_list.append(trip['user_id'])
assert user_id_list.count(user_id_list[0]) == len(user_id_list), "Multiple user_ids found for trip_list, expected unique user_id for all trips"
# Assertion successful, use unique user_id
user_id = user_id_list[0]

# load model
start_model_load_time = time.process_time()
model = eamur._load_stored_trip_model(user_id, model_type, model_storage)
print(f"{arrow.now()} Inside predict_labels_n: Model load time = {time.process_time() - start_model_load_time}")

labels_n_list = eamur.predict_labels_with_n(trip_list, model)
predictions_list = []
for labels, n in labels_n_list:
if n <= 0: # No model data or trip didn't match a cluster
Expand Down
18 changes: 2 additions & 16 deletions emission/analysis/modelling/trip_model/run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,29 +99,15 @@ def update_trip_model(

def predict_labels_with_n(
trip_list: List[ecwc.Confirmedtrip],
model_type = eamumt.ModelType.GREEDY_SIMILARITY_BINNING,
model_storage = eamums.ModelStorage.DOCUMENT_DATABASE,
model_config = None):
model: eamuu.TripModel):
"""
invoke the user label prediction model to predict labels for a trip.
:param trip_list: the list of trips to predict labels for
:param model_type: type of prediction model to run
:param model_storage: location to read/write models
:param model_config: optional configuration for model, for debugging purposes
:param model: trip model used for predictions
:return: a list of predictions
"""

user_id_list = []
for trip in trip_list:
user_id_list.append(trip['user_id'])
assert user_id_list.count(user_id_list[0]) == len(user_id_list), "Multiple user_ids found for trip_list, expected unique user_id for all trips"
# Assertion successful, use unique user_id
user_id = user_id_list[0]

start_model_load_time = time.process_time()
model = _load_stored_trip_model(user_id, model_type, model_storage, model_config)
print(f"{arrow.now()} Inside predict_labels_n: Model load time = {time.process_time() - start_model_load_time}")
predictions_list = []
print(f"{arrow.now()} Inside predict_labels_n: Predicting...")
start_predict_time = time.process_time()
Expand Down
11 changes: 8 additions & 3 deletions emission/tests/modellingTests/TestRunGreedyModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,18 @@ def test1RoundTripGreedySimilarityBinning(self):
origin=self.origin,
destination=self.destination
)
predictions_list = eamur.predict_labels_with_n(
trip_list = [test],

model = eamur._load_stored_trip_model(
user_id,
model_type=eamumt.ModelType.GREEDY_SIMILARITY_BINNING,
model_storage=eamums.ModelStorage.DOCUMENT_DATABASE,
model_config=greedy_model_config
)

predictions_list = eamur.predict_labels_with_n(
trip_list = [test],
model = model
)

for prediction, n in predictions_list:
[logging.debug(p) for p in sorted(prediction, key=lambda r: r['p'], reverse=True)]
Expand Down

0 comments on commit 64232ab

Please sign in to comment.