/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace concatenation {

// This file has two implementations of Concatenation.
enum KernelType {
  kReference,
  kGenericOptimized,
};

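// The variant is selected at registration time through the template
// parameter of Eval<kernel_type> below; both variants share the same Prepare.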
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
  int axis = params->axis;
  int num_inputs = node->inputs->size;

  // The number of dimensions of the input tensors must match, and all
  // dimensions except 'axis' must be equal.
  TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]];
  TfLiteType input_type = t0->type;
  if (axis < 0) axis += t0->dims->size;
  TF_LITE_ENSURE(context, axis >= 0);
  TF_LITE_ENSURE(context, axis < t0->dims->size);
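  // For example, a negative axis counts from the back: on a 4-D input,
  // axis == -1 resolves to 3.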

  // TODO(ahentz): These are limitations of our implementation that could be
  // removed with a bit of effort.
  TF_LITE_ENSURE(context, t0->dims->size <= 4);
  TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
  TF_LITE_ENSURE(context,
                 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
                     input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
                     input_type == kTfLiteInt32 || input_type == kTfLiteInt64);
  // Output dimensions will match input dimensions, except for 'axis', which
  // will be the sum of the inputs' sizes along that axis.
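  // For example, concatenating inputs of shape [2, 3] and [2, 5] along
  // axis 1 yields an output of shape [2, 8].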
  int sum_axis = t0->dims->data[axis];
  for (int i = 1; i < num_inputs; ++i) {
    TfLiteTensor* t = &context->tensors[node->inputs->data[i]];
    TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
    TF_LITE_ENSURE_EQ(context, t->type, input_type);
    for (int d = 0; d < t0->dims->size; ++d) {
      if (d == axis) {
        sum_axis += t->dims->data[axis];
      } else {
        TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]);
      }
    }
  }

  TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size);
  for (int d = 0; d < t0->dims->size; ++d) {
    output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d];
  }

  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
  TF_LITE_ENSURE_EQ(context, output->type, input_type);

  if (input_type == kTfLiteInt8) {
    // Make sure no re-scaling is needed for the int8 quantized kernel. This
    // is a restriction we introduced for int8 kernels: every input must
    // already share the output's quantization parameters.
    for (int i = 0; i < node->inputs->size; ++i) {
      TfLiteTensor* t = &context->tensors[node->inputs->data[i]];
      TF_LITE_ENSURE_EQ(context, t->params.scale, output->params.scale);
      TF_LITE_ENSURE_EQ(context, t->params.zero_point,
                        output->params.zero_point);
    }
  }
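
  // Note: ResizeTensor takes ownership of output_size, so it does not need to
  // be freed here.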
  return context->ResizeTensor(context, output, output_size);
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
  int axis = params->axis;
  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
  if (axis < 0) axis += output->dims->size;

  // TODO(ahentz): Creating 'all_inputs' below is not very efficient. We
  // should allocate and populate these during Prepare().
  // TODO(ycling): The activation function parameter is ignored. For now we
  // don't have a model with a Concatenation with a fused activation function.
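
// The macro below instantiates the non-quantized path: VectorOfTensors
// gathers each input's shape and data pointer once, and 'type' selects
// between reference_ops and optimized_ops at compile time.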
#define TF_LITE_CONCATENATION(type, scalar)                                  \
  {                                                                          \
    VectorOfTensors<scalar> all_inputs(*context, *node->inputs);             \
    tflite::ConcatenationParams op_params;                                   \
    op_params.axis = axis;                                                   \
    op_params.inputs_count = node->inputs->size;                             \
    type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(),   \
                        GetTensorShape(output),                              \
                        GetTensorData<scalar>(output));                      \
  }

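// The quantized (uint8) variant differs in that each input may carry its own
// (scale, zero_point); ConcatenationWithScaling requantizes every input to
// the output's quantization parameters while copying.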
#define TF_LITE_CONCATENATION_QUANTIZED(type)                                 \
  {                                                                           \
    VectorOfQuantizedTensors all_inputs(*context, *node->inputs);             \
    tflite::ConcatenationParams op_params;                                    \
    op_params.axis = axis;                                                    \
    op_params.input_zeropoint = all_inputs.zero_point();                      \
    op_params.input_scale = all_inputs.scale();                               \
    op_params.inputs_count = node->inputs->size;                              \
    op_params.output_zeropoint = output->params.zero_point;                   \
    op_params.output_scale = output->params.scale;                            \
    type::ConcatenationWithScaling(op_params, all_inputs.shapes(),            \
                                   all_inputs.data(), GetTensorShape(output), \
                                   GetTensorData<uint8>(output));             \
  }

  switch (output->type) {  // We already know the input and output types match.
    case kTfLiteFloat32:
      if (kernel_type == kReference) {
        TF_LITE_CONCATENATION(reference_ops, float);
      } else {
        TF_LITE_CONCATENATION(optimized_ops, float);
      }
      break;
    case kTfLiteInt32:
      if (kernel_type == kReference) {
        TF_LITE_CONCATENATION(reference_ops, int32);
      } else {
        TF_LITE_CONCATENATION(optimized_ops, int32);
      }
      break;
    case kTfLiteUInt8:
      if (kernel_type == kReference) {
        TF_LITE_CONCATENATION_QUANTIZED(reference_ops);
      } else {
        TF_LITE_CONCATENATION_QUANTIZED(optimized_ops);
      }
      break;
    case kTfLiteInt8:
      if (kernel_type == kReference) {
        TF_LITE_CONCATENATION(reference_ops, int8_t);
      } else {
        TF_LITE_CONCATENATION(optimized_ops, int8_t);
      }
      break;
    case kTfLiteInt64:
      if (kernel_type == kReference) {
        TF_LITE_CONCATENATION(reference_ops, int64_t);
      } else {
        TF_LITE_CONCATENATION(optimized_ops, int64_t);
      }
      break;

    default:
      context->ReportError(
          context,
          "Only float32, int32, uint8, int8 and int64 are currently "
          "supported.");
      return kTfLiteError;
  }

#undef TF_LITE_CONCATENATION_QUANTIZED
#undef TF_LITE_CONCATENATION

  return kTfLiteOk;
}

}  // namespace concatenation

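// Each TfLiteRegistration below lists {init, free, prepare, invoke}; this op
// keeps no per-node state, so init and free are null.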
TfLiteRegistration* Register_CONCATENATION_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, concatenation::Prepare,
      concatenation::Eval<concatenation::kReference>};
  return &r;
}

TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, concatenation::Prepare,
      concatenation::Eval<concatenation::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_CONCATENATION() {
  // TODO(ahentz): It turns out the two versions of Concatenation are almost
  // identical, so we should consider removing one.
  return Register_CONCATENATION_GENERIC_OPT();
}
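
// A minimal usage sketch, assuming the standard BuiltinOpResolver API:
//
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   resolver.AddBuiltin(BuiltinOperator_CONCATENATION,
//                       Register_CONCATENATION());
//
// BuiltinOpResolver registers this op by default; the explicit AddBuiltin
// call is only needed when assembling a custom MutableOpResolver.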

}  // namespace builtin
}  // namespace ops
}  // namespace tflite