-
Notifications
You must be signed in to change notification settings - Fork 91
/
Copy pathplotting.py
30 lines (25 loc) · 799 Bytes
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import matplotlib.pyplot as plt
from matplotlib import cm, colors
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
def plot(embeds, labels, fig_path='./example.pdf'):
    """Scatter 3-D embeddings over a translucent unit sphere and save the figure.

    Args:
        embeds: Array-like with at least three columns; columns 0, 1 and 2 are
            plotted as x, y and z. Presumably unit-normalized so the points lie
            on the drawn sphere -- this is not checked here.
        labels: Passed to ``Axes3D.scatter`` as ``c`` (per-point colours, or
            values mapped through the default colormap).
        fig_path: Output file path; matplotlib infers the format from the
            extension (PDF for the default).

    The figure is closed after saving so repeated calls do not accumulate
    open pyplot figures.
    """
    # Column slicing below needs an ndarray; accept lists/tuples too.
    embeds = np.asarray(embeds)

    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(111, projection='3d')

        # Sphere mesh of radius r: polar angle phi in [0, pi], azimuth theta in [0, 2*pi].
        r = 1
        phi, theta = np.mgrid[0.0:np.pi:100j, 0.0:2.0 * np.pi:100j]
        x = r * np.sin(phi) * np.cos(theta)
        y = r * np.sin(phi) * np.sin(theta)
        z = r * np.cos(phi)
        ax.plot_surface(
            x, y, z, rstride=1, cstride=1, color='w', alpha=0.3, linewidth=0)

        ax.scatter(embeds[:, 0], embeds[:, 1], embeds[:, 2], c=labels, s=20)
        ax.set_xlim([-1, 1])
        ax.set_ylim([-1, 1])
        ax.set_zlim([-1, 1])

        try:
            ax.set_aspect("equal")
        except NotImplementedError:
            # Matplotlib 3.1-3.5 reject 'equal' on 3-D axes; equalize the box
            # aspect instead (set_box_aspect is available from 3.3).
            if hasattr(ax, 'set_box_aspect'):
                ax.set_box_aspect((1, 1, 1))

        # Operate on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
        fig.savefig(fig_path)
    finally:
        # Release the figure even if plotting/saving failed, to avoid leaking
        # open figures across repeated calls.
        plt.close(fig)