/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/dequantize.h"

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/quantize.h"
#include "tensorflow/lite/kernels/internal/reference/requantize.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace dequantize {

struct OpData {
  tflite::DequantizationParams quantization_params;
  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
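  // Illustrative example (values not from this file): QuantizeMultiplier()
  // stores a real multiplier of 0.75 as output_multiplier = 0.75 * 2^31 =
  // 1610612736 with output_shift = 0, and a real multiplier of 3.0 as the
  // same output_multiplier with output_shift = 2, since 3.0 = 0.75 * 2^2.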
  int32_t output_multiplier;
  int output_shift;
  int32_t output_zero_point;
};

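// Init runs once per op instance; the OpData allocated here lives in the
// interpreter's persistent arena for the lifetime of the model.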
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  OpData* data = static_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  // TODO(b/140515557): Add cached dequant to improve hybrid model performance.
  const TfLiteTensor* input = GetInput(context, node, 0);
  TF_LITE_ENSURE(context, input != nullptr);
  TfLiteTensor* output = GetOutput(context, node, 0);
  TF_LITE_ENSURE(context, output != nullptr);

  TF_LITE_ENSURE(context, input->type == kTfLiteUInt8 ||
                              input->type == kTfLiteInt8 ||
                              input->type == kTfLiteInt16);
  TF_LITE_ENSURE(
      context, output->type == kTfLiteFloat32 || output->type == kTfLiteInt32);

  if (output->type == kTfLiteInt32) {
    const double effective_output_scale =
        static_cast<double>(input->params.scale) /
        static_cast<double>(output->params.scale);
    QuantizeMultiplier(effective_output_scale, &data->output_multiplier,
                       &data->output_shift);
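    // E.g. (illustrative values): an input scale of 0.5 and an output scale
    // of 0.25 give an effective scale of 2.0, which QuantizeMultiplier()
    // encodes as output_multiplier = 0.5 * 2^31 = 1073741824, output_shift = 2.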
  }

  data->quantization_params.zero_point = input->params.zero_point;
  data->quantization_params.scale = static_cast<double>(input->params.scale);
  data->output_zero_point = output->params.zero_point;
  return kTfLiteOk;
}

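// Eval dispatches on the output type. For float32 outputs, the reference
// dequantize computes, per element, output = scale * (input - zero_point).
// For int32 outputs, values are conceptually requantized as
// output = output_zero_point + round((input - input_zero_point) *
// effective_scale), using the fixed-point multiplier derived in Prepare().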
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  OpData* data = static_cast<OpData*>(node->user_data);

  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);

  if (output->type == kTfLiteFloat32) {
    switch (input->type) {
      case kTfLiteUInt8:
        reference_ops::Dequantize(data->quantization_params,
                                  tflite::micro::GetTensorShape(input),
                                  tflite::micro::GetTensorData<uint8_t>(input),
                                  tflite::micro::GetTensorShape(output),
                                  tflite::micro::GetTensorData<float>(output));
        break;
      case kTfLiteInt8:
        reference_ops::Dequantize(data->quantization_params,
                                  tflite::micro::GetTensorShape(input),
                                  tflite::micro::GetTensorData<int8_t>(input),
                                  tflite::micro::GetTensorShape(output),
                                  tflite::micro::GetTensorData<float>(output));
        break;
      case kTfLiteInt16:
        reference_ops::Dequantize(data->quantization_params,
                                  tflite::micro::GetTensorShape(input),
                                  tflite::micro::GetTensorData<int16_t>(input),
                                  tflite::micro::GetTensorShape(output),
                                  tflite::micro::GetTensorData<float>(output));
        break;
      default:
        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                           TfLiteTypeGetName(input->type),
                           TfLiteTypeGetName(output->type));
        return kTfLiteError;
    }
  } else if (output->type == kTfLiteInt32) {
    int flat_size = MatchingFlatSize(tflite::micro::GetTensorShape(input),
                                     tflite::micro::GetTensorShape(output));
    switch (input->type) {
      case kTfLiteInt8: {
        reference_ops::Requantize(
            tflite::micro::GetTensorData<int8_t>(input), flat_size,
            data->output_multiplier, data->output_shift,
            data->quantization_params.zero_point, data->output_zero_point,
            tflite::micro::GetTensorData<int32_t>(output));
        break;
      }
      default:
        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                           TfLiteTypeGetName(input->type),
                           TfLiteTypeGetName(output->type));
        return kTfLiteError;
    }
  } else {
    TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                       TfLiteTypeGetName(input->type),
                       TfLiteTypeGetName(output->type));
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace dequantize

TfLiteRegistration Register_DEQUANTIZE() {
  return {/*init=*/dequantize::Init,
          /*free=*/nullptr,
          /*prepare=*/dequantize::Prepare,
          /*invoke=*/dequantize::Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}
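
// Note: this registration is typically wired into an op resolver (e.g. a
// MicroMutableOpResolver) so the interpreter can resolve DEQUANTIZE nodes;
// the exact AddBuiltin() registration call varies across TFLM versions.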

}  // namespace micro
}  // namespace ops
}  // namespace tflite