-
Notifications
You must be signed in to change notification settings - Fork 173
/
signal.go
220 lines (181 loc) · 5.25 KB
/
signal.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
package goka
import (
"fmt"
"sync"
)
// State is the type of a Signal's state. It is a plain int, so states can be
// ordered numerically; WaitForStateMin relies on that ordering.
type State int
// waiter is a single pending wait registered with a Signal.
type waiter struct {
// done is closed by the Signal once the awaited state has been reached.
done chan struct{}
// state is the state being waited for.
state State
// minState, when true, makes the waiter match any state >= state
// instead of only an exact match.
minState bool
}
// StateReader is a read only abstraction of a Signal to expose the current state.
// The method descriptions below reflect how *Signal implements them.
type StateReader interface {
// State returns the current state.
State() State
// IsState reports whether the current state equals the passed state.
IsState(State) bool
// WaitForStateMin returns a channel that is closed once the passed state
// or a numerically higher one is reached.
WaitForStateMin(state State) <-chan struct{}
// WaitForState returns a channel that is closed once exactly the passed
// state is reached.
WaitForState(state State) <-chan struct{}
// ObserveStateChange returns an observer delivering state changes.
ObserveStateChange() *StateChangeObserver
}
// Signal allows synchronization on a state, waiting for that state and checking
// the current state
type Signal struct {
// m is held by the methods of Signal while they access state, waiters and
// stateChangeObservers.
m sync.RWMutex
// state is the current state; it starts as the zero State.
state State
// waiters is the set of pending waits that have not been satisfied yet.
waiters map[*waiter]struct{}
// stateChangeObservers are notified on every effective state change.
stateChangeObservers []*StateChangeObserver
// allowedStates holds the states that SetState accepts; it is filled by NewSignal.
allowedStates map[State]bool
}
// emptyValue is the zero-size value stored for each key of the waiters set.
var emptyValue struct{}
// NewSignal returns a Signal whose SetState accepts exactly the passed states.
// The returned signal starts in the zero State, regardless of whether that
// value is among the allowed states.
func NewSignal(states ...State) *Signal {
	allowed := make(map[State]bool)
	for i := range states {
		allowed[states[i]] = true
	}

	return &Signal{
		allowedStates: allowed,
		waiters:       make(map[*waiter]struct{}),
	}
}
// SetState changes the state of the signal and notifies all goroutines
// waiting for the new state.
//
// It panics if state is not one of the states passed to NewSignal. Setting the
// state the signal is already in is a no-op: neither waiters nor observers are
// notified. The signal itself is returned to allow call chaining.
//
// Observers are notified while the signal's lock is held, so an observer whose
// channel is not being consumed blocks this call until it is drained or stopped.
func (s *Signal) SetState(state State) *Signal {
	s.m.Lock()
	defer s.m.Unlock()

	if !s.allowedStates[state] {
		panic(fmt.Errorf("trying to set illegal state %v", state))
	}

	// if we're already in the state, do not notify anyone
	if s.state == state {
		return s
	}

	s.state = state

	// Release every waiter satisfied by the new state. Deleting entries while
	// ranging over a map is safe in Go; unsatisfied waiters simply stay registered.
	for w := range s.waiters {
		if w.state == state || (w.minState && state >= w.state) {
			delete(s.waiters, w)
			close(w.done)
		}
	}

	// notify the state change observers
	for _, obs := range s.stateChangeObservers {
		obs.notify(state)
	}

	return s
}
// IsState reports whether the signal is currently in the passed state.
// The current state is read under the signal's read lock.
func (s *Signal) IsState(state State) bool {
	s.m.RLock()
	current := s.state
	s.m.RUnlock()

	return current == state
}
// State returns the signal's current state, read under the signal's read lock.
func (s *Signal) State() State {
	s.m.RLock()
	current := s.state
	s.m.RUnlock()

	return current
}
// WaitForStateMin returns a channel that is closed as soon as the signal is in
// the passed state or any numerically higher one (states are plain ints). If
// that already holds at call time, the returned channel is already closed.
func (s *Signal) WaitForStateMin(state State) <-chan struct{} {
	return s.waitForWaiter(state, &waiter{
		done:     make(chan struct{}),
		state:    state,
		minState: true,
	})
}
// WaitForStateMinWithCleanup behaves like WaitForStateMin, but additionally
// returns a cleanup function. A caller that no longer wants to wait calls it
// to remove the pending waiter from the signal. The cleanup function does not
// close the returned channel.
func (s *Signal) WaitForStateMinWithCleanup(state State) (<-chan struct{}, func()) {
	pending := &waiter{
		done:     make(chan struct{}),
		state:    state,
		minState: true,
	}

	done := s.waitForWaiter(state, pending)

	return done, func() {
		s.m.Lock()
		delete(s.waiters, pending)
		s.m.Unlock()
	}
}
// WaitForState returns a channel that is closed once the signal is in exactly
// the passed state. If the signal is already in that state, the returned
// channel is already closed.
func (s *Signal) WaitForState(state State) <-chan struct{} {
	return s.waitForWaiter(state, &waiter{
		done:  make(chan struct{}),
		state: state,
	})
}
// waitForWaiter either satisfies w immediately or registers it with the signal.
// The waiter is satisfied (its done channel closed) when the signal is already
// in state, or in a higher state if w.minState is set. Otherwise it is stored
// until SetState releases it. In both cases w.done is returned.
func (s *Signal) waitForWaiter(state State, w *waiter) chan struct{} {
	s.m.Lock()
	defer s.m.Unlock()

	reached := s.state == state || (w.minState && s.state >= state)
	if !reached {
		// not there yet: park the waiter
		s.waiters[w] = emptyValue
		return w.done
	}

	close(w.done)
	return w.done
}
// StateChangeObserver wraps a channel that triggers when the signal's state changes.
// Instances are created by Signal.ObserveStateChange, which also sets stop.
type StateChangeObserver struct {
// state notifier channel
c chan State
// closed is closed when the observer is stopped, so that a pending notify
// gives up instead of sending to a closed channel
closed chan struct{}
// stop is a callback to stop the observer
stop func()
}
// Stop stops the observer by invoking its stop callback. With the callback
// installed by Signal.ObserveStateChange, the update channel is closed and the
// observer is removed from the signal's list of observers.
func (s *StateChangeObserver) Stop() {
s.stop()
}
// C returns the receive-only channel on which state changes can be observed.
func (s *StateChangeObserver) C() <-chan State {
return s.c
}
// notify delivers state to the observer's channel. It blocks until the value
// has been sent or the observer's closed channel has been closed, whichever
// happens first.
func (s *StateChangeObserver) notify(state State) {
	select {
	case s.c <- state:
		// delivered to the consumer
	case <-s.closed:
		// observer was stopped; drop the update
	}
}
// ObserveStateChange returns an observer whose channel receives state changes.
// The current state is delivered to the observer immediately.
//
// Note that the caller must take care of consuming that channel, otherwise the
// Signal will block upon state changes. Calling Stop on the returned observer
// closes its channel and removes it from the signal. Stop is safe to call more
// than once; only the first call has an effect.
func (s *Signal) ObserveStateChange() *StateChangeObserver {
	s.m.Lock()
	defer s.m.Unlock()

	observer := &StateChangeObserver{
		c:      make(chan State, 1),
		closed: make(chan struct{}),
	}

	// initialize the observer with the current state; the channel has a buffer
	// of one, so this first send cannot block
	observer.notify(s.state)

	// stopOnce guards the stop function: closing a channel twice panics, so a
	// repeated Stop must be a no-op.
	var stopOnce sync.Once

	// the stop function stops the observer by closing its channel
	// and removing it from the list of observers
	observer.stop = func() {
		stopOnce.Do(func() {
			// Closing `closed` before taking the lock releases a notify that
			// may be blocked in SetState while holding that lock.
			close(observer.closed)

			s.m.Lock()
			defer s.m.Unlock()

			// find *this* observer and cut it out of the list
			for idx, obs := range s.stateChangeObservers {
				if obs != observer {
					continue
				}
				last := len(s.stateChangeObservers) - 1
				copy(s.stateChangeObservers[idx:], s.stateChangeObservers[idx+1:])
				s.stateChangeObservers[last] = nil // drop the stale reference
				s.stateChangeObservers = s.stateChangeObservers[:last]
				break // an observer is registered exactly once
			}

			// Safe to close: notify only runs under s.m, which is held here,
			// and the observer is no longer in the list.
			close(observer.c)
		})
	}

	s.stateChangeObservers = append(s.stateChangeObservers, observer)
	return observer
}