• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "../TfLiteParser.hpp"
9 
10 #include <string>
11 #include <iostream>
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14 
15 struct SplitVFixture : public ParserFlatbuffersFixture
16 {
SplitVFixtureSplitVFixture17     explicit SplitVFixture(const std::string& inputShape,
18                            const std::string& splitValues,
19                            const std::string& sizeSplitsShape,
20                            const std::string& axisShape,
21                            const std::string& numSplits,
22                            const std::string& outputShape1,
23                            const std::string& outputShape2,
24                            const std::string& axisData,
25                            const std::string& dataType)
26     {
27         m_JsonString = R"(
28             {
29                 "version": 3,
30                 "operator_codes": [ { "builtin_code": "SPLIT_V" } ],
31                 "subgraphs": [ {
32                     "tensors": [
33                         {
34                             "shape": )" + inputShape + R"(,
35                             "type": )" + dataType + R"(,
36                             "buffer": 0,
37                             "name": "inputTensor",
38                             "quantization": {
39                                 "min": [ 0.0 ],
40                                 "max": [ 255.0 ],
41                                 "scale": [ 1.0 ],
42                                 "zero_point": [ 0 ],
43                             }
44                         },
45                         {
46                             "shape": )" + sizeSplitsShape + R"(,
47                             "type": "INT32",
48                             "buffer": 1,
49                             "name": "sizeSplits",
50                             "quantization": {
51                                 "min": [ 0.0 ],
52                                 "max": [ 255.0 ],
53                                 "scale": [ 1.0 ],
54                                 "zero_point": [ 0 ],
55                             }
56                         },
57                         {
58                             "shape": )" + axisShape + R"(,
59                             "type": "INT32",
60                             "buffer": 2,
61                             "name": "axis",
62                             "quantization": {
63                                 "min": [ 0.0 ],
64                                 "max": [ 255.0 ],
65                                 "scale": [ 1.0 ],
66                                 "zero_point": [ 0 ],
67                             }
68                         },
69                         {
70                             "shape": )" + outputShape1 + R"( ,
71                             "type":)" + dataType + R"(,
72                             "buffer": 3,
73                             "name": "outputTensor1",
74                             "quantization": {
75                                 "min": [ 0.0 ],
76                                 "max": [ 255.0 ],
77                                 "scale": [ 1.0 ],
78                                 "zero_point": [ 0 ],
79                             }
80                         },
81                         {
82                             "shape": )" + outputShape2 + R"( ,
83                             "type":)" + dataType + R"(,
84                             "buffer": 4,
85                             "name": "outputTensor2",
86                             "quantization": {
87                                 "min": [ 0.0 ],
88                                 "max": [ 255.0 ],
89                                 "scale": [ 1.0 ],
90                                 "zero_point": [ 0 ],
91                             }
92                         }
93                     ],
94                     "inputs": [ 0, 1, 2 ],
95                     "outputs": [ 3, 4 ],
96                     "operators": [
97                         {
98                             "opcode_index": 0,
99                             "inputs": [ 0, 1, 2 ],
100                             "outputs": [ 3, 4 ],
101                             "builtin_options_type": "SplitVOptions",
102                             "builtin_options": {
103                                 "num_splits": )" + numSplits + R"(
104                             },
105                             "custom_options_format": "FLEXBUFFERS"
106                         }
107                     ],
108                 } ],
109                 "buffers" : [ {}, { "data": )" + splitValues + R"( }, { "data": )" + axisData + R"( }, {}, {}]
110             }
111         )";
112 
113         Setup();
114     }
115 };
116 
/*
 * Inferring a split size from a -1 entry in splitValues (e.g. [-1, 1]) has only
 * been verified locally; the fixtures below all use explicit split sizes.
 */
