/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/quantize.h"
#include "tensorflow/lite/kernels/internal/reference/requantize.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/quantize.h"
#include "tensorflow/lite/micro/micro_utils.h"

namespace tflite {

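// Reference Eval for the QUANTIZE op. It covers both initial quantization
// (float32 -> int8/uint8/int16) and requantization between integer types
// when the input and output scales and/or zero points differ.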
TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
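  // The op data carries the output quantization parameters and, for integer
  // inputs, the precomputed fixed-point requantization multiplier and shift;
  // it is expected to be populated by the op's Prepare step.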
  TFLITE_DCHECK(node->user_data != nullptr);
  auto* data = static_cast<OpDataQuantizeReference*>(node->user_data);

  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);

  if (input->type == kTfLiteFloat32) {
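    // Float -> integer: AffineQuantize computes
    //   q = zero_point + round(x / scale)
    // and saturates the result to the numeric range of the output type.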
    switch (output->type) {
      case kTfLiteInt8:
        reference_ops::AffineQuantize(
            data->quantization_params, tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<float>(input),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int8_t>(output));
        break;
      case kTfLiteUInt8:
        reference_ops::AffineQuantize(
            data->quantization_params, tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<float>(input),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<uint8_t>(output));
        break;
      case kTfLiteInt16:
        reference_ops::AffineQuantize(
            data->quantization_params, tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<float>(input),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int16_t>(output));
        return kTfLiteOk;
      default:
        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                           TfLiteTypeGetName(input->type),
                           TfLiteTypeGetName(output->type));
        return kTfLiteError;
    }
  } else if (input->type == kTfLiteInt16) {
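    // Integer -> integer: Requantize rescales with the precomputed fixed-point
    // multiplier and shift, roughly
    //   out = out_zp + MultiplyByQuantizedMultiplier(in - in_zp, mult, shift),
    // then saturates to the output type's range.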
    size_t size = ElementCount(*input->dims);
    switch (output->type) {
      case kTfLiteInt8:
        reference_ops::Requantize(
            tflite::micro::GetTensorData<int16_t>(input), size,
            data->requantize_output_multiplier, data->requantize_output_shift,
            data->input_zero_point, data->quantization_params.zero_point,
            tflite::micro::GetTensorData<int8_t>(output));
        break;
      case kTfLiteInt16:
        reference_ops::Requantize(
            tflite::micro::GetTensorData<int16_t>(input), size,
            data->requantize_output_multiplier, data->requantize_output_shift,
            data->input_zero_point, data->quantization_params.zero_point,
            tflite::micro::GetTensorData<int16_t>(output));
        return kTfLiteOk;
      case kTfLiteInt32:
        reference_ops::Requantize(
            tflite::micro::GetTensorData<int16_t>(input), size,
            data->requantize_output_multiplier, data->requantize_output_shift,
            data->input_zero_point, data->quantization_params.zero_point,
            tflite::micro::GetTensorData<int32_t>(output));
        return kTfLiteOk;
      default:
        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                           TfLiteTypeGetName(input->type),
                           TfLiteTypeGetName(output->type));
        return kTfLiteError;
    }
  } else if (input->type == kTfLiteInt8) {
    // Int8 to Int8 requantization, required if the input and output tensors
    // have different scales and/or zero points.
    size_t size = ElementCount(*input->dims);
    switch (output->type) {
      case kTfLiteInt8:
        reference_ops::Requantize(
            tflite::micro::GetTensorData<int8_t>(input), size,
            data->requantize_output_multiplier, data->requantize_output_shift,
            data->input_zero_point, data->quantization_params.zero_point,
            tflite::micro::GetTensorData<int8_t>(output));
        break;
      default:
        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                           TfLiteTypeGetName(input->type),
                           TfLiteTypeGetName(output->type));
        return kTfLiteError;
    }
  } else {
    TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                       TfLiteTypeGetName(input->type),
                       TfLiteTypeGetName(output->type));
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace tflite