/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {
namespace {

// Gymnastics with nudged zero point is to ensure that the real zero maps to
// an integer, which is required for e.g. zero-padding in convolutional layers.
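// For example (illustrative numbers only): with min = -0.125, max = 1.0 and an
// 8-bit quantized range [0, 255], the scale is 1.125 / 255 (about 0.00441) and
// the raw zero point is about 28.33. It is rounded to 28, and the range is
// nudged to roughly [-0.1235, 1.0015] so that real 0.0 lands exactly on a
// quantized integer.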
void CpuNudge(const float min, const float max, const float quant_min,
              const float quant_max, float* nudged_min, float* nudged_max,
              float* scale) {
  *scale = (max - min) / (quant_max - quant_min);

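  // Quantized value that real 0.0 would map to; clamp it into the
  // representable range [quant_min, quant_max] and round it to the nearest
  // integer.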
  const float zero_point_from_min = quant_min - min / *scale;
  float nudged_zero_point;
  if (zero_point_from_min <= quant_min) {
    nudged_zero_point = quant_min;
  } else if (zero_point_from_min >= quant_max) {
    nudged_zero_point = quant_max;
  } else {
    nudged_zero_point = std::round(zero_point_from_min);
  }

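  // Recompute the real-valued range from the nudged zero point so that the
  // quantization grid contains an exact zero.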
  *nudged_min = (quant_min - nudged_zero_point) * (*scale);
  *nudged_max = (quant_max - nudged_zero_point) * (*scale);
}

// An XLA version of CpuNudge().
void XlaNudge(xla::XlaBuilder* b, const DataType data_type,
              const xla::XlaOp& min, const xla::XlaOp& max,
              const float quant_min_value, const float quant_max_value,
              xla::XlaOp* nudged_min, xla::XlaOp* nudged_max,
              xla::XlaOp* scale) {
  *scale = xla::Div(xla::Sub(max, min),
                    XlaHelpers::FloatLiteral(
                        b, data_type, quant_max_value - quant_min_value));
  xla::XlaOp quant_min =
      XlaHelpers::FloatLiteral(b, data_type, quant_min_value);
  xla::XlaOp zero_point_from_min = xla::Sub(quant_min, xla::Div(min, *scale));
  xla::XlaOp quant_max =
      XlaHelpers::FloatLiteral(b, data_type, quant_max_value);
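  // Nested Selects mirror the clamp-then-round branch in CpuNudge().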
  xla::XlaOp nudged_zero_point =
      xla::Select(xla::Le(zero_point_from_min, quant_min), quant_min,
                  xla::Select(xla::Ge(zero_point_from_min, quant_max),
                              quant_max, xla::Round(zero_point_from_min)));
  *nudged_min = xla::Mul(xla::Sub(quant_min, nudged_zero_point), *scale);
  *nudged_max = xla::Mul(xla::Sub(quant_max, nudged_zero_point), *scale);
}

xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input,
                    const DataType data_type,
                    const xla::XlaOp& nudged_input_min,
                    const xla::XlaOp& nudged_input_max,
                    const xla::XlaOp& input_scale) {
  xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
  xla::XlaOp inv_scale = xla::Div(one, input_scale);
  xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f);

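  // Clamp the input to the nudged range, snap it to the nearest quantization
  // step with floor(x + 0.5), then rescale back to real values.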
  xla::XlaOp clamped = xla::Clamp(nudged_input_min, input, nudged_input_max);
  xla::XlaOp clamped_shifted = xla::Sub(clamped, nudged_input_min);
  xla::XlaOp rounded =
      xla::Floor(xla::Add(xla::Mul(clamped_shifted, inv_scale), half));
  return xla::Add(xla::Mul(rounded, input_scale), nudged_input_min);
}

// Builds a custom_call to a method named 'fake_quant_with_min_max_vars'.
// The method will be provided the input, the min/max range from the original
// TensorFlow op, and the num_bits and narrow_range attributes.
xla::StatusOr<xla::XlaOp> BuildFakeQuantCustomCall(
    xla::XlaBuilder* b, xla::XlaOp input, xla::XlaOp input_min,
    xla::XlaOp input_max, int num_bits, bool narrow_range) {
  xla::XlaOp num_bits_arg =
      XlaHelpers::IntegerLiteral(b, DataType::DT_INT32, num_bits);
  xla::XlaOp narrow_range_arg = narrow_range
                                    ? XlaHelpers::One(b, DataType::DT_BOOL)
                                    : XlaHelpers::Zero(b, DataType::DT_BOOL);

  std::vector<xla::XlaOp> args = {input, input_min, input_max, num_bits_arg,
                                  narrow_range_arg};
  std::vector<xla::Shape> arg_shapes;
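  // Attach an explicit major-to-minor (descending) layout to each operand
  // shape, since CustomCallWithLayout expects fully specified operand layouts.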
  for (const xla::XlaOp& arg : args) {
    TF_ASSIGN_OR_RETURN(xla::Shape arg_shape, b->GetShape(arg));
    *arg_shape.mutable_layout() =
        xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank());
    arg_shapes.push_back(std::move(arg_shape));
  }

  // Input and output shapes match exactly.
  TF_ASSIGN_OR_RETURN(xla::Shape output_shape, b->GetShape(input));

  return xla::CustomCallWithLayout(b, "fake_quant_with_min_max_vars", args,
                                   output_shape, arg_shapes);
}

class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ >= 2 && num_bits_ <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
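    // narrow_range reserves the lowest quantized value, e.g. [1, 255] instead
    // of [0, 255] for 8 bits.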
    quant_min_ = narrow_range_ ? 1 : 0;
    quant_max_ = (1 << num_bits_) - 1;

    OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max_));
    CpuNudge(input_min_, input_max_, quant_min_, quant_max_, &nudged_input_min_,
             &nudged_input_max_, &input_scale_);
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    const DataType data_type = ctx->input_type(0);

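    // When the backend allows CPU custom calls, defer the fake-quant math to
    // the 'fake_quant_with_min_max_vars' custom-call target instead of
    // expanding it inline.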
    if (ctx->compiler()->options().allow_cpu_custom_calls &&
        ctx->compiler()->options().custom_fake_quant_op_calls) {
      xla::XlaOp custom_call_output =
          b->ReportErrorOrReturn(BuildFakeQuantCustomCall(
              b, input,
              XlaHelpers::FloatLiteral(b, DataType::DT_FLOAT, input_min_),
              XlaHelpers::FloatLiteral(b, DataType::DT_FLOAT, input_max_),
              num_bits_, narrow_range_));
      ctx->SetOutput(0, custom_call_output);
      return;
    }

    xla::XlaOp nudged_input_min =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
    xla::XlaOp nudged_input_max =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
    xla::XlaOp input_scale =
        XlaHelpers::FloatLiteral(b, data_type, input_scale_);
    xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min,
                                 nudged_input_max, input_scale);
    ctx->SetOutput(0, output);
  }

 private:
  int num_bits_;
  bool narrow_range_;
  float input_min_;
  float input_max_;
  float quant_min_;
  float quant_max_;
  float nudged_input_min_;
  float nudged_input_max_;
  float input_scale_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgs"), FakeQuantWithMinMaxArgsOp);

class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxArgsGradOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    const float quant_min = narrow_range ? 1 : 0;
    const float quant_max = (1 << num_bits) - 1;

    float input_min, input_max, scale;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
    CpuNudge(input_min, input_max, quant_min, quant_max, &nudged_input_min_,
             &nudged_input_max_, &scale);
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp gradient = ctx->Input(0);
    const TensorShape gradient_shape = ctx->InputShape(0);
    xla::XlaOp input = ctx->Input(1);
    const DataType data_type = ctx->input_type(1);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp nudged_input_min =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
    xla::XlaOp nudged_input_max =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);

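    // Straight-through estimator: pass the incoming gradient through where the
    // input lies inside the nudged range, and zero it elsewhere.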
    xla::XlaOp between_nudged_min_max = xla::And(
        xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max));
    xla::XlaOp zeroes = xla::Broadcast(XlaHelpers::Zero(b, data_type),
                                       gradient_shape.dim_sizes());
    xla::XlaOp output = xla::Select(between_nudged_min_max, gradient, zeroes);
    ctx->SetOutput(0, output);
  }

 private:
  float nudged_input_min_;
  float nudged_input_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgsGradient"),
                FakeQuantWithMinMaxArgsGradOp);

class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ >= 2 && num_bits_ <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
    quant_min_ = narrow_range_ ? 1 : 0;
    quant_max_ = (1 << num_bits_) - 1;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    const DataType data_type = ctx->input_type(0);
    xla::XlaOp input_min = ctx->Input(1);
    xla::XlaOp input_max = ctx->Input(2);

    if (ctx->compiler()->options().allow_cpu_custom_calls &&
        ctx->compiler()->options().custom_fake_quant_op_calls) {
      xla::XlaOp custom_call_output =
          b->ReportErrorOrReturn(BuildFakeQuantCustomCall(
              b, input, input_min, input_max, num_bits_, narrow_range_));
      ctx->SetOutput(0, custom_call_output);
      return;
    }

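    // Unlike FakeQuantWithMinMaxArgs, min/max arrive as runtime tensors here,
    // so the nudging has to be expressed as XLA ops rather than precomputed on
    // the host.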
    xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
    XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
             &nudged_input_min, &nudged_input_max, &input_scale);

    xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min,
                                 nudged_input_max, input_scale);
    ctx->SetOutput(0, output);
  }

 private:
  int num_bits_;
  bool narrow_range_;
  float quant_min_;
  float quant_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVars"), FakeQuantWithMinMaxVarsOp);

class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsGradOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp gradient = ctx->Input(0);
    const TensorShape gradient_shape = ctx->InputShape(0);
    xla::XlaOp input = ctx->Input(1);
    const DataType data_type = ctx->input_type(1);
    const DataType accumulation_type =
        XlaHelpers::SumAccumulationType(data_type);
    xla::XlaOp input_min = ctx->Input(2);
    xla::XlaOp input_max = ctx->Input(3);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
    XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
             &nudged_input_min, &nudged_input_max, &input_scale);

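    // Output 0: gradient w.r.t. the input, passed through only where the input
    // lies inside the nudged range (straight-through estimator).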
    xla::XlaOp between_nudged_min_max = xla::And(
        xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max));
    xla::XlaOp zero = XlaHelpers::Zero(b, data_type);
    xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes());
    xla::XlaOp output0 = xla::Select(between_nudged_min_max, gradient, zeroes);
    ctx->SetOutput(0, output0);

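    // Output 1: gradient w.r.t. min, the sum of the incoming gradient over
    // elements clipped below the nudged range.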
    xla::XlaOp below_min = xla::Lt(input, nudged_input_min);
    xla::XlaOp select1 = xla::Select(below_min, gradient, zeroes);
    xla::XlaOp reduce1 = xla::ReduceAll(
        XlaHelpers::ConvertElementType(select1, accumulation_type),
        XlaHelpers::Zero(b, accumulation_type),
        *ctx->GetOrCreateAdd(accumulation_type));
    xla::XlaOp output1 = XlaHelpers::ConvertElementType(reduce1, data_type);
    ctx->SetOutput(1, output1);

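    // Output 2: gradient w.r.t. max, the sum of the incoming gradient over
    // elements clipped above the nudged range.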
    xla::XlaOp above_max = xla::Gt(input, nudged_input_max);
    xla::XlaOp select2 = xla::Select(above_max, gradient, zeroes);
    xla::XlaOp reduce2 = xla::ReduceAll(
        XlaHelpers::ConvertElementType(select2, accumulation_type),
        XlaHelpers::Zero(b, accumulation_type),
        *ctx->GetOrCreateAdd(accumulation_type));
    xla::XlaOp output2 = XlaHelpers::ConvertElementType(reduce2, data_type);
    ctx->SetOutput(2, output2);
  }

 private:
  float quant_min_;
  float quant_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVarsGradient"),
                FakeQuantWithMinMaxVarsGradOp);

}  // namespace
}  // namespace tensorflow