• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
17 
18 #include "tensorflow/lite/c/builtin_op_data.h"
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/kernels/internal/quantization_util.h"
21 #include "tensorflow/lite/kernels/internal/reference/quantize.h"
22 #include "tensorflow/lite/kernels/internal/reference/requantize.h"
23 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 #include "tensorflow/lite/micro/kernels/kernel_util.h"
26 
27 namespace tflite {
28 namespace ops {
29 namespace micro {
30 namespace dequantize {
31 
32 struct OpData {
33   tflite::DequantizationParams quantization_params;
34   // The scaling factor from input to output (aka the 'real multiplier') can
35   // be represented as a fixed point multiplier plus a left shift.
36   int32_t output_multiplier;
37   int output_shift;
38   int32_t output_zero_point;
39 };
40 
Init(TfLiteContext * context,const char * buffer,size_t length)41 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
42   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
43   return context->AllocatePersistentBuffer(context, sizeof(OpData));
44 }
45 
Prepare(TfLiteContext * context,TfLiteNode * node)46 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
47   TFLITE_DCHECK(node->user_data != nullptr);
48   OpData* data = static_cast<OpData*>(node->user_data);
49 
50   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
51   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
52 
53   // TODO(b/140515557): Add cached dequant to improve hybrid model performance.
54   const TfLiteTensor* input = GetInput(context, node, 0);
55   TF_LITE_ENSURE(context, input != nullptr);
56   TfLiteTensor* output = GetOutput(context, node, 0);
57   TF_LITE_ENSURE(context, output != nullptr);
58 
59   TF_LITE_ENSURE(context, input->type == kTfLiteUInt8 ||
60                               input->type == kTfLiteInt8 ||
61                               input->type == kTfLiteInt16);
62   TF_LITE_ENSURE(
63       context, output->type == kTfLiteFloat32 || output->type == kTfLiteInt32);
64 
65   if (output->type == kTfLiteInt32) {
66     const double effective_output_scale =
67         static_cast<double>(input->params.scale) /
68         static_cast<double>(output->params.scale);
69     QuantizeMultiplier(effective_output_scale, &data->output_multiplier,
70                        &data->output_shift);
71   }
72 
73   data->quantization_params.zero_point = input->params.zero_point;
74   data->quantization_params.scale = static_cast<double>(input->params.scale);
75   data->output_zero_point = output->params.zero_point;
76   return kTfLiteOk;
77 }
78 
Eval(TfLiteContext * context,TfLiteNode * node)79 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
80   TFLITE_DCHECK(node->user_data != nullptr);
81   OpData* data = static_cast<OpData*>(node->user_data);
82 
83   const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
84   TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
85 
86   if (output->type == kTfLiteFloat32) {
87     switch (input->type) {
88       case kTfLiteUInt8:
89         reference_ops::Dequantize(data->quantization_params,
90                                   tflite::micro::GetTensorShape(input),
91                                   tflite::micro::GetTensorData<uint8_t>(input),
92                                   tflite::micro::GetTensorShape(output),
93                                   tflite::micro::GetTensorData<float>(output));
94         break;
95       case kTfLiteInt8:
96         reference_ops::Dequantize(data->quantization_params,
97                                   tflite::micro::GetTensorShape(input),
98                                   tflite::micro::GetTensorData<int8_t>(input),
99                                   tflite::micro::GetTensorShape(output),
100                                   tflite::micro::GetTensorData<float>(output));
101         break;
102       case kTfLiteInt16:
103         reference_ops::Dequantize(data->quantization_params,
104                                   tflite::micro::GetTensorShape(input),
105                                   tflite::micro::GetTensorData<int16_t>(input),
106                                   tflite::micro::GetTensorShape(output),
107                                   tflite::micro::GetTensorData<float>(output));
108         break;
109       default:
110         TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
111                            TfLiteTypeGetName(input->type),
112                            TfLiteTypeGetName(output->type));
113         return kTfLiteError;
114     }
115   } else if (output->type == kTfLiteInt32) {
116     int flat_size = MatchingFlatSize(tflite::micro::GetTensorShape(input),
117                                      tflite::micro::GetTensorShape(output));
118     switch (input->type) {
119       case kTfLiteInt8: {
120         reference_ops::Requantize(
121             tflite::micro::GetTensorData<int8_t>(input), flat_size,
122             data->output_multiplier, data->output_shift,
123             data->quantization_params.zero_point, data->output_zero_point,
124             tflite::micro::GetTensorData<int32_t>(output));
125         break;
126       }
127       default:
128         TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
129                            TfLiteTypeGetName(input->type),
130                            TfLiteTypeGetName(output->type));
131         return kTfLiteError;
132     }
133   } else {
134     TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
135                        TfLiteTypeGetName(input->type),
136                        TfLiteTypeGetName(output->type));
137     return kTfLiteError;
138   }
139 
140   return kTfLiteOk;
141 }
142 
143 }  // namespace dequantize
144 
Register_DEQUANTIZE()145 TfLiteRegistration Register_DEQUANTIZE() {
146   return {/*init=*/dequantize::Init,
147           /*free=*/nullptr,
148           /*prepare=*/dequantize::Prepare,
149           /*invoke=*/dequantize::Eval,
150           /*profiling_string=*/nullptr,
151           /*builtin_code=*/0,
152           /*custom_name=*/nullptr,
153           /*version=*/0};
154 }
155 
156 }  // namespace micro
157 }  // namespace ops
158 }  // namespace tflite
159