/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>  // for std::round

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {
namespace {

// Gymnastics with nudged zero point is to ensure that the real zero maps to
// an integer, which is required for e.g. zero-padding in convolutional layers.
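// For example, with min = -1.0f, max = 3.0f and an 8-bit range
// [quant_min, quant_max] = [0, 255], the raw scale is 4.0f / 255 and the zero
// point computed from min is 255 / 4 = 63.75, which is nudged to 64. The
// nudged range becomes roughly [-1.0039f, 2.9961f], so the real value 0.0f
// quantizes exactly to the integer 64.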
void CpuNudge(const float min, const float max, const float quant_min,
              const float quant_max, float* nudged_min, float* nudged_max,
              float* scale) {
  *scale = (max - min) / (quant_max - quant_min);

  const float zero_point_from_min = quant_min - min / *scale;
  float nudged_zero_point;
  if (zero_point_from_min <= quant_min) {
    nudged_zero_point = quant_min;
  } else if (zero_point_from_min >= quant_max) {
    nudged_zero_point = quant_max;
  } else {
    nudged_zero_point = std::round(zero_point_from_min);
  }

  *nudged_min = (quant_min - nudged_zero_point) * (*scale);
  *nudged_max = (quant_max - nudged_zero_point) * (*scale);
}

// An XLA version of CpuNudge().
void XlaNudge(xla::XlaBuilder* b, const DataType data_type,
              const xla::XlaOp& min, const xla::XlaOp& max,
              const float quant_min_value, const float quant_max_value,
              xla::XlaOp* nudged_min, xla::XlaOp* nudged_max,
              xla::XlaOp* scale) {
  *scale = xla::Div(xla::Sub(max, min),
                    XlaHelpers::FloatLiteral(
                        b, data_type, quant_max_value - quant_min_value));
  xla::XlaOp quant_min =
      XlaHelpers::FloatLiteral(b, data_type, quant_min_value);
  xla::XlaOp zero_point_from_min = xla::Sub(quant_min, xla::Div(min, *scale));
  xla::XlaOp quant_max =
      XlaHelpers::FloatLiteral(b, data_type, quant_max_value);
  xla::XlaOp nudged_zero_point =
      xla::Select(xla::Le(zero_point_from_min, quant_min), quant_min,
                  xla::Select(xla::Ge(zero_point_from_min, quant_max),
                              quant_max, xla::Round(zero_point_from_min)));
  *nudged_min = xla::Mul(xla::Sub(quant_min, nudged_zero_point), *scale);
  *nudged_max = xla::Mul(xla::Sub(quant_max, nudged_zero_point), *scale);
}

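// Fake-quantizes `input`: values are clamped to
// [nudged_input_min, nudged_input_max], snapped to the nearest quantization
// step of width input_scale (rounding implemented as floor(x + 0.5)), and
// mapped back to the float range, i.e.
//   round((clamp(input) - nudged_input_min) / input_scale) * input_scale
//       + nudged_input_min.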
xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input,
                    const DataType data_type,
                    const xla::XlaOp& nudged_input_min,
                    const xla::XlaOp& nudged_input_max,
                    const xla::XlaOp& input_scale) {
  xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
  xla::XlaOp inv_scale = xla::Div(one, input_scale);
  xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f);

  xla::XlaOp clamped = xla::Clamp(nudged_input_min, input, nudged_input_max);
  xla::XlaOp clamped_shifted = xla::Sub(clamped, nudged_input_min);
  xla::XlaOp rounded =
      xla::Floor(xla::Add(xla::Mul(clamped_shifted, inv_scale), half));
  return xla::Add(xla::Mul(rounded, input_scale), nudged_input_min);
}

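// FakeQuantWithMinMaxArgs takes min/max as compile-time attributes, so the
// nudged range and scale are computed once on the host with CpuNudge() and
// embedded into the computation as float literals.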
class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;

    float input_min, input_max;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
    CpuNudge(input_min, input_max, quant_min_, quant_max_, &nudged_input_min_,
             &nudged_input_max_, &input_scale_);
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp input = ctx->Input(0);
    const DataType data_type = ctx->input_type(0);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp nudged_input_min =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
    xla::XlaOp nudged_input_max =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
    xla::XlaOp input_scale =
        XlaHelpers::FloatLiteral(b, data_type, input_scale_);
    xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min,
                                 nudged_input_max, input_scale);
    ctx->SetOutput(0, output);
  }

 private:
  float quant_min_;
  float quant_max_;
  float nudged_input_min_;
  float nudged_input_max_;
  float input_scale_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgs"), FakeQuantWithMinMaxArgsOp);

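// Gradient of FakeQuantWithMinMaxArgs: a straight-through estimator that
// passes the incoming gradient (input 0) where the original input (input 1)
// lies inside the nudged [min, max] range and emits zero elsewhere.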
class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxArgsGradOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    const float quant_min = narrow_range ? 1 : 0;
    const float quant_max = (1 << num_bits) - 1;

    float input_min, input_max, scale;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
    CpuNudge(input_min, input_max, quant_min, quant_max, &nudged_input_min_,
             &nudged_input_max_, &scale);
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp gradient = ctx->Input(0);
    const TensorShape gradient_shape = ctx->InputShape(0);
    xla::XlaOp input = ctx->Input(1);
    const DataType data_type = ctx->input_type(1);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp nudged_input_min =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
    xla::XlaOp nudged_input_max =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);

    xla::XlaOp between_nudged_min_max = xla::And(
        xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max));
    xla::XlaOp zeroes = xla::Broadcast(XlaHelpers::Zero(b, data_type),
                                       gradient_shape.dim_sizes());
    xla::XlaOp output = xla::Select(between_nudged_min_max, gradient, zeroes);
    ctx->SetOutput(0, output);
  }

 private:
  float nudged_input_min_;
  float nudged_input_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgsGradient"),
                FakeQuantWithMinMaxArgsGradOp);

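// FakeQuantWithMinMaxVars differs from the Args variant in that min and max
// arrive as runtime tensors (inputs 1 and 2), so the nudging is expressed as
// XLA ops via XlaNudge() rather than precomputed on the host.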
class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp input = ctx->Input(0);
    const DataType data_type = ctx->input_type(0);
    xla::XlaOp input_min = ctx->Input(1);
    xla::XlaOp input_max = ctx->Input(2);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
    XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
             &nudged_input_min, &nudged_input_max, &input_scale);

    xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min,
                                 nudged_input_max, input_scale);
    ctx->SetOutput(0, output);
  }

 private:
  float quant_min_;
  float quant_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVars"), FakeQuantWithMinMaxVarsOp);

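// Gradient of FakeQuantWithMinMaxVars. Produces three outputs: the gradient
// w.r.t. the input (passed through inside the nudged range, zero outside),
// the gradient w.r.t. min (the incoming gradient summed over elements that
// fall below the nudged min), and the gradient w.r.t. max (summed over
// elements above the nudged max). The sums are accumulated in
// XlaHelpers::SumAccumulationType(data_type) and converted back.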
class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsGradOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp gradient = ctx->Input(0);
    const TensorShape gradient_shape = ctx->InputShape(0);
    xla::XlaOp input = ctx->Input(1);
    const DataType data_type = ctx->input_type(1);
    const DataType accumulation_type =
        XlaHelpers::SumAccumulationType(data_type);
    xla::XlaOp input_min = ctx->Input(2);
    xla::XlaOp input_max = ctx->Input(3);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
    XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
             &nudged_input_min, &nudged_input_max, &input_scale);

    xla::XlaOp between_nudged_min_max = xla::And(
        xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max));
    xla::XlaOp zero = XlaHelpers::Zero(b, data_type);
    xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes());
    xla::XlaOp output0 = xla::Select(between_nudged_min_max, gradient, zeroes);
    ctx->SetOutput(0, output0);

    xla::XlaOp below_min = xla::Lt(input, nudged_input_min);
    xla::XlaOp select1 = xla::Select(below_min, gradient, zeroes);
    xla::XlaOp reduce1 = xla::ReduceAll(
        XlaHelpers::ConvertElementType(select1, accumulation_type),
        XlaHelpers::Zero(b, accumulation_type),
        *ctx->GetOrCreateAdd(accumulation_type));
    xla::XlaOp output1 = XlaHelpers::ConvertElementType(reduce1, data_type);
    ctx->SetOutput(1, output1);

    xla::XlaOp above_max = xla::Gt(input, nudged_input_max);
    xla::XlaOp select2 = xla::Select(above_max, gradient, zeroes);
    xla::XlaOp reduce2 = xla::ReduceAll(
        XlaHelpers::ConvertElementType(select2, accumulation_type),
        XlaHelpers::Zero(b, accumulation_type),
        *ctx->GetOrCreateAdd(accumulation_type));
    xla::XlaOp output2 = XlaHelpers::ConvertElementType(reduce2, data_type);
    ctx->SetOutput(2, output2);
  }

 private:
  float quant_min_;
  float quant_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVarsGradient"),
                FakeQuantWithMinMaxVarsGradOp);

}  // namespace
}  // namespace tensorflow