/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/sub.h"

#include <algorithm>  // For std::max, used below when comparing input scales.

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace sub {

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

struct OpData {
  bool requires_broadcast;

  // These fields are used in both the general 8-bit -> 8-bit quantized path
  // and the special 16-bit -> 16-bit quantized path.
  int input1_shift;
  int input2_shift;
  int32_t output_activation_min;
  int32_t output_activation_max;

  // These fields are used only in the general 8-bit -> 8-bit quantized path.
  int32_t input1_multiplier;
  int32_t input2_multiplier;
  int32_t output_multiplier;
  int output_shift;
  int left_shift;
  int32_t input1_offset;
  int32_t input2_offset;
  int32_t output_offset;
};

TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteSubParams* params,
                             const TfLiteTensor* input1,
                             const TfLiteTensor* input2, TfLiteTensor* output,
                             OpData* data) {
  data->requires_broadcast = !HaveSameShapes(input1, input2);

  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
    // 8-bit -> 8-bit general quantized path, with general rescalings.
    data->input1_offset = -input1->params.zero_point;
    data->input2_offset = -input2->params.zero_point;
    data->output_offset = output->params.zero_point;
    data->left_shift = 20;
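
    // Note: the rescaling below maps both inputs to the common real scale
    // 2 * max(s1, s2), so each real input multiplier is at most 0.5.
    // left_shift = 20 pre-shifts the offset-corrected operands to gain
    // integer headroom for the subtraction, and the output multiplier folds
    // that factor back out, keeping all three multipliers below 1 (for
    // realistic scale combinations), as QuantizeMultiplierSmallerThanOneExp
    // requires.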
    const float twice_max_input_scale =
        2 * std::max(input1->params.scale, input2->params.scale);
    const double real_input1_multiplier =
        static_cast<double>(input1->params.scale / twice_max_input_scale);
    const double real_input2_multiplier =
        static_cast<double>(input2->params.scale / twice_max_input_scale);
    const double real_output_multiplier =
        static_cast<double>(twice_max_input_scale /
                            ((1 << data->left_shift) * output->params.scale));

    QuantizeMultiplierSmallerThanOneExp(
        real_input1_multiplier, &data->input1_multiplier, &data->input1_shift);

    QuantizeMultiplierSmallerThanOneExp(
        real_input2_multiplier, &data->input2_multiplier, &data->input2_shift);

    QuantizeMultiplierSmallerThanOneExp(
        real_output_multiplier, &data->output_multiplier, &data->output_shift);

    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
  }

  return kTfLiteOk;
}

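// Init runs once per op instance. TFLM kernels avoid heap allocation, so the
// per-op OpData is carved out of the interpreter's persistent arena, where it
// remains valid for the lifetime of the interpreter.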
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

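// Prepare runs after tensor allocation. Full TfLiteTensor metadata (types and
// quantization params) is only guaranteed to be available at this stage, so
// all quantization arithmetic is precomputed here and cached in OpData; Eval
// later sees only the leaner TfLiteEvalTensor views.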
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpData* data = static_cast<OpData*>(node->user_data);
  auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  TF_LITE_ENSURE(context, input1 != nullptr);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TF_LITE_ENSURE(context, input2 != nullptr);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);

  TF_LITE_ENSURE_STATUS(
      CalculateOpData(context, params, input1, input2, output, data));
  return kTfLiteOk;
}

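// Float path: the activation range is cheap to derive from the builtin
// params, so it is recomputed at eval time rather than cached in OpData.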
void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
             const OpData* data, const TfLiteEvalTensor* input1,
             const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
  tflite::ArithmeticParams op_params;
  SetActivationParams(output_activation_min, output_activation_max, &op_params);
  if (data->requires_broadcast) {
    tflite::reference_ops::BroadcastSubSlow(
        op_params, tflite::micro::GetTensorShape(input1),
        tflite::micro::GetTensorData<float>(input1),
        tflite::micro::GetTensorShape(input2),
        tflite::micro::GetTensorData<float>(input2),
        tflite::micro::GetTensorShape(output),
        tflite::micro::GetTensorData<float>(output));
  } else {
    tflite::reference_ops::SubWithActivation(
        op_params, tflite::micro::GetTensorShape(input1),
        tflite::micro::GetTensorData<float>(input1),
        tflite::micro::GetTensorShape(input2),
        tflite::micro::GetTensorData<float>(input2),
        tflite::micro::GetTensorShape(output),
        tflite::micro::GetTensorData<float>(output));
  }
}

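// Quantized (int8/uint8) path. Per element, the reference kernels compute
// approximately:
//   scaled1 = ((q1 + input1_offset) << left_shift) * input1_multiplier
//   scaled2 = ((q2 + input2_offset) << left_shift) * input2_multiplier
//   out     = (scaled1 - scaled2) * output_multiplier + output_offset
// where each '*' is a rounding fixed-point multiply-and-shift, and the result
// is clamped to the activation range. ProcessBroadcastShapes reports whether
// broadcasting is needed; BroadcastSubSlow then handles arbitrary broadcast
// shapes.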
TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
                              TfLiteSubParams* params, const OpData* data,
                              const TfLiteEvalTensor* input1,
                              const TfLiteEvalTensor* input2,
                              TfLiteEvalTensor* output) {
  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
    tflite::ArithmeticParams op_params;
    op_params.left_shift = data->left_shift;
    op_params.input1_offset = data->input1_offset;
    op_params.input1_multiplier = data->input1_multiplier;
    op_params.input1_shift = data->input1_shift;
    op_params.input2_offset = data->input2_offset;
    op_params.input2_multiplier = data->input2_multiplier;
    op_params.input2_shift = data->input2_shift;
    op_params.output_offset = data->output_offset;
    op_params.output_multiplier = data->output_multiplier;
    op_params.output_shift = data->output_shift;
    SetActivationParams(data->output_activation_min,
                        data->output_activation_max, &op_params);
    bool need_broadcast = reference_ops::ProcessBroadcastShapes(
        tflite::micro::GetTensorShape(input1),
        tflite::micro::GetTensorShape(input2), &op_params);

    if (output->type == kTfLiteInt8) {
      if (need_broadcast) {
        tflite::reference_ops::BroadcastSubSlow(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<int8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<int8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int8_t>(output));
      } else {
        tflite::reference_ops::Sub(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<int8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<int8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int8_t>(output));
      }
    } else {
      if (need_broadcast) {
        tflite::reference_ops::BroadcastSubSlow(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<uint8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<uint8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<uint8_t>(output));
      } else {
        tflite::reference_ops::Sub(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<uint8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<uint8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<uint8_t>(output));
      }
    }
  }

  return kTfLiteOk;
}

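// Eval dispatches on the output type: float32 takes the float reference path,
// int8/uint8 take the precomputed quantized path, and any other type is
// rejected with an error.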
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  if (output->type == kTfLiteFloat32) {
    EvalSub(context, node, params, &data, input1, input2, output);
  } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
    TF_LITE_ENSURE_OK(context, EvalSubQuantized(context, node, params, &data,
                                                input1, input2, output));
  } else {
    TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                       TfLiteTypeGetName(output->type), output->type);
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace sub

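// Register_SUB returns the registration record the op resolver consumes. Only
// init, prepare, and invoke are provided by this kernel; the remaining fields
// keep their default (null/zero) values.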
TfLiteRegistration Register_SUB() {
  return {/*init=*/sub::Init,
          /*free=*/nullptr,
          /*prepare=*/sub::Prepare,
          /*invoke=*/sub::Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite