1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/reference/concatenation.h"
16
17 #include <cstdint>
18
19 #include "tensorflow/lite/c/builtin_op_data.h"
20 #include "tensorflow/lite/c/common.h"
21 #include "tensorflow/lite/kernels/internal/portable_tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/internal/types.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 #include "tensorflow/lite/micro/kernels/kernel_util.h"
26
27 namespace tflite {
28 namespace ops {
29 namespace micro {
30 namespace concatenation {
31
// Upper bound on the number of input tensors this kernel accepts. The Eval
// helpers size their stack arrays with it and Prepare() rejects larger nodes.
constexpr int kMaxInputNum{10};
// Index of the op's single output tensor.
constexpr int kOutputTensor{0};
34
// Per-node state kept in node->user_data. Init() allocates storage of
// sizeof(OpData), Prepare() fills in `params`, and the Eval helpers hand
// `params` to the reference concatenation routines.
struct OpData {
  ConcatenationParams params;
};
38
39 // Handles negative axis index, coerces to positive index value.
CalculatePositiveAxis(int axis,const TfLiteTensor * output_tensor)40 inline int CalculatePositiveAxis(int axis, const TfLiteTensor* output_tensor) {
41 if (axis >= 0) {
42 return axis;
43 } else {
44 return NumDimensions(output_tensor) + axis;
45 }
46 }
47
48 // The following functions are helpers to get tensor data in the format that the
49 // reference op implementation expects. They provide the same functionality as
50 // class VectorOfTensors and class VectorOfQuantizedTensors in TFLite.
51
52 // Gets shapes from a list of tensors.
GetAllInputTensorShapes(const TfLiteContext * context,const TfLiteNode * node,RuntimeShape all_shapes[kMaxInputNum])53 inline void GetAllInputTensorShapes(const TfLiteContext* context,
54 const TfLiteNode* node,
55 RuntimeShape all_shapes[kMaxInputNum]) {
56 TFLITE_DCHECK(context != nullptr);
57 TFLITE_DCHECK(node != nullptr);
58 for (int i = 0; i < node->inputs->size; ++i) {
59 const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
60 RuntimeShape shape = tflite::micro::GetTensorShape(t);
61 all_shapes[i].ReplaceWith(shape.DimensionsCount(), shape.DimsData());
62 }
63 }
64
65 // Get shape pointers from a list of shapes.
GetShapesPointers(const RuntimeShape * shapes,size_t num,const RuntimeShape * pointers[])66 inline void GetShapesPointers(const RuntimeShape* shapes, size_t num,
67 const RuntimeShape* pointers[]) {
68 for (size_t i = 0; i < num; ++i) {
69 pointers[i] = &shapes[i];
70 }
71 }
72
73 // Gets data pointers from a list of tensors.
74 template <typename T>
GetAllInputTensorData(const TfLiteContext * context,const TfLiteNode * node,T * all_data[kMaxInputNum])75 inline void GetAllInputTensorData(const TfLiteContext* context,
76 const TfLiteNode* node,
77 T* all_data[kMaxInputNum]) {
78 TFLITE_DCHECK(context != nullptr);
79 TFLITE_DCHECK(node != nullptr);
80 for (int i = 0; i < node->inputs->size; ++i) {
81 const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
82 all_data[i] = tflite::micro::GetTensorData<T>(t);
83 }
84 }
85
86 template <typename data_type>
EvalUnquantized(TfLiteContext * context,TfLiteNode * node)87 void EvalUnquantized(TfLiteContext* context, TfLiteNode* node) {
88 // Collect the shapes and data pointer of input tensors
89 RuntimeShape inputs_shape[kMaxInputNum];
90 const RuntimeShape* inputs_shape_ptr[kMaxInputNum];
91 const data_type* inputs_data[kMaxInputNum];
92 GetAllInputTensorShapes(context, node, inputs_shape);
93 GetShapesPointers(inputs_shape, node->inputs->size, inputs_shape_ptr);
94 GetAllInputTensorData(context, node, inputs_data);
95
96 TfLiteEvalTensor* output =
97 tflite::micro::GetEvalOutput(context, node, kOutputTensor);
98
99 TFLITE_DCHECK(node->user_data != nullptr);
100 const OpData* data = static_cast<const OpData*>(node->user_data);
101
102 reference_ops::Concatenation(data->params, inputs_shape_ptr, inputs_data,
103 tflite::micro::GetTensorShape(output),
104 tflite::micro::GetTensorData<data_type>(output));
105 }
106
EvalQuantizedUInt8(TfLiteContext * context,TfLiteNode * node)107 void EvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node) {
108 // Collect the shapes and data pointer of input tensors
109 RuntimeShape inputs_shape[kMaxInputNum];
110 const RuntimeShape* inputs_shape_ptr[kMaxInputNum];
111 const uint8_t* inputs_data[kMaxInputNum];
112 GetAllInputTensorShapes(context, node, inputs_shape);
113 GetShapesPointers(inputs_shape, node->inputs->size, inputs_shape_ptr);
114 GetAllInputTensorData(context, node, inputs_data);
115
116 TfLiteEvalTensor* output =
117 tflite::micro::GetEvalOutput(context, node, kOutputTensor);
118
119 TFLITE_DCHECK(node->user_data != nullptr);
120 const OpData* data = static_cast<const OpData*>(node->user_data);
121
122 reference_ops::ConcatenationWithScaling(
123 data->params, inputs_shape_ptr, inputs_data,
124 tflite::micro::GetTensorShape(output),
125 tflite::micro::GetTensorData<uint8_t>(output));
126 }
127
Init(TfLiteContext * context,const char * buffer,size_t length)128 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
129 TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
130 return context->AllocatePersistentBuffer(context, sizeof(OpData));
131 }
132
Prepare(TfLiteContext * context,TfLiteNode * node)133 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
134 // This function only checks the types. Additional shape validations are
135 // performed in the reference implementation called during Eval().
136 const TfLiteConcatenationParams* params =
137 reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
138
139 const TfLiteTensor* input_tensor = GetInput(context, node, 0);
140 TF_LITE_ENSURE(context, input_tensor != nullptr);
141 TfLiteType input_type = input_tensor->type;
142 const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
143 TF_LITE_ENSURE(context, output_tensor != nullptr);
144 TfLiteType output_type = output_tensor->type;
145
146 // Check activation and input type
147 TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
148 TF_LITE_ENSURE(context,
149 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
150 input_type == kTfLiteInt8 || input_type == kTfLiteInt32 ||
151 input_type == kTfLiteInt64);
152
153 // Output type must match input type
154 TF_LITE_ENSURE_EQ(context, output_type, input_type);
155
156 // This implementation does not support large number of input tensors
157 const int num_inputs = NumInputs(node);
158 TF_LITE_ENSURE(context, num_inputs <= kMaxInputNum);
159
160 // Shapes with dimensions >4 are not yet supported with static allocation.
161 for (int i = 0; i < num_inputs; ++i) {
162 const TfLiteTensor* input = GetInput(context, node, i);
163 TF_LITE_ENSURE(context, input != nullptr);
164 int num_dimensions = NumDimensions(input);
165
166 if (num_dimensions > 4) {
167 TF_LITE_KERNEL_LOG(
168 context,
169 "Op Concatenation does not currently support num dimensions >4 "
170 "Tensor has %d dimensions.",
171 num_dimensions);
172 return kTfLiteError;
173 }
174 }
175
176 // Calculate OpData.
177 TFLITE_DCHECK(node->user_data != nullptr);
178 OpData* data = static_cast<OpData*>(node->user_data);
179
180 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
181 TF_LITE_ENSURE(context, output != nullptr);
182
183 switch (output_type) { // Already know in/outtypes are same.
184 case kTfLiteFloat32:
185 case kTfLiteInt32:
186 case kTfLiteInt64: {
187 data->params.axis = CalculatePositiveAxis(params->axis, output);
188 data->params.inputs_count = node->inputs->size;
189 break;
190 }
191 case kTfLiteUInt8:
192 case kTfLiteInt8: {
193 data->params.axis = CalculatePositiveAxis(params->axis, output);
194 data->params.inputs_count = node->inputs->size;
195
196 float* input_scales =
197 reinterpret_cast<float*>(context->AllocatePersistentBuffer(
198 context, node->inputs->size * sizeof(float)));
199
200 int32_t* input_zero_points =
201 reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
202 context, node->inputs->size * sizeof(int32_t)));
203
204 // Allocate persistent scale and zeropoint buffers.
205 // Store input scale and zero point values in OpParams:
206 for (int i = 0; i < node->inputs->size; ++i) {
207 const TfLiteTensor* t = GetInput(context, node, i);
208 TF_LITE_ENSURE(context, t != nullptr);
209 input_scales[i] = t->params.scale;
210 input_zero_points[i] = t->params.zero_point;
211 }
212
213 data->params.input_scale = input_scales;
214 data->params.input_zeropoint = input_zero_points;
215 data->params.output_zeropoint = output->params.zero_point;
216 data->params.output_scale = output->params.scale;
217 break;
218 }
219 default:
220 TF_LITE_KERNEL_LOG(
221 context, "Op Concatenation does not currently support Type '%s'.",
222 TfLiteTypeGetName(output_type));
223 return kTfLiteError;
224 }
225
226 return kTfLiteOk;
227 }
228
Eval(TfLiteContext * context,TfLiteNode * node)229 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
230 const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
231 TF_LITE_ENSURE(context, output_tensor != nullptr);
232 TfLiteType output_type = output_tensor->type;
233
234 switch (output_type) { // Already know in/outtypes are same.
235 case kTfLiteFloat32:
236 EvalUnquantized<float>(context, node);
237 break;
238 case kTfLiteInt32:
239 EvalUnquantized<int32_t>(context, node);
240 break;
241 case kTfLiteUInt8:
242 EvalQuantizedUInt8(context, node);
243 break;
244 case kTfLiteInt8:
245 EvalUnquantized<int8_t>(context, node);
246 break;
247 case kTfLiteInt64:
248 EvalUnquantized<int64_t>(context, node);
249 break;
250
251 default:
252 TF_LITE_KERNEL_LOG(
253 context, "Op Concatenation does not currently support Type '%s'.",
254 TfLiteTypeGetName(output_type));
255 return kTfLiteError;
256 }
257
258 return kTfLiteOk;
259 }
260
261 } // namespace concatenation
262
Register_CONCATENATION()263 TfLiteRegistration Register_CONCATENATION() {
264 return {/*init=*/concatenation::Init,
265 /*free=*/nullptr,
266 /*prepare=*/concatenation::Prepare,
267 /*invoke=*/concatenation::Eval,
268 /*profiling_string=*/nullptr,
269 /*builtin_code=*/0,
270 /*custom_name=*/nullptr,
271 /*version=*/0};
272 }
273
274 } // namespace micro
275 } // namespace ops
276 } // namespace tflite
277