• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <vector>
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/c_api_internal.h"
18 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20 #include "tensorflow/lite/kernels/internal/tensor.h"
21 #include "tensorflow/lite/kernels/kernel_util.h"
22 #include "tensorflow/lite/kernels/op_macros.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace split_v {
28 
29 struct OpContext {
OpContexttflite::ops::builtin::split_v::OpContext30   OpContext(TfLiteContext* context, TfLiteNode* node) {
31     params = reinterpret_cast<TfLiteSplitVParams*>(node->builtin_data);
32     input = GetInput(context, node, 0);
33     size_splits = GetInput(context, node, 1);
34     axis = GetInput(context, node, 2);
35   }
36   TfLiteSplitVParams* params;
37   const TfLiteTensor* input;
38   const TfLiteTensor* size_splits;
39   const TfLiteTensor* axis;
40 };
41 
UseDynamicOutputTensors(TfLiteContext * context,TfLiteNode * node)42 TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) {
43   for (int i = 0; i < NumOutputs(node); ++i) {
44     SetTensorToDynamic(GetOutput(context, node, i));
45   }
46   return kTfLiteOk;
47 }
48 
49 template <typename T>
GetSizeSplitsVector(const TfLiteTensor * size_splits,std::vector<int64_t> * size_splits_vector)50 void GetSizeSplitsVector(const TfLiteTensor* size_splits,
51                          std::vector<int64_t>* size_splits_vector) {
52   const auto num_elements = NumElements(size_splits);
53   for (int i = 0; i < num_elements; ++i) {
54     size_splits_vector->push_back(GetTensorData<T>(size_splits)[i]);
55   }
56 }
57 
ResizeOutputTensors(TfLiteContext * context,TfLiteNode * node,const TfLiteTensor * input,const TfLiteTensor * size_splits,const TfLiteTensor * axis)58 TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
59                                  const TfLiteTensor* input,
60                                  const TfLiteTensor* size_splits,
61                                  const TfLiteTensor* axis) {
62   int axis_value = GetTensorData<int>(axis)[0];
63   if (axis_value < 0) {
64     axis_value += NumDimensions(input);
65   }
66 
67   std::vector<int64_t> size_splits_vector;
68   if (size_splits->type == kTfLiteInt32) {
69     GetSizeSplitsVector<int32_t>(size_splits, &size_splits_vector);
70   } else if (size_splits->type == kTfLiteInt64) {
71     GetSizeSplitsVector<int64_t>(size_splits, &size_splits_vector);
72   } else {
73     context->ReportError(context, "size_splits only support type int32|int64.");
74     return kTfLiteError;
75   }
76 
77   int minus_one_index = -1;
78   int64_t size_splits_sum = 0;
79 
80   for (int i = 0; i < size_splits_vector.size(); ++i) {
81     if (size_splits_vector.at(i) == -1) {
82       if (minus_one_index == -1) {
83         minus_one_index = i;
84       } else {
85         context->ReportError(context,
86                              "The size_splits contains more than one -1.");
87       }
88     } else {
89       size_splits_sum += size_splits_vector.at(i);
90     }
91   }
92 
93   const int input_size = SizeOfDimension(input, axis_value);
94 
95   if (minus_one_index != -1) {
96     if (size_splits_sum > input_size) {
97       context->ReportError(
98           context,
99           "The sum of size_splits must be less than the dimension of value.");
100     } else {
101       size_splits_vector[minus_one_index] = input_size - size_splits_sum;
102     }
103   } else if (size_splits_sum != input_size) {
104     context->ReportError(
105         context,
106         "The size_splits must sum to the dimension of value along axis.");
107   }
108 
109   for (int i = 0; i < NumOutputs(node); ++i) {
110     TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
111     output_dims->data[axis_value] = size_splits_vector.at(i);
112     TfLiteTensor* output = GetOutput(context, node, i);
113     TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims));
114   }
115 
116   return kTfLiteOk;
117 }
118 
Prepare(TfLiteContext * context,TfLiteNode * node)119 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
120   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
121 
122   OpContext op_context(context, node);
123 
124   TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits);
125 
126   auto input_type = op_context.input->type;
127   TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 ||
128                               input_type == kTfLiteUInt8 ||
129                               input_type == kTfLiteInt16);
130   for (int i = 0; i < NumOutputs(node); ++i) {
131     GetOutput(context, node, i)->type = input_type;
132   }
133 
134   auto size_splits = op_context.size_splits;
135   TF_LITE_ENSURE_EQ(context, NumDimensions(size_splits), 1);
136   TF_LITE_ENSURE_EQ(context, NumOutputs(node), NumElements(size_splits));
137 
138   // If we know the contents of the 'size_splits' tensor and the 'axis' tensor,
139   // resize all outputs. Otherwise, wait until Eval().
140   if (IsConstantTensor(op_context.size_splits) &&
141       IsConstantTensor(op_context.axis)) {
142     return ResizeOutputTensors(context, node, op_context.input,
143                                op_context.size_splits, op_context.axis);
144   } else {
145     return UseDynamicOutputTensors(context, node);
146   }
147 }
148 
Eval(TfLiteContext * context,TfLiteNode * node)149 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
150   OpContext op_context(context, node);
151 
152   // When the 'size_splits' and the 'axis' tensor is non-const we can't resize
153   // output tensors in Prepare(), and we have to do it now.
154   if (!IsConstantTensor(op_context.axis) ||
155       !IsConstantTensor(op_context.size_splits)) {
156     TF_LITE_ENSURE_OK(
157         context, ResizeOutputTensors(context, node, op_context.input,
158                                      op_context.size_splits, op_context.axis));
159   }
160 
161   int axis_value = GetTensorData<int>(op_context.axis)[0];
162 
163   // Use split function to build the outputs since they share the same logic.
164 #define TF_LITE_SPLIT_V(scalar)                                     \
165   VectorOfTensors<scalar> all_outputs(*context, *node->outputs);    \
166   tflite::SplitParams op_params;                                    \
167   op_params.num_split = NumOutputs(node);                           \
168   op_params.axis = axis_value;                                      \
169   reference_ops::Split(op_params, GetTensorShape(op_context.input), \
170                        GetTensorData<scalar>(op_context.input),     \
171                        all_outputs.shapes(), all_outputs.data());
172   switch (op_context.input->type) {
173     case kTfLiteFloat32: {
174       TF_LITE_SPLIT_V(float);
175       break;
176     }
177     case kTfLiteUInt8: {
178       TF_LITE_SPLIT_V(uint8_t);
179       break;
180     }
181     case kTfLiteInt16: {
182       TF_LITE_SPLIT_V(int16_t);
183       break;
184     }
185     default:
186       context->ReportError(context, "Type %s currently not supported.",
187                            TfLiteTypeGetName(op_context.input->type));
188       return kTfLiteError;
189   }
190 #undef TF_LITE_SPLIT_V
191 
192   return kTfLiteOk;
193 }
194 
195 }  // namespace split_v
196 
Register_SPLIT_V()197 TfLiteRegistration* Register_SPLIT_V() {
198   static TfLiteRegistration r = {nullptr, nullptr, split_v::Prepare,
199                                  split_v::Eval};
200   return &r;
201 }
202 
203 }  // namespace builtin
204 }  // namespace ops
205 }  // namespace tflite
206