/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace split {

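// Copies the contents of `input` into the node's outputs, splitting along
// `axis_value`. The output shapes are assumed to already describe equal-sized
// slices of the input along that axis; this routine only performs the copies.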
template <typename T>
TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
                       const TfLiteEvalTensor* input, int axis_value) {
  const int output_count = NumOutputs(node);
  const TfLiteIntArray* input_dims = input->dims;
  const TfLiteEvalTensor* output0 =
      tflite::micro::GetEvalOutput(context, node, 0);
  const TfLiteIntArray* output_dims = output0->dims;

  const int split_dimensions = input_dims->size;
  int axis = axis_value < 0 ? axis_value + split_dimensions : axis_value;

  TFLITE_DCHECK_LT(axis, split_dimensions);
  TFLITE_DCHECK_EQ(output_dims->size, split_dimensions);

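  // All outputs share output0's shape, so the equal-sized slices must tile
  // the split axis of the input exactly.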
  int64_t split_size = output_dims->data[axis] * output_count;

  TFLITE_DCHECK_EQ(split_size, input_dims->data[axis]);
  int64_t outer_size = 1;
  for (int i = 0; i < axis; ++i) {
    outer_size *= input_dims->data[i];
  }

  int64_t base_inner_size = 1;
  for (int i = axis + 1; i < split_dimensions; ++i) {
    base_inner_size *= input_dims->data[i];
  }

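  // With the leading and trailing dimensions collapsed, the input behaves as a
  // [outer_size, split_size, base_inner_size] buffer, and each output takes
  // one contiguous chunk of the middle dimension per outer step. For example,
  // splitting a [2, 6] input along axis 1 into 3 outputs gives outer_size = 2,
  // base_inner_size = 1 and copy_size = 2, producing three [2, 2] outputs.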
  const T* input_ptr = tflite::micro::GetTensorData<T>(input);
  for (int k = 0; k < outer_size; ++k) {
    for (int i = 0; i < output_count; ++i) {
      TfLiteEvalTensor* t = tflite::micro::GetEvalOutput(context, node, i);
      T* output_data = tflite::micro::GetTensorData<T>(t);
      const int copy_size = output_dims->data[axis] * base_inner_size;
      T* output_ptr = output_data + k * copy_size;
      for (int j = 0; j < copy_size; ++j) output_ptr[j] = input_ptr[j];
      input_ptr += copy_size;
    }
  }

  return kTfLiteOk;
}

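// Output shapes are taken from the model, so the only check needed at prepare
// time is that the axis input is available as a constant.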
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* axis = GetInput(context, node, 0);
  TF_LITE_ENSURE(context, axis != nullptr);

  // Dynamic output tensors would be needed if the axis tensor were not
  // constant, but Micro doesn't support dynamic memory allocation, so only a
  // constant axis tensor is supported for now.
  TF_LITE_ENSURE_MSG(context, IsConstantTensor(axis),
                     "Non constant axis tensor not supported");
  return kTfLiteOk;
}

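// Input 0 holds the scalar split axis and input 1 the tensor to split; Eval
// normalizes a negative axis and dispatches on the element type.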
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 1);

  int axis_value = tflite::micro::GetTensorData<int32_t>(axis)[0];
  if (axis_value < 0) {
    axis_value += input->dims->size;
  }

  TF_LITE_ENSURE(context, axis_value >= 0);
  TF_LITE_ENSURE(context, axis_value < input->dims->size);

  switch (input->type) {
    case kTfLiteFloat32: {
      return SplitImpl<float>(context, node, input, axis_value);
    }
    case kTfLiteUInt8: {
      return SplitImpl<uint8_t>(context, node, input, axis_value);
    }
    case kTfLiteInt8: {
      return SplitImpl<int8_t>(context, node, input, axis_value);
    }
    case kTfLiteInt16: {
      return SplitImpl<int16_t>(context, node, input, axis_value);
    }
    case kTfLiteInt32: {
      return SplitImpl<int32_t>(context, node, input, axis_value);
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace split

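// Registration hook for the op resolver. The kernel keeps no per-instance
// state, so init and free are left null.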
TfLiteRegistration Register_SPLIT() {
  return {/*init=*/nullptr,
          /*free=*/nullptr,
          /*prepare=*/split::Prepare,
          /*invoke=*/split::Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite