#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/FakeQuantAffine.h>
#include <c10/util/irange.h>

// FakeQuantize Op for PerChannelAffine quantization scheme.

namespace at::native {

// Use REGISTER_DISPATCH to run CPU and CUDA backend.
DEFINE_DISPATCH(fake_quant_per_channel_cachemask_stub);
DEFINE_DISPATCH(fake_quant_grad_learnable_channel_stub);

/* Per channel fake-quantizes the 'inputs' tensor.
Args:
  X: Forward input tensor.
  dY: Backward input tensor (_backward op only).
  scale: scale of per channel affine quantization
  zero_point: zero_point of per channel affine quantization
  axis: int specifying the axis to be quantized
  quant_min: minimum quantized value
  quant_max: maximum quantized value
Returns:
  Fake quantized tensor (same dtype as the input).
*/
Tensor fake_quantize_per_channel_affine(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max) {
  const auto res = at::fake_quantize_per_channel_affine_cachemask(
      self, scale, zero_point, axis, quant_min, quant_max);
  return std::get<0>(res);
}
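
// Illustrative usage sketch (not part of the implementation). Per-channel fake
// quantization applies a quantize/dequantize round trip with a separate
// (scale, zero_point) pair for each slice along `axis`, roughly
//   Y = (clamp(round(X / scale) + zero_point, quant_min, quant_max) - zero_point) * scale
// with scale and zero_point broadcast along `axis`. A minimal ATen call,
// assuming a float input quantized along dim 0:
//
//   at::Tensor X = at::randn({3, 4});
//   at::Tensor scale = at::full({3}, 0.1f);
//   at::Tensor zero_point = at::zeros({3}, at::kInt);
//   at::Tensor Y = at::fake_quantize_per_channel_affine(
//       X, scale, zero_point, /*axis=*/0, /*quant_min=*/-128, /*quant_max=*/127);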

std::tuple<Tensor, Tensor> fake_quantize_per_channel_affine_cachemask(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max) {
  TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
              "Scale must be Float, found ", scale.scalar_type());
  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Int || zero_point.scalar_type() == ScalarType::Float || zero_point.scalar_type() == ScalarType::Half,
              "Zero-point must be Int32, Float or Half, found ", zero_point.scalar_type());
  TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
  TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
  TORCH_CHECK(
      scale.numel() == zero_point.numel(),
      "scale and zero-point need to have the same dimensions");
  TORCH_CHECK(
      scale.numel() == self.size(axis),
      "dimensions of scale and zero-point are not consistent with input tensor");
  TORCH_CHECK(
      quant_min <= quant_max,
      "`quant_min` should be less than or equal to `quant_max`.");
  if (!at::isFloatingType(zero_point.scalar_type())) {
    TORCH_CHECK(
        at::min(zero_point).item().toInt() >= quant_min &&
            at::max(zero_point).item().toInt() <= quant_max,
        "`zero_point` must be between `quant_min` and `quant_max`.");
  }
  TORCH_CHECK(
      axis >= 0 && axis < self.dim(),
      "`axis` must be between 0 and number of dimensions of input");

  auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
  auto mask = at::empty_like(self, at::kBool, MemoryFormat::Preserve);

  c10::DimVector expected_shape(self.dim(), 1);
  expected_shape[axis] = self.size(axis);

  TensorIterator iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .add_output(Y)
    .add_input(self)
    .add_owned_input(native::_unsafe_view(scale, expected_shape))
    .add_owned_input(native::_unsafe_view(zero_point, expected_shape))
    .build();

  // TODO(future, optional): read once, write twice. Not done at the moment
  // for simplicity, as we do not expect this to be a bottleneck.
  TensorIterator iter_mask = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .add_output(mask)
    .add_input(self)
    .add_owned_input(native::_unsafe_view(scale, expected_shape))
    .add_owned_input(native::_unsafe_view(zero_point, expected_shape))
    .build();

  // TODO(future, optional): look into packing the mask further (BoolTensor uses
  // 1 byte per element, we only need 1 bit per element).
  fake_quant_per_channel_cachemask_stub(iter.device_type(), iter, iter_mask, quant_min, quant_max);
  return std::make_tuple(Y, mask);
}
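
/* Illustrative note on the mask output (as consumed by the backward below):
   each boolean marks whether the element's quantized value stayed inside
   [quant_min, quant_max] (true) or had to be clamped (false). For example,
   with X = {{0.4, 2.6, 7.0}} (shape [1, 3]), axis = 0, scale = {1.0},
   zero_point = {0}, quant_min = 0 and quant_max = 3, the outputs are
   Y = {{0.0, 3.0, 3.0}} and mask = {{true, true, false}}: only the clamped
   7.0 is masked out. Caching the mask lets the backward be a single
   elementwise multiply instead of re-running the quantization math.
*/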

/* Backward path to fake-quantize the 'inputs' tensor per channel, with mask.
Args:
  dY: output grad.
  mask: mask tensor from the forward pass.
Returns:
  dX (input grad).
*/
Tensor fake_quantize_per_channel_affine_cachemask_backward(
    const Tensor& dY,
    const Tensor& mask) {
  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool);
  TORCH_CHECK(mask.numel() == dY.numel(),
      "`mask` and `dY` are not the same size: ",
      "`mask` is size ", mask.numel(), " and `dY` is size ", dY.numel());
  if (dY.numel() <= 0) {
    return dY;
  }
  // Note: no additional kernels needed, since mask is pre-computed
  // and we can use the existing tensor multiplication kernels.
  return dY * mask;
}
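
/* Illustrative sketch: the cachemask forward plus this backward implement a
   straight-through estimator for the input gradient, i.e. dY passes through
   unchanged wherever quantization did not clamp and is zeroed elsewhere.
   Continuing the example above, mask = {{true, true, false}} and
   dY = {{0.1, 0.2, 0.3}} give dX = dY * mask = {{0.1, 0.2, 0.0}}.
*/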

static Tensor _get_rounded_zero_point(
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max) {
  // This assumes the per channel zero point vector is single-dimensioned.
  return zero_point.round().clamp_(quant_min, quant_max);
}

Tensor _fake_quantize_learnable_per_channel_affine(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max,
    double grad_factor) {
  Tensor zero_point_rounded = _get_rounded_zero_point(zero_point, quant_min, quant_max).to(at::kInt);
  return native::fake_quantize_per_channel_affine(
      self, scale, zero_point_rounded, axis, quant_min, quant_max);
}
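
// Illustrative sketch (calls the native functions defined in this file; not a
// documented public API). The learnable variant accepts floating-point scale
// and zero_point so both can be trained; this forward just rounds and clamps
// zero_point and reuses the non-learnable op. `grad_factor` is unused in the
// forward and only comes into play in the backward below, where it is passed
// to the gradient kernel (presumably to rescale the scale/zero-point grads).
//
//   at::Tensor X = at::randn({3, 4});
//   at::Tensor scale = at::full({3}, 0.1f);
//   at::Tensor zero_point = at::zeros({3});  // float zero_point is allowed here
//   at::Tensor Y = at::native::_fake_quantize_learnable_per_channel_affine(
//       X, scale, zero_point, /*axis=*/0, /*quant_min=*/-128, /*quant_max=*/127,
//       /*grad_factor=*/1.0);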

std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_channel_affine_backward(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max,
    double grad_factor) {
  /* The gradients with respect to scale and zero point are calculated as follows:
     Let X_fq be the fake quantized version of X.
     Let X_q be the quantized version of X (clamped at q_min and q_max).
     Let Delta and z be the scale and the zero point.
     :math:
      \frac{\partial X_{fq}}{\partial \Delta} =
        \begin{cases}
          q_{\min} - z & \text{ if } X_q = q_{\min} \\
          q_{\max} - z & \text{ if } X_q = q_{\max} \\
          (X_{fq} - X) / \Delta & \text{ otherwise }
        \end{cases}

      \frac{\partial X_{fq}}{\partial z} =
        \begin{cases}
          -\Delta & \text{ if } X_q = q_{\min} \text{ or } X_q = q_{\max} \\
          0 & \text{ otherwise }
        \end{cases}
  */
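  /* Worked example (illustrative only): with Delta = 0.5, z = 0, q_min = -10
     and q_max = 10, an element X = 1.3 quantizes to X_q = 3 (not clamped), so
     X_fq = 1.5, dX_fq/dDelta = (1.5 - 1.3) / 0.5 = 0.4 and dX_fq/dz = 0.
     An element X = 7.0 rounds to 14 and is clamped to X_q = q_max = 10, so
     dX_fq/dDelta = q_max - z = 10 and dX_fq/dz = -Delta = -0.5. The
     per-element values computed by the kernel below are then summed over
     every axis except `axis` to produce dScale and dZeroPoint.
  */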
  auto zero_point_rounded = _get_rounded_zero_point(zero_point, quant_min, quant_max);

  TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
  TORCH_CHECK(X.scalar_type() == ScalarType::Float);
  TORCH_CHECK(scale.scalar_type() == ScalarType::Float);
  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float);

  TORCH_CHECK(X.sizes() == dY.sizes(), "`X` and `dY` are not the same size");
  TORCH_CHECK(
      quant_min <= 0 && quant_max >= 0,
      "Expecting `quant_min` <= 0 and `quant_max` >= 0");
  TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
  TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
  TORCH_CHECK(
      scale.numel() == zero_point.numel(),
      "scale and zero-point need to have the same dimensions");
  TORCH_CHECK(
      scale.numel() == X.size(axis),
      "dimensions of scale and zero-point are not consistent with input tensor");
  TORCH_CHECK(
      at::min(zero_point_rounded).item().toLong() >= quant_min &&
          at::max(zero_point_rounded).item().toLong() <= quant_max,
      "`zero_point` must be between `quant_min` and `quant_max`.");
  TORCH_CHECK(
      axis >= 0 && axis < X.dim(),
      "`axis` must be between 0 and number of dimensions of input");

  if (X.numel() <= 0) {
    return std::make_tuple(X, scale, zero_point);
  }

  auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto numDimensions = X.ndimension();

  // Create an axis mask for vectorizing and reshaping the scale and zero point tensors
  // into the same shapes as X along the channel axis.
  c10::DimVector axis_mask(numDimensions);
  for (const auto i : c10::irange(numDimensions)) {
    axis_mask[i] = (i == axis) ? X.size(axis) : 1;
  }
  auto X_shape = X.sizes();
  auto scale_vectorized = scale.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
  auto zero_point_vectorized = zero_point_rounded.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);

  auto iter = TensorIteratorConfig()
    .add_output(dX)
    .add_output(dScale_vec)
    .add_output(dZeroPoint_vec)
    .add_input(X)
    .add_input(dY)
    .add_input(scale_vectorized)
    .add_input(zero_point_vectorized)
    .build();

  fake_quant_grad_learnable_channel_stub(
      X.device().type(), iter, quant_min, quant_max, grad_factor);

  auto numElements = X.ndimension() - 1;

  // Create a collection of axes that include all but the channel axis for
  // reduction when summing over the dScale and dZeroPoint tensors.
  c10::DimVector axis_for_reduction(numElements);
  for (const auto i : c10::irange(axis)) {
    axis_for_reduction[i] = i;
  }
  for (const auto i : c10::irange(axis, numElements)) {
    axis_for_reduction[i] = i + 1;
  }

  auto dScale = dScale_vec.sum(at::IntArrayRef(axis_for_reduction.data(), numElements));
  auto dZeroPoint = dZeroPoint_vec.sum(at::IntArrayRef(axis_for_reduction.data(), numElements));
  return std::make_tuple(dX, dScale, dZeroPoint);
}

} // namespace at::native