-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgraph_datatypes.py
149 lines (122 loc) · 4.63 KB
/
graph_datatypes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import numpy as np
import time
from numba import int32, jitclass
@jitclass([
    ('_parent', int32[:]),
    ('_size', int32[:]),
    ('n_clusters', int32),
    ('_track_cc', int32),
    ('cc1_root', int32),
    ('cc2_root', int32),
])
class DisjointSetPlus(object):
    """Union-find (disjoint-set) structure that can also track its largest clusters.

    Uses union by size and path halving. When ``track_cc >= 1`` the root of the
    largest cluster is kept in ``cc1_root``. When ``track_cc >= 2`` the root of
    the second-largest cluster is also kept, in ``cc2_root``. Values above 2
    behave like 2. A root of -1 means the cluster is not tracked or does not exist.
    """

    def __init__(self, size: int, track_cc: int = 1):
        self._parent = np.empty(size, dtype=np.int32)
        self._size = np.empty(size, dtype=np.int32)
        self.n_clusters = -1
        self._track_cc = -1
        self.cc1_root = -1
        self.cc2_root = -1
        self.reset(track_cc)

    def reset(self, track_cc):
        """Put every element back into its own singleton cluster.

        Args:
            track_cc: how many of the largest clusters to track (0, 1 or >= 2).
        """
        n = len(self._parent)
        for i in range(n):
            self._parent[i] = i  # every element starts as its own root
        self._size.fill(1)
        self._track_cc = track_cc
        # Clear roots left over from a previous configuration.
        self.cc1_root = -1
        self.cc2_root = -1
        # All clusters have size 1. By convention, element 0 is the largest
        # and element 1 the second largest.
        if track_cc >= 1 and n >= 1:
            self.cc1_root = 0
        if track_cc >= 2 and n >= 2:
            self.cc2_root = 1
        self.n_clusters = n

    @property
    def track_cc(self):
        return self._track_cc

    @property
    def cc1_size(self) -> int:
        """Size of the largest cluster, or 0 if it is not tracked."""
        if self.cc1_root < 0:
            return 0
        return self._size[self.cc1_root]

    @property
    def cc2_size(self) -> int:
        """Size of the second-largest cluster, or 0 if untracked / nonexistent."""
        if self.n_clusters <= 1 or self.cc2_root < 0:
            return 0
        return self._size[self.cc2_root]

    def find(self, x: int) -> int:
        """Return the root of ``x``, halving the path on the way up."""
        parent = self._parent
        parent_x = parent[x]
        while parent_x != x:
            grandpa_x = parent[parent_x]
            parent[x] = grandpa_x  # path halving: skip one level
            x, parent_x = parent_x, grandpa_x
        return x

    def _largest_root_excluding(self, a, b):
        """Return the largest live root other than ``a`` and ``b`` (first on ties), or -1."""
        best = -1
        best_size = -1
        for i in range(len(self._parent)):
            # Only roots carry a meaningful size. Entries of absorbed roots are stale.
            if self._parent[i] == i and i != a and i != b and self._size[i] > best_size:
                best = i
                best_size = self._size[i]
        return best

    def _update_cc1_and_cc2(self, root_x: int, root_y: int, new_size: int) -> None:
        """Update the cc1/cc2 roots before ``root_y`` is merged into ``root_x``.

        Preconditions: ``_size[root_x] >= _size[root_y]``, and ``_size`` and
        ``_parent`` still reflect the state before the merge.
        """
        size = self._size
        cc1 = self.cc1_root
        cc2 = self.cc2_root
        merges_cc1 = root_x == cc1 or root_y == cc1
        merges_cc2 = root_x == cc2 or root_y == cc2
        if merges_cc1 and merges_cc2:
            # cc1 and cc2 are merged. The result is the new cc1. The new cc2 is
            # the largest of the remaining live roots.
            self.cc1_root = root_x
            self.cc2_root = self._largest_root_excluding(root_x, root_y)
        elif merges_cc1:
            # cc1 grows (its root may change to root_x). cc2 is untouched and
            # is still the largest of the others.
            self.cc1_root = root_x
        elif new_size > size[cc1]:
            # The new cluster overtakes cc1, which is still alive and becomes cc2.
            self.cc2_root = cc1
            self.cc1_root = root_x
        elif new_size == size[cc1]:
            # Tie with cc1: keep the old cc1 for stability. The new cluster becomes cc2.
            self.cc2_root = root_x
        elif merges_cc2 or cc2 < 0 or new_size > size[cc2]:
            # cc2 grew, or it was overtaken by the new cluster.
            self.cc2_root = root_x

    def union(self, x: int, y: int) -> bool:
        """Merge the clusters of ``x`` and ``y``.

        Returns:
            True if a merge happened, False if they were already connected.
        """
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        size = self._size
        if size[root_x] < size[root_y]:  # union by size: attach smaller under larger
            root_x, root_y = root_y, root_x
        new_size = size[root_x] + size[root_y]
        if self._track_cc >= 2:
            self._update_cc1_and_cc2(root_x, root_y, new_size)
        elif self._track_cc == 1 and new_size > size[self.cc1_root]:  # cc1 is replaced
            self.cc1_root = root_x
        self._parent[root_y] = root_x
        size[root_x] = new_size
        self.n_clusters -= 1
        return True

    def get_cluster(self, x):
        """Return ``(root, size)`` of the cluster containing ``x``."""
        root = self.find(x)
        size = self._size[root]
        return root, size

    def get_size(self, x):
        """Return the size of the cluster containing ``x``."""
        root = self.find(x)
        return self._size[root]
