• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/reference/concatenation.h"
16 
17 #include <stdint.h>
18 
19 #include "tensorflow/lite/c/builtin_op_data.h"
20 #include "tensorflow/lite/c/common.h"
21 #include "tensorflow/lite/kernels/internal/compatibility.h"
22 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
23 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
24 #include "tensorflow/lite/kernels/internal/tensor.h"
25 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
26 #include "tensorflow/lite/kernels/internal/types.h"
27 #include "tensorflow/lite/kernels/kernel_util.h"
28 
29 namespace tflite {
30 namespace ops {
31 namespace builtin {
32 namespace concatenation {
33 
// Selects which of the two Concatenation implementations Eval dispatches to:
// the portable reference kernels or the generic optimized kernels.
enum KernelType {
  kReference = 0,         // reference_ops::Concatenation*
  kGenericOptimized = 1,  // optimized_ops::Concatenation*
};
39 
Prepare(TfLiteContext * context,TfLiteNode * node)40 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
41   auto* params =
42       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
43   int axis = params->axis;
44   int num_inputs = node->inputs->size;
45 
46   // The number of dimensions of the input tensors must match, and all
47   // dimensions except 'axis' must be equal.
48   const TfLiteTensor* t0;
49   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &t0));
50   TfLiteType input_type = t0->type;
51   if (axis < 0) axis += t0->dims->size;
52   TF_LITE_ENSURE(context, axis >= 0);
53   TF_LITE_ENSURE(context, axis < t0->dims->size);
54 
55   TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
56   TF_LITE_ENSURE(context,
57                  input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
58                      input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
59                      input_type == kTfLiteInt32 || input_type == kTfLiteInt64 ||
60                      input_type == kTfLiteBool);
61 
62   // Output dimensions will match input dimensions, except 'axis', which
63   // will be the sum of inputs
64   int sum_axis = t0->dims->data[axis];
65   for (int i = 1; i < num_inputs; ++i) {
66     const TfLiteTensor* t;
67     TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
68     TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
69     TF_LITE_ENSURE_EQ(context, t->type, input_type);
70     for (int d = 0; d < t0->dims->size; ++d) {
71       if (d == axis) {
72         sum_axis += t->dims->data[axis];
73       } else {
74         TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]);
75       }
76     }
77   }
78 
79   TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size);
80   for (int d = 0; d < t0->dims->size; ++d) {
81     output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d];
82   }
83 
84   TfLiteTensor* output;
85   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
86   TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);
87 
88   if (input_type == kTfLiteInt8) {
89     // Make sure there is no re-scaling needed for Int8 quantized kernel. This
90     // is a restriction we introduced to Int8 kernels.
91     VectorOfTensors<int8_t> all_inputs(*context, *node->inputs);
92     for (int i = 0; i < node->inputs->size; ++i) {
93       const TfLiteTensor* t;
94       TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
95       TF_LITE_ENSURE_EQ(context, t->params.scale, output->params.scale);
96       TF_LITE_ENSURE_EQ(context, t->params.zero_point,
97                         output->params.zero_point);
98     }
99   }
100 
101   if (input_type == kTfLiteInt16) {
102     // Make sure that all Int16 inputs have a null zero-point.
103     for (int i = 0; i < node->inputs->size; ++i) {
104       const TfLiteTensor* t = GetInput(context, node, i);
105       TF_LITE_ENSURE_EQ(context, t->params.zero_point, 0);
106     }
107     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
108   }
109 
110   return context->ResizeTensor(context, output, output_size);
111 }
112 
113 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)114 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
115   auto* params =
116       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
117   int axis = params->axis;
118   TfLiteTensor* output;
119   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
120   if (axis < 0) axis += output->dims->size;
121 
122 // TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
123 // allocate and populate these during Prepare().
124 // TODO(ycling): Activation function parameter is ignored. For now we don't have
125 // a model with a Concatenation with fused activation function.
126 #define TF_LITE_CONCATENATION(scalar)                                         \
127   {                                                                           \
128     VectorOfTensors<scalar> all_inputs(*context, *node->inputs);              \
129     tflite::ConcatenationParams op_params;                                    \
130     op_params.axis = axis;                                                    \
131     op_params.inputs_count = node->inputs->size;                              \
132     if (kernel_type == kReference) {                                          \
133       reference_ops::Concatenation(op_params, all_inputs.shapes(),            \
134                                    all_inputs.data(), GetTensorShape(output), \
135                                    GetTensorData<scalar>(output));            \
136     } else {                                                                  \
137       optimized_ops::Concatenation(op_params, all_inputs.shapes(),            \
138                                    all_inputs.data(), GetTensorShape(output), \
139                                    GetTensorData<scalar>(output));            \
140     }                                                                         \
141   }
142 
143 #define TF_LITE_CONCATENATION_QUANTIZED()                         \
144   {                                                               \
145     VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
146     tflite::ConcatenationParams op_params;                        \
147     op_params.axis = axis;                                        \
148     op_params.input_zeropoint = all_inputs.zero_point();          \
149     op_params.input_scale = all_inputs.scale();                   \
150     op_params.inputs_count = node->inputs->size;                  \
151     op_params.output_zeropoint = output->params.zero_point;       \
152     op_params.output_scale = output->params.scale;                \
153     if (kernel_type == kReference) {                              \
154       reference_ops::ConcatenationWithScaling(                    \
155           op_params, all_inputs.shapes(), all_inputs.data(),      \
156           GetTensorShape(output), GetTensorData<uint8>(output));  \
157     } else {                                                      \
158       optimized_ops::ConcatenationWithScaling(                    \
159           op_params, all_inputs.shapes(), all_inputs.data(),      \
160           GetTensorShape(output), GetTensorData<uint8>(output));  \
161     }                                                             \
162   }
163 
164   switch (output->type) {  // Already know in/outtypes are same.
165     case kTfLiteFloat32:
166       TF_LITE_CONCATENATION(float);
167       break;
168     case kTfLiteInt32:
169       TF_LITE_CONCATENATION(int32);
170       break;
171     case kTfLiteUInt8:
172       TF_LITE_CONCATENATION_QUANTIZED();
173       break;
174     case kTfLiteInt8:
175       TF_LITE_CONCATENATION(int8_t);
176       break;
177     case kTfLiteInt64:
178       TF_LITE_CONCATENATION(int64_t);
179       break;
180     case kTfLiteInt16:
181       TF_LITE_CONCATENATION(int16_t);
182       break;
183     case kTfLiteBool:
184       TF_LITE_CONCATENATION(bool);
185       break;
186     default:
187       context->ReportError(context, "Type '%s' is not supported currently.",
188                            TfLiteTypeGetName(output->type));
189       return kTfLiteError;
190   }
191 
192 #undef TF_LITE_CONCATENATION_QUANTIZED
193 #undef TF_LITE_CONCATENATION
194 
195   return kTfLiteOk;
196 }
197 
198 #undef TF_LITE_MACRO_DISPATCH
199 
200 }  // namespace concatenation
201 
Register_CONCATENATION_REF()202 TfLiteRegistration* Register_CONCATENATION_REF() {
203   static TfLiteRegistration r = {
204       nullptr, nullptr, concatenation::Prepare,
205       concatenation::Eval<concatenation::kReference>};
206   return &r;
207 }
208 
Register_CONCATENATION_GENERIC_OPT()209 TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
210   static TfLiteRegistration r = {
211       nullptr, nullptr, concatenation::Prepare,
212       concatenation::Eval<concatenation::kGenericOptimized>};
213   return &r;
214 }
215 
Register_CONCATENATION()216 TfLiteRegistration* Register_CONCATENATION() {
217   // TODO(ahentz): It turns out the two versions of Concatenation are almost
218   // identical, so we should consider removing one.
219   return Register_CONCATENATION_GENERIC_OPT();
220 }
221 
222 }  // namespace builtin
223 }  // namespace ops
224 }  // namespace tflite
225