• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include <initializer_list>
18 #include <vector>
19 
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
23 #include "tensorflow/lite/kernels/test_util.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 
26 namespace tflite {
27 namespace {
28 
29 using ::testing::ElementsAreArray;
30 
// Sentinel passed as the `axis` argument of SplitVOpModel to request that the
// axis be supplied through a runtime input tensor (filled via SetAxis())
// rather than baked into the model as a constant input. Its value is far
// outside the axis values the tests in this file use.
constexpr int kAxisIsATensor{-1000};
32 
33 class SplitVOpModel : public SingleOpModel {
34  public:
SplitVOpModel(const TensorData & input,const TensorData & size_splits,int num_splits,int axis)35   SplitVOpModel(const TensorData& input, const TensorData& size_splits,
36                 int num_splits, int axis) {
37     input_ = AddInput(input);
38     size_splits_ = AddInput(size_splits);
39     if (axis == kAxisIsATensor) {
40       axis_ = AddInput({TensorType_INT32, {1}});
41     } else {
42       axis_ = AddConstInput(TensorType_INT32, {axis}, {1});
43     }
44     for (int i = 0; i < num_splits; ++i) {
45       outputs_.push_back(AddOutput(input.type));
46     }
47     SetBuiltinOp(BuiltinOperator_SPLIT_V, BuiltinOptions_SplitVOptions,
48                  CreateSplitVOptions(builder_, num_splits).Union());
49     if (axis == kAxisIsATensor) {
50       BuildInterpreter(
51           {GetShape(input_), GetShape(size_splits_), GetShape(axis_)});
52     } else {
53       BuildInterpreter({GetShape(input_), GetShape(size_splits_), {}});
54     }
55   }
56 
57   template <typename T>
SetInput(std::initializer_list<T> data)58   void SetInput(std::initializer_list<T> data) {
59     PopulateTensor<T>(input_, data);
60   }
SetSizeSplits(std::initializer_list<int> data)61   void SetSizeSplits(std::initializer_list<int> data) {
62     PopulateTensor(size_splits_, data);
63   }
SetAxis(int axis)64   void SetAxis(int axis) { PopulateTensor(axis_, {axis}); }
65 
66   template <typename T>
GetOutput(int i)67   std::vector<T> GetOutput(int i) {
68     return ExtractVector<T>(outputs_[i]);
69   }
GetOutputShape(int i)70   std::vector<int> GetOutputShape(int i) { return GetTensorShape(outputs_[i]); }
71 
72  private:
73   int input_;
74   int size_splits_;
75   int axis_;
76   std::vector<int> outputs_;
77 };
78 
79 template <typename T>
Check(int axis,std::initializer_list<int> input_shape,std::initializer_list<int> size_splits_shape,std::vector<std::initializer_list<int>> output_shapes,const std::initializer_list<T> & input_data,const std::initializer_list<int> & size_splits_data,const std::vector<std::initializer_list<T>> & output_data)80 void Check(int axis, std::initializer_list<int> input_shape,
81            std::initializer_list<int> size_splits_shape,
82            std::vector<std::initializer_list<int>> output_shapes,
83            const std::initializer_list<T>& input_data,
84            const std::initializer_list<int>& size_splits_data,
85            const std::vector<std::initializer_list<T>>& output_data) {
86   int num_splits = size_splits_data.size();
87   SplitVOpModel m({GetTensorType<T>(), input_shape},
88                   {TensorType_INT32, size_splits_shape}, num_splits,
89                   kAxisIsATensor);
90   m.SetInput<T>(input_data);
91   m.SetSizeSplits(size_splits_data);
92   m.SetAxis(axis);
93   m.Invoke();
94   for (int i = 0; i < num_splits; ++i) {
95     EXPECT_THAT(m.GetOutput<T>(i), ElementsAreArray(output_data[i]));
96     EXPECT_THAT(m.GetOutputShape(i), ElementsAreArray(output_shapes[i]));
97   }
98 
99   SplitVOpModel const_m({GetTensorType<T>(), input_shape},
100                         {TensorType_INT32, size_splits_shape}, num_splits,
101                         axis);
102   const_m.SetInput<T>(input_data);
103   const_m.SetSizeSplits(size_splits_data);
104   const_m.Invoke();
105   for (int i = 0; i < num_splits; ++i) {
106     EXPECT_THAT(const_m.GetOutput<T>(i), ElementsAreArray(output_data[i]));
107     EXPECT_THAT(const_m.GetOutputShape(i), ElementsAreArray(output_shapes[i]));
108   }
109 }
110 
111 template <typename T>
112 class SplitVOpTypedTest : public ::testing::Test {};
113 
114 using DataTypes = ::testing::Types<float, uint8_t, int8_t, int16_t, int32_t>;
115 TYPED_TEST_SUITE(SplitVOpTypedTest, DataTypes);
116 
TYPED_TEST(SplitVOpTypedTest, TwoDimensional) {
  // Input shape: {4, 3}
  // size_splits: {1, 1, 2}
  // axis: 0
  // We should have 3 outputs with shapes respectively:
  //  output 1 : {1, 3}
  //  output 2 : {1, 3}
  //  output 3 : {2, 3}
  Check<TypeParam>(
      /*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}},
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2},
      {{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}});

  // Uneven split along the inner axis of the same {4, 3} input:
  // size_splits: {1, 2}
  // axis: 1
  //  output 1 : {4, 1} -> column 0
  //  output 2 : {4, 2} -> columns 1 and 2
  Check<TypeParam>(
      /*axis=*/1, {4, 3}, {2}, {{4, 1}, {4, 2}},
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 2},
      {{1, 4, 7, 10}, {2, 3, 5, 6, 8, 9, 11, 12}});
}
130 
// Splits a {2, 2, 2, 2} tensor holding 1..16 into two halves along each of
// its four axes in turn. The axis=1 case passes -1 as the second split size;
// the expected output shapes show it resolving to the remaining extent (1).
TYPED_TEST(SplitVOpTypedTest, FourDimensional) {
  Check<TypeParam>(/*axis=*/0, /*input_shape=*/{2, 2, 2, 2},
                   /*size_splits_shape=*/{2},
                   /*output_shapes=*/{{1, 2, 2, 2}, {1, 2, 2, 2}},
                   /*input_data=*/
                   {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
                   /*size_splits_data=*/{1, 1},
                   /*output_data=*/
                   {{1, 2, 3, 4, 5, 6, 7, 8},
                    {9, 10, 11, 12, 13, 14, 15, 16}});

  Check<TypeParam>(/*axis=*/1, /*input_shape=*/{2, 2, 2, 2},
                   /*size_splits_shape=*/{2},
                   /*output_shapes=*/{{2, 1, 2, 2}, {2, 1, 2, 2}},
                   /*input_data=*/
                   {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
                   /*size_splits_data=*/{1, -1},
                   /*output_data=*/
                   {{1, 2, 3, 4, 9, 10, 11, 12},
                    {5, 6, 7, 8, 13, 14, 15, 16}});

  Check<TypeParam>(/*axis=*/2, /*input_shape=*/{2, 2, 2, 2},
                   /*size_splits_shape=*/{2},
                   /*output_shapes=*/{{2, 2, 1, 2}, {2, 2, 1, 2}},
                   /*input_data=*/
                   {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
                   /*size_splits_data=*/{1, 1},
                   /*output_data=*/
                   {{1, 2, 5, 6, 9, 10, 13, 14},
                    {3, 4, 7, 8, 11, 12, 15, 16}});

  Check<TypeParam>(/*axis=*/3, /*input_shape=*/{2, 2, 2, 2},
                   /*size_splits_shape=*/{2},
                   /*output_shapes=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
                   /*input_data=*/
                   {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
                   /*size_splits_data=*/{1, 1},
                   /*output_data=*/
                   {{1, 3, 5, 7, 9, 11, 13, 15},
                    {2, 4, 6, 8, 10, 12, 14, 16}});
}
161 
// Splits an 8-element vector into eight single-element outputs.
TYPED_TEST(SplitVOpTypedTest, OneDimensional) {
  Check<TypeParam>(/*axis=*/0, /*input_shape=*/{8},
                   /*size_splits_shape=*/{8},
                   /*output_shapes=*/
                   {{1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}},
                   /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
                   /*size_splits_data=*/{1, 1, 1, 1, 1, 1, 1, 1},
                   /*output_data=*/
                   {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}});
}
168 
// Splits an 8-element vector where the explicit sizes already consume every
// element, so the trailing -1 size resolves to 0 and yields an empty output
// of shape {0}.
TYPED_TEST(SplitVOpTypedTest, OneDimensional2) {
  Check<TypeParam>(/*axis=*/0, /*input_shape=*/{8},
                   /*size_splits_shape=*/{8},
                   /*output_shapes=*/
                   {{1}, {1}, {1}, {1}, {1}, {1}, {2}, {0}},
                   /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
                   /*size_splits_data=*/{1, 1, 1, 1, 1, 1, 2, -1},
                   /*output_data=*/
                   {{1}, {2}, {3}, {4}, {5}, {6}, {7, 8}, {}});
}
175 
// Negative axes count from the last dimension of the rank-4 input.
TYPED_TEST(SplitVOpTypedTest, NegativeAxis) {
  // axis = -4 selects the leading dimension: the two outputs are the first
  // and second halves of the flattened data.
  Check<TypeParam>(
      /*axis=*/-4, {2, 2, 2, 2}, {2}, {{1, 2, 2, 2}, {1, 2, 2, 2}},
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1},
      {
          {1, 2, 3, 4, 5, 6, 7, 8},
          {9, 10, 11, 12, 13, 14, 15, 16},
      });

  // axis = -1 selects the innermost dimension, so the outputs interleave:
  // even positions go to the first output, odd positions to the second.
  Check<TypeParam>(
      /*axis=*/-1, {2, 2, 2, 2}, {2}, {{2, 2, 2, 1}, {2, 2, 2, 1}},
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1},
      {
          {1, 3, 5, 7, 9, 11, 13, 15},
          {2, 4, 6, 8, 10, 12, 14, 16},
      });
}
185 
186 }  // namespace
187 }  // namespace tflite
188