1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/c/common.h"
17 #include "tensorflow/lite/kernels/internal/quantization_util.h"
18 #include "tensorflow/lite/kernels/internal/reference/quantize.h"
19 #include "tensorflow/lite/kernels/internal/reference/requantize.h"
20 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
21 #include "tensorflow/lite/kernels/kernel_util.h"
22 #include "tensorflow/lite/micro/kernels/kernel_util.h"
23 #include "tensorflow/lite/micro/kernels/quantize.h"
24 #include "tensorflow/lite/micro/micro_utils.h"
25
26 namespace tflite {
27
EvalQuantizeReference(TfLiteContext * context,TfLiteNode * node)28 TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
29 TFLITE_DCHECK(node->user_data != nullptr);
30 auto* data = static_cast<OpDataQuantizeReference*>(node->user_data);
31
32 const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
33 TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
34
35 if (input->type == kTfLiteFloat32) {
36 switch (output->type) {
37 case kTfLiteInt8:
38 reference_ops::AffineQuantize(
39 data->quantization_params, tflite::micro::GetTensorShape(input),
40 tflite::micro::GetTensorData<float>(input),
41 tflite::micro::GetTensorShape(output),
42 tflite::micro::GetTensorData<int8_t>(output));
43 break;
44 case kTfLiteUInt8:
45 reference_ops::AffineQuantize(
46 data->quantization_params, tflite::micro::GetTensorShape(input),
47 tflite::micro::GetTensorData<float>(input),
48 tflite::micro::GetTensorShape(output),
49 tflite::micro::GetTensorData<uint8_t>(output));
50 break;
51 case kTfLiteInt16:
52 reference_ops::AffineQuantize(
53 data->quantization_params, tflite::micro::GetTensorShape(input),
54 tflite::micro::GetTensorData<float>(input),
55 tflite::micro::GetTensorShape(output),
56 tflite::micro::GetTensorData<int16_t>(output));
57 return kTfLiteOk;
58 default:
59 TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
60 TfLiteTypeGetName(input->type),
61 TfLiteTypeGetName(output->type));
62 return kTfLiteError;
63 }
64 } else if (input->type == kTfLiteInt16) {
65 size_t size = ElementCount(*input->dims);
66 switch (output->type) {
67 case kTfLiteInt8:
68 reference_ops::Requantize(
69 tflite::micro::GetTensorData<int16_t>(input), size,
70 data->requantize_output_multiplier, data->requantize_output_shift,
71 data->input_zero_point, data->quantization_params.zero_point,
72 tflite::micro::GetTensorData<int8_t>(output));
73 break;
74 case kTfLiteInt16:
75 reference_ops::Requantize(
76 tflite::micro::GetTensorData<int16_t>(input), size,
77 data->requantize_output_multiplier, data->requantize_output_shift,
78 data->input_zero_point, data->quantization_params.zero_point,
79 tflite::micro::GetTensorData<int16_t>(output));
80 return kTfLiteOk;
81 case kTfLiteInt32:
82 reference_ops::Requantize(
83 tflite::micro::GetTensorData<int16_t>(input), size,
84 data->requantize_output_multiplier, data->requantize_output_shift,
85 data->input_zero_point, data->quantization_params.zero_point,
86 tflite::micro::GetTensorData<int32_t>(output));
87 return kTfLiteOk;
88 default:
89 TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
90 TfLiteTypeGetName(input->type),
91 TfLiteTypeGetName(output->type));
92 return kTfLiteError;
93 }
94 } else if (input->type == kTfLiteInt8) {
95 // Int8 to Int8 requantization, required if the input and output tensors
96 // have different scales and/or zero points.
97 size_t size = ElementCount(*input->dims);
98 switch (output->type) {
99 case kTfLiteInt8:
100 reference_ops::Requantize(
101 tflite::micro::GetTensorData<int8_t>(input), size,
102 data->requantize_output_multiplier, data->requantize_output_shift,
103 data->input_zero_point, data->quantization_params.zero_point,
104 tflite::micro::GetTensorData<int8_t>(output));
105 break;
106 default:
107 TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
108 TfLiteTypeGetName(input->type),
109 TfLiteTypeGetName(output->type));
110 return kTfLiteError;
111 }
112 } else {
113 TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
114 TfLiteTypeGetName(input->type),
115 TfLiteTypeGetName(output->type));
116 return kTfLiteError;
117 }
118
119 return kTfLiteOk;
120 }
121
122 } // namespace tflite
123