• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/reference/concatenation.h"
16 
17 #include <cstdint>
18 
19 #include "tensorflow/lite/c/builtin_op_data.h"
20 #include "tensorflow/lite/c/common.h"
21 #include "tensorflow/lite/kernels/internal/portable_tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/internal/types.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 #include "tensorflow/lite/micro/kernels/kernel_util.h"
26 
27 namespace tflite {
28 namespace ops {
29 namespace micro {
30 namespace concatenation {
31 
constexpr int kMaxInputNum = 10;  // Maximum number of input tensors
constexpr int kOutputTensor = 0;  // Index of this op's single output tensor.

// Per-node persistent state: computed once in Prepare() and read in Eval().
struct OpData {
  // Axis, input count and (for quantized types) per-input scale/zero-point
  // arrays, in the layout reference_ops::Concatenation expects.
  ConcatenationParams params;
};
38 
39 // Handles negative axis index, coerces to positive index value.
CalculatePositiveAxis(int axis,const TfLiteTensor * output_tensor)40 inline int CalculatePositiveAxis(int axis, const TfLiteTensor* output_tensor) {
41   if (axis >= 0) {
42     return axis;
43   } else {
44     return NumDimensions(output_tensor) + axis;
45   }
46 }
47 
48 // The following functions are helpers to get tensor data in the format that the
49 // reference op implementation expects. They provide the same functionality as
50 // class VectorOfTensors and class VectorOfQuantizedTensors in TFLite.
51 
52 // Gets shapes from a list of tensors.
GetAllInputTensorShapes(const TfLiteContext * context,const TfLiteNode * node,RuntimeShape all_shapes[kMaxInputNum])53 inline void GetAllInputTensorShapes(const TfLiteContext* context,
54                                     const TfLiteNode* node,
55                                     RuntimeShape all_shapes[kMaxInputNum]) {
56   TFLITE_DCHECK(context != nullptr);
57   TFLITE_DCHECK(node != nullptr);
58   for (int i = 0; i < node->inputs->size; ++i) {
59     const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
60     RuntimeShape shape = tflite::micro::GetTensorShape(t);
61     all_shapes[i].ReplaceWith(shape.DimensionsCount(), shape.DimsData());
62   }
63 }
64 
65 // Get shape pointers from a list of shapes.
GetShapesPointers(const RuntimeShape * shapes,size_t num,const RuntimeShape * pointers[])66 inline void GetShapesPointers(const RuntimeShape* shapes, size_t num,
67                               const RuntimeShape* pointers[]) {
68   for (size_t i = 0; i < num; ++i) {
69     pointers[i] = &shapes[i];
70   }
71 }
72 
73 // Gets data pointers from a list of tensors.
74 template <typename T>
GetAllInputTensorData(const TfLiteContext * context,const TfLiteNode * node,T * all_data[kMaxInputNum])75 inline void GetAllInputTensorData(const TfLiteContext* context,
76                                   const TfLiteNode* node,
77                                   T* all_data[kMaxInputNum]) {
78   TFLITE_DCHECK(context != nullptr);
79   TFLITE_DCHECK(node != nullptr);
80   for (int i = 0; i < node->inputs->size; ++i) {
81     const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
82     all_data[i] = tflite::micro::GetTensorData<T>(t);
83   }
84 }
85 
86 template <typename data_type>
EvalUnquantized(TfLiteContext * context,TfLiteNode * node)87 void EvalUnquantized(TfLiteContext* context, TfLiteNode* node) {
88   // Collect the shapes and data pointer of input tensors
89   RuntimeShape inputs_shape[kMaxInputNum];
90   const RuntimeShape* inputs_shape_ptr[kMaxInputNum];
91   const data_type* inputs_data[kMaxInputNum];
92   GetAllInputTensorShapes(context, node, inputs_shape);
93   GetShapesPointers(inputs_shape, node->inputs->size, inputs_shape_ptr);
94   GetAllInputTensorData(context, node, inputs_data);
95 
96   TfLiteEvalTensor* output =
97       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
98 
99   TFLITE_DCHECK(node->user_data != nullptr);
100   const OpData* data = static_cast<const OpData*>(node->user_data);
101 
102   reference_ops::Concatenation(data->params, inputs_shape_ptr, inputs_data,
103                                tflite::micro::GetTensorShape(output),
104                                tflite::micro::GetTensorData<data_type>(output));
105 }
106 
EvalQuantizedUInt8(TfLiteContext * context,TfLiteNode * node)107 void EvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node) {
108   // Collect the shapes and data pointer of input tensors
109   RuntimeShape inputs_shape[kMaxInputNum];
110   const RuntimeShape* inputs_shape_ptr[kMaxInputNum];
111   const uint8_t* inputs_data[kMaxInputNum];
112   GetAllInputTensorShapes(context, node, inputs_shape);
113   GetShapesPointers(inputs_shape, node->inputs->size, inputs_shape_ptr);
114   GetAllInputTensorData(context, node, inputs_data);
115 
116   TfLiteEvalTensor* output =
117       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
118 
119   TFLITE_DCHECK(node->user_data != nullptr);
120   const OpData* data = static_cast<const OpData*>(node->user_data);
121 
122   reference_ops::ConcatenationWithScaling(
123       data->params, inputs_shape_ptr, inputs_data,
124       tflite::micro::GetTensorShape(output),
125       tflite::micro::GetTensorData<uint8_t>(output));
126 }
127 
Init(TfLiteContext * context,const char * buffer,size_t length)128 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
129   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
130   return context->AllocatePersistentBuffer(context, sizeof(OpData));
131 }
132 
Prepare(TfLiteContext * context,TfLiteNode * node)133 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
134   // This function only checks the types. Additional shape validations are
135   // performed in the reference implementation called during Eval().
136   const TfLiteConcatenationParams* params =
137       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
138 
139   const TfLiteTensor* input_tensor = GetInput(context, node, 0);
140   TF_LITE_ENSURE(context, input_tensor != nullptr);
141   TfLiteType input_type = input_tensor->type;
142   const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
143   TF_LITE_ENSURE(context, output_tensor != nullptr);
144   TfLiteType output_type = output_tensor->type;
145 
146   // Check activation and input type
147   TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
148   TF_LITE_ENSURE(context,
149                  input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
150                      input_type == kTfLiteInt8 || input_type == kTfLiteInt32 ||
151                      input_type == kTfLiteInt64);
152 
153   // Output type must match input type
154   TF_LITE_ENSURE_EQ(context, output_type, input_type);
155 
156   // This implementation does not support large number of input tensors
157   const int num_inputs = NumInputs(node);
158   TF_LITE_ENSURE(context, num_inputs <= kMaxInputNum);
159 
160   // Shapes with dimensions >4 are not yet supported with static allocation.
161   for (int i = 0; i < num_inputs; ++i) {
162     const TfLiteTensor* input = GetInput(context, node, i);
163     TF_LITE_ENSURE(context, input != nullptr);
164     int num_dimensions = NumDimensions(input);
165 
166     if (num_dimensions > 4) {
167       TF_LITE_KERNEL_LOG(
168           context,
169           "Op Concatenation does not currently support num dimensions >4 "
170           "Tensor has %d dimensions.",
171           num_dimensions);
172       return kTfLiteError;
173     }
174   }
175 
176   // Calculate OpData.
177   TFLITE_DCHECK(node->user_data != nullptr);
178   OpData* data = static_cast<OpData*>(node->user_data);
179 
180   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
181   TF_LITE_ENSURE(context, output != nullptr);
182 
183   switch (output_type) {  // Already know in/outtypes are same.
184     case kTfLiteFloat32:
185     case kTfLiteInt32:
186     case kTfLiteInt64: {
187       data->params.axis = CalculatePositiveAxis(params->axis, output);
188       data->params.inputs_count = node->inputs->size;
189       break;
190     }
191     case kTfLiteUInt8:
192     case kTfLiteInt8: {
193       data->params.axis = CalculatePositiveAxis(params->axis, output);
194       data->params.inputs_count = node->inputs->size;
195 
196       float* input_scales =
197           reinterpret_cast<float*>(context->AllocatePersistentBuffer(
198               context, node->inputs->size * sizeof(float)));
199 
200       int32_t* input_zero_points =
201           reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
202               context, node->inputs->size * sizeof(int32_t)));
203 
204       // Allocate persistent scale and zeropoint buffers.
205       // Store input scale and zero point values in OpParams:
206       for (int i = 0; i < node->inputs->size; ++i) {
207         const TfLiteTensor* t = GetInput(context, node, i);
208         TF_LITE_ENSURE(context, t != nullptr);
209         input_scales[i] = t->params.scale;
210         input_zero_points[i] = t->params.zero_point;
211       }
212 
213       data->params.input_scale = input_scales;
214       data->params.input_zeropoint = input_zero_points;
215       data->params.output_zeropoint = output->params.zero_point;
216       data->params.output_scale = output->params.scale;
217       break;
218     }
219     default:
220       TF_LITE_KERNEL_LOG(
221           context, "Op Concatenation does not currently support Type '%s'.",
222           TfLiteTypeGetName(output_type));
223       return kTfLiteError;
224   }
225 
226   return kTfLiteOk;
227 }
228 
Eval(TfLiteContext * context,TfLiteNode * node)229 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
230   const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
231   TF_LITE_ENSURE(context, output_tensor != nullptr);
232   TfLiteType output_type = output_tensor->type;
233 
234   switch (output_type) {  // Already know in/outtypes are same.
235     case kTfLiteFloat32:
236       EvalUnquantized<float>(context, node);
237       break;
238     case kTfLiteInt32:
239       EvalUnquantized<int32_t>(context, node);
240       break;
241     case kTfLiteUInt8:
242       EvalQuantizedUInt8(context, node);
243       break;
244     case kTfLiteInt8:
245       EvalUnquantized<int8_t>(context, node);
246       break;
247     case kTfLiteInt64:
248       EvalUnquantized<int64_t>(context, node);
249       break;
250 
251     default:
252       TF_LITE_KERNEL_LOG(
253           context, "Op Concatenation does not currently support Type '%s'.",
254           TfLiteTypeGetName(output_type));
255       return kTfLiteError;
256   }
257 
258   return kTfLiteOk;
259 }
260 
261 }  // namespace concatenation
262 
Register_CONCATENATION()263 TfLiteRegistration Register_CONCATENATION() {
264   return {/*init=*/concatenation::Init,
265           /*free=*/nullptr,
266           /*prepare=*/concatenation::Prepare,
267           /*invoke=*/concatenation::Eval,
268           /*profiling_string=*/nullptr,
269           /*builtin_code=*/0,
270           /*custom_name=*/nullptr,
271           /*version=*/0};
272 }
273 
274 }  // namespace micro
275 }  // namespace ops
276 }  // namespace tflite
277