Skip to content

Commit

Permalink
modified: app/main.py
Browse files Browse the repository at this point in the history
	modified:   app/utils.py
	modified:   setup.sh
  • Loading branch information
meilame-tayebjee committed Oct 31, 2024
1 parent 4bff8f0 commit 194a5aa
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
6 changes: 5 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ async def predict_and_explain(
surface: str | None = None,
event: str | None = None,
prob_min: float = 0.01,
top_k: int = 3,
):
"""
Predict code APE.
Expand All @@ -247,7 +248,9 @@ async def predict_and_explain(
text = [text_description] # model needs a list of strings
params = {"additional_var": [1] * len(text)} # TBR

pred, confidence, all_scores = model.predict_and_explain(text, params)
pred, confidence, all_scores, all_scores_letters = model.predict_and_explain(
text, params, top_k=top_k
)

response = process_response_explain(
text=text,
Expand All @@ -257,6 +260,7 @@ async def predict_and_explain(
all_scores=all_scores,
prob_min=prob_min,
libs=libs,
all_scores_letters=all_scores_letters,
)

return response
Expand Down
17 changes: 11 additions & 6 deletions app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def process_response_explain(
all_scores: list[dict[str, float]],
prob_min: float,
libs: dict,
all_scores_letters,
):
"""
Processes model predictions and generates response.
Expand All @@ -335,16 +336,20 @@ def process_response_explain(
detail="The model is not confident enough to make a prediction (and explain it).",
)

k = int(all_scores_letters.shape[0]) # top_k
output_dict = {
str(1): {
"code": predictions[liasse_nb][-1].replace("__label__", ""),
"probabilite": float(confidence[liasse_nb][-1]),
"libelle": libs[predictions[liasse_nb][-1].replace("__label__", "")],
str(rank_pred + 1): {
"code": predictions[liasse_nb][-rank_pred - 1].replace("__label__", ""),
"probabilite": float(confidence[liasse_nb][-rank_pred - 1]),
"libelle": libs[predictions[liasse_nb][-rank_pred - 1].replace("__label__", "")],
"letter_attr": np.array(all_scores_letters).tolist()[
-rank_pred - 1
], # Converts numpy arrays to lists
"word_attr": np.array(all_scores[liasse_nb]).tolist()[-rank_pred - 1],
}
for rank_pred in range(k)
}

output_dict[text[liasse_nb]] = all_scores[liasse_nb]

try:
response = output_dict
return response
Expand Down
2 changes: 1 addition & 1 deletion setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pre-commit install
export MLFLOW_S3_ENDPOINT_URL="https://$AWS_S3_ENDPOINT"
export MLFLOW_TRACKING_URI=https://user-meilametayebjee-mlflow.user.lab.sspcloud.fr
export MLFLOW_MODEL_NAME=fasttext-pytorch
export MLFLOW_MODEL_VERSION=7
export MLFLOW_MODEL_VERSION=8
export API_USERNAME=username
export API_PASSWORD=password
export AUTH_API=False
Expand Down

0 comments on commit 194a5aa

Please sign in to comment.