/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {
namespace {

// The zero-point nudging below ensures that the real value zero maps exactly
// to an integer, which is required for e.g. zero-padding in convolutional
// layers.
void CpuNudge(const float min, const float max, const float quant_min,
              const float quant_max, float* nudged_min, float* nudged_max,
              float* scale) {
  *scale = (max - min) / (quant_max - quant_min);

  const float zero_point_from_min = quant_min - min / *scale;
  float nudged_zero_point;
  if (zero_point_from_min <= quant_min) {
    nudged_zero_point = quant_min;
  } else if (zero_point_from_min >= quant_max) {
    nudged_zero_point = quant_max;
  } else {
    nudged_zero_point = std::round(zero_point_from_min);
  }

  *nudged_min = (quant_min - nudged_zero_point) * (*scale);
  *nudged_max = (quant_max - nudged_zero_point) * (*scale);
}
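
// Worked example (illustrative values only): with 8 bits and
// narrow_range = false the integer range is [0, 255]. For min = -0.125 and
// max = 1.0:
//   scale               = (1.0 - (-0.125)) / 255   ~= 0.0044118
//   zero_point_from_min = 0 - (-0.125) / scale     ~= 28.33 -> nudged to 28
//   nudged_min          = (0 - 28) * scale         ~= -0.12353
//   nudged_max          = (255 - 28) * scale       ~= 1.00147
// so the real value 0.0 now lands exactly on the integer step 28.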

// An XLA version of CpuNudge().
void XlaNudge(xla::XlaBuilder* b, const DataType data_type,
              const xla::XlaOp& min, const xla::XlaOp& max,
              const float quant_min_value, const float quant_max_value,
              xla::XlaOp* nudged_min, xla::XlaOp* nudged_max,
              xla::XlaOp* scale) {
  *scale = xla::Div(xla::Sub(max, min),
                    XlaHelpers::FloatLiteral(
                        b, data_type, quant_max_value - quant_min_value));
  xla::XlaOp quant_min =
      XlaHelpers::FloatLiteral(b, data_type, quant_min_value);
  xla::XlaOp zero_point_from_min = xla::Sub(quant_min, xla::Div(min, *scale));
  xla::XlaOp quant_max =
      XlaHelpers::FloatLiteral(b, data_type, quant_max_value);
  xla::XlaOp nudged_zero_point =
      xla::Select(xla::Le(zero_point_from_min, quant_min), quant_min,
                  xla::Select(xla::Ge(zero_point_from_min, quant_max),
                              quant_max, xla::Round(zero_point_from_min)));
  *nudged_min = xla::Mul(xla::Sub(quant_min, nudged_zero_point), *scale);
  *nudged_max = xla::Mul(xla::Sub(quant_max, nudged_zero_point), *scale);
}

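// Fake-quantizes 'input': clamps it to the nudged range and snaps each value
// to the nearest quantization step, i.e. roughly
//   round((clamp(x, nudged_min, nudged_max) - nudged_min) / scale)
//       * scale + nudged_min
// where rounding is implemented below as floor(v + 0.5).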
xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input,
                    const DataType data_type,
                    const xla::XlaOp& nudged_input_min,
                    const xla::XlaOp& nudged_input_max,
                    const xla::XlaOp& input_scale) {
  xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
  xla::XlaOp inv_scale = xla::Div(one, input_scale);
  xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f);

  xla::XlaOp clamped = xla::Clamp(nudged_input_min, input, nudged_input_max);
  xla::XlaOp clamped_shifted = xla::Sub(clamped, nudged_input_min);
  xla::XlaOp rounded =
      xla::Floor(xla::Add(xla::Mul(clamped_shifted, inv_scale), half));
  return xla::Add(xla::Mul(rounded, input_scale), nudged_input_min);
}

// Builds a custom_call to a method named 'fake_quant_with_min_max_vars'.
// The method will be provided the input, the min/max range from the original
// TensorFlow op, and the num_bits and narrow_range attributes.
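// As built below, the custom call receives five operands in this order:
// input, input_min, input_max, num_bits (an s32 scalar), and narrow_range
// (a pred scalar), all with descending (row-major) layouts; the result has
// the same shape as the input.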
xla::StatusOr<xla::XlaOp> BuildFakeQuantCustomCall(
    xla::XlaBuilder* b, xla::XlaOp input, xla::XlaOp input_min,
    xla::XlaOp input_max, int num_bits, bool narrow_range) {
  xla::XlaOp num_bits_arg =
      XlaHelpers::IntegerLiteral(b, DataType::DT_INT32, num_bits);
  xla::XlaOp narrow_range_arg = narrow_range
                                    ? XlaHelpers::One(b, DataType::DT_BOOL)
                                    : XlaHelpers::Zero(b, DataType::DT_BOOL);

  std::vector<xla::XlaOp> args = {input, input_min, input_max, num_bits_arg,
                                  narrow_range_arg};
  std::vector<xla::Shape> arg_shapes;
  for (const xla::XlaOp& arg : args) {
    TF_ASSIGN_OR_RETURN(xla::Shape arg_shape, b->GetShape(arg));
    *arg_shape.mutable_layout() =
        xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank());
    arg_shapes.push_back(std::move(arg_shape));
  }

  // Input and output shapes match exactly.
  TF_ASSIGN_OR_RETURN(xla::Shape output_shape, b->GetShape(input));

  return xla::CustomCallWithLayout(b, "fake_quant_with_min_max_vars", args,
                                   output_shape, arg_shapes);
}

class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ >= 2 && num_bits_ <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
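    // The integer target range is [0, 2^num_bits - 1], narrowed to
    // [1, 2^num_bits - 1] when narrow_range is set (e.g. [0, 255] vs.
    // [1, 255] for 8 bits).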
    quant_min_ = narrow_range_ ? 1 : 0;
    quant_max_ = (1 << num_bits_) - 1;

    OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max_));
    CpuNudge(input_min_, input_max_, quant_min_, quant_max_, &nudged_input_min_,
             &nudged_input_max_, &input_scale_);
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    const DataType data_type = ctx->input_type(0);

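    // When the backend permits CPU custom calls and custom fake-quant calls
    // are requested, lower to the custom call instead of expanding the op
    // into XLA arithmetic.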
    if (ctx->compiler()->options().allow_cpu_custom_calls &&
        ctx->compiler()->options().custom_fake_quant_op_calls) {
      xla::XlaOp custom_call_output =
          b->ReportErrorOrReturn(BuildFakeQuantCustomCall(
              b, input,
              XlaHelpers::FloatLiteral(b, DataType::DT_FLOAT, input_min_),
              XlaHelpers::FloatLiteral(b, DataType::DT_FLOAT, input_max_),
              num_bits_, narrow_range_));
      ctx->SetOutput(0, custom_call_output);
      return;
    }

    xla::XlaOp nudged_input_min =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
    xla::XlaOp nudged_input_max =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
    xla::XlaOp input_scale =
        XlaHelpers::FloatLiteral(b, data_type, input_scale_);
    xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min,
                                 nudged_input_max, input_scale);
    ctx->SetOutput(0, output);
  }

 private:
  int num_bits_;
  bool narrow_range_;
  float input_min_;
  float input_max_;
  float quant_min_;
  float quant_max_;
  float nudged_input_min_;
  float nudged_input_max_;
  float input_scale_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgs"), FakeQuantWithMinMaxArgsOp);

class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxArgsGradOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    const float quant_min = narrow_range ? 1 : 0;
    const float quant_max = (1 << num_bits) - 1;

    float input_min, input_max, scale;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
    CpuNudge(input_min, input_max, quant_min, quant_max, &nudged_input_min_,
             &nudged_input_max_, &scale);
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp gradient = ctx->Input(0);
    const TensorShape gradient_shape = ctx->InputShape(0);
    xla::XlaOp input = ctx->Input(1);
    const DataType data_type = ctx->input_type(1);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp nudged_input_min =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
    xla::XlaOp nudged_input_max =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);

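    // Straight-through estimator: pass the incoming gradient through wherever
    // the input lies inside the nudged range, and zero it elsewhere.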
    xla::XlaOp between_nudged_min_max = xla::And(
        xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max));
    xla::XlaOp zeroes = xla::Broadcast(XlaHelpers::Zero(b, data_type),
                                       gradient_shape.dim_sizes());
    xla::XlaOp output = xla::Select(between_nudged_min_max, gradient, zeroes);
    ctx->SetOutput(0, output);
  }

 private:
  float nudged_input_min_;
  float nudged_input_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgsGradient"),
                FakeQuantWithMinMaxArgsGradOp);

class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ >= 2 && num_bits_ <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
    quant_min_ = narrow_range_ ? 1 : 0;
    quant_max_ = (1 << num_bits_) - 1;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    const DataType data_type = ctx->input_type(0);
    xla::XlaOp input_min = ctx->Input(1);
    xla::XlaOp input_max = ctx->Input(2);

    if (ctx->compiler()->options().allow_cpu_custom_calls &&
        ctx->compiler()->options().custom_fake_quant_op_calls) {
      xla::XlaOp custom_call_output =
          b->ReportErrorOrReturn(BuildFakeQuantCustomCall(
              b, input, input_min, input_max, num_bits_, narrow_range_));
      ctx->SetOutput(0, custom_call_output);
      return;
    }

    xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
    XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
             &nudged_input_min, &nudged_input_max, &input_scale);

    xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min,
                                 nudged_input_max, input_scale);
    ctx->SetOutput(0, output);
  }

 private:
  int num_bits_;
  bool narrow_range_;
  float quant_min_;
  float quant_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVars"), FakeQuantWithMinMaxVarsOp);

class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsGradOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp gradient = ctx->Input(0);
    const TensorShape gradient_shape = ctx->InputShape(0);
    xla::XlaOp input = ctx->Input(1);
    const DataType data_type = ctx->input_type(1);
    const DataType accumulation_type =
        XlaHelpers::SumAccumulationType(data_type);
    xla::XlaOp input_min = ctx->Input(2);
    xla::XlaOp input_max = ctx->Input(3);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
    XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
             &nudged_input_min, &nudged_input_max, &input_scale);

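    // Three gradients are emitted:
    //   output 0 (w.r.t. input): the incoming gradient where input lies
    //       inside the nudged range, zero elsewhere (straight-through);
    //   output 1 (w.r.t. min): the sum of the incoming gradient where input
    //       is below nudged_input_min;
    //   output 2 (w.r.t. max): the sum of the incoming gradient where input
    //       is above nudged_input_max.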
    xla::XlaOp between_nudged_min_max = xla::And(
        xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max));
    xla::XlaOp zero = XlaHelpers::Zero(b, data_type);
    xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes());
    xla::XlaOp output0 = xla::Select(between_nudged_min_max, gradient, zeroes);
    ctx->SetOutput(0, output0);

    xla::XlaOp below_min = xla::Lt(input, nudged_input_min);
    xla::XlaOp select1 = xla::Select(below_min, gradient, zeroes);
    xla::XlaOp reduce1 = xla::ReduceAll(
        XlaHelpers::ConvertElementType(select1, accumulation_type),
        XlaHelpers::Zero(b, accumulation_type),
        *ctx->GetOrCreateAdd(accumulation_type));
    xla::XlaOp output1 = XlaHelpers::ConvertElementType(reduce1, data_type);
    ctx->SetOutput(1, output1);

    xla::XlaOp above_max = xla::Gt(input, nudged_input_max);
    xla::XlaOp select2 = xla::Select(above_max, gradient, zeroes);
    xla::XlaOp reduce2 = xla::ReduceAll(
        XlaHelpers::ConvertElementType(select2, accumulation_type),
        XlaHelpers::Zero(b, accumulation_type),
        *ctx->GetOrCreateAdd(accumulation_type));
    xla::XlaOp output2 = XlaHelpers::ConvertElementType(reduce2, data_type);
    ctx->SetOutput(2, output2);
  }

 private:
  float quant_min_;
  float quant_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVarsGradient"),
                FakeQuantWithMinMaxVarsGradOp);

}  // namespace
}  // namespace tensorflow