/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/errors.h"

namespace {
enum {
  QUANTIZE_MODE_MIN_COMBINED,
  QUANTIZE_MODE_MIN_FIRST,
  QUANTIZE_MODE_SCALED,
};
}  // namespace

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T>
class DequantizeOp : public OpKernel {
 public:
  explicit DequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    // For signed quantized types, half_range_ shifts values into the
    // unsigned range before scaling; for unsigned types no shift is needed.
    half_range_ = !std::is_signed<T>::value
                      ? 0.0f
                      : (static_cast<float>(std::numeric_limits<T>::max()) -
                         std::numeric_limits<T>::min() + 1) /
                            2.0f;
    string mode_string;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
    OP_REQUIRES(ctx,
                (mode_string == "MIN_COMBINED" || mode_string == "MIN_FIRST" ||
                 mode_string == "SCALED"),
                errors::InvalidArgument("Mode string must be 'MIN_COMBINED',"
                                        " 'MIN_FIRST', or 'SCALED', is '" +
                                        mode_string + "'"));
    if (mode_string == "MIN_COMBINED") {
      mode_ = QUANTIZE_MODE_MIN_COMBINED;
    } else if (mode_string == "MIN_FIRST") {
      mode_ = QUANTIZE_MODE_MIN_FIRST;
    } else if (mode_string == "SCALED") {
      mode_ = QUANTIZE_MODE_SCALED;
    }
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    const float min_range = ctx->input(1).flat<float>()(0);
    const float max_range = ctx->input(2).flat<float>()(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    if (mode_ == QUANTIZE_MODE_MIN_COMBINED) {
      // MIN_COMBINED: map the full range of T linearly onto
      // [min_range, max_range].
      const float scale_factor =
          (max_range - min_range) /
          (static_cast<float>(std::numeric_limits<T>::max()) -
           std::numeric_limits<T>::min());

      float* out_ptr = output->flat<float>().data();
      const T* in_ptr = input.flat<T>().data();

      const int64 num_elements = input.NumElements();
      for (int64 i = 0; i < num_elements; ++i) {
        out_ptr[i] =
            ((static_cast<int>(in_ptr[i]) + half_range_) * scale_factor) +
            min_range;
      }
    } else if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
      // MIN_FIRST: use the optimized meta kernel for quint8 when available,
      // otherwise fall back to the Eigen implementation.
      if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
        auto input_ui8_array = input.flat<quint8>();
        meta::Dequantize(ctx, input_ui8_array.data(), input_ui8_array.size(),
                         min_range, max_range, output->flat<float>().data());
      } else {
        QuantizedTensorToFloatInPlaceUsingEigen<T>(
            ctx->template eigen_device<Device>(), input, min_range, max_range,
            output);
      }
    } else if (mode_ == QUANTIZE_MODE_SCALED) {
      // SCALED: choose the scale so that both range endpoints remain
      // representable; unsigned types only need to cover max_range.
      const float scale_factor =
          std::numeric_limits<T>::min() == 0
              ? (max_range / std::numeric_limits<T>::max())
              : std::max(min_range / std::numeric_limits<T>::min(),
                         max_range / std::numeric_limits<T>::max());
      float* out_ptr = output->flat<float>().data();
      const T* in_ptr = input.flat<T>().data();
      const int64 num_elements = input.NumElements();
      for (int64 i = 0; i < num_elements; ++i) {
        out_ptr[i] = static_cast<int>(in_ptr[i]) * scale_factor;
      }
    }
  }

 private:
  float half_range_;
  int mode_;
};

REGISTER_KERNEL_BUILDER(
    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<quint8>("T"),
    DequantizeOp<CPUDevice, quint8>);
REGISTER_KERNEL_BUILDER(
    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
    DequantizeOp<CPUDevice, qint8>);
REGISTER_KERNEL_BUILDER(
    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<quint16>("T"),
    DequantizeOp<CPUDevice, quint16>);
REGISTER_KERNEL_BUILDER(
    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint16>("T"),
    DequantizeOp<CPUDevice, qint16>);

REGISTER_KERNEL_BUILDER(
    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint32>("T"),
    DequantizeOp<CPUDevice, qint32>);

}  // namespace tensorflow