/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/sub.h"

#include <algorithm>  // for std::max in CalculateOpData

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace sub {

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

struct OpData {
  bool requires_broadcast;

  // These fields are used in both the general 8-bit -> 8-bit quantized path
  // and the special 16-bit -> 16-bit quantized path. Note that only the
  // float and 8-bit paths are implemented in this kernel.
  int input1_shift;
  int input2_shift;
  int32_t output_activation_min;
  int32_t output_activation_max;

  // These fields are used only in the general 8-bit -> 8-bit quantized path.
  int32_t input1_multiplier;
  int32_t input2_multiplier;
  int32_t output_multiplier;
  int output_shift;
  int left_shift;
  int32_t input1_offset;
  int32_t input2_offset;
  int32_t output_offset;
};

TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteSubParams* params,
                             const TfLiteTensor* input1,
                             const TfLiteTensor* input2, TfLiteTensor* output,
                             OpData* data) {
  data->requires_broadcast = !HaveSameShapes(input1, input2);

  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
    // 8-bit -> 8-bit general quantized path, with general rescalings.
    data->input1_offset = -input1->params.zero_point;
    data->input2_offset = -input2->params.zero_point;
    data->output_offset = output->params.zero_point;
    data->left_shift = 20;
    const float twice_max_input_scale =
        2 * std::max(input1->params.scale, input2->params.scale);
    const double real_input1_multiplier =
        static_cast<double>(input1->params.scale / twice_max_input_scale);
    const double real_input2_multiplier =
        static_cast<double>(input2->params.scale / twice_max_input_scale);
    const double real_output_multiplier =
        static_cast<double>(twice_max_input_scale /
                            ((1 << data->left_shift) * output->params.scale));

    QuantizeMultiplierSmallerThanOneExp(
        real_input1_multiplier, &data->input1_multiplier, &data->input1_shift);

    QuantizeMultiplierSmallerThanOneExp(
        real_input2_multiplier, &data->input2_multiplier, &data->input2_shift);

    QuantizeMultiplierSmallerThanOneExp(
        real_output_multiplier, &data->output_multiplier, &data->output_shift);

    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
  }

  return kTfLiteOk;
}
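
// Worked example (illustrative values, not from any particular model): with
// input scales s1 = 0.5, s2 = 0.25 and output scale s_out = 0.5,
// twice_max_input_scale = 1.0, so
//   real_input1_multiplier = 0.5 / 1.0  = 0.5
//   real_input2_multiplier = 0.25 / 1.0 = 0.25
//   real_output_multiplier = 1.0 / ((1 << 20) * 0.5) ~= 1.9e-6
// QuantizeMultiplierSmallerThanOneExp then splits each real multiplier into a
// Q31 fixed-point multiplier and a non-positive exponent (a right shift),
// e.g. 0.5 becomes (multiplier = 1 << 30, shift = 0), since
// (1 << 30) / 2^31 = 0.5.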

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}
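
// A note on lifetime: AllocatePersistentBuffer hands out memory from the
// interpreter's arena that remains valid for the life of the interpreter, so
// the OpData filled in by Prepare() can be read back in Eval() without any
// per-invocation allocation.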

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpData* data = static_cast<OpData*>(node->user_data);
  auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  TF_LITE_ENSURE(context, input1 != nullptr);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TF_LITE_ENSURE(context, input2 != nullptr);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);

  TF_LITE_ENSURE_STATUS(
      CalculateOpData(context, params, input1, input2, output, data));
  return kTfLiteOk;
}
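
// Prepare() works with full TfLiteTensor structs, which carry quantization
// parameters (scale, zero_point). At Eval() time the kernel only sees
// lightweight TfLiteEvalTensor structs (shape and data), which is why all of
// the quantization arithmetic is precomputed here and stashed in OpData.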

void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
             const OpData* data, const TfLiteEvalTensor* input1,
             const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
  tflite::ArithmeticParams op_params;
  SetActivationParams(output_activation_min, output_activation_max,
                      &op_params);
  if (data->requires_broadcast) {
    tflite::reference_ops::BroadcastSubSlow(
        op_params, tflite::micro::GetTensorShape(input1),
        tflite::micro::GetTensorData<float>(input1),
        tflite::micro::GetTensorShape(input2),
        tflite::micro::GetTensorData<float>(input2),
        tflite::micro::GetTensorShape(output),
        tflite::micro::GetTensorData<float>(output));
  } else {
    tflite::reference_ops::SubWithActivation(
        op_params, tflite::micro::GetTensorShape(input1),
        tflite::micro::GetTensorData<float>(input1),
        tflite::micro::GetTensorShape(input2),
        tflite::micro::GetTensorData<float>(input2),
        tflite::micro::GetTensorShape(output),
        tflite::micro::GetTensorData<float>(output));
  }
}
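
// Unlike the quantized path, the float path recomputes the activation range
// on every invocation; the range is just a pair of floats (e.g. 0 to FLT_MAX
// for kTfLiteActRelu), so there is little to gain from caching it in OpData.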

TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
                              TfLiteSubParams* params, const OpData* data,
                              const TfLiteEvalTensor* input1,
                              const TfLiteEvalTensor* input2,
                              TfLiteEvalTensor* output) {
  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
    tflite::ArithmeticParams op_params;
    op_params.left_shift = data->left_shift;
    op_params.input1_offset = data->input1_offset;
    op_params.input1_multiplier = data->input1_multiplier;
    op_params.input1_shift = data->input1_shift;
    op_params.input2_offset = data->input2_offset;
    op_params.input2_multiplier = data->input2_multiplier;
    op_params.input2_shift = data->input2_shift;
    op_params.output_offset = data->output_offset;
    op_params.output_multiplier = data->output_multiplier;
    op_params.output_shift = data->output_shift;
    SetActivationParams(data->output_activation_min,
                        data->output_activation_max, &op_params);
    bool need_broadcast = reference_ops::ProcessBroadcastShapes(
        tflite::micro::GetTensorShape(input1),
        tflite::micro::GetTensorShape(input2), &op_params);

    if (output->type == kTfLiteInt8) {
      if (need_broadcast) {
        tflite::reference_ops::BroadcastSubSlow(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<int8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<int8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int8_t>(output));
      } else {
        tflite::reference_ops::Sub(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<int8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<int8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int8_t>(output));
      }
    } else {
      if (need_broadcast) {
        tflite::reference_ops::BroadcastSubSlow(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<uint8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<uint8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<uint8_t>(output));
      } else {
        tflite::reference_ops::Sub(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<uint8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<uint8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<uint8_t>(output));
      }
    }
  }

  return kTfLiteOk;
}
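
// Sketch of the per-element integer arithmetic performed by the reference
// Sub above (see reference_ops for the authoritative version):
//   v1 = (q1 + input1_offset) << left_shift
//   v2 = (q2 + input2_offset) << left_shift
//   s1 = MultiplyByQuantizedMultiplierSmallerThanOneExp(
//            v1, input1_multiplier, input1_shift)
//   s2 = MultiplyByQuantizedMultiplierSmallerThanOneExp(
//            v2, input2_multiplier, input2_shift)
//   out = MultiplyByQuantizedMultiplierSmallerThanOneExp(
//             s1 - s2, output_multiplier, output_shift) + output_offset,
//         clamped to [quantized_activation_min, quantized_activation_max].
// The left_shift of 20 gives the intermediate values enough fixed-point
// headroom that the per-input rescaling loses little precision.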

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  if (output->type == kTfLiteFloat32) {
    EvalSub(context, node, params, &data, input1, input2, output);
  } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
    TF_LITE_ENSURE_OK(context, EvalSubQuantized(context, node, params, &data,
                                                input1, input2, output));
  } else {
    TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                       TfLiteTypeGetName(output->type), output->type);
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace sub

TfLiteRegistration Register_SUB() {
  return {/*init=*/sub::Init,
          /*free=*/nullptr,
          /*prepare=*/sub::Prepare,
          /*invoke=*/sub::Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite