/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {
namespace {

// Gymnastics with nudged zero point is to ensure that the real zero maps to
// an integer, which is required for e.g. zero-padding in convolutional layers.
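// For example, with min = -0.1, max = 0.9 and an 8-bit range [0, 255], the
// scale is 1/255 and the raw zero point 25.5 rounds to 26, so the nudged
// range becomes roughly [-0.102, 0.898] (numbers are illustrative only).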
void CpuNudge(const float min, const float max, const float quant_min,
              const float quant_max, float* nudged_min, float* nudged_max,
              float* scale) {
  *scale = (max - min) / (quant_max - quant_min);

  const float zero_point_from_min = quant_min - min / *scale;
  float nudged_zero_point;
  if (zero_point_from_min <= quant_min) {
    nudged_zero_point = quant_min;
  } else if (zero_point_from_min >= quant_max) {
    nudged_zero_point = quant_max;
  } else {
    nudged_zero_point = std::round(zero_point_from_min);
  }

  *nudged_min = (quant_min - nudged_zero_point) * (*scale);
  *nudged_max = (quant_max - nudged_zero_point) * (*scale);
}

// An XLA version of CpuNudge().
void XlaNudge(xla::XlaBuilder* b, const DataType data_type,
              const xla::XlaOp& min, const xla::XlaOp& max,
              const float quant_min_value, const float quant_max_value,
              xla::XlaOp* nudged_min, xla::XlaOp* nudged_max,
              xla::XlaOp* scale) {
  *scale = xla::Div(xla::Sub(max, min),
                    XlaHelpers::FloatLiteral(
                        b, data_type, quant_max_value - quant_min_value));
  xla::XlaOp quant_min =
      XlaHelpers::FloatLiteral(b, data_type, quant_min_value);
  xla::XlaOp zero_point_from_min = xla::Sub(quant_min, xla::Div(min, *scale));
  xla::XlaOp quant_max =
      XlaHelpers::FloatLiteral(b, data_type, quant_max_value);
  xla::XlaOp nudged_zero_point =
      xla::Select(xla::Le(zero_point_from_min, quant_min), quant_min,
                  xla::Select(xla::Ge(zero_point_from_min, quant_max),
                              quant_max, xla::Round(zero_point_from_min)));
  *nudged_min = xla::Mul(xla::Sub(quant_min, nudged_zero_point), *scale);
  *nudged_max = xla::Mul(xla::Sub(quant_max, nudged_zero_point), *scale);
}

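// Fake-quantizes 'input': clamps it to the nudged range, rounds it to the
// nearest point on the quantization grid, and maps it back to float.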
xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input,
                    const DataType data_type,
                    const xla::XlaOp& nudged_input_min,
                    const xla::XlaOp& nudged_input_max,
                    const xla::XlaOp& input_scale) {
  xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
  xla::XlaOp inv_scale = xla::Div(one, input_scale);
  xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f);

  xla::XlaOp clamped = xla::Clamp(nudged_input_min, input, nudged_input_max);
  xla::XlaOp clamped_shifted = xla::Sub(clamped, nudged_input_min);
  xla::XlaOp rounded =
      xla::Floor(xla::Add(xla::Mul(clamped_shifted, inv_scale), half));
  return xla::Add(xla::Mul(rounded, input_scale), nudged_input_min);
}

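// Fake-quantizes the input using 'min' and 'max' given as compile-time
// attributes; the nudged range and scale are computed once on the host in
// the kernel constructor.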
class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;

    float input_min, input_max;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
    CpuNudge(input_min, input_max, quant_min_, quant_max_, &nudged_input_min_,
             &nudged_input_max_, &input_scale_);
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp input = ctx->Input(0);
    const DataType data_type = ctx->input_type(0);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp nudged_input_min =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
    xla::XlaOp nudged_input_max =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
    xla::XlaOp input_scale =
        XlaHelpers::FloatLiteral(b, data_type, input_scale_);
    xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min,
                                 nudged_input_max, input_scale);
    ctx->SetOutput(0, output);
  }

 private:
  float quant_min_;
  float quant_max_;
  float nudged_input_min_;
  float nudged_input_max_;
  float input_scale_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgs"), FakeQuantWithMinMaxArgsOp);

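// Gradient of FakeQuantWithMinMaxArgs: the incoming gradient is passed
// through where the input lies inside the nudged [min, max] range and zeroed
// elsewhere (a straight-through estimator).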
class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxArgsGradOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    const float quant_min = narrow_range ? 1 : 0;
    const float quant_max = (1 << num_bits) - 1;

    float input_min, input_max, scale;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
    CpuNudge(input_min, input_max, quant_min, quant_max, &nudged_input_min_,
             &nudged_input_max_, &scale);
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp gradient = ctx->Input(0);
    const TensorShape gradient_shape = ctx->InputShape(0);
    xla::XlaOp input = ctx->Input(1);
    const DataType data_type = ctx->input_type(1);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp nudged_input_min =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
    xla::XlaOp nudged_input_max =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);

    xla::XlaOp between_nudged_min_max = xla::And(
        xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max));
    xla::XlaOp zeroes = xla::Broadcast(XlaHelpers::Zero(b, data_type),
                                       gradient_shape.dim_sizes());
    xla::XlaOp output = xla::Select(between_nudged_min_max, gradient, zeroes);
    ctx->SetOutput(0, output);
  }

 private:
  float nudged_input_min_;
  float nudged_input_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgsGradient"),
                FakeQuantWithMinMaxArgsGradOp);

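// Fake-quantizes the input using 'min' and 'max' supplied as runtime tensors
// (inputs 1 and 2); the nudged range and scale are computed inside the XLA
// computation via XlaNudge().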
class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp input = ctx->Input(0);
    const DataType data_type = ctx->input_type(0);
    xla::XlaOp input_min = ctx->Input(1);
    xla::XlaOp input_max = ctx->Input(2);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
    XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
             &nudged_input_min, &nudged_input_max, &input_scale);

    xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min,
                                 nudged_input_max, input_scale);
    ctx->SetOutput(0, output);
  }

 private:
  float quant_min_;
  float quant_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVars"), FakeQuantWithMinMaxVarsOp);

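// Gradient of FakeQuantWithMinMaxVars: produces three outputs - the gradient
// w.r.t. the input (straight-through inside the nudged range) and the
// gradients w.r.t. min and max, obtained by summing the incoming gradient
// over the elements that fall below / above the nudged range.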
class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsGradOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp gradient = ctx->Input(0);
    const TensorShape gradient_shape = ctx->InputShape(0);
    xla::XlaOp input = ctx->Input(1);
    const DataType data_type = ctx->input_type(1);
    const DataType accumulation_type =
        XlaHelpers::SumAccumulationType(data_type);
    xla::XlaOp input_min = ctx->Input(2);
    xla::XlaOp input_max = ctx->Input(3);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
    XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
             &nudged_input_min, &nudged_input_max, &input_scale);

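    // Output 0: gradient w.r.t. the input, passed through only where the
    // input lies inside the nudged range.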
    xla::XlaOp between_nudged_min_max = xla::And(
        xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max));
    xla::XlaOp zero = XlaHelpers::Zero(b, data_type);
    xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes());
    xla::XlaOp output0 = xla::Select(between_nudged_min_max, gradient, zeroes);
    ctx->SetOutput(0, output0);

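    // Output 1: gradient w.r.t. min, the sum of the incoming gradient over
    // all elements below the nudged minimum.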
    xla::XlaOp below_min = xla::Lt(input, nudged_input_min);
    xla::XlaOp select1 = xla::Select(below_min, gradient, zeroes);
    xla::XlaOp reduce1 = xla::ReduceAll(
        XlaHelpers::ConvertElementType(select1, accumulation_type),
        XlaHelpers::Zero(b, accumulation_type),
        *ctx->GetOrCreateAdd(accumulation_type));
    xla::XlaOp output1 = XlaHelpers::ConvertElementType(reduce1, data_type);
    ctx->SetOutput(1, output1);

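    // Output 2: gradient w.r.t. max, the sum of the incoming gradient over
    // all elements above the nudged maximum.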
    xla::XlaOp above_max = xla::Gt(input, nudged_input_max);
    xla::XlaOp select2 = xla::Select(above_max, gradient, zeroes);
    xla::XlaOp reduce2 = xla::ReduceAll(
        XlaHelpers::ConvertElementType(select2, accumulation_type),
        XlaHelpers::Zero(b, accumulation_type),
        *ctx->GetOrCreateAdd(accumulation_type));
    xla::XlaOp output2 = XlaHelpers::ConvertElementType(reduce2, data_type);
    ctx->SetOutput(2, output2);
  }

 private:
  float quant_min_;
  float quant_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVarsGradient"),
                FakeQuantWithMinMaxVarsGradOp);

}  // namespace
}  // namespace tensorflow