forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LegacyBatchedFallback.cpp
378 lines (349 loc) · 17.5 KB
/
LegacyBatchedFallback.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
#include <ATen/Context.h>
#include <ATen/LegacyBatchedFallback.h>
#include <ATen/MatrixRef.h>
#include <ATen/LegacyVmapTransforms.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/accumulate.h>
#include <c10/util/llvmMathExtras.h>
#include <c10/util/irange.h>
namespace at {
// Converts a flat, row-major index into one index per dimension of `sizes`.
// Example: linear_idx = 3 with sizes = [5, 2] gives [1, 1], since 3 == 1 * 2 + 1.
// Each entry of `sizes` is used as a divisor, so none of them may be zero.
static SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize>
computeIndex(int64_t linear_idx, IntArrayRef sizes) {
  SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize> indices;
  indices.reserve(sizes.size());
  // Peel coordinates off starting with the last dimension; they are therefore
  // collected in reverse order and flipped once at the end.
  for (auto dim = sizes.size(); dim > 0; --dim) {
    const auto extent = sizes[dim - 1];
    auto coordinate = linear_idx % extent;
    indices.push_back(coordinate);
    // Truncating division already discards the remainder.
    linear_idx /= extent;
  }
  std::reverse(indices.begin(), indices.end());
  return indices;
}
static bool areAllReturnsTensors(const FunctionSchema& schema) {
return std::all_of(
schema.returns().begin(),
schema.returns().end(),
[] (const Argument& arg) { return arg.type() == TensorType::get(); });
}
static bool areAnyArgumentsTensorList(const FunctionSchema& schema) {
return std::any_of(
schema.arguments().begin(),
schema.arguments().end(),
[] (const Argument& arg) { return arg.type()->isSubtypeOf(*ListType::ofTensors()); });
}
// Decides whether `schema` describes an in-place operator. All of the
// following must hold:
//   1. The schema is mutable and has exactly one return.
//   2. The first argument carries alias info that marks it as written to.
//   3. No argument after the first carries any alias info.
//   4. The single return carries alias info that marks it as written to.
// A schema that qualifies:
//   add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
// Note that the type of the first argument is not inspected here.
static bool isInplaceOp(const c10::FunctionSchema& schema) {
  const auto& args = schema.arguments();
  const auto& rets = schema.returns();
  if (rets.size() != 1 || !schema.is_mutable()) {
    return false;
  }

  // NOTE(review): reading the first argument presumably relies on
  // is_mutable() implying a non-empty argument list - confirm.
  const AliasInfo* self_alias = args.front().alias_info();
  const bool self_is_written = self_alias != nullptr && self_alias->isWrite();
  if (!self_is_written) {
    return false;
  }

  // Any alias annotation on a later argument disqualifies the operator.
  for (size_t pos = 1; pos < args.size(); ++pos) {
    if (args[pos].alias_info() != nullptr) {
      return false;
    }
  }

  // The return must also be annotated as written, i.e. `-> Tensor(a!)`.
  const AliasInfo* ret_alias = rets.front().alias_info();
  if (ret_alias == nullptr) {
    return false;
  }
  return ret_alias->isWrite();
}
// Emits a warning that the slow fallback is being used for the operator named
// by `schema`. Does nothing when the global context reports that vmap
// fallback warnings are disabled.
static void warnFallback(const c10::FunctionSchema& schema) {
  const bool warnings_enabled = globalContext().areVmapFallbackWarningsEnabled();
  if (warnings_enabled) {
    TORCH_WARN("There is a performance drop because we have not yet implemented ",
               "the batching rule for ", schema.operator_name(), ". ",
               "You are using the legacy vmap prototype (torch._vmap_internals.vmap). ",
               "If you are using torch.autograd.functional.{jacobian, hessian} ",
               "or torch._vmap_internals.vmap: please switch to using ",
               "torch.func.{jacrev, jacfwd, hessian} and/or torch.vmap instead ",
               "for better operator coverage and performance improvements .");
  }
}
// For-loop fallback for in-place operators. Calls `op` once per batch element
// on slices of the batched arguments and finally leaves `self` (the first
// argument) on the stack as the only result.
//
// `op`    - the operator to run; its schema determines how many arguments are
//           read from the top of `stack`.
// `stack` - holds the arguments on entry; on exit they have been replaced by
//           `self`.
//
// Fails via TORCH_CHECK when some other argument is vmapped over a level that
// `self` is not vmapped over, or when the total number of batches is 0.
//
// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
// to a vector. We also store a vector of which index of the arguments list
// each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
// This returns a vector of VmapPhysicalView that hold tensors that contain
// all of the collective batch dimensions at the front of the tensors.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
// we repeatedly slice the input arguments (if they are BatchedTensors),
// put the sliced (or a not-sliced) version of the input onto the stack, invoke
// the operator, and then pop the results off the stack.
static void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
warnFallback(schema);
const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
// NOTE(review): `arguments` presumably refers to storage owned by `stack`;
// it is only read in the scan below, before anything is pushed - confirm.
const auto arguments = torch::jit::last(stack, num_arguments);
// Position in `stack` of the first original argument.
const auto arguments_begin = stack->size() - num_arguments;
// `self` is the Tensor being modified in-place
Tensor self = arguments[0].toTensor();
const auto* self_impl = maybeGetBatchedImpl(self);
// Stays all-zero when `self` is not a BatchedTensor.
std::bitset<kVmapMaxTensorDims> self_vmap_levels;
if (self_impl) {
self_vmap_levels = createVmapLevelsBitset(self_impl->bdims());
}
// Figure out which arguments are BatchedTensor. Save them to a vector.
// For each BatchedTensor, also record what position of `arguments` they came from.
SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
VmapDimVector batched_tensor_inputs_position;
for (const auto idx : c10::irange(arguments.size())) {
const auto& ivalue = arguments[idx];
if (!ivalue.isTensor()) {
continue;
}
const auto& tensor = ivalue.toTensor();
if (!tensor.defined()) {
continue;
}
const auto* batched = maybeGetBatchedImpl(tensor);
if (!batched) {
continue;
}
// NOTE: [vmap-incompatible in-place operations]
// In-place operations on `self` are not possible if there exists some vmap
// level `l` such that `self` is not being vmapped on that level but another
// argument is. For example, let B0 be a batch dim inside vmap and consider
// vmap(Tensor.add_, in_dims=(None, 0))(torch.ones(3), torch.ones(B0, 3))
// - self is torch.ones(3) and does not participate in this vmap
// - other is BatchedTensor(torch.ones(B0, 3))
// There's no way to do self.add_(other) because `other` has more
// elements than `self` due to being vmapped over.
//
// In the vmap fallback, we should error out when we detect this.
auto other_vmap_levels = createVmapLevelsBitset(batched->bdims());
// True exactly when `other` has a level bit set that `self` lacks.
if (self_vmap_levels != (self_vmap_levels | other_vmap_levels)) {
// Find one vmap level to complain about
auto additional_bdims = (self_vmap_levels | other_vmap_levels) ^ self_vmap_levels;
// Only referenced inside the error message below.
[[maybe_unused]] auto offending_level = llvm::findLastSet(additional_bdims.to_ulong());
// The following prints out "vmap: aten::add_(tensor, ...) is not possible",
// but it would be better to print out "tensor.add_(...) is not possible".
// Afaict there's no official way to get the add_ and there is no way to
// tell if an operator has method or function variants.
TORCH_CHECK(false,
"vmap: ", schema.name(), "(self, *extra_args) is not possible because ",
"there exists a Tensor `other` in extra_args that has more elements ",
"than `self`. This happened due to `other` being vmapped over but ",
"`self` not being vmapped over at level ", offending_level, ". ",
"Please try to use out-of-place operators instead of ", schema.name(), ". ",
"If said operator is being called inside the PyTorch framework, ",
"please file a bug report instead.");
}
batched_tensor_inputs.push_back(tensor);
batched_tensor_inputs_position.push_back(idx);
}
// The code below reads input_physical_views.front(), so at least one
// BatchedTensor argument is required.
TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());
// MultiBatchVmapTransform the BatchedTensor arguments. This returns
// VmapPhysicalViews that contain all of the batch dimensions.
const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
batched_tensor_inputs);
// Compute the total number of batches
auto num_batch_dims = input_physical_views.front().numBatchDims();
auto first_physical_view_sizes = input_physical_views.front().tensor().sizes();
// The leading `num_batch_dims` sizes of the first physical view.
auto batch_sizes = ArrayRef<int64_t>(
first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims);
const auto num_batches = c10::multiply_integers(batch_sizes);
// Without a shape-checking API, we're unable to compute the correct shape of
// the output so we just error out.
TORCH_CHECK(num_batches > 0,
"Batching rule not implemented for ", schema.operator_name(), ". ",
"The fallback path does not support vmap over dims of size 0.");
// Strategy: For each batch, we are going to push slices (where applicable)
// of the arguments onto `stack`, and call `op`.
for (const auto linear_idx : c10::irange(num_batches)) {
auto index = computeIndex(linear_idx, batch_sizes);
// Both iterators advance together: one over the recorded argument positions,
// one over the physical views built from those same arguments.
auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
auto input_physical_views_iter = input_physical_views.begin();
for (const auto arg_idx : c10::irange(num_arguments)) {
// We assume that torch::jit::Stack is backed by vector<IValue> for
// simplicity. When that is not the case, this code should be updated.
const auto& argument = (*stack)[arguments_begin + arg_idx];
if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
|| arg_idx != *batched_tensor_inputs_pos_iter) {
// argument isn't a BatchedTensor
torch::jit::push(stack, argument);
continue;
}
// argument is a BatchedTensor
TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
const auto& physical_view_for_argument = *input_physical_views_iter;
// NOTE(review): the indexed slice is presumably a view, so that writes made
// by `op` land in the underlying tensor - confirm.
torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
batched_tensor_inputs_pos_iter++;
input_physical_views_iter++;
}
op.callBoxed(stack);
// Discard the one value left by the call; the result of interest is `self`.
torch::jit::drop(stack, 1);
}
// Return the tensor that was written to in-place
torch::jit::drop(stack, num_arguments);
torch::jit::push(stack, self);
}
// Stacks `tensors` with at::stack when every one of them is defined, returns
// an undefined Tensor when none of them is defined, and raises an error when
// the list mixes defined and undefined tensors.
static Tensor safeStack(TensorList tensors) {
  size_t num_defined = 0;
  for (const Tensor& shard : tensors) {
    if (shard.defined()) {
      ++num_defined;
    }
  }
  if (num_defined == tensors.size()) {
    return at::stack(tensors);
  }
  // NOTE [vmap through backward and undefined grad]
  // When vmapping through a backward function (to compute a batched grad),
  // the backward function may return an undefined grad for some grad_input
  // for every example. An undefined grad is returned in that case.
  //
  // In theory only *some* of the examples could produce an undefined grad
  // (a kernel could inspect the gradient values and return an undefined
  // tensor when the gradient is all zeros). That could be handled by
  // treating each undefined grad as a zero-filled tensor of the right shape
  // while stacking. It is expected to be very rare (no example was found in
  // the codebase), so that situation is reported as an error instead.
  TORCH_CHECK(num_defined == 0,
      "vmap: slow fallback received a mix of undefined and defined tensors ",
      "as the result of an operation. This is not supported, please file us ",
      "an issue on github.");
  return Tensor();
}
// For-loop vmap fallback. Calls `op` once per batch element on slices of the
// batched arguments, then stacks the per-batch results into the final returns.
// In-place operators (see isInplaceOp) are forwarded to
// batchedTensorInplaceForLoopFallback.
//
// `op`    - the operator to run; its schema determines how many arguments are
//           read from the top of `stack` and how many returns are produced.
// `stack` - holds the arguments on entry; on exit they have been replaced by
//           one value per return of the schema.
//
// Fails via TORCH_CHECK when the schema is mutable or has alias info (and is
// not an in-place op), when some return is not a Tensor, when some argument is
// a Tensor list, when the schema has no returns, or when the total number of
// batches is 0.
//
// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
// to a vector. We also store a vector of which index of the arguments list
// each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
// This returns a vector of VmapPhysicalView that hold tensors that contain
// all of the collective batch dimensions at the front of the tensors.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
// we repeatedly slice the input arguments (if they are BatchedTensors),
// put the sliced (or a not-sliced) version of the input onto the stack, invoke
// the operator, and then pop the results off the stack.
// - Each result obtained from the previous step is a slice of the total result,
// so we stack those tensors together to form the final result.
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
// In-place operators are handled by a dedicated fallback.
if (isInplaceOp(schema)) {
batchedTensorInplaceForLoopFallback(op, stack);
return;
}
TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(),
"Batching rule not implemented for ", schema.operator_name(), "; ",
"the fallback path doesn't work on out= or view ops.");
TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema),
"Batching rule not implemented for ", schema.operator_name(), ". ",
"We could not generate a fallback.");
TORCH_CHECK(num_returns >= 1,
"Batching rule not implemented for ", schema.operator_name(), ". ",
"The fallback path does not support operations with no returns.");
warnFallback(schema);
const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
// NOTE(review): `arguments` presumably refers to storage owned by `stack`;
// it is only read in the scan below, before anything is pushed - confirm.
const auto arguments = torch::jit::last(stack, num_arguments);
// Position in `stack` of the first original argument.
const auto arguments_begin = stack->size() - num_arguments;
// Figure out which arguments are BatchedTensor. Save them to a vector.
// For each BatchedTensor, also record what position of `arguments` they came from.
SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
VmapDimVector batched_tensor_inputs_position;
for (const auto idx : c10::irange(arguments.size())) {
const auto& ivalue = arguments[idx];
if (!ivalue.isTensor()) {
continue;
}
const auto& tensor = ivalue.toTensor();
if (!tensor.defined()) {
continue;
}
const auto* batched = maybeGetBatchedImpl(tensor);
if (!batched) {
continue;
}
batched_tensor_inputs.push_back(tensor);
batched_tensor_inputs_position.push_back(idx);
}
// The code below reads input_physical_views.front(), so at least one
// BatchedTensor argument is required.
TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());
// MultiBatchVmapTransform the BatchedTensor arguments. This returns
// VmapPhysicalViews that contain all of the batch dimensions.
const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
batched_tensor_inputs);
// Compute the total number of batches
auto num_batch_dims = input_physical_views.front().numBatchDims();
auto some_sizes = input_physical_views.front().tensor().sizes();
// The leading `num_batch_dims` sizes of the first physical view.
auto batch_sizes = ArrayRef<int64_t>(some_sizes.begin(), some_sizes.begin() + num_batch_dims);
const auto num_batches = c10::multiply_integers(batch_sizes);
// Without a shape-checking API, we're unable to compute the correct shape of
// the output so we just error out.
TORCH_CHECK(num_batches > 0,
"Batching rule not implemented for ", schema.operator_name(), ". ",
"The fallback path does not support vmap over dims of size 0.");
// Strategy: For each batch, we are going to push slices (where applicable)
// of the arguments onto `stack`, call `op`, and store the result in
// `output_shards`.
//
// NOTE: [Output shards layout]
// Assume that the operator has three outputs: a, b, c.
// The layout of output_shards is as follows:
// [ a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3]
// This is so that we can call at::stack([a0...a3]), at::stack([b0...b3])
// more easily in the next step.
std::vector<Tensor> output_shards(num_batches * num_returns);
for (const auto linear_idx : c10::irange(num_batches)) {
auto index = computeIndex(linear_idx, batch_sizes);
// Both iterators advance together: one over the recorded argument positions,
// one over the physical views built from those same arguments.
auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
auto input_physical_views_iter = input_physical_views.begin();
for (const auto arg_idx : c10::irange(num_arguments)) {
// We assume that torch::jit::Stack is backed by vector<IValue> for
// simplicity. When that is not the case, this code should be updated.
const auto& argument = (*stack)[arguments_begin + arg_idx];
if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
|| arg_idx != *batched_tensor_inputs_pos_iter) {
// argument isn't a BatchedTensor
torch::jit::push(stack, argument);
continue;
}
// argument is a BatchedTensor
TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
const auto& physical_view_for_argument = *input_physical_views_iter;
// Push the slice of the physical tensor selected by this batch index.
torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
batched_tensor_inputs_pos_iter++;
input_physical_views_iter++;
}
op.callBoxed(stack);
// Store the result into `output_shards`. See NOTE: [Output shards layout]
// to learn about the details of how we store the shards.
const auto returns = torch::jit::last(stack, num_returns);
for (const auto return_idx : c10::irange(returns.size())) {
output_shards[num_batches * return_idx + linear_idx] = returns[return_idx].toTensor();
}
// Remove this call's returns so the stack again ends with the original arguments.
torch::jit::drop(stack, num_returns);
}
// For each output Tensor, stack the shards of the tensor together to form a return
torch::jit::drop(stack, num_arguments);
// View `output_shards` as rows of `num_batches` shards, one row per return.
auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
for (const auto return_idx : c10::irange(num_returns)) {
auto shards = output_shards_chunks[return_idx];
auto flat_output = safeStack(shards);
// See NOTE [vmap through backward and undefined grad]
if (!flat_output.defined()) {
torch::jit::push(stack, flat_output);
continue;
}
// Replace the single stacked leading dim of `flat_output` with the original
// batch sizes, keeping the remaining per-example sizes.
VmapDimVector output_sizes(batch_sizes);
output_sizes.insert(
output_sizes.end(),
flat_output.sizes().begin() + 1,
flat_output.sizes().end());
// NOTE(review): getPhysicalToLogicalMap().apply presumably converts the
// physical result back into its logical (batched) form - confirm.
torch::jit::push(
stack,
input_physical_views.front().getPhysicalToLogicalMap().apply(flat_output.view(output_sizes)));
}
}
} // namespace at