forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
observer.h
164 lines (135 loc) · 3.72 KB
/
observer.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#pragma once
#include <memory>
#include <unordered_set>
#include "caffe2/core/logging.h"
namespace caffe2 {
/**
 * Base class for observers in the Observer pattern.
 *
 * An observer keeps a non-owning pointer to the subject of type T that it
 * watches; this class stores the pointer but never dereferences or deletes
 * it. Derive from ObserverBase and override the virtual hooks you need; every
 * hook has a harmless default implementation.
 */
template <class T>
class ObserverBase {
 public:
  /** Binds the observer to `subject`; the pointer is stored as given. */
  explicit ObserverBase(T* subject) : subject_{subject} {}

  /** Start hook. The default implementation does nothing. */
  virtual void Start() {}

  /** Stop hook. The default implementation does nothing. */
  virtual void Stop() {}

  /** Human-readable description; the default is a fixed placeholder. */
  virtual std::string debugInfo() {
    return std::string("Not implemented.");
  }

  virtual ~ObserverBase() noexcept = default;

  /** Returns the subject pointer supplied at construction. */
  T* subject() const {
    return subject_;
  }

  /**
   * Copy hook; the default implementation ignores both arguments and returns
   * a null pointer.
   * NOTE(review): the name suggests subclasses override this to clone the
   * observer for a step of a recurrent net -- confirm against the callers.
   */
  virtual std::unique_ptr<ObserverBase<T>> rnnCopy(T* subject, int rnn_order)
      const {
    return std::unique_ptr<ObserverBase<T>>{};
  }

 protected:
  // Non-owning pointer to the observed object.
  T* subject_;
};
/**
* Inherit to make your class observable.
*/
template <class T>
class Observable {
public:
Observable() = default;
Observable(Observable&&) = default;
Observable& operator =(Observable&&) = default;
virtual ~Observable() = default;
C10_DISABLE_COPY_AND_ASSIGN(Observable);
using Observer = ObserverBase<T>;
/* Returns a reference to the observer after addition. */
const Observer* AttachObserver(std::unique_ptr<Observer> observer) {
CAFFE_ENFORCE(observer, "Couldn't attach a null observer.");
std::unordered_set<const Observer*> observers;
for (auto& ob : observers_list_) {
observers.insert(ob.get());
}
const auto* observer_ptr = observer.get();
if (observers.count(observer_ptr)) {
return observer_ptr;
}
observers_list_.push_back(std::move(observer));
UpdateCache();
return observer_ptr;
}
/**
* Returns a unique_ptr to the removed observer. If not found, return a
* nullptr
*/
std::unique_ptr<Observer> DetachObserver(const Observer* observer_ptr) {
for (auto it = observers_list_.begin(); it != observers_list_.end(); ++it) {
if (it->get() == observer_ptr) {
auto res = std::move(*it);
observers_list_.erase(it);
UpdateCache();
return res;
}
}
return nullptr;
}
virtual size_t NumObservers() {
return num_observers_;
}
private:
inline static void StartObserver(Observer* observer) {
try {
observer->Start();
} catch (const std::exception& e) {
LOG(ERROR) << "Exception from observer: " << e.what();
} catch (...) {
LOG(ERROR) << "Exception from observer: unknown";
}
}
inline static void StopObserver(Observer* observer) {
try {
observer->Stop();
} catch (const std::exception& e) {
LOG(ERROR) << "Exception from observer: " << e.what();
} catch (...) {
LOG(ERROR) << "Exception from observer: unknown";
}
}
void UpdateCache() {
num_observers_ = observers_list_.size();
if (num_observers_ != 1) {
// we cannot take advantage of the cache
return;
}
observer_cache_ = observers_list_[0].get();
}
public:
void StartAllObservers() {
// do not access observers_list_ unless necessary
if (num_observers_ == 0) {
return;
} else if (num_observers_ == 1) {
StartObserver(observer_cache_);
} else {
for (auto& observer : observers_list_) {
StartObserver(observer.get());
}
}
}
void StopAllObservers() {
// do not access observers_list_ unless necessary
if (num_observers_ == 0) {
return;
} else if (num_observers_ == 1) {
StopObserver(observer_cache_);
} else {
for (auto& observer : observers_list_) {
StopObserver(observer.get());
}
}
}
private:
// an on-stack cache for fast iteration;
// ideally, inside StartAllObservers and StopAllObservers,
// we should never access observers_list_
Observer* observer_cache_;
size_t num_observers_ = 0;
protected:
std::vector<std::unique_ptr<Observer>> observers_list_;
};
} // namespace caffe2