Skip to content

Commit

Permalink
modif of train_model
Browse files Browse the repository at this point in the history
  • Loading branch information
LouiseDurandJanin committed Aug 21, 2023
1 parent 1b72094 commit b37c6cc
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions src/models/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,20 @@
import sklearn
import pandas as pd
from sklearn import ensemble
from sklearn.metrics import f1_score, make_scorer
from sklearn.model_selection import train_test_split
import joblib
from sklearn.model_selection import cross_val_score, GridSearchCV

df = pd.read_csv('data/preprocessed/preprocessed.csv')

target = df['grav']
feats = df.drop(['grav'], axis = 1)

X_train, X_test, y_train, y_test = train_test_split(feats, target, test_size=0.3, random_state = 42)

X_train = pd.read_csv('data/preprocessed/X_train.csv')
X_test = pd.read_csv('data/preprocessed/X_test.csv')
y_train = pd.read_csv('data/preprocessed/y_train.csv')
y_test = pd.read_csv('data/preprocessed/y_test.csv')

rf_classifier = ensemble.RandomForestClassifier(n_jobs = -1, n_estimators= 100)

# Perform the grid search on the data
#--Train the model
rf_classifier.fit(X_train, y_train)

# Save the trained model to a file
model_filename = 'trained_model.joblib'
#--Save the trained model to a file
model_filename = 'C:/Users/lenov/Documents/Template_MLOps_accidents/src/models/trained_model.joblib'
joblib.dump(rf_classifier, model_filename)
print("Model trained and saved successfully.")

0 comments on commit b37c6cc

Please sign in to comment.