/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/reduce.h"

#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_utils.h"

namespace tflite {
namespace ops {
namespace micro {
namespace reduce {

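// These constants bound the on-stack index buffers used by EvalMean below:
// temp_index holds one entry per input dimension and resolved_axis holds the
// normalized reduction axes.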
constexpr int kMaxNumberOfAxis = 4;
constexpr int kMaxNumberOfReducedAxis = 2;

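// Per-node state allocated by InitReduce (via AllocatePersistentBuffer),
// filled in by the Prepare functions, and read back from node->user_data at
// Eval time.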
struct OpData {
  int32_t multiplier;
  int shift;
  int temp_buffer_idx;
  int resolved_axis_idx;
  int input_zp;
  float input_scale;
  int output_zp;
  float output_scale;
  int num_output_elements;
};

void* InitReduce(TfLiteContext* context, const char* buffer, size_t length) {
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
  // Input tensors (dtype depends on quantization):
  // [0] = Input
  // [1] = Axis
  const TfLiteTensor* input = GetInput(context, node, 0);

  // Output tensors (dtype depends on quantization):
  // [0] = Output

  // Validate number of inputs and outputs
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 2);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  // Validate axis type
  const TfLiteTensor* axis = GetInput(context, node, 1);
  TF_LITE_ENSURE(context, axis != nullptr);
  TF_LITE_ENSURE_TYPES_EQ(context, axis->type, kTfLiteInt32);

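  // For int8, precompute the fixed-point multiplier and shift that rescale
  // values from the input scale to the output scale.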
  if (input->type == kTfLiteInt8) {
    OpData* data = static_cast<OpData*>(node->user_data);
    const TfLiteTensor* output = GetOutput(context, node, 0);
    const double real_multiplier = static_cast<double>(input->params.scale) /
                                   static_cast<double>(output->params.scale);
    QuantizeMultiplier(real_multiplier, &data->multiplier, &data->shift);
  }

  return kTfLiteOk;
}

TfLiteStatus PrepareMax(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));

  OpData* op_data = static_cast<OpData*>(node->user_data);
  const TfLiteTensor* input = GetInput(context, node, 0);
  const TfLiteTensor* output = GetOutput(context, node, 0);
  const TfLiteTensor* axis = GetInput(context, node, 1);

  op_data->input_scale = input->params.scale;
  op_data->output_scale = output->params.scale;
  op_data->num_output_elements = NumElements(output);

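  // Request scratch memory for the reduction: one int per input dimension for
  // the temporary index buffer, and one int per axis entry for the resolved
  // axes.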
  context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size,
                                       &op_data->temp_buffer_idx);
  context->RequestScratchBufferInArena(
      context, sizeof(int) * static_cast<int>(ElementCount(*axis->dims)),
      &op_data->resolved_axis_idx);

  return kTfLiteOk;
}

TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, 0);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  const TfLiteTensor* output = GetOutput(context, node, 0);
  if (input->type == kTfLiteInt8) {
    const double real_multiplier = static_cast<double>(input->params.scale) /
                                   static_cast<double>(output->params.scale);
    QuantizeMultiplier(real_multiplier, &op_data->multiplier, &op_data->shift);
  }

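  // Quantized inputs accumulate into a 32-bit scratch buffer with one
  // accumulator per output element.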
  int output_size = NumElements(output);
  if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
    context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t),
                                         &op_data->temp_buffer_idx);
    op_data->input_zp = input->params.zero_point;
    op_data->input_scale = input->params.scale;
    op_data->output_zp = output->params.zero_point;
    op_data->output_scale = output->params.scale;
  }

  TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
  // TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018)
  return kTfLiteOk;
}

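// Copies the requested axes into MeanParams (as int16_t) and pads the unused
// slots, up to four, with 1. MeanParams is only consumed by the specialized
// 4D Mean kernels in EvalMean.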
void ResolveAxis(const int* axis_data, int axis_count,
                 tflite::MeanParams* op_params) {
  int i = 0;
  for (; i < axis_count; ++i) {
    op_params->axis[i] = static_cast<int16_t>(axis_data[i]);
  }
  for (; i < 4; ++i) {
    op_params->axis[i] = 1;
  }
  op_params->axis_count = axis_count;
}

TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  TfLiteReducerParams* params =
      reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  int num_axis = static_cast<int>(ElementCount(*axis->dims));
  int temp_index[kMaxNumberOfAxis];
  int resolved_axis[kMaxNumberOfReducedAxis];

  tflite::MeanParams op_params;
  ResolveAxis(tflite::micro::GetTensorData<int>(axis), num_axis, &op_params);

  // Special case mean implementation exists for 4D mean across axes 1 and 2.
  bool special_case_4d_axes_1_and_2 =
      input->dims->size == 4 && op_params.axis_count == 2 &&
      ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
       (op_params.axis[0] == 2 && op_params.axis[1] == 1));

  switch (input->type) {
    case kTfLiteFloat32: {
      // Defer to specialized implementation for 4D Mean across axes 1 & 2.
      if (params->keep_dims && special_case_4d_axes_1_and_2) {
        reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input),
                            tflite::micro::GetTensorData<float>(input),
                            tflite::micro::GetTensorShape(output),
                            tflite::micro::GetTensorData<float>(output));
      } else {
        TF_LITE_ENSURE(
            context,
            reference_ops::Mean(
                tflite::micro::GetTensorData<float>(input), input->dims->data,
                input->dims->size, tflite::micro::GetTensorData<float>(output),
                output->dims->data, output->dims->size,
                tflite::micro::GetTensorData<int>(axis), num_axis,
                params->keep_dims, temp_index, resolved_axis,
                tflite::micro::GetTensorData<float>(output)));
      }
    } break;
    case kTfLiteInt8: {
      // Defer to specialized implementation for 4D Mean across axes 1 & 2.
      if (params->keep_dims && special_case_4d_axes_1_and_2) {
        reference_integer_ops::Mean(
            op_params, op_data->multiplier, op_data->shift,
            tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int8_t>(output), op_data->output_zp);
      } else if (op_data->input_zp == op_data->output_zp &&
                 op_data->input_scale == op_data->output_scale) {
        int32_t* temp_buffer = static_cast<int32_t*>(
            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
        TF_LITE_ENSURE(
            context,
            reference_ops::Mean(
                tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
                input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
                output->dims->data, output->dims->size,
                tflite::micro::GetTensorData<int>(axis), num_axis,
                params->keep_dims, temp_index, resolved_axis, temp_buffer));
      } else {
        int32_t* temp_buffer = static_cast<int32_t*>(
            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
        TF_LITE_ENSURE(
            context,
            reference_ops::QuantizedMeanOrSum(
                tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
                op_data->input_scale, input->dims->data, input->dims->size,
                tflite::micro::GetTensorData<int8_t>(output),
                op_data->output_zp, op_data->output_scale, output->dims->data,
                output->dims->size, tflite::micro::GetTensorData<int>(axis),
                num_axis, params->keep_dims, temp_index, resolved_axis,
                temp_buffer, false));
      }
    } break;
    case kTfLiteUInt8: {
      // Defer to specialized implementation for 4D Mean across axes 1 & 2.
      if (params->keep_dims && special_case_4d_axes_1_and_2) {
        reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input),
                            tflite::micro::GetTensorData<uint8_t>(input),
                            op_data->input_zp, op_data->input_scale,
                            tflite::micro::GetTensorShape(output),
                            tflite::micro::GetTensorData<uint8_t>(output),
                            op_data->output_zp, op_data->output_scale);
      } else if (op_data->input_zp == op_data->output_zp &&
                 op_data->input_scale == op_data->output_scale) {
        uint32_t* temp_buffer = static_cast<uint32_t*>(
            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
        TF_LITE_ENSURE(
            context,
            reference_ops::Mean(tflite::micro::GetTensorData<uint8_t>(input),
                                input->dims->data, input->dims->size,
                                tflite::micro::GetTensorData<uint8_t>(output),
                                output->dims->data, output->dims->size,
                                tflite::micro::GetTensorData<int>(axis),
                                num_axis, params->keep_dims, temp_index,
                                resolved_axis, temp_buffer));
      } else {
        uint32_t* temp_buffer = static_cast<uint32_t*>(
            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
        TF_LITE_ENSURE(
            context,
            reference_ops::QuantizedMeanOrSum(
                tflite::micro::GetTensorData<uint8_t>(input), op_data->input_zp,
                op_data->input_scale, input->dims->data, input->dims->size,
                tflite::micro::GetTensorData<uint8_t>(output),
                op_data->output_zp, op_data->output_scale, output->dims->data,
                output->dims->size, tflite::micro::GetTensorData<int>(axis),
                num_axis, params->keep_dims, temp_index, resolved_axis,
                temp_buffer, false));
      }
    } break;
    default:
      TF_LITE_ENSURE_MSG(context, false,
                         "Currently, only float32, int8 or uint8 input type "
                         "is supported.");
  }
  return kTfLiteOk;
}

TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  TfLiteReducerParams* params =
      static_cast<TfLiteReducerParams*>(node->builtin_data);
  OpData* op_data = static_cast<OpData*>(node->user_data);

  // Interpret an axis tensor with null dimensions as a scalar
  int num_axis = static_cast<int>(ElementCount(*axis->dims));
  int* temp_buffer = static_cast<int*>(
      context->GetScratchBuffer(context, op_data->temp_buffer_idx));
  int* resolved_axis = static_cast<int*>(
      context->GetScratchBuffer(context, op_data->resolved_axis_idx));
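  // ReduceGeneric folds each reduced element with the max lambda below,
  // starting from the lowest representable value of the element type. The
  // temp and resolved-axis buffers come from the scratch arena requested in
  // PrepareMax.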
  switch (input->type) {
    case kTfLiteFloat32:
      TF_LITE_ENSURE(
          context,
          reference_ops::ReduceGeneric<float>(
              tflite::micro::GetTensorData<float>(input), input->dims->data,
              input->dims->size, tflite::micro::GetTensorData<float>(output),
              output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_buffer, resolved_axis,
              std::numeric_limits<float>::lowest(),
              [](const float current, const float in) -> float {
                return (in > current) ? in : current;
              }));
      break;
    case kTfLiteInt8:
      TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
                        static_cast<double>(op_data->output_scale));
      TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
      TF_LITE_ENSURE(
          context,
          reference_ops::ReduceGeneric<int8_t>(
              tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
              input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
              output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_buffer, resolved_axis,
              std::numeric_limits<int8_t>::lowest(),
              [](const int8_t current, const int8_t in) -> int8_t {
                return (in > current) ? in : current;
              }));
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Only float32 and int8 types are supported.\n");
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace reduce

TfLiteRegistration Register_MEAN() {
  return {/*init=*/reduce::InitReduce,
          /*free=*/nullptr,
          /*prepare=*/reduce::PrepareMeanOrSum,
          /*invoke=*/reduce::EvalMean,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_REDUCE_MAX() {
  return {/*init=*/reduce::InitReduce,
          /*free=*/nullptr,
          /*prepare=*/reduce::PrepareMax,
          /*invoke=*/reduce::EvalMax,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
343