/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
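// SplitV op: splits the input tensor along 'axis' into one output per entry
// of 'size_splits'. At most one 'size_splits' entry may be -1, in which case
// that output's size is inferred from the remaining extent of the split axis.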
#include <vector>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace split_v {

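// Bundles the op's builtin params with its three inputs: the tensor to be
// split, the requested per-output sizes, and the split axis.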
struct OpContext {
  OpContext(TfLiteContext* context, TfLiteNode* node) {
    params = reinterpret_cast<TfLiteSplitVParams*>(node->builtin_data);
    input = GetInput(context, node, 0);
    size_splits = GetInput(context, node, 1);
    axis = GetInput(context, node, 2);
  }
  TfLiteSplitVParams* params;
  const TfLiteTensor* input;
  const TfLiteTensor* size_splits;
  const TfLiteTensor* axis;
};

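// Marks every output tensor as dynamic so its shape can be set during Eval().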
TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context,
                                     TfLiteNode* node) {
  for (int i = 0; i < NumOutputs(node); ++i) {
    SetTensorToDynamic(GetOutput(context, node, i));
  }
  return kTfLiteOk;
}

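// Copies the requested split sizes out of the 'size_splits' tensor into an
// int64 vector, regardless of whether the tensor stores int32 or int64.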
template <typename T>
void GetSizeSplitsVector(const TfLiteTensor* size_splits,
                         std::vector<int64_t>* size_splits_vector) {
  const auto num_elements = NumElements(size_splits);
  for (int i = 0; i < num_elements; ++i) {
    size_splits_vector->push_back(GetTensorData<T>(size_splits)[i]);
  }
}

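// Computes the shape of every output tensor from the input shape, the
// requested split sizes, and the split axis, then resizes the outputs.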
TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
                                 const TfLiteTensor* input,
                                 const TfLiteTensor* size_splits,
                                 const TfLiteTensor* axis) {
  int axis_value = GetTensorData<int>(axis)[0];
  if (axis_value < 0) {
    axis_value += NumDimensions(input);
  }

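  // Read the requested split sizes into a common int64 representation.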
  std::vector<int64_t> size_splits_vector;
  if (size_splits->type == kTfLiteInt32) {
    GetSizeSplitsVector<int32_t>(size_splits, &size_splits_vector);
  } else if (size_splits->type == kTfLiteInt64) {
    GetSizeSplitsVector<int64_t>(size_splits, &size_splits_vector);
  } else {
    context->ReportError(context,
                         "size_splits only supports int32 or int64 types.");
    return kTfLiteError;
  }

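  // Find the inferred (-1) entry, if any, while summing the explicit sizes.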
  int minus_one_index = -1;
  int64_t size_splits_sum = 0;

  for (int i = 0; i < size_splits_vector.size(); ++i) {
    if (size_splits_vector.at(i) == -1) {
      if (minus_one_index == -1) {
        minus_one_index = i;
      } else {
        context->ReportError(context,
                             "The size_splits contains more than one -1.");
        return kTfLiteError;
      }
    } else {
      size_splits_sum += size_splits_vector.at(i);
    }
  }

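  // Resolve the inferred entry against the extent of the split axis, and
  // check that the requested sizes account for that extent exactly.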
  const int input_size = SizeOfDimension(input, axis_value);

  if (minus_one_index != -1) {
    if (size_splits_sum > input_size) {
      context->ReportError(
          context,
          "The sum of size_splits must not exceed the dimension of value.");
      return kTfLiteError;
    }
    size_splits_vector[minus_one_index] = input_size - size_splits_sum;
  } else if (size_splits_sum != input_size) {
    context->ReportError(
        context,
        "The size_splits must sum to the dimension of value along axis.");
    return kTfLiteError;
  }

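  // Each output keeps the input's shape, except along the split axis where it
  // takes its requested (or inferred) size.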
  for (int i = 0; i < NumOutputs(node); ++i) {
    TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
    output_dims->data[axis_value] = size_splits_vector.at(i);
    TfLiteTensor* output = GetOutput(context, node, i);
    TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims));
  }

  return kTfLiteOk;
}

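// Validates the node's inputs and, when both 'size_splits' and 'axis' are
// constant, resizes the outputs now; otherwise defers resizing to Eval().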
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);

  OpContext op_context(context, node);

  TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits);

  auto input_type = op_context.input->type;
  TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 ||
                              input_type == kTfLiteUInt8 ||
                              input_type == kTfLiteInt16);
  for (int i = 0; i < NumOutputs(node); ++i) {
    GetOutput(context, node, i)->type = input_type;
  }

  auto size_splits = op_context.size_splits;
  TF_LITE_ENSURE_EQ(context, NumDimensions(size_splits), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), NumElements(size_splits));

  // If we know the contents of the 'size_splits' tensor and the 'axis' tensor,
  // resize all outputs. Otherwise, wait until Eval().
  if (IsConstantTensor(op_context.size_splits) &&
      IsConstantTensor(op_context.axis)) {
    return ResizeOutputTensors(context, node, op_context.input,
                               op_context.size_splits, op_context.axis);
  } else {
    return UseDynamicOutputTensors(context, node);
  }
}

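// Performs any output resizing deferred from Prepare(), then copies each
// slice of the input into its corresponding output tensor.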
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);

  // When the 'size_splits' and 'axis' tensors are non-const we can't resize
  // the output tensors in Prepare(), so we have to do it now.
  if (!IsConstantTensor(op_context.axis) ||
      !IsConstantTensor(op_context.size_splits)) {
    TF_LITE_ENSURE_OK(
        context, ResizeOutputTensors(context, node, op_context.input,
                                     op_context.size_splits, op_context.axis));
  }

  int axis_value = GetTensorData<int>(op_context.axis)[0];
  // Normalize a negative axis, mirroring ResizeOutputTensors() above.
  if (axis_value < 0) {
    axis_value += NumDimensions(op_context.input);
  }

  // Use the Split implementation to build the outputs, since the two ops
  // share the same copy logic.
#define TF_LITE_SPLIT_V(scalar)                                     \
  VectorOfTensors<scalar> all_outputs(*context, *node->outputs);    \
  tflite::SplitParams op_params;                                    \
  op_params.num_split = NumOutputs(node);                           \
  op_params.axis = axis_value;                                      \
  reference_ops::Split(op_params, GetTensorShape(op_context.input), \
                       GetTensorData<scalar>(op_context.input),     \
                       all_outputs.shapes(), all_outputs.data());
  switch (op_context.input->type) {
    case kTfLiteFloat32: {
      TF_LITE_SPLIT_V(float);
      break;
    }
    case kTfLiteUInt8: {
      TF_LITE_SPLIT_V(uint8_t);
      break;
    }
    case kTfLiteInt16: {
      TF_LITE_SPLIT_V(int16_t);
      break;
    }
    default:
      context->ReportError(context, "Type %s is currently not supported.",
                           TfLiteTypeGetName(op_context.input->type));
      return kTfLiteError;
  }
#undef TF_LITE_SPLIT_V

  return kTfLiteOk;
}

}  // namespace split_v

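// The op keeps no per-node state, so no init/free hooks are registered.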
TfLiteRegistration* Register_SPLIT_V() {
  static TfLiteRegistration r = {nullptr, nullptr, split_v::Prepare,
                                 split_v::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite