-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_data.py
29 lines (24 loc) · 948 Bytes
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
def load_data(data_root='./dataset'):
    """Load the train/test CSV files and integer-encode the training labels.

    Reads ``train.csv`` and ``test.csv`` from ``data_root``. In the training
    frame the ``Target`` column is rewritten with the mapping
    ``'Graduate' -> 1``, ``'Dropout' -> 0``, ``'Enrolled' -> 2``; any value
    not present in the mapping is left unchanged.

    Args:
        data_root: Directory containing both CSV files, as a string or any
            path-like object. Defaults to ``'./dataset'``, the location that
            was previously hard-coded.

    Returns:
        A ``(train, test)`` tuple of ``pandas.DataFrame`` objects.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    import os

    target_codes = {'Graduate': 1, 'Dropout': 0, 'Enrolled': 2}

    # os.path.join accepts both str and path-like roots, unlike ``root + '/...'``.
    train = pd.read_csv(os.path.join(data_root, 'train.csv'))
    test = pd.read_csv(os.path.join(data_root, 'test.csv'))

    # Newer pandas releases deprecate the implicit object -> int downcast that
    # replace() used to perform, so request the dtype inference explicitly.
    train['Target'] = train['Target'].replace(target_codes).infer_objects()
    return train, test
# Generate a heat map of the correlations between the columns of the data.
def show_heat_map(data_train, show_image=True):
    """Draw an annotated heat map of the column correlations of ``data_train``.

    The correlation matrix is obtained from ``data_train.corr()`` and drawn
    with seaborn on a new 15x12 matplotlib figure titled
    'Heatmap of Correlation Matrix'.

    Args:
        data_train: Object exposing ``.corr()`` (e.g. a pandas DataFrame).
        show_image: The figure is displayed via ``plt.show()`` only when this
            is exactly ``True``.
    """
    correlations = data_train.corr()

    plt.figure(figsize=(15, 12))
    sns.heatmap(
        correlations,
        annot=True,
        cmap='coolwarm',
        linewidths=0.5,
    )
    plt.title('Heatmap of Correlation Matrix')

    # Identity check kept on purpose: only the literal ``True`` triggers display.
    if show_image is not True:
        return
    plt.show()
# To generate the heat map, uncomment the following two lines:
# train, test = load_data()
# show_heat_map(data_train=train)