1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include <initializer_list>
18 #include <vector>
19
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
23 #include "tensorflow/lite/kernels/test_util.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25
26 namespace tflite {
27 namespace {
28
29 using ::testing::ElementsAreArray;
30
// Sentinel passed as the `axis` argument of SplitVOpModel to request that the
// split axis be a regular (runtime) input tensor filled through SetAxis(),
// rather than a constant tensor holding the axis value. The value is chosen
// well outside the range of real axis values used by these tests.
constexpr int kAxisIsATensor{-1000};
32
33 class SplitVOpModel : public SingleOpModel {
34 public:
SplitVOpModel(const TensorData & input,const TensorData & size_splits,int num_splits,int axis)35 SplitVOpModel(const TensorData& input, const TensorData& size_splits,
36 int num_splits, int axis) {
37 input_ = AddInput(input);
38 size_splits_ = AddInput(size_splits);
39 if (axis == kAxisIsATensor) {
40 axis_ = AddInput({TensorType_INT32, {1}});
41 } else {
42 axis_ = AddConstInput(TensorType_INT32, {axis}, {1});
43 }
44 for (int i = 0; i < num_splits; ++i) {
45 outputs_.push_back(AddOutput(input.type));
46 }
47 SetBuiltinOp(BuiltinOperator_SPLIT_V, BuiltinOptions_SplitVOptions,
48 CreateSplitVOptions(builder_, num_splits).Union());
49 if (axis == kAxisIsATensor) {
50 BuildInterpreter(
51 {GetShape(input_), GetShape(size_splits_), GetShape(axis_)});
52 } else {
53 BuildInterpreter({GetShape(input_), GetShape(size_splits_), {}});
54 }
55 }
56
57 template <typename T>
SetInput(std::initializer_list<T> data)58 void SetInput(std::initializer_list<T> data) {
59 PopulateTensor<T>(input_, data);
60 }
SetSizeSplits(std::initializer_list<int> data)61 void SetSizeSplits(std::initializer_list<int> data) {
62 PopulateTensor(size_splits_, data);
63 }
SetAxis(int axis)64 void SetAxis(int axis) { PopulateTensor(axis_, {axis}); }
65
66 template <typename T>
GetOutput(int i)67 std::vector<T> GetOutput(int i) {
68 return ExtractVector<T>(outputs_[i]);
69 }
GetOutputShape(int i)70 std::vector<int> GetOutputShape(int i) { return GetTensorShape(outputs_[i]); }
71
72 private:
73 int input_;
74 int size_splits_;
75 int axis_;
76 std::vector<int> outputs_;
77 };
78
79 template <typename T>
Check(int axis,std::initializer_list<int> input_shape,std::initializer_list<int> size_splits_shape,std::vector<std::initializer_list<int>> output_shapes,const std::initializer_list<T> & input_data,const std::initializer_list<int> & size_splits_data,const std::vector<std::initializer_list<T>> & output_data)80 void Check(int axis, std::initializer_list<int> input_shape,
81 std::initializer_list<int> size_splits_shape,
82 std::vector<std::initializer_list<int>> output_shapes,
83 const std::initializer_list<T>& input_data,
84 const std::initializer_list<int>& size_splits_data,
85 const std::vector<std::initializer_list<T>>& output_data) {
86 int num_splits = size_splits_data.size();
87 SplitVOpModel m({GetTensorType<T>(), input_shape},
88 {TensorType_INT32, size_splits_shape}, num_splits,
89 kAxisIsATensor);
90 m.SetInput<T>(input_data);
91 m.SetSizeSplits(size_splits_data);
92 m.SetAxis(axis);
93 m.Invoke();
94 for (int i = 0; i < num_splits; ++i) {
95 EXPECT_THAT(m.GetOutput<T>(i), ElementsAreArray(output_data[i]));
96 EXPECT_THAT(m.GetOutputShape(i), ElementsAreArray(output_shapes[i]));
97 }
98
99 SplitVOpModel const_m({GetTensorType<T>(), input_shape},
100 {TensorType_INT32, size_splits_shape}, num_splits,
101 axis);
102 const_m.SetInput<T>(input_data);
103 const_m.SetSizeSplits(size_splits_data);
104 const_m.Invoke();
105 for (int i = 0; i < num_splits; ++i) {
106 EXPECT_THAT(const_m.GetOutput<T>(i), ElementsAreArray(output_data[i]));
107 EXPECT_THAT(const_m.GetOutputShape(i), ElementsAreArray(output_shapes[i]));
108 }
109 }
110
111 template <typename T>
112 class SplitVOpTypedTest : public ::testing::Test {};
113
114 using DataTypes = ::testing::Types<float, uint8_t, int8_t, int16_t, int32_t>;
115 TYPED_TEST_SUITE(SplitVOpTypedTest, DataTypes);
116
TYPED_TEST(SplitVOpTypedTest, TwoDimensional) {
  // Splitting a {4, 3} input along axis 0 with size_splits {1, 1, 2} yields
  // three outputs with shapes {1, 3}, {1, 3} and {2, 3}, respectively.
  Check<TypeParam>(
      /*axis=*/0,
      /*input_shape=*/{4, 3},
      /*size_splits_shape=*/{3},
      /*output_shapes=*/{{1, 3}, {1, 3}, {2, 3}},
      /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
      /*size_splits_data=*/{1, 1, 2},
      /*output_data=*/{{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}});
}
130
TYPED_TEST(SplitVOpTypedTest, FourDimensional) {
  // The same {2, 2, 2, 2} tensor holding 1..16 is split in half along each
  // axis in turn. For axis 1 the second size is -1; the expected outputs show
  // it takes the remaining extent of that dimension.
  const std::initializer_list<int> shape = {2, 2, 2, 2};
  const std::initializer_list<TypeParam> values = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

  Check<TypeParam>(/*axis=*/0, shape, /*size_splits_shape=*/{2},
                   /*output_shapes=*/{{1, 2, 2, 2}, {1, 2, 2, 2}}, values,
                   /*size_splits_data=*/{1, 1},
                   /*output_data=*/
                   {{1, 2, 3, 4, 5, 6, 7, 8},
                    {9, 10, 11, 12, 13, 14, 15, 16}});

  Check<TypeParam>(/*axis=*/1, shape, /*size_splits_shape=*/{2},
                   /*output_shapes=*/{{2, 1, 2, 2}, {2, 1, 2, 2}}, values,
                   /*size_splits_data=*/{1, -1},
                   /*output_data=*/
                   {{1, 2, 3, 4, 9, 10, 11, 12},
                    {5, 6, 7, 8, 13, 14, 15, 16}});

  Check<TypeParam>(/*axis=*/2, shape, /*size_splits_shape=*/{2},
                   /*output_shapes=*/{{2, 2, 1, 2}, {2, 2, 1, 2}}, values,
                   /*size_splits_data=*/{1, 1},
                   /*output_data=*/
                   {{1, 2, 5, 6, 9, 10, 13, 14},
                    {3, 4, 7, 8, 11, 12, 15, 16}});

  Check<TypeParam>(/*axis=*/3, shape, /*size_splits_shape=*/{2},
                   /*output_shapes=*/{{2, 2, 2, 1}, {2, 2, 2, 1}}, values,
                   /*size_splits_data=*/{1, 1},
                   /*output_data=*/
                   {{1, 3, 5, 7, 9, 11, 13, 15},
                    {2, 4, 6, 8, 10, 12, 14, 16}});
}
161
TYPED_TEST(SplitVOpTypedTest, OneDimensional) {
  // Splits an 8-element vector into eight single-element outputs.
  Check<TypeParam>(
      /*axis=*/0,
      /*input_shape=*/{8},
      /*size_splits_shape=*/{8},
      /*output_shapes=*/{{1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}},
      /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
      /*size_splits_data=*/{1, 1, 1, 1, 1, 1, 1, 1},
      /*output_data=*/{{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}});
}
168
TYPED_TEST(SplitVOpTypedTest, OneDimensional2) {
  // The trailing -1 size has nothing left to take once the explicit sizes
  // (1 * 6 + 2 = 8) consume the whole vector, so the last output is empty
  // with shape {0}.
  Check<TypeParam>(
      /*axis=*/0,
      /*input_shape=*/{8},
      /*size_splits_shape=*/{8},
      /*output_shapes=*/{{1}, {1}, {1}, {1}, {1}, {1}, {2}, {0}},
      /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
      /*size_splits_data=*/{1, 1, 1, 1, 1, 1, 2, -1},
      /*output_data=*/{{1}, {2}, {3}, {4}, {5}, {6}, {7, 8}, {}});
}
175
TYPED_TEST(SplitVOpTypedTest, NegativeAxis) {
  // A negative axis counts back from the last dimension. On this rank-4 input
  // -4 must select the same dimension as axis 0, and -1 the same as axis 3;
  // the expectations below match the axis-0 and axis-3 cases of the
  // FourDimensional test. Covering -1 as well as the lower bound -rank checks
  // the most common negative-axis usage, not just the boundary.
  Check<TypeParam>(
      /*axis=*/-4, {2, 2, 2, 2}, {2}, {{1, 2, 2, 2}, {1, 2, 2, 2}},
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1},
      {
          {1, 2, 3, 4, 5, 6, 7, 8},
          {9, 10, 11, 12, 13, 14, 15, 16},
      });
  Check<TypeParam>(
      /*axis=*/-1, {2, 2, 2, 2}, {2}, {{2, 2, 2, 1}, {2, 2, 2, 1}},
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1},
      {
          {1, 3, 5, 7, 9, 11, 13, 15},
          {2, 4, 6, 8, 10, 12, 14, 16},
      });
}
185
186 } // namespace
187 } // namespace tflite
188