/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/lite/kernels/internal/reference/add.h"
17
18 #include "arm_nnfunctions.h"
19 #include "tensorflow/lite/c/builtin_op_data.h"
20 #include "tensorflow/lite/kernels/internal/quantization_util.h"
21 #include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
22 #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
23 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 #include "tensorflow/lite/kernels/op_macros.h"
26
27 namespace tflite {
28 namespace ops {
29 namespace micro {
30 namespace add {
31
// Indices of this op's tensors within the node's input and output arrays.
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
35
// Per-invocation parameters computed by CalculateOpData and consumed by the
// Eval* functions. For quantized output types this holds the fixed-point
// rescaling parameters (multipliers/shifts/offsets) derived from the tensors'
// quantization params.
struct OpData {
  // True when input1 and input2 have different shapes, so the broadcasting
  // kernel must be used instead of the element-wise one.
  bool requires_broadcast;

  // These fields are used in both the general 8-bit -> 8bit quantized path,
  // and the special 16-bit -> 16bit quantized path
  int input1_shift;
  int input2_shift;
  int32 output_activation_min;
  int32 output_activation_max;

  // These fields are used only in the general 8-bit -> 8bit quantized path
  int32 input1_multiplier;
  int32 input2_multiplier;
  int32 output_multiplier;
  int output_shift;
  // Inputs are left-shifted by this amount before rescaling, to preserve
  // precision in the fixed-point addition.
  int left_shift;
  // Negated input zero points (so they can be added rather than subtracted),
  // and the output zero point as-is.
  int32 input1_offset;
  int32 input2_offset;
  int32 output_offset;
};
56
CalculateOpData(TfLiteContext * context,TfLiteAddParams * params,const TfLiteTensor * input1,const TfLiteTensor * input2,TfLiteTensor * output,OpData * data)57 TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteAddParams* params,
58 const TfLiteTensor* input1,
59 const TfLiteTensor* input2, TfLiteTensor* output,
60 OpData* data) {
61 data->requires_broadcast = !HaveSameShapes(input1, input2);
62
63 if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
64 // 8bit -> 8bit general quantized path, with general rescalings
65 data->input1_offset = -input1->params.zero_point;
66 data->input2_offset = -input2->params.zero_point;
67 data->output_offset = output->params.zero_point;
68 data->left_shift = 20;
69 const double twice_max_input_scale =
70 2 * std::max(input1->params.scale, input2->params.scale);
71 const double real_input1_multiplier =
72 input1->params.scale / twice_max_input_scale;
73 const double real_input2_multiplier =
74 input2->params.scale / twice_max_input_scale;
75 const double real_output_multiplier =
76 twice_max_input_scale /
77 ((1 << data->left_shift) * output->params.scale);
78
79 QuantizeMultiplierSmallerThanOneExp(
80 real_input1_multiplier, &data->input1_multiplier, &data->input1_shift);
81
82 QuantizeMultiplierSmallerThanOneExp(
83 real_input2_multiplier, &data->input2_multiplier, &data->input2_shift);
84
85 QuantizeMultiplierSmallerThanOneExp(
86 real_output_multiplier, &data->output_multiplier, &data->output_shift);
87
88 TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
89 context, params->activation, output, &data->output_activation_min,
90 &data->output_activation_max));
91 }
92
93 return kTfLiteOk;
94 }
95
EvalAdd(TfLiteContext * context,TfLiteNode * node,TfLiteAddParams * params,const OpData * data,const TfLiteTensor * input1,const TfLiteTensor * input2,TfLiteTensor * output)96 void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
97 const OpData* data, const TfLiteTensor* input1,
98 const TfLiteTensor* input2, TfLiteTensor* output) {
99 float output_activation_min, output_activation_max;
100 CalculateActivationRange(params->activation, &output_activation_min,
101 &output_activation_max);
102 tflite::ArithmeticParams op_params;
103 SetActivationParams(output_activation_min, output_activation_max, &op_params);
104 #define TF_LITE_ADD(opname) \
105 reference_ops::opname(op_params, GetTensorShape(input1), \
106 GetTensorData<float>(input1), GetTensorShape(input2), \
107 GetTensorData<float>(input2), GetTensorShape(output), \
108 GetTensorData<float>(output))
109 if (data->requires_broadcast) {
110 TF_LITE_ADD(BroadcastAdd4DSlow);
111 } else {
112 TF_LITE_ADD(Add);
113 }
114 #undef TF_LITE_ADD
115 }
116
EvalAddQuantized(TfLiteContext * context,TfLiteNode * node,TfLiteAddParams * params,const OpData * data,const TfLiteTensor * input1,const TfLiteTensor * input2,TfLiteTensor * output)117 TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
118 TfLiteAddParams* params, const OpData* data,
119 const TfLiteTensor* input1,
120 const TfLiteTensor* input2,
121 TfLiteTensor* output) {
122 if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
123 tflite::ArithmeticParams op_params;
124 op_params.left_shift = data->left_shift;
125 op_params.input1_offset = data->input1_offset;
126 op_params.input1_multiplier = data->input1_multiplier;
127 op_params.input1_shift = data->input1_shift;
128 op_params.input2_offset = data->input2_offset;
129 op_params.input2_multiplier = data->input2_multiplier;
130 op_params.input2_shift = data->input2_shift;
131 op_params.output_offset = data->output_offset;
132 op_params.output_multiplier = data->output_multiplier;
133 op_params.output_shift = data->output_shift;
134 SetActivationParams(data->output_activation_min,
135 data->output_activation_max, &op_params);
136 bool need_broadcast = reference_ops::ProcessBroadcastShapes(
137 GetTensorShape(input1), GetTensorShape(input2), &op_params);
138 #define TF_LITE_ADD(type, opname, dtype) \
139 type::opname(op_params, GetTensorShape(input1), \
140 GetTensorData<dtype>(input1), GetTensorShape(input2), \
141 GetTensorData<dtype>(input2), GetTensorShape(output), \
142 GetTensorData<dtype>(output));
143 if (output->type == kTfLiteInt8) {
144 if (need_broadcast) {
145 TF_LITE_ADD(reference_integer_ops, BroadcastAdd4DSlow, int8_t);
146 } else {
147 arm_elementwise_add_s8(
148 GetTensorData<int8_t>(input1), GetTensorData<int8_t>(input2),
149 op_params.input1_offset, op_params.input1_multiplier,
150 op_params.input1_shift, op_params.input2_offset,
151 op_params.input2_multiplier, op_params.input2_shift,
152 op_params.left_shift, GetTensorData<int8_t>(output),
153 op_params.output_offset, op_params.output_multiplier,
154 op_params.output_shift, op_params.quantized_activation_min,
155 op_params.quantized_activation_max,
156 MatchingElementsSize(GetTensorShape(input1), GetTensorShape(input2),
157 GetTensorShape(output)));
158 }
159 } else {
160 if (need_broadcast) {
161 TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, uint8_t);
162 } else {
163 TF_LITE_ADD(reference_ops, Add, uint8_t);
164 }
165 }
166 #undef TF_LITE_ADD
167 }
168
169 return kTfLiteOk;
170 }
171
Eval(TfLiteContext * context,TfLiteNode * node)172 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
173 auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
174
175 const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
176 const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
177 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
178
179 OpData data;
180 TF_LITE_ENSURE_STATUS(
181 CalculateOpData(context, params, input1, input2, output, &data));
182
183 if (output->type == kTfLiteFloat32) {
184 EvalAdd(context, node, params, &data, input1, input2, output);
185 } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
186 TF_LITE_ENSURE_OK(context, EvalAddQuantized(context, node, params, &data,
187 input1, input2, output));
188 } else {
189 context->ReportError(context,
190 "Inputs and outputs not all float|uint8|int8 types.");
191 return kTfLiteError;
192 }
193
194 return kTfLiteOk;
195 }
196
197 } // namespace add
198
Register_ADD()199 TfLiteRegistration* Register_ADD() {
200 static TfLiteRegistration r = {nullptr /* Init */, nullptr /* Free */,
201 nullptr /* Prepare */, add::Eval};
202 return &r;
203 }
204
205 } // namespace micro
206 } // namespace ops
207 } // namespace tflite
208