• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "../TfLiteParser.hpp"
9 
10 #include <string>
11 #include <iostream>
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14 
15 struct PackFixture : public ParserFlatbuffersFixture
16 {
PackFixturePackFixture17     explicit PackFixture(const std::string & inputShape,
18                          const unsigned int numInputs,
19                          const std::string & outputShape,
20                          const std::string & axis)
21     {
22         m_JsonString = R"(
23             {
24                 "version": 3,
25                 "operator_codes": [ { "builtin_code": "PACK" } ],
26                 "subgraphs": [ {
27                     "tensors": [)";
28 
29         for (unsigned int i = 0; i < numInputs; ++i)
30         {
31             m_JsonString += R"(
32                         {
33                             "shape": )" + inputShape + R"(,
34                             "type": "FLOAT32",
35                             "buffer": )" + std::to_string(i) + R"(,
36                             "name": "inputTensor)" + std::to_string(i + 1) + R"(",
37                             "quantization": {
38                                 "min": [ 0.0 ],
39                                 "max": [ 255.0 ],
40                                 "scale": [ 1.0 ],
41                                 "zero_point": [ 0 ],
42                             }
43                         },)";
44         }
45 
46         std::string inputIndexes;
47         for (unsigned int i = 0; i < numInputs-1; ++i)
48         {
49             inputIndexes += std::to_string(i) + R"(, )";
50         }
51         inputIndexes += std::to_string(numInputs-1);
52 
53         m_JsonString += R"(
54                         {
55                             "shape": )" + outputShape + R"( ,
56                             "type": "FLOAT32",
57                             "buffer": )" + std::to_string(numInputs) + R"(,
58                             "name": "outputTensor",
59                             "quantization": {
60                                 "min": [ 0.0 ],
61                                 "max": [ 255.0 ],
62                                 "scale": [ 1.0 ],
63                                 "zero_point": [ 0 ],
64                             }
65                         }
66                     ],
67                     "inputs": [ )" + inputIndexes + R"( ],
68                     "outputs": [ 2 ],
69                     "operators": [
70                         {
71                             "opcode_index": 0,
72                             "inputs": [ )" + inputIndexes + R"( ],
73                             "outputs": [ 2 ],
74                             "builtin_options_type": "PackOptions",
75                             "builtin_options": {
76                                 "axis": )" + axis + R"(,
77                                 "values_count": )" + std::to_string(numInputs) + R"(
78                             },
79                             "custom_options_format": "FLEXBUFFERS"
80                         }
81                     ],
82                 } ],
83                 "buffers" : [)";
84 
85             for (unsigned int i = 0; i < numInputs-1; ++i)
86             {
87                 m_JsonString += R"(
88                     { },)";
89             }
90             m_JsonString += R"(
91                     { }
92                 ]
93             })";
94         Setup();
95     }
96 };
97 
98 struct SimplePackFixture : PackFixture
99 {
SimplePackFixtureSimplePackFixture100     SimplePackFixture() : PackFixture("[ 3, 2, 3 ]",
101                                       2,
102                                       "[ 3, 2, 3, 2 ]",
103                                       "3") {}
104 };
105 
// Feeds two 18-element inputs (1..18 and 19..36) through the model built by SimplePackFixture and
// expects the output to interleave them element by element: packing along the last axis places
// inputTensor1[i] and inputTensor2[i] next to each other.
// NOTE(review): the template argument 4 presumably is the output rank and the leading 0 the
// subgraph index - confirm against ParserFlatbuffersFixture::RunTest.
BOOST_FIXTURE_TEST_CASE(ParsePack, SimplePackFixture)
{
    RunTest<4, armnn::DataType::Float32>(
        0,
        // Input tensors, one row of three values per line.
        {
            { "inputTensor1",
              {  1,  2,  3,
                 4,  5,  6,
                 7,  8,  9,
                10, 11, 12,
                13, 14, 15,
                16, 17, 18 } },
            { "inputTensor2",
              { 19, 20, 21,
                22, 23, 24,
                25, 26, 27,
                28, 29, 30,
                31, 32, 33,
                34, 35, 36 } }
        },
        // Expected output, three (inputTensor1, inputTensor2) pairs per line.
        {
            { "outputTensor",
              {  1, 19,    2, 20,    3, 21,
                 4, 22,    5, 23,    6, 24,
                 7, 25,    8, 26,    9, 27,
                10, 28,   11, 29,   12, 30,
                13, 31,   14, 32,   15, 33,
                16, 34,   17, 35,   18, 36 } }
        });
}
152 
153 BOOST_AUTO_TEST_SUITE_END()
154