1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/concatenation.h"

#include <stdint.h>

#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
28
29 namespace tflite {
30 namespace ops {
31 namespace builtin {
32 namespace concatenation {
33
// This file has two implementations of Concatenation: a portable reference
// kernel and a generically optimized one. Eval()'s template parameter
// selects between them.
enum KernelType {
  kReference,         // reference_ops::Concatenation.
  kGenericOptimized,  // optimized_ops::Concatenation.
};
39
Prepare(TfLiteContext * context,TfLiteNode * node)40 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
41 auto* params =
42 reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
43 int axis = params->axis;
44 int num_inputs = node->inputs->size;
45
46 // The number of dimensions of the input tensors must match, and all
47 // dimensions except 'axis' must be equal.
48 const TfLiteTensor* t0;
49 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &t0));
50 TfLiteType input_type = t0->type;
51 if (axis < 0) axis += t0->dims->size;
52 TF_LITE_ENSURE(context, axis >= 0);
53 TF_LITE_ENSURE(context, axis < t0->dims->size);
54
55 TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
56 TF_LITE_ENSURE(context,
57 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
58 input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
59 input_type == kTfLiteInt32 || input_type == kTfLiteInt64 ||
60 input_type == kTfLiteBool);
61
62 // Output dimensions will match input dimensions, except 'axis', which
63 // will be the sum of inputs
64 int sum_axis = t0->dims->data[axis];
65 for (int i = 1; i < num_inputs; ++i) {
66 const TfLiteTensor* t;
67 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
68 TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
69 TF_LITE_ENSURE_EQ(context, t->type, input_type);
70 for (int d = 0; d < t0->dims->size; ++d) {
71 if (d == axis) {
72 sum_axis += t->dims->data[axis];
73 } else {
74 TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]);
75 }
76 }
77 }
78
79 TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size);
80 for (int d = 0; d < t0->dims->size; ++d) {
81 output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d];
82 }
83
84 TfLiteTensor* output;
85 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
86 TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);
87
88 if (input_type == kTfLiteInt8) {
89 // Make sure there is no re-scaling needed for Int8 quantized kernel. This
90 // is a restriction we introduced to Int8 kernels.
91 VectorOfTensors<int8_t> all_inputs(*context, *node->inputs);
92 for (int i = 0; i < node->inputs->size; ++i) {
93 const TfLiteTensor* t;
94 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
95 TF_LITE_ENSURE_EQ(context, t->params.scale, output->params.scale);
96 TF_LITE_ENSURE_EQ(context, t->params.zero_point,
97 output->params.zero_point);
98 }
99 }
100
101 if (input_type == kTfLiteInt16) {
102 // Make sure that all Int16 inputs have a null zero-point.
103 for (int i = 0; i < node->inputs->size; ++i) {
104 const TfLiteTensor* t = GetInput(context, node, i);
105 TF_LITE_ENSURE_EQ(context, t->params.zero_point, 0);
106 }
107 TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
108 }
109
110 return context->ResizeTensor(context, output, output_size);
111 }
112
// Runs the concatenation. Prepare() already validated types/shapes, so this
// only normalizes the axis, gathers the input tensors, and dispatches to the
// reference or optimized kernel chosen by the template parameter.
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
  int axis = params->axis;
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  // Same negative-axis normalization as Prepare(); the output rank equals the
  // input rank, so resolving against output->dims is equivalent.
  if (axis < 0) axis += output->dims->size;

// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
// allocate and populate these during Prepare().
// TODO(ycling): Activation function parameter is ignored. For now we don't have
// a model with a Concatenation with fused activation function.
// Non-quantized (and Int8/Int16 no-rescale) path: concatenate raw buffers of
// type 'scalar'.
#define TF_LITE_CONCATENATION(scalar)                                         \
  {                                                                           \
    VectorOfTensors<scalar> all_inputs(*context, *node->inputs);              \
    tflite::ConcatenationParams op_params;                                    \
    op_params.axis = axis;                                                    \
    op_params.inputs_count = node->inputs->size;                              \
    if (kernel_type == kReference) {                                          \
      reference_ops::Concatenation(op_params, all_inputs.shapes(),            \
                                   all_inputs.data(), GetTensorShape(output), \
                                   GetTensorData<scalar>(output));            \
    } else {                                                                  \
      optimized_ops::Concatenation(op_params, all_inputs.shapes(),            \
                                   all_inputs.data(), GetTensorShape(output), \
                                   GetTensorData<scalar>(output));            \
    }                                                                         \
  }

// UInt8 path: inputs may carry per-tensor scale/zero-point, so the kernel
// re-quantizes each input into the output's quantization parameters.
#define TF_LITE_CONCATENATION_QUANTIZED()                         \
  {                                                               \
    VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
    tflite::ConcatenationParams op_params;                        \
    op_params.axis = axis;                                        \
    op_params.input_zeropoint = all_inputs.zero_point();          \
    op_params.input_scale = all_inputs.scale();                   \
    op_params.inputs_count = node->inputs->size;                  \
    op_params.output_zeropoint = output->params.zero_point;       \
    op_params.output_scale = output->params.scale;                \
    if (kernel_type == kReference) {                              \
      reference_ops::ConcatenationWithScaling(                    \
          op_params, all_inputs.shapes(), all_inputs.data(),      \
          GetTensorShape(output), GetTensorData<uint8>(output));  \
    } else {                                                      \
      optimized_ops::ConcatenationWithScaling(                    \
          op_params, all_inputs.shapes(), all_inputs.data(),      \
          GetTensorShape(output), GetTensorData<uint8>(output));  \
    }                                                             \
  }

  switch (output->type) {  // Already know in/outtypes are same.
    case kTfLiteFloat32:
      TF_LITE_CONCATENATION(float);
      break;
    case kTfLiteInt32:
      TF_LITE_CONCATENATION(int32);
      break;
    case kTfLiteUInt8:
      TF_LITE_CONCATENATION_QUANTIZED();
      break;
    case kTfLiteInt8:
      TF_LITE_CONCATENATION(int8_t);
      break;
    case kTfLiteInt64:
      TF_LITE_CONCATENATION(int64_t);
      break;
    case kTfLiteInt16:
      TF_LITE_CONCATENATION(int16_t);
      break;
    case kTfLiteBool:
      TF_LITE_CONCATENATION(bool);
      break;
    default:
      context->ReportError(context, "Type '%s' is not supported currently.",
                           TfLiteTypeGetName(output->type));
      return kTfLiteError;
  }

#undef TF_LITE_CONCATENATION_QUANTIZED
#undef TF_LITE_CONCATENATION

  return kTfLiteOk;
}
197
198 #undef TF_LITE_MACRO_DISPATCH
199
200 } // namespace concatenation
201
Register_CONCATENATION_REF()202 TfLiteRegistration* Register_CONCATENATION_REF() {
203 static TfLiteRegistration r = {
204 nullptr, nullptr, concatenation::Prepare,
205 concatenation::Eval<concatenation::kReference>};
206 return &r;
207 }
208
Register_CONCATENATION_GENERIC_OPT()209 TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
210 static TfLiteRegistration r = {
211 nullptr, nullptr, concatenation::Prepare,
212 concatenation::Eval<concatenation::kGenericOptimized>};
213 return &r;
214 }
215
// Default registration: the generically optimized kernel.
TfLiteRegistration* Register_CONCATENATION() {
  // TODO(ahentz): It turns out the two versions of Concatenation are almost
  // identical, so we should consider removing one.
  return Register_CONCATENATION_GENERIC_OPT();
}
221
222 } // namespace builtin
223 } // namespace ops
224 } // namespace tflite
225