-
Notifications
You must be signed in to change notification settings - Fork 284
/
Copy pathvisualization_metrics.py
102 lines (77 loc) · 3.32 KB
/
visualization_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Time-series Generative Adversarial Networks (TimeGAN) Codebase.
Reference: Jinsung Yoon, Daniel Jarrett, Mihaela van der Schaar,
"Time-series Generative Adversarial Networks,"
Neural Information Processing Systems (NeurIPS), 2019.
Paper link: https://papers.nips.cc/paper/8789-time-series-generative-adversarial-networks
Last updated Date: April 24th 2020
Code author: Jinsung Yoon ([email protected])
-----------------------------
visualization_metrics.py
Note: Use PCA or tSNE for generated and original data visualization
"""
# Necessary packages
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np
def visualization(ori_data, generated_data, analysis):
  """Using PCA or t-SNE for generated and original data visualization.

  Each sequence is collapsed to one value per time step by averaging over the
  last (feature) axis, then the resulting curves are embedded into 2-D and
  scattered: original samples in red, synthetic samples in blue.
  At most 1000 randomly chosen sequences are used, and the SAME random indices
  are applied to both datasets.

  Args:
    - ori_data: original data, array-like indexed as (no, seq_len, dim)
    - generated_data: generated synthetic data with the same layout; it is
      indexed with the indices drawn from ori_data, so it must contain at
      least len(ori_data) sequences
    - analysis: 'pca' or 'tsne'

  Raises:
    - ValueError: if analysis is not 'pca'/'tsne', or ori_data is empty
  """
  if analysis not in ('pca', 'tsne'):
    raise ValueError("analysis must be 'pca' or 'tsne', got %r" % (analysis,))

  # Analysis sample size (for faster computation)
  anal_sample_no = min(1000, len(ori_data))
  if anal_sample_no == 0:
    raise ValueError("ori_data must contain at least one sequence")
  idx = np.random.permutation(len(ori_data))[:anal_sample_no]

  # Data preprocessing: take the same random subset from both datasets
  ori_data = np.asarray(ori_data)[idx]
  generated_data = np.asarray(generated_data)[idx]

  # Average over the feature axis: (no, seq_len, dim) -> (no, seq_len).
  # Vectorised equivalent of stacking np.mean(x[i], 1) row by row, without
  # the quadratic cost of repeated np.concatenate.
  prep_data = np.mean(ori_data, axis=2)
  prep_data_hat = np.mean(generated_data, axis=2)

  if analysis == 'pca':
    # Fit the basis on the original data only, then project both sets into it
    pca = PCA(n_components=2)
    pca.fit(prep_data)
    pca_results = pca.transform(prep_data)
    pca_hat_results = pca.transform(prep_data_hat)
    _scatter_original_vs_synthetic(pca_results, pca_hat_results,
                                   'PCA plot', 'x-pca', 'y_pca')
  else:
    # t-SNE cannot embed unseen points, so both sets are embedded together
    prep_data_final = np.concatenate((prep_data, prep_data_hat), axis=0)
    # scikit-learn requires perplexity < n_samples; clamp for small inputs
    perplexity = min(40, len(prep_data_final) - 1)
    tsne = _make_tsne(perplexity)
    tsne_results = tsne.fit_transform(prep_data_final)
    # First anal_sample_no rows are original, the rest synthetic
    _scatter_original_vs_synthetic(tsne_results[:anal_sample_no],
                                   tsne_results[anal_sample_no:],
                                   't-SNE plot', 'x-tsne', 'y_tsne')


def _make_tsne(perplexity):
  """Build the 2-D, 300-iteration TSNE estimator used by visualization.

  scikit-learn 1.5 renamed the iteration argument from n_iter to max_iter
  (and later removed n_iter), so the new name is tried first and the old one
  is used as a fallback for releases that reject max_iter.
  """
  try:
    return TSNE(n_components=2, verbose=1, perplexity=perplexity, max_iter=300)
  except TypeError:
    return TSNE(n_components=2, verbose=1, perplexity=perplexity, n_iter=300)


def _scatter_original_vs_synthetic(ori_points, gen_points, title, xlabel, ylabel):
  """Scatter original (red) and synthetic (blue) 2-D points and show the figure.

  Args:
    - ori_points: 2-D embedding of the original samples (column 0 = x, 1 = y)
    - gen_points: 2-D embedding of the synthetic samples (column 0 = x, 1 = y)
    - title, xlabel, ylabel: figure annotations
  """
  _, ax = plt.subplots(1)
  ax.scatter(ori_points[:, 0], ori_points[:, 1],
             color="red", alpha=0.2, label="Original")
  ax.scatter(gen_points[:, 0], gen_points[:, 1],
             color="blue", alpha=0.2, label="Synthetic")
  ax.legend()
  ax.set_title(title)
  ax.set_xlabel(xlabel)
  ax.set_ylabel(ylabel)
  plt.show()