From 194a5aa415f5a910c9946d5e905ed73cd768e147 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Thu, 31 Oct 2024 19:11:17 +0000 Subject: [PATCH] modified: app/main.py modified: app/utils.py modified: setup.sh --- app/main.py | 6 +++++- app/utils.py | 17 +++++++++++------ setup.sh | 2 +- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/app/main.py b/app/main.py index db9cc45..78d3153 100644 --- a/app/main.py +++ b/app/main.py @@ -224,6 +224,7 @@ async def predict_and_explain( surface: str | None = None, event: str | None = None, prob_min: float = 0.01, + top_k: int = 3, ): """ Predict code APE. @@ -247,7 +248,9 @@ async def predict_and_explain( text = [text_description] # model needs a list of strings params = {"additional_var": [1] * len(text)} # TBR - pred, confidence, all_scores = model.predict_and_explain(text, params) + pred, confidence, all_scores, all_scores_letters = model.predict_and_explain( + text, params, top_k=top_k + ) response = process_response_explain( text=text, @@ -257,6 +260,7 @@ async def predict_and_explain( all_scores=all_scores, prob_min=prob_min, libs=libs, + all_scores_letters=all_scores_letters, ) return response diff --git a/app/utils.py b/app/utils.py index 2f89c16..572ffc7 100644 --- a/app/utils.py +++ b/app/utils.py @@ -312,6 +312,7 @@ def process_response_explain( all_scores: list[dict[str, float]], prob_min: float, libs: dict, + all_scores_letters, ): """ Processes model predictions and generates response. @@ -335,16 +336,20 @@ def process_response_explain( detail="The model is not confident enough to make a prediction (and explain it).", ) + k = int(all_scores_letters.shape[0]) # top_k output_dict = { - str(1): { - "code": predictions[liasse_nb][-1].replace("__label__", ""), - "probabilite": float(confidence[liasse_nb][-1]), - "libelle": libs[predictions[liasse_nb][-1].replace("__label__", "")], + str(rank_pred + 1): { + "code": predictions[liasse_nb][-rank_pred - 1].replace("__label__", ""), + "probabilite": float(confidence[liasse_nb][-rank_pred - 1]), + "libelle": libs[predictions[liasse_nb][-rank_pred - 1].replace("__label__", "")], + "letter_attr": np.array(all_scores_letters).tolist()[ + -rank_pred - 1 + ], # Converts numpy arrays to lists + "word_attr": np.array(all_scores[liasse_nb]).tolist()[-rank_pred - 1], } + for rank_pred in range(k) } - output_dict[text[liasse_nb]] = all_scores[liasse_nb] - try: response = output_dict return response diff --git a/setup.sh b/setup.sh index 8f17e16..2be3610 100755 --- a/setup.sh +++ b/setup.sh @@ -8,7 +8,7 @@ pre-commit install export MLFLOW_S3_ENDPOINT_URL="https://$AWS_S3_ENDPOINT" export MLFLOW_TRACKING_URI=https://user-meilametayebjee-mlflow.user.lab.sspcloud.fr export MLFLOW_MODEL_NAME=fasttext-pytorch -export MLFLOW_MODEL_VERSION=7 +export MLFLOW_MODEL_VERSION=8 export API_USERNAME=username export API_PASSWORD=password export AUTH_API=False