#ifndef CAFFE2_CORE_NET_ASYNC_BASE_H_
#define CAFFE2_CORE_NET_ASYNC_BASE_H_
#include <c10/macros/Macros.h>
#include "c10/core/thread_pool.h"
#include "c10/util/Registry.h"
#include "caffe2/core/common.h"
#include "caffe2/core/net.h"
#include "caffe2/core/net_dag_utils.h"
#include "caffe2/core/prof_dag_counters.h"
#include "caffe2/core/stats.h"
#include "caffe2/core/timer.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/prof_dag.pb.h"
#include "caffe2/utils/proto_utils.h"
C10_DECLARE_int(caffe2_streams_per_gpu);
C10_DECLARE_int(caffe2_net_async_max_gpus);
C10_DECLARE_int(caffe2_net_async_max_numa_nodes);
C10_DECLARE_int(caffe2_net_async_thread_pool_size);
C10_DECLARE_bool(caffe2_net_async_check_stream_status);
C10_DECLARE_bool(caffe2_net_async_use_single_pool);
C10_DECLARE_bool(caffe2_net_async_use_per_net_pools);
C10_DECLARE_bool(caffe2_net_async_run_root_tasks_inline);
C10_DECLARE_bool(caffe2_net_async_profile_operators);
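// These flags are only declared here (their definitions live in the
// executor's .cc file); they provide the process-wide defaults that the
// per-net ExecutionOptions declared below are initialized from.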
namespace caffe2 {
class AsyncNetExecutorHelper;
namespace tracing {
class Tracer;
} // namespace tracing
struct ExecutionOptions {
  explicit ExecutionOptions(const std::shared_ptr<const NetDef>& net_def);

  // number of gpu streams per gpu per cpu thread
  int streams_per_gpu_ = 1;
  // ops synchronization options
  bool finish_chain_ = false;
  bool always_schedule_child_ = false;
  // try to pick a gpu stream that is not busy
  bool check_stream_status_ = false;
  // use a single thread pool for all devices
  bool use_single_pool_ = false;
  // use per-net thread pool instances instead of global ones
  bool use_per_net_pools_ = false;
  // whether RunAsync is blocking
  bool is_blocking_ = false;
  // prof_dag counters reporting
  bool report_stats_ = false;
  // immediately run children tasks inline whenever possible
  bool use_dfs_scheduling_ = false;
  // run the net's root tasks in the RunAsync thread instead of in the thread pool
  bool run_root_tasks_inline_ = false;
};
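// Illustrative sketch (not part of the API surface): ExecutionOptions is
// normally built by AsyncNetBase from the net's NetDef, but given the
// constructor above it can also be inspected directly, e.g.:
//
//   std::shared_ptr<const NetDef> net_def = ...; // assumed loaded elsewhere
//   ExecutionOptions opts(net_def);
//   if (opts.use_dfs_scheduling_) {
//     // children of a finished task may be run inline in the same thread
//   }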
struct TORCH_API AsyncNetCancelled : public std::exception {
  const char* what() const noexcept override {
    return "Cancelled";
  }
};
class TORCH_API AsyncNetBase : public NetBase {
 public:
  AsyncNetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
  ~AsyncNetBase() override;

  bool SupportsAsync() override {
    return true;
  }

  vector<OperatorBase*> GetOperators() const override {
    return operators_;
  }

  bool RunAsync() override;

  const dag_utils::ExecutionChains& TEST_execution_chains() const {
    return execution_chains_;
  }

  ProfDAGProtos GetOperatorStats() const;
  ProfDAGProtos GetPerOperatorCost() const;
  ProfDAGReport GetProfReport() const;

 protected:
  bool canSchedule(
      int chain_id,
      const std::vector<EventStatus>* status = nullptr,
      bool* parent_failed = nullptr);
  bool canSchedule(int parent_id, int child_id);

  int tasksNum() const;
  Event& event(int task_id) const;
  EventStatus query(int task_id) const;
  const std::vector<int>& children(int task_id) const;
  const std::vector<int>& parents(int task_id) const;
  int updateParentCount(int child_id);
  int getParentCount(int child_id);
  bool testAndSetScheduled(int task_id);
  int numOps(int task_id) const;

  int firstTaskOpId(int task_id) const;
  int lastTaskOpId(int task_id) const;
  const OperatorBase* firstTaskOp(int task_id) const;
  const OperatorBase* lastTaskOp(int task_id) const;
  OperatorBase* firstTaskOp(int task_id);
  OperatorBase* lastTaskOp(int task_id);

  void asyncWait(
      int task_id,
      int stream_id,
      const std::vector<int>& wait_task_ids) const;
  bool run(int task_id, int stream_id) noexcept;
  int stream(int task_id);
  TaskThreadPoolBase* pool(const DeviceOption& device_option);
  TaskThreadPoolBase* pool();

  void finishTasks(const std::unordered_set<int>& task_ids);
  void finalizeEvents();

  bool isStreamFree(int task_id, int stream_id) const;

  virtual void reset();

  bool handleRunError() override;

  // Operator/task graph
  std::vector<OperatorBase*> operators_;
  std::vector<dag_utils::OperatorNode> operator_nodes_;
  std::vector<std::vector<int>> chains_;
  std::vector<dag_utils::OpGraphNode> chain_nodes_; // chains' parents/children
  dag_utils::ExecutionChains execution_chains_; // for testing

  // Pools and streams
  std::mutex pools_mutex_;
  // first int key - device id, second - pool size, one pool per (device, size)
  typedef std::unordered_map<
      int,
      std::unordered_map<int, std::shared_ptr<TaskThreadPoolBase>>>
      PoolsMap;
  PoolsMap cpu_pools_;
  PoolsMap gpu_pools_;
  static std::vector<int>& getStreamCounters();
  int num_workers_;

  // Exception/error handling
  void handleChainError(
      int task_id,
      OperatorBase* op,
      const char* err_msg,
      bool save_exception = false) noexcept;
  std::atomic<bool> success_;

  // Tracing
  std::shared_ptr<tracing::Tracer> tracer_;

  // execution mode flags
  ExecutionOptions options_;

  ProfDAGCounters counters_;

  C10_DISABLE_COPY_AND_ASSIGN(AsyncNetBase);

 private:
  TaskThreadPoolBase*
  poolGetter(PoolsMap& pools, int device_type, int device_id, int pool_size);

  std::unique_ptr<AsyncNetExecutorHelper> helper_;

  friend class AsyncNetExecutorHelper;
  friend class tracing::Tracer;
};
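// Illustrative sketch only (the actual scheduling lives in derived executors,
// e.g. net_async_scheduling): the protected API above is meant to be used
// roughly as follows, where each "task" is a chain of operators:
//
//   for (int task_id = 0; task_id < tasksNum(); ++task_id) {
//     if (canSchedule(task_id) && testAndSetScheduled(task_id)) {
//       pool()->run([this, task_id] {
//         run(task_id, stream(task_id)); // runs the whole chain on a stream
//       });
//     }
//   }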
class AsyncNetExecutorHelper : public ExecutorHelper {
 public:
  explicit AsyncNetExecutorHelper(AsyncNetBase* net) : net_(net) {}
  TaskThreadPoolBase* GetPool(const DeviceOption& option) const override {
    return net_->pool(option);
  }

 private:
  AsyncNetBase* net_;
};
template <class TaskThreadPoolImpl, int device_type>
std::shared_ptr<TaskThreadPoolBase>
GetAsyncNetThreadPool(int device_id, int pool_size, bool create_new) {
  static std::unordered_map<
      int,
      std::unordered_map<int, std::weak_ptr<TaskThreadPoolBase>>>
      pools;
  static std::mutex pool_mutex;

  const auto& device_type_name = DeviceTypeName(device_type);

  if (pool_size <= 0) {
    if (FLAGS_caffe2_net_async_thread_pool_size > 0) {
      pool_size = FLAGS_caffe2_net_async_thread_pool_size;
      LOG(INFO) << "Using default " << device_type_name
                << " pool size: " << pool_size << "; device id: " << device_id;
    } else {
      auto num_cores = std::thread::hardware_concurrency();
      CAFFE_ENFORCE(num_cores > 0, "Failed to get number of CPU cores");
      LOG(INFO) << "Using estimated " << device_type_name
                << " pool size: " << num_cores << "; device id: " << device_id;
      pool_size = num_cores;
    }
  } else {
    LOG(INFO) << "Using specified " << device_type_name
              << " pool size: " << pool_size << "; device id: " << device_id;
  }

  if (create_new) {
    LOG(INFO) << "Created new " << device_type_name
              << " pool, size: " << pool_size << "; device id: " << device_id;
    return std::make_shared<TaskThreadPoolImpl>(pool_size, device_id);
  } else {
    std::lock_guard<std::mutex> lock(pool_mutex);
    auto shared_pool = pools[device_id][pool_size].lock();
    if (!shared_pool) {
      LOG(INFO) << "Created shared " << device_type_name
                << " pool, size: " << pool_size << "; device id: " << device_id;
      shared_pool = std::make_shared<TaskThreadPoolImpl>(pool_size, device_id);
      pools[device_id][pool_size] = shared_pool;
    }
    return shared_pool;
  }
}
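// Usage sketch (the pool implementation and device-type constant are examples,
// not mandated by this header): a shared CPU pool whose size falls back to the
// caffe2_net_async_thread_pool_size flag could be obtained via
//
//   auto pool = GetAsyncNetThreadPool<TaskThreadPool, PROTO_CPU>(
//       /*device_id=*/0, /*pool_size=*/-1, /*create_new=*/false);
//
// Requesting the same (device_id, pool_size) pair again returns the cached
// shared pool for as long as a previously returned shared_ptr is still alive.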
} // namespace caffe2
#endif // CAFFE2_CORE_NET_ASYNC_BASE_H_