Skip to content

Commit

Permalink
modif of predict_model
Browse files Browse the repository at this point in the history
  • Loading branch information
LouiseDurandJanin committed Aug 21, 2023
1 parent 1b86827 commit 1b72094
Showing 1 changed file with 29 additions and 8 deletions.
37 changes: 29 additions & 8 deletions src/models/predict_model.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,35 @@

import click
import joblib
from imblearn.metrics import classification_report_imbalanced
import pandas as pd

# Load your saved model
loaded_model = joblib.load("src/models/trained_model.joblib")

@click.command()
def main():
# Load your training dataset (replace 'train_data.csv' with your file)
X_train = pd.read_csv("data/preprocessed/X_train.csv")

# Get feature names from X_train columns
feature_names = X_train.columns.tolist()

features = {}

# Load the trained model from the file
model_filename = 'trained_model.joblib'
loaded_model = joblib.load(model_filename)
# Get user input for each feature
for feature_name in feature_names:
feature_value = click.prompt(f"Enter value for {feature_name}", type=float)
features[feature_name] = feature_value

y_pred = loaded_model.predict(X_test)
# Predict using the model
result = predict_model(features)
print("Prediction:", result)

def predict_model(features):
input_df = pd.DataFrame([features])
prediction = loaded_model.predict(input_df)
return prediction

if __name__ == "__main__":
main()

print(f1_score(y_test, y_pred_rf))
pd.crosstab(y_test, y_pred, rownames=['Classe réelle'], colnames=['Classe prédite'])
print(classification_report_imbalanced(y_test, y_pred))

0 comments on commit 1b72094

Please sign in to comment.