/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_requires.h"
#define EIGEN_USE_THREADS

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/quantize_and_dequantize_op.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Simulate quantization precision loss in a float tensor by:
// 1. Quantizing the tensor to fixed-point numbers, matching the target
//    quantization method as it will be used in inference.
// 2. Dequantizing it back to floating-point numbers for the following ops,
//    most likely matmul.
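//
// A rough worked example (an editorial sketch; the authoritative logic is
// QuantizeAndDequantizeOneScaleImpl in the header): with signed_input=true,
// num_bits=8, narrow_range=false, and range [-1.0, 1.0], the quantized range
// is [-128, 127]. The scale is chosen so that both endpoints remain
// representable, i.e. min(|-128 / -1.0|, |127 / 1.0|) = 127. An input of 0.8
// then round-trips as round_half_to_even(0.8 * 127) / 127 = 102 / 127
// ≈ 0.80315.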
template <typename Device, typename T>
class QuantizeAndDequantizeV2Op : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV2Op(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
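    // Editorial note (assumption, not in the original source): the 62/63
    // upper bound presumably keeps shifts of the form 1 << num_bits
    // representable in int64 when the quantized range is computed.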
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                errors::InvalidArgument("num_bits is out of range: ", num_bits_,
                                        " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));

    string round_mode_string;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string));
    OP_REQUIRES(
        ctx,
        (round_mode_string == "HALF_UP" || round_mode_string == "HALF_TO_EVEN"),
        errors::InvalidArgument("Round mode string must be "
                                "'HALF_UP' or "
                                "'HALF_TO_EVEN', is '" +
                                round_mode_string + "'"));
    if (round_mode_string == "HALF_UP") {
      round_mode_ = ROUND_HALF_UP;
    } else if (round_mode_string == "HALF_TO_EVEN") {
      round_mode_ = ROUND_HALF_TO_EVEN;
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    OP_REQUIRES(
        ctx, axis_ >= -1,
        errors::InvalidArgument("Axis must be at least -1. Found ", axis_));
    OP_REQUIRES(
        ctx, (axis_ == -1 || axis_ < input.shape().dims()),
        errors::InvalidArgument("Shape must be at least rank ", axis_ + 1,
                                " but is rank ", input.shape().dims()));
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    Tensor input_min_tensor;
    Tensor input_max_tensor;
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    if (range_given_) {
      input_min_tensor = ctx->input(1);
      input_max_tensor = ctx->input(2);
      if (axis_ == -1) {
        auto min_val = input_min_tensor.scalar<T>()();
        auto max_val = input_max_tensor.scalar<T>()();
        OP_REQUIRES(ctx, min_val <= max_val,
                    errors::InvalidArgument("Invalid range: input_min ",
                                            min_val, " > input_max ", max_val));
      } else {
        OP_REQUIRES(ctx, input_min_tensor.dim_size(0) == depth,
                    errors::InvalidArgument(
                        "input_min_tensor has incorrect size, was ",
                        input_min_tensor.dim_size(0), " expected ", depth,
                        " to match dim ", axis_, " of the input ",
                        input_min_tensor.shape()));
        OP_REQUIRES(ctx, input_max_tensor.dim_size(0) == depth,
                    errors::InvalidArgument(
                        "input_max_tensor has incorrect size, was ",
                        input_max_tensor.dim_size(0), " expected ", depth,
                        " to match dim ", axis_, " of the input ",
                        input_max_tensor.shape()));
      }
    } else {
      auto range_shape = (axis_ == -1) ? TensorShape({}) : TensorShape({depth});
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_min_tensor));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_max_tensor));
    }

    if (axis_ == -1) {
      functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_, num_bits_,
        range_given_, &input_min_tensor, &input_max_tensor, round_mode_,
        narrow_range_, output->flat<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1), signed_input_,
        num_bits_, range_given_, &input_min_tensor, &input_max_tensor,
        round_mode_, narrow_range_,
        output->template flat_inner_outer_dims<T, 3>(axis_ - 1));
    }
  }

 private:
  int num_bits_;
  int axis_;
  QuantizerRoundMode round_mode_;
  bool signed_input_;
  bool range_given_;
  bool narrow_range_;
};
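
// A minimal scalar sketch of the per-tensor round trip above (an editorial
// illustration, assuming min_range < 0 < max_range, signed input,
// narrow_range=false, and ROUND_HALF_TO_EVEN; the authoritative version is
// QuantizeAndDequantizeOneScaleImpl in the header):
//
//   float QuantizeDequantize(float x, float min_range, float max_range,
//                            int num_bits) {
//     const float min_q = -(1LL << (num_bits - 1));     // e.g. -128 for 8 bits
//     const float max_q = (1LL << (num_bits - 1)) - 1;  // e.g. 127
//     // Pick the smaller scale so both endpoints stay representable.
//     const float scale = std::min(min_q / min_range, max_q / max_range);
//     x = std::max(min_range, std::min(max_range, x));  // clamp to the range
//     return std::nearbyint(x * scale) / scale;  // default mode: half-to-even
//   }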

// Implementation of QuantizeAndDequantizeV4GradientOp.
// When back-propagating the error through a quantized layer, the following
// paper gives evidence that clipped-ReLU is better than non-clipped:
// "Deep Learning with Low Precision by Half-wave Gaussian Quantization"
// http://zpascal.net/cvpr2017/Cai_Deep_Learning_With_CVPR_2017_paper.pdf
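//
// Roughly, the gradient semantics (a sketch of what the functors below
// compute, per the header implementation): gradients pass through only where
// the input fell inside [input_min, input_max]; gradients of clipped elements
// accumulate into the range gradients instead:
//
//   input_backprop[i]  = gradient[i] if input_min <= input[i] <= input_max,
//                        else 0
//   input_min_backprop = sum of gradient[i] where input[i] < input_min
//   input_max_backprop = sum of gradient[i] where input[i] > input_max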
template <typename Device, typename T>
class QuantizeAndDequantizeV4GradientOp : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV4GradientOp(OpKernelConstruction* ctx)
      : OpKernel::OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& gradient = ctx->input(0);
    const Tensor& input = ctx->input(1);
    Tensor* input_backprop = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, input.shape(), &input_backprop));
    OP_REQUIRES(
        ctx, axis_ >= -1,
        errors::InvalidArgument("Axis must be at least -1. Found ", axis_));
    OP_REQUIRES(ctx, (axis_ == -1 || axis_ < input.shape().dims()),
                errors::InvalidArgument(
                    "Axis should be -1 or 0 or a positive value less than ",
                    input.shape().dims(), " but given axis value was ", axis_));

    OP_REQUIRES(
        ctx, input.IsSameSize(gradient),
        errors::InvalidArgument("gradient and input must be the same size"));
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    const Tensor& input_min_tensor = ctx->input(2);
    OP_REQUIRES(ctx,
                input_min_tensor.dims() == 0 || input_min_tensor.dims() == 1,
                errors::InvalidArgument(
                    "Input min tensor must have dimension 0 or 1. Received ",
                    input_min_tensor.dims(), "."));
    const Tensor& input_max_tensor = ctx->input(3);
    OP_REQUIRES(ctx,
                input_max_tensor.dims() == 0 || input_max_tensor.dims() == 1,
                errors::InvalidArgument(
                    "Input max tensor must have dimension 0 or 1. Received ",
                    input_max_tensor.dims(), "."));
    if (axis_ != -1) {
      OP_REQUIRES(
          ctx, input_min_tensor.dim_size(0) == depth,
          errors::InvalidArgument("min has incorrect size, expected ", depth,
                                  " was ", input_min_tensor.dim_size(0)));
      OP_REQUIRES(
          ctx, input_max_tensor.dim_size(0) == depth,
          errors::InvalidArgument("max has incorrect size, expected ", depth,
                                  " was ", input_max_tensor.dim_size(0)));
    }

    TensorShape min_max_shape(input_min_tensor.shape());
    Tensor* input_min_backprop;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(1, min_max_shape, &input_min_backprop));

    Tensor* input_max_backprop;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(2, min_max_shape, &input_max_backprop));

    if (axis_ == -1) {
      functor::QuantizeAndDequantizeOneScaleGradientFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), gradient.template flat<T>(),
        input.template flat<T>(), input_min_tensor.scalar<T>(),
        input_max_tensor.scalar<T>(), input_backprop->template flat<T>(),
        input_min_backprop->template scalar<T>(),
        input_max_backprop->template scalar<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelGradientFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        gradient.template flat_inner_outer_dims<T, 3>(axis_ - 1),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1),
        &input_min_tensor, &input_max_tensor,
        input_backprop->template flat_inner_outer_dims<T, 3>(axis_ - 1),
        input_min_backprop->template flat<T>(),
        input_max_backprop->template flat<T>());
    }
  }

 private:
  int axis_;
};

// Simulate quantization precision loss in a float tensor by:
// 1. Quantizing the tensor to fixed-point numbers, matching the target
//    quantization method as it will be used in inference.
// 2. Dequantizing it back to floating-point numbers for the following ops,
//    most likely matmul.
// Almost identical to QuantizeAndDequantizeV2Op, except that num_bits is a
// tensor.
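//
// Shape contract example (illustrative, derived from the checks in Compute):
// with axis = 1 and input shape [2, 3, 4], depth is 3, so input_min and
// input_max must each have shape [3], and
// flat_inner_outer_dims<T, 3>(axis_ - 1) views the data as
// [outer, depth, inner] = [2, 3, 4].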
template <typename Device, typename T>
class QuantizeAndDequantizeV3Op : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV3Op(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    OP_REQUIRES(
        ctx, axis_ >= -1,
        errors::InvalidArgument("Axis must be at least -1. Found ", axis_));
    OP_REQUIRES(ctx, axis_ < input.dims(),
                errors::InvalidArgument(
                    "Axis requested is larger than input dimensions. Axis: ",
                    axis_, " Input Dimensions: ", input.dims()));
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    const Tensor& num_bits_tensor = ctx->input(3);
    const int num_bits_val = num_bits_tensor.scalar<int32>()();

    OP_REQUIRES(
        ctx, num_bits_val > 0 && num_bits_val < (signed_input_ ? 62 : 63),
        errors::InvalidArgument("num_bits is out of range: ", num_bits_val,
                                " with signed_input_ ", signed_input_));

    Tensor input_min_tensor;
    Tensor input_max_tensor;
    if (range_given_) {
      input_min_tensor = ctx->input(1);
      input_max_tensor = ctx->input(2);
      if (axis_ == -1) {
        auto min_val = input_min_tensor.scalar<T>()();
        auto max_val = input_max_tensor.scalar<T>()();
        OP_REQUIRES(ctx, min_val <= max_val,
                    errors::InvalidArgument("Invalid range: input_min ",
                                            min_val, " > input_max ", max_val));
      } else {
        OP_REQUIRES(ctx, input_min_tensor.dim_size(0) == depth,
                    errors::InvalidArgument(
                        "input_min_tensor has incorrect size, was ",
                        input_min_tensor.dim_size(0), " expected ", depth,
                        " to match dim ", axis_, " of the input ",
                        input_min_tensor.shape()));
        OP_REQUIRES(ctx, input_max_tensor.dim_size(0) == depth,
                    errors::InvalidArgument(
                        "input_max_tensor has incorrect size, was ",
                        input_max_tensor.dim_size(0), " expected ", depth,
                        " to match dim ", axis_, " of the input ",
                        input_max_tensor.shape()));
      }
    } else {
      auto range_shape = (axis_ == -1) ? TensorShape({}) : TensorShape({depth});
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_min_tensor));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_max_tensor));
    }

    if (axis_ == -1) {
      functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_,
        num_bits_val, range_given_, &input_min_tensor, &input_max_tensor,
        ROUND_HALF_TO_EVEN, narrow_range_, output->flat<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1), signed_input_,
        num_bits_val, range_given_, &input_min_tensor, &input_max_tensor,
        ROUND_HALF_TO_EVEN, narrow_range_,
        output->template flat_inner_outer_dims<T, 3>(axis_ - 1));
    }
  }

 private:
  int axis_;
  bool signed_input_;
  bool range_given_;
  bool narrow_range_;
};

// DEPRECATED: Use QuantizeAndDequantizeV2Op.
template <typename Device, typename T>
class QuantizeAndDequantizeOp : public OpKernel {
 public:
  explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                errors::InvalidArgument("num_bits is out of range: ", num_bits_,
                                        " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_min", &input_min_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_max", &input_max_));
    if (range_given_) {
      OP_REQUIRES(
          ctx, input_min_ <= input_max_,
          errors::InvalidArgument("Invalid range: input_min ", input_min_,
                                  " > input_max ", input_max_));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    // One global scale.
    Tensor input_min_tensor(DataTypeToEnum<T>::value, TensorShape());
    Tensor input_max_tensor(DataTypeToEnum<T>::value, TensorShape());
    // Initialize the tensors with the values in the Attrs.
    input_min_tensor.template scalar<T>()() = static_cast<T>(input_min_);
    input_max_tensor.template scalar<T>()() = static_cast<T>(input_max_);

    functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> functor;
    functor(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_,
            num_bits_, range_given_, &input_min_tensor, &input_max_tensor,
            ROUND_HALF_TO_EVEN, /*narrow_range=*/false, output->flat<T>());
  }

 private:
  bool signed_input_;
  int num_bits_;
  bool range_given_;
  float input_min_;
  float input_max_;
};

// Specializations for CPUDevice.

namespace functor {
template <typename T>
struct QuantizeAndDequantizeOneScaleFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstVec input,
                  const bool signed_input, const int num_bits,
                  const bool range_given, Tensor* input_min_tensor,
                  Tensor* input_max_tensor, QuantizerRoundMode round_mode,
                  bool narrow_range, typename TTypes<T>::Vec out) {
    QuantizeAndDequantizeOneScaleImpl<CPUDevice, T>::Compute(
        d, input, signed_input, num_bits, range_given, input_min_tensor,
        input_max_tensor, round_mode, narrow_range, out);
  }
};

template <typename T>
struct QuantizeAndDequantizePerChannelFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T, 3>::ConstTensor input,
                  bool signed_input, int num_bits, bool range_given,
                  Tensor* input_min_tensor, Tensor* input_max_tensor,
                  QuantizerRoundMode round_mode, bool narrow_range,
                  typename TTypes<T, 3>::Tensor out) {
    QuantizeAndDequantizePerChannelImpl<CPUDevice, T>::Compute(
        d, input, signed_input, num_bits, range_given, input_min_tensor,
        input_max_tensor, round_mode, narrow_range, out);
  }
};

template <typename T>
struct QuantizeAndDequantizeOneScaleGradientFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat gradient,
                  typename TTypes<T>::ConstFlat input,
                  typename TTypes<T>::ConstScalar input_min_tensor,
                  typename TTypes<T>::ConstScalar input_max_tensor,
                  typename TTypes<T>::Flat input_backprop,
                  typename TTypes<T>::Scalar input_min_backprop,
                  typename TTypes<T>::Scalar input_max_backprop) {
    QuantizeAndDequantizeOneScaleGradientImpl<CPUDevice, T>::Compute(
        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
        input_min_backprop, input_max_backprop);
  }
};

template <typename T>
struct QuantizeAndDequantizePerChannelGradientFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d,
                  typename TTypes<T, 3>::ConstTensor gradient,
                  typename TTypes<T, 3>::ConstTensor input,
                  const Tensor* input_min_tensor,
                  const Tensor* input_max_tensor,
                  typename TTypes<T, 3>::Tensor input_backprop,
                  typename TTypes<T>::Flat input_min_backprop,
                  typename TTypes<T>::Flat input_max_backprop) {
    QuantizeAndDequantizePerChannelGradientImpl<CPUDevice, T>::Compute(
        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
        input_min_backprop, input_max_backprop);
  }
};

template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor<CPUDevice,
                                                                      float>;
template struct functor::QuantizeAndDequantizePerChannelGradientFunctor<
    CPUDevice, double>;

}  // namespace functor

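// Note (editorial): in the registrations below, QuantizeAndDequantizeV4 is
// served by the same QuantizeAndDequantizeV2Op kernel class; V4 exists as a
// separate op apparently so that a gradient (QuantizeAndDequantizeV4Grad)
// can be registered for it.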
#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<CPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV3Op<CPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<CPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad")                  \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV4GradientOp<CPUDevice, T>);    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QuantizeAndDequantize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      QuantizeAndDequantizeOp<CPUDevice, T>);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define REGISTER_GPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<GPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .HostMemory("num_bits")                          \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV3Op<GPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<GPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad")                  \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV4GradientOp<GPUDevice, T>);    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QuantizeAndDequantize").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      QuantizeAndDequantizeOp<GPUDevice, T>);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}  // namespace tensorflow