-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsom.py
75 lines (59 loc) · 2.46 KB
/
som.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import math
import random
class SOM(object):
    """Self-organising map trained by single-sample stochastic updates.

    The map works on caller-supplied node objects and relies only on this
    interface:

    * ``dist(datapoint)``
    * ``update_weight(center_node, datapoint, learn_rate, sigma)``
    * ``serialize()``
    * the ``nodes_per_dim`` attribute, read from the first node in :meth:`save`
    """

    def __init__(self, dataset, nodes, sigma_0, learn_rate_0, sigma_timeconst,
                 learn_timeconst, total_iterations, plot_interval, graphics):
        """Store the training configuration.

        Args:
            dataset: Sequence of datapoints; one is picked with
                ``random.choice`` on every iteration.
            nodes: Sequence of node objects forming the map.
            sigma_0: Initial neighbourhood parameter (value of :meth:`sigma`
                at iteration 0).
            learn_rate_0: Initial learning rate (value of :meth:`learn_rate`
                at iteration 0).
            sigma_timeconst: Time constant of the exponential decay of sigma.
            learn_timeconst: Time constant of the exponential decay of the
                learning rate.
            total_iterations: Number of single-sample updates performed by
                :meth:`train`.
            plot_interval: Draw a frame after every ``plot_interval``-th
                completed iteration; a falsy value (0/None) disables drawing.
            graphics: Object exposing ``draw_frame(som, iteration)``;
                ``None`` disables drawing.
        """
        self.total_iterations = total_iterations
        # 1-based index of the next iteration to run; train() advances it.
        self.iteration = 1
        self.dataset = dataset
        self.nodes = nodes
        self.sigma_0 = sigma_0
        self.learn_rate_0 = learn_rate_0
        self.sigma_timeconst = sigma_timeconst
        self.learn_timeconst = learn_timeconst
        self.plot_interval = plot_interval
        self.graphics = graphics

    def train(self):
        """Run the remaining iterations up to ``total_iterations``, then print a summary.

        Each iteration does the following:

        1. Draw a random datapoint.
        2. Find its closest node.
        3. Let every node update its weight towards the datapoint.
        4. When plotting is enabled and the completed iteration number is a
           multiple of ``plot_interval``, pass the map to
           ``graphics.draw_frame``.
        """
        while self.iteration <= self.total_iterations:
            datapoint = random.choice(self.dataset)
            winner = self.closest_node(datapoint)
            self.update_nodes(winner, datapoint)
            # Draw *before* advancing the counter: the frame labelled N then
            # reflects exactly N completed updates, including the final one.
            # A falsy plot_interval or a missing graphics object disables
            # drawing instead of raising ZeroDivisionError/AttributeError.
            if (self.graphics is not None and self.plot_interval
                    and self.iteration % self.plot_interval == 0):
                self.graphics.draw_frame(self, self.iteration)
            self.iteration += 1
        # NOTE(review): the scraped source lost its indentation; these prints
        # are assumed to be a one-off end-of-training summary - confirm
        # against the original file.
        print("Initial neighbourhood {}".format(self.sigma_0))
        print("Initial learning rate {}".format(self.learn_rate_0))
        print("Learn rate {}".format(self.learn_rate()))
        print("Sigma {}".format(self.sigma()))

    def sigma(self):
        """Return ``sigma_0 * exp(-iteration / sigma_timeconst)`` for the current iteration."""
        # float() forces true division even under Python 2 with integer
        # arguments, where -1 / 1000 would floor to -1.
        return self.sigma_0 * math.exp(-self.iteration / float(self.sigma_timeconst))

    def learn_rate(self):
        """Return ``learn_rate_0 * exp(-iteration / learn_timeconst)`` for the current iteration."""
        # float() forces true division even under Python 2 (see sigma()).
        return self.learn_rate_0 * math.exp(-self.iteration / float(self.learn_timeconst))

    def closest_node(self, datapoint):
        """Return the node whose ``dist(datapoint)`` is smallest.

        Ties go to the earliest node in ``self.nodes``. Returns ``None`` when
        there are no nodes, or when no distance compares less than infinity.
        """
        best_dist = float("inf")
        best_node = None
        for node in self.nodes:
            d = node.dist(datapoint)
            if d < best_dist:  # strict '<' keeps the first of equal minima
                best_dist = d
                best_node = node
        return best_node

    def save(self, filename, type):
        """Write the map to ``filename`` as a JSON object.

        The object has three keys:

        * ``nodes``: list of ``node.serialize()`` results.
        * ``type``: the caller-supplied tag.
        * ``classes``: the first node's ``nodes_per_dim``, or ``None`` for an
          empty map.

        Args:
            filename: Destination path; an existing file is overwritten.
            type: Tag stored under the ``type`` key. The name shadows the
                builtin but is kept for keyword-argument callers.
        """
        import json
        output = {
            'nodes': [node.serialize() for node in self.nodes],
            'type': type,
            'classes': self.nodes[0].nodes_per_dim if len(self.nodes) > 0 else None,
        }
        # Serialize before opening, so a serialization error does not leave
        # an existing file truncated.
        text = json.dumps(output)
        with open(filename, 'w') as fh:
            fh.write(text)

    def update_nodes(self, center_node, datapoint):
        """Ask every node to update its weight for ``datapoint``.

        The learning rate and sigma are computed once, for the current
        iteration, and passed to each node along with ``center_node``.
        """
        learn_rate = self.learn_rate()
        sigma = self.sigma()
        for node in self.nodes:
            node.update_weight(center_node, datapoint, learn_rate, sigma)

    def report(self):
        """Not implemented in this class; always raises NotImplementedError."""
        raise NotImplementedError