class TimingMessage(object):
    """Context manager that prints ``[name] verb (elapsed): message`` progress lines.

    On entry it optionally prints ``Started``, then starts the clock. On a
    normal exit it prints ``Finished``, unless ``finished()`` was already
    called inside the block. On an exception it prints ``Interrupted`` with the
    exception text and lets the exception propagate.

    Args:
        name: label printed in brackets at the start of every line.
        print_fun: callable that receives each formatted line (default ``print``).
        print_on_start: if true, print a ``Started`` line on ``__enter__``.
    """

    def __init__(self, name: str, print_fun=print, print_on_start: bool = False):
        self.name = name
        # Start the clock at construction so that timed messages also work
        # without ``with``. ``__enter__`` restarts it.
        self._t0 = time.perf_counter()
        self._print_fun = print_fun
        self._print_on_start = print_on_start
        self._already_printed = False

    def print_raw(self, verb, extra, timed=True):
        """Emit ``[name] verb[ (elapsed)][: extra]`` through ``print_fun``."""
        verb = verb if not timed else f'{verb} ({self.elapsed():.2f}s)'
        sep = ': ' if verb and extra else ''
        self._print_fun(f'[{self.name}] {verb}{sep}{extra}')

    def started(self, msg=''):
        self.print_raw('Started', msg, timed=False)

    def finished(self, msg=''):
        """Print ``Finished`` and suppress the automatic one on exit."""
        self.print_raw('Finished', msg)
        self._already_printed = True

    def exception(self, msg):
        self.print_raw('Interrupted', msg)

    def message(self, msg, verb='Running', timed=True):
        self.print_raw(verb, msg, timed=timed)

    def print(self, msg, verb='', timed=False):
        self.print_raw(verb, msg, timed=timed)

    def elapsed(self):
        """Seconds since construction or the last ``__enter__`` (monotonic clock)."""
        return time.perf_counter() - self._t0

    def __enter__(self):
        self._already_printed = False  # allow the same instance to be reused
        if self._print_on_start:
            self.started()
        self._t0 = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            if not self._already_printed:
                self.finished()
        else:
            self.exception(str(exc_val))
        # Implicitly returns None, so any exception is re-raised.