• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/kernels/internal/reference/reduce.h"
17 
18 #include "tensorflow/lite/c/builtin_op_data.h"
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/kernels/internal/quantization_util.h"
21 #include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/internal/types.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 #include "tensorflow/lite/micro/kernels/kernel_util.h"
26 #include "tensorflow/lite/micro/micro_utils.h"
27 
28 namespace tflite {
29 namespace ops {
30 namespace micro {
31 namespace reduce {
32 
// Upper bounds on tensor rank / number of reduced axes supported by the
// stack-allocated scratch arrays in EvalMean().
constexpr int kMaxNumberOfAxis = 4;
constexpr int kMaxNumberOfReducedAxis = 2;

// Per-node state, allocated from the persistent arena in InitReduce() and
// populated by the Prepare functions below.
struct OpData {
  int32_t multiplier;  // Fixed-point rescale factor (input scale / output
                       // scale), produced by QuantizeMultiplier().
  int shift;           // Shift paired with `multiplier`.
  int temp_buffer_idx;      // Scratch-buffer index for reduction accumulators.
  int resolved_axis_idx;    // Scratch-buffer index for resolved axes (max op).
  int input_zp;             // Input zero point (quantized paths).
  float input_scale;        // Input scale (quantized paths).
  int output_zp;            // Output zero point (quantized paths).
  float output_scale;       // Output scale (quantized paths).
  int num_output_elements;  // Cached NumElements(output).
};
47 
// Allocates the per-node OpData from the interpreter's persistent arena.
// `buffer`/`length` (flexbuffer custom options) are unused by this kernel.
void* InitReduce(TfLiteContext* context, const char* buffer, size_t length) {
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}
51 
PrepareSimple(TfLiteContext * context,TfLiteNode * node)52 TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
53   // Inputs Tensor (dtype depends on quantization):
54   // [0] = Input
55   // [1] = Axis
56   const TfLiteTensor* input = GetInput(context, node, 0);
57 
58   // Outputs Tensor (dtype depends on quantization):
59   // [0] = Output
60 
61   // Validate number of inputs and outputs
62   TF_LITE_ENSURE_EQ(context, node->inputs->size, 2);
63   TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
64 
65   // Validate axis type
66   const TfLiteTensor* axis = GetInput(context, node, 1);
67   TF_LITE_ENSURE(context, axis != nullptr);
68   TF_LITE_ENSURE_TYPES_EQ(context, axis->type, kTfLiteInt32);
69 
70   if (input->type == kTfLiteInt8) {
71     OpData* data = static_cast<OpData*>(node->user_data);
72     const TfLiteTensor* output = GetOutput(context, node, 0);
73     const double real_multiplier = static_cast<double>(input->params.scale) /
74                                    static_cast<double>(output->params.scale);
75     QuantizeMultiplier(real_multiplier, &data->multiplier, &data->shift);
76   }
77 
78   return kTfLiteOk;
79 }
80 
PrepareMax(TfLiteContext * context,TfLiteNode * node)81 TfLiteStatus PrepareMax(TfLiteContext* context, TfLiteNode* node) {
82   TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
83 
84   OpData* op_data = static_cast<OpData*>(node->user_data);
85   const TfLiteTensor* input = GetInput(context, node, 0);
86   const TfLiteTensor* output = GetOutput(context, node, 0);
87   const TfLiteTensor* axis = GetInput(context, node, 1);
88 
89   op_data->input_scale = input->params.scale;
90   op_data->output_scale = output->params.scale;
91   op_data->num_output_elements = NumElements(output);
92 
93   context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size,
94                                        &op_data->temp_buffer_idx);
95   context->RequestScratchBufferInArena(
96       context, sizeof(int) * static_cast<int>(ElementCount(*axis->dims)),
97       &op_data->resolved_axis_idx);
98 
99   return kTfLiteOk;
100 }
101 
PrepareMeanOrSum(TfLiteContext * context,TfLiteNode * node)102 TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
103   const TfLiteTensor* input = GetInput(context, node, 0);
104   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
105   const TfLiteTensor* output = GetOutput(context, node, 0);
106   if (input->type == kTfLiteInt8) {
107     const double real_multiplier = static_cast<double>(input->params.scale) /
108                                    static_cast<double>(output->params.scale);
109     QuantizeMultiplier(real_multiplier, &op_data->multiplier, &op_data->shift);
110   }
111 
112   int output_size = NumElements(output);
113   if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
114     context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t),
115                                          &op_data->temp_buffer_idx);
116     op_data->input_zp = input->params.zero_point;
117     op_data->input_scale = input->params.scale;
118     op_data->output_zp = output->params.zero_point;
119     op_data->output_scale = output->params.scale;
120   }
121 
122   TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
123   // TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018)
124   return kTfLiteOk;
125 }
126 
ResolveAxis(const int * axis_data,int axis_count,tflite::MeanParams * op_params)127 void ResolveAxis(const int* axis_data, int axis_count,
128                  tflite::MeanParams* op_params) {
129   int i = 0;
130   for (; i < axis_count; ++i) {
131     op_params->axis[i] = static_cast<int16_t>(axis_data[i]);
132   }
133   for (; i < 4; ++i) {
134     op_params->axis[i] = 1;
135   }
136   op_params->axis_count = axis_count;
137 }
138 
// Kernel entry for MEAN: averages `input` over the axes listed in the `axis`
// tensor. Handles float32, int8 and uint8. Quantized inputs take one of
// three paths: a specialized 4D axes-{1,2} kernel, a plain integer mean when
// input/output quantization match, or QuantizedMeanOrSum with rescaling.
TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  TfLiteReducerParams* params =
      reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  int num_axis = static_cast<int>(ElementCount(*axis->dims));
  // Stack scratch for the generic reducer.
  // NOTE(review): assumes input rank <= kMaxNumberOfAxis (4) and at most
  // kMaxNumberOfReducedAxis (2) reduced axes -- confirm upstream validation.
  int temp_index[kMaxNumberOfAxis];
  int resolved_axis[kMaxNumberOfReducedAxis];

  tflite::MeanParams op_params;
  ResolveAxis(tflite::micro::GetTensorData<int>(axis), num_axis, &op_params);

  // Special case mean implementation exists for 4D mean across axes 1 and 2.
  bool special_case_4d_axes_1_and_2 =
      input->dims->size == 4 && op_params.axis_count == 2 &&
      ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
       (op_params.axis[0] == 2 && op_params.axis[1] == 1));

  switch (input->type) {
    case kTfLiteFloat32: {
      // Defer to specialized implementation for 4D Mean across axes 1 & 2.
      if (params->keep_dims && special_case_4d_axes_1_and_2) {
        reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input),
                            tflite::micro::GetTensorData<float>(input),
                            tflite::micro::GetTensorShape(output),
                            tflite::micro::GetTensorData<float>(output));
      } else {
        // Generic path: the output buffer is also passed as the float
        // accumulator (final argument).
        TF_LITE_ENSURE(
            context,
            reference_ops::Mean(
                tflite::micro::GetTensorData<float>(input), input->dims->data,
                input->dims->size, tflite::micro::GetTensorData<float>(output),
                output->dims->data, output->dims->size,
                tflite::micro::GetTensorData<int>(axis), num_axis,
                params->keep_dims, temp_index, resolved_axis,
                tflite::micro::GetTensorData<float>(output)));
      }
    } break;
    case kTfLiteInt8: {
      // Defer to specialized implementation for 4D Mean across axes 1 & 2.
      if (params->keep_dims && special_case_4d_axes_1_and_2) {
        reference_integer_ops::Mean(
            op_params, op_data->multiplier, op_data->shift,
            tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int8_t>(output), op_data->output_zp);
      } else if (op_data->input_zp == op_data->output_zp &&
                 op_data->input_scale == op_data->output_scale) {
        // Identical quantization on both sides: plain integer mean with an
        // int32 scratch accumulator.
        int32_t* temp_buffer = static_cast<int32_t*>(
            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
        TF_LITE_ENSURE(
            context,
            reference_ops::Mean(
                tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
                input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
                output->dims->data, output->dims->size,
                tflite::micro::GetTensorData<int>(axis), num_axis,
                params->keep_dims, temp_index, resolved_axis, temp_buffer));
      } else {
        // Quantization differs: accumulate in int32, then rescale with the
        // cached zero points and scales.
        int32_t* temp_buffer = static_cast<int32_t*>(
            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
        TF_LITE_ENSURE(
            context,
            reference_ops::QuantizedMeanOrSum(
                tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
                op_data->input_scale, input->dims->data, input->dims->size,
                tflite::micro::GetTensorData<int8_t>(output),
                op_data->output_zp, op_data->output_scale, output->dims->data,
                output->dims->size, tflite::micro::GetTensorData<int>(axis),
                num_axis, params->keep_dims, temp_index, resolved_axis,
                temp_buffer, /*compute_sum=*/false));
      }
    } break;
    case kTfLiteUInt8: {
      // Defer to specialized implementation for 4D Mean across axes 1 & 2.
      if (params->keep_dims && special_case_4d_axes_1_and_2) {
        reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input),
                            tflite::micro::GetTensorData<uint8_t>(input),
                            op_data->input_zp, op_data->input_scale,
                            tflite::micro::GetTensorShape(output),
                            tflite::micro::GetTensorData<uint8_t>(output),
                            op_data->output_zp, op_data->output_scale);
      } else if (op_data->input_zp == op_data->output_zp &&
                 op_data->input_scale == op_data->output_scale) {
        // Identical quantization: plain mean with a uint32 accumulator.
        uint32_t* temp_buffer = static_cast<uint32_t*>(
            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
        TF_LITE_ENSURE(
            context,
            reference_ops::Mean(tflite::micro::GetTensorData<uint8_t>(input),
                                input->dims->data, input->dims->size,
                                tflite::micro::GetTensorData<uint8_t>(output),
                                output->dims->data, output->dims->size,
                                tflite::micro::GetTensorData<int>(axis),
                                num_axis, params->keep_dims, temp_index,
                                resolved_axis, temp_buffer));
      } else {
        // Quantization differs: accumulate then rescale.
        uint32_t* temp_buffer = static_cast<uint32_t*>(
            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
        TF_LITE_ENSURE(
            context,
            reference_ops::QuantizedMeanOrSum(
                tflite::micro::GetTensorData<uint8_t>(input), op_data->input_zp,
                op_data->input_scale, input->dims->data, input->dims->size,
                tflite::micro::GetTensorData<uint8_t>(output),
                op_data->output_zp, op_data->output_scale, output->dims->data,
                output->dims->size, tflite::micro::GetTensorData<int>(axis),
                num_axis, params->keep_dims, temp_index, resolved_axis,
                temp_buffer, /*compute_sum=*/false));
      }
    } break;
    default:
      TF_LITE_ENSURE_MSG(context, false,
                         "Currently, only float32, int8 or uint8 input type "
                         "is supported.");
  }
  return kTfLiteOk;
}
260 
// Kernel entry for REDUCE_MAX: reduces `input` with a running-max combiner
// over the axes in the `axis` tensor. Supports float32 and int8; the int8
// path requires input and output quantization to match (max commutes with an
// affine quantization only when scale/zero-point are identical).
TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  TfLiteReducerParams* params =
      static_cast<TfLiteReducerParams*>(node->builtin_data);
  OpData* op_data = static_cast<OpData*>(node->user_data);

  // Interpret an axis tensor with null dimensions as a scalar
  int num_axis = static_cast<int>(ElementCount(*axis->dims));
  // Arena scratch requested in PrepareMax: per-dimension index state and the
  // resolved axis list.
  int* temp_buffer = static_cast<int*>(
      context->GetScratchBuffer(context, op_data->temp_buffer_idx));
  int* resolved_axis = static_cast<int*>(
      context->GetScratchBuffer(context, op_data->resolved_axis_idx));
  switch (input->type) {
    case kTfLiteFloat32:
      TF_LITE_ENSURE(
          context,
          reference_ops::ReduceGeneric<float>(
              tflite::micro::GetTensorData<float>(input), input->dims->data,
              input->dims->size, tflite::micro::GetTensorData<float>(output),
              output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_buffer, resolved_axis,
              std::numeric_limits<float>::lowest(),
              [](const float current, const float in) -> float {
                return (in > current) ? in : current;
              }));
      break;
    case kTfLiteInt8:
      // NOTE(review): input_scale/input_zp etc. are assumed to be filled in
      // by PrepareMax -- verify they are initialized before this comparison.
      TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
                        static_cast<double>(op_data->output_scale));
      TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
      TF_LITE_ENSURE(
          context,
          reference_ops::ReduceGeneric<int8_t>(
              tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
              input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
              output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_buffer, resolved_axis,
              std::numeric_limits<int8_t>::lowest(),
              [](const int8_t current, const int8_t in) -> int8_t {
                return (in > current) ? in : current;
              }));
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Only float32 and int8 types are supported.\n");
      return kTfLiteError;
  }
  return kTfLiteOk;
}
315 
316 }  // namespace reduce
317 
Register_MEAN()318 TfLiteRegistration Register_MEAN() {
319   return {/*init=*/reduce::InitReduce,
320           /*free=*/nullptr,
321           /*prepare=*/reduce::PrepareMeanOrSum,
322           /*invoke=*/reduce::EvalMean,
323           /*profiling_string=*/nullptr,
324           /*builtin_code=*/0,
325           /*custom_name=*/nullptr,
326           /*version=*/0};
327 }
328 
Register_REDUCE_MAX()329 TfLiteRegistration Register_REDUCE_MAX() {
330   return {/*init=*/reduce::InitReduce,
331           /*free=*/nullptr,
332           /*prepare=*/reduce::PrepareMax,
333           /*invoke=*/reduce::EvalMax,
334           /*profiling_string=*/nullptr,
335           /*builtin_code=*/0,
336           /*custom_name=*/nullptr,
337           /*version=*/0};
338 }
339 
340 }  // namespace micro
341 }  // namespace ops
342 }  // namespace tflite
343