120 
121 struct SimpleSplitVAxisOneFixture : SplitVFixture
122 {
SimpleSplitVAxisOneFixtureSimpleSplitVAxisOneFixture123     SimpleSplitVAxisOneFixture()
124         : SplitVFixture( "[ 4, 2, 2, 2 ]", "[ 1, 0, 0, 0, 3, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
125                          "[ 1, 2, 2, 2 ]", "[ 3, 2, 2, 2 ]", "[ 0, 0, 0, 0 ]", "FLOAT32")
126     {}
127 };
128 
// Input [ 4, 2, 2, 2 ] split on axis 0 with sizes {1, 3}: the first output receives the
// leading 2x2x2 slice (8 values), the second output the remaining three slices (24 values).
BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitVTwo, SimpleSplitVAxisOneFixture)
{
    RunTest<4, armnn::DataType::Float32>(
        0,
        {
            { "inputTensor", {  1.0f,  2.0f,  3.0f,  4.0f,  5.0f,  6.0f,  7.0f,  8.0f,    // slice 0
                                9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,    // slice 1
                               17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,    // slice 2
                               25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } // slice 3
        },
        {
            { "outputTensor1", {  1.0f,  2.0f,  3.0f,  4.0f,  5.0f,  6.0f,  7.0f,  8.0f } },  // slice 0
            { "outputTensor2", {  9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,      // slice 1
                                 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,      // slice 2
                                 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } }   // slice 3
        });
}
142 
143 struct SimpleSplitVAxisTwoFixture : SplitVFixture
144 {
SimpleSplitVAxisTwoFixtureSimpleSplitVAxisTwoFixture145     SimpleSplitVAxisTwoFixture()
146         : SplitVFixture( "[ 2, 4, 2, 2 ]", "[ 3, 0, 0, 0, 1, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
147                          "[ 2, 3, 2, 2 ]", "[ 2, 1, 2, 2 ]", "[ 1, 0, 0, 0 ]", "FLOAT32")
148     {}
149 };
150 
// Input [ 2, 4, 2, 2 ] split on axis 1 with sizes {3, 1}: within each batch the first
// three 2x2 planes go to outputTensor1 and the last plane goes to outputTensor2.
BOOST_FIXTURE_TEST_CASE(ParseAxisTwoSplitVTwo, SimpleSplitVAxisTwoFixture)
{
    RunTest<4, armnn::DataType::Float32>(
        0,
        {
            { "inputTensor", {  1.0f,  2.0f,  3.0f,  4.0f,     // batch 0, plane 0
                                5.0f,  6.0f,  7.0f,  8.0f,     // batch 0, plane 1
                                9.0f, 10.0f, 11.0f, 12.0f,     // batch 0, plane 2
                               13.0f, 14.0f, 15.0f, 16.0f,     // batch 0, plane 3
                               17.0f, 18.0f, 19.0f, 20.0f,     // batch 1, plane 0
                               21.0f, 22.0f, 23.0f, 24.0f,     // batch 1, plane 1
                               25.0f, 26.0f, 27.0f, 28.0f,     // batch 1, plane 2
                               29.0f, 30.0f, 31.0f, 32.0f } }  // batch 1, plane 3
        },
        {
            { "outputTensor1", {  1.0f,  2.0f,  3.0f,  4.0f,  5.0f,  6.0f,
                                  7.0f,  8.0f,  9.0f, 10.0f, 11.0f, 12.0f,      // batch 0, planes 0-2
                                 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f,
                                 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f } },  // batch 1, planes 0-2
            { "outputTensor2", { 13.0f, 14.0f, 15.0f, 16.0f,                    // batch 0, plane 3
                                 29.0f, 30.0f, 31.0f, 32.0f } }                 // batch 1, plane 3
        });
}
164 
165 struct SimpleSplitVAxisThreeFixture : SplitVFixture
166 {
SimpleSplitVAxisThreeFixtureSimpleSplitVAxisThreeFixture167     SimpleSplitVAxisThreeFixture()
168         : SplitVFixture( "[ 2, 2, 4, 2 ]", "[ 1, 0, 0, 0, 3, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
169                          "[ 2, 2, 1, 2 ]", "[ 2, 2, 3, 2 ]", "[ 2, 0, 0, 0 ]", "FLOAT32")
170     {}
171 };
172 
// Input [ 2, 2, 4, 2 ] split on axis 2 with sizes {1, 3}: of every group of four rows
// (two values each), the first row goes to outputTensor1 and the other three to outputTensor2.
BOOST_FIXTURE_TEST_CASE(ParseAxisThreeSplitVTwo, SimpleSplitVAxisThreeFixture)
{
    RunTest<4, armnn::DataType::Float32>(
        0,
        {
            { "inputTensor", {  1.0f,  2.0f,  3.0f,  4.0f,  5.0f,  6.0f,  7.0f,  8.0f,    // [0, 0]
                                9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,    // [0, 1]
                               17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,    // [1, 0]
                               25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } // [1, 1]
        },
        {
            { "outputTensor1", {  1.0f,  2.0f,       // [0, 0]
                                  9.0f, 10.0f,       // [0, 1]
                                 17.0f, 18.0f,       // [1, 0]
                                 25.0f, 26.0f } },   // [1, 1]
            { "outputTensor2", {  3.0f,  4.0f,  5.0f,  6.0f,  7.0f,  8.0f,      // [0, 0]
                                 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,      // [0, 1]
                                 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,      // [1, 0]
                                 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } }   // [1, 1]
        });
}
186 
187 struct SimpleSplitVAxisFourFixture : SplitVFixture
188 {
SimpleSplitVAxisFourFixtureSimpleSplitVAxisFourFixture189     SimpleSplitVAxisFourFixture()
190         : SplitVFixture( "[ 2, 2, 2, 4 ]", "[ 3, 0, 0, 0, 1, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
191                          "[ 2, 2, 2, 3 ]", "[ 2, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]", "FLOAT32")
192     {}
193 };
194 
// Input [ 2, 2, 2, 4 ] split on the innermost axis with sizes {3, 1}: each row of four
// values contributes its first three to outputTensor1 and its last one to outputTensor2.
BOOST_FIXTURE_TEST_CASE(ParseAxisFourSplitVTwo, SimpleSplitVAxisFourFixture)
{
    RunTest<4, armnn::DataType::Float32>(
        0,
        {
            { "inputTensor", {  1.0f,  2.0f,  3.0f,  4.0f,
                                5.0f,  6.0f,  7.0f,  8.0f,
                                9.0f, 10.0f, 11.0f, 12.0f,
                               13.0f, 14.0f, 15.0f, 16.0f,
                               17.0f, 18.0f, 19.0f, 20.0f,
                               21.0f, 22.0f, 23.0f, 24.0f,
                               25.0f, 26.0f, 27.0f, 28.0f,
                               29.0f, 30.0f, 31.0f, 32.0f } }
        },
        {
            // First three values of each input row.
            { "outputTensor1", {  1.0f,  2.0f,  3.0f,
                                  5.0f,  6.0f,  7.0f,
                                  9.0f, 10.0f, 11.0f,
                                 13.0f, 14.0f, 15.0f,
                                 17.0f, 18.0f, 19.0f,
                                 21.0f, 22.0f, 23.0f,
                                 25.0f, 26.0f, 27.0f,
                                 29.0f, 30.0f, 31.0f } },
            // Last value of each input row.
            { "outputTensor2", {  4.0f,  8.0f, 12.0f, 16.0f, 20.0f, 24.0f, 28.0f, 32.0f } }
        });
}
208 
209 BOOST_AUTO_TEST_SUITE_END()
210