• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "../TfLiteParser.hpp"
9 
10 #include <string>
11 #include <iostream>
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14 
15 struct FullyConnectedFixture : public ParserFlatbuffersFixture
16 {
FullyConnectedFixtureFullyConnectedFixture17     explicit FullyConnectedFixture(const std::string& inputShape,
18                                            const std::string& outputShape,
19                                            const std::string& filterShape,
20                                            const std::string& filterData,
21                                            const std::string biasShape = "",
22                                            const std::string biasData = "")
23     {
24         std::string inputTensors = "[ 0, 2 ]";
25         std::string biasTensor = "";
26         std::string biasBuffer = "";
27         if (biasShape.size() > 0 && biasData.size() > 0)
28         {
29             inputTensors = "[ 0, 2, 3 ]";
30             biasTensor = R"(
31                         {
32                             "shape": )" + biasShape + R"( ,
33                             "type": "INT32",
34                             "buffer": 3,
35                             "name": "biasTensor",
36                             "quantization": {
37                                 "min": [ 0.0 ],
38                                 "max": [ 255.0 ],
39                                 "scale": [ 1.0 ],
40                                 "zero_point": [ 0 ],
41                             }
42                         } )";
43             biasBuffer = R"(
44                     { "data": )" + biasData + R"(, }, )";
45         }
46         m_JsonString = R"(
47             {
48                 "version": 3,
49                 "operator_codes": [ { "builtin_code": "FULLY_CONNECTED" } ],
50                 "subgraphs": [ {
51                     "tensors": [
52                         {
53                             "shape": )" + inputShape + R"(,
54                             "type": "UINT8",
55                             "buffer": 0,
56                             "name": "inputTensor",
57                             "quantization": {
58                                 "min": [ 0.0 ],
59                                 "max": [ 255.0 ],
60                                 "scale": [ 1.0 ],
61                                 "zero_point": [ 0 ],
62                             }
63                         },
64                         {
65                             "shape": )" + outputShape + R"(,
66                             "type": "UINT8",
67                             "buffer": 1,
68                             "name": "outputTensor",
69                             "quantization": {
70                                 "min": [ 0.0 ],
71                                 "max": [ 511.0 ],
72                                 "scale": [ 2.0 ],
73                                 "zero_point": [ 0 ],
74                             }
75                         },
76                         {
77                             "shape": )" + filterShape + R"(,
78                             "type": "UINT8",
79                             "buffer": 2,
80                             "name": "filterTensor",
81                             "quantization": {
82                                 "min": [ 0.0 ],
83                                 "max": [ 255.0 ],
84                                 "scale": [ 1.0 ],
85                                 "zero_point": [ 0 ],
86                             }
87                         }, )" + biasTensor + R"(
88                     ],
89                     "inputs": [ 0 ],
90                     "outputs": [ 1 ],
91                     "operators": [
92                         {
93                             "opcode_index": 0,
94                             "inputs": )" + inputTensors + R"(,
95                             "outputs": [ 1 ],
96                             "builtin_options_type": "FullyConnectedOptions",
97                             "builtin_options": {
98                                 "fused_activation_function": "NONE"
99                             },
100                             "custom_options_format": "FLEXBUFFERS"
101                         }
102                     ],
103                 } ],
104                 "buffers" : [
105                     { },
106                     { },
107                     { "data": )" + filterData + R"(, }, )"
108                        + biasBuffer + R"(
109                 ]
110             }
111         )";
112         SetupSingleInputSingleOutput("inputTensor", "outputTensor");
113     }
114 };
115 
116 struct FullyConnectedWithNoBiasFixture : FullyConnectedFixture
117 {
FullyConnectedWithNoBiasFixtureFullyConnectedWithNoBiasFixture118     FullyConnectedWithNoBiasFixture()
119         : FullyConnectedFixture("[ 1, 4, 1, 1 ]",     // inputShape
120                                 "[ 1, 1 ]",           // outputShape
121                                 "[ 1, 4 ]",           // filterShape
122                                 "[ 2, 3, 4, 5 ]")     // filterData
123     {}
124 };
125 
// Expected output: dot product of the input with the filter row { 2, 3, 4, 5 },
// divided by the output quantization scale of 2.0.
BOOST_FIXTURE_TEST_CASE(FullyConnectedWithNoBias, FullyConnectedWithNoBiasFixture)
{
    RunTest<2, armnn::DataType::QAsymmU8>(
        0,                                            // subgraph
        { 10, 20, 30, 40 },                           // input
        { (10*2 + 20*3 + 30*4 + 40*5) / 2 });         // expected output
}
133 
134 struct FullyConnectedWithBiasFixture : FullyConnectedFixture
135 {
FullyConnectedWithBiasFixtureFullyConnectedWithBiasFixture136     FullyConnectedWithBiasFixture()
137         : FullyConnectedFixture("[ 1, 4, 1, 1 ]",     // inputShape
138                                 "[ 1, 1 ]",           // outputShape
139                                 "[ 1, 4 ]",           // filterShape
140                                 "[ 2, 3, 4, 5 ]",     // filterData
141                                 "[ 1 ]",              // biasShape
142                                 "[ 10, 0, 0, 0 ]" )   // biasData
143     {}
144 };
145 
// Expected output: input dotted with the filter row { 2, 3, 4, 5 }, plus the
// bias of 10, divided by the output quantization scale of 2.0.
BOOST_FIXTURE_TEST_CASE(ParseFullyConnectedWithBias, FullyConnectedWithBiasFixture)
{
    RunTest<2, armnn::DataType::QAsymmU8>(
        0,                                               // subgraph
        { 10, 20, 30, 40 },                              // input
        { (10*2 + 20*3 + 30*4 + 40*5 + 10) / 2 });       // expected output
}
153 
154 struct FullyConnectedWithBiasMultipleOutputsFixture : FullyConnectedFixture
155 {
FullyConnectedWithBiasMultipleOutputsFixtureFullyConnectedWithBiasMultipleOutputsFixture156     FullyConnectedWithBiasMultipleOutputsFixture()
157             : FullyConnectedFixture("[ 1, 4, 2, 1 ]",     // inputShape
158                                     "[ 2, 1 ]",           // outputShape
159                                     "[ 1, 4 ]",           // filterShape
160                                     "[ 2, 3, 4, 5 ]",     // filterData
161                                     "[ 1 ]",              // biasShape
162                                     "[ 10, 0, 0, 0 ]" )   // biasData
163     {}
164 };
165 
// Each group of four inputs is dotted with { 2, 3, 4, 5 }, offset by the bias of 10,
// and divided by the output quantization scale of 2.0.
BOOST_FIXTURE_TEST_CASE(FullyConnectedWithBiasMultipleOutputs, FullyConnectedWithBiasMultipleOutputsFixture)
{
    RunTest<2, armnn::DataType::QAsymmU8>(
        0,                                                // subgraph
        { 1, 2, 3, 4,                                     // first group
          10, 20, 30, 40 },                               // second group
        { (1*2 + 2*3 + 3*4 + 4*5 + 10) / 2,               // first output
          (10*2 + 20*3 + 30*4 + 40*5 + 10) / 2 });        // second output
}
173 
174 struct DynamicFullyConnectedWithBiasMultipleOutputsFixture : FullyConnectedFixture
175 {
DynamicFullyConnectedWithBiasMultipleOutputsFixtureDynamicFullyConnectedWithBiasMultipleOutputsFixture176     DynamicFullyConnectedWithBiasMultipleOutputsFixture()
177         : FullyConnectedFixture("[ 1, 4, 2, 1 ]",     // inputShape
178                                 "[ ]",               // outputShape
179                                 "[ 1, 4 ]",           // filterShape
180                                 "[ 2, 3, 4, 5 ]",     // filterData
181                                 "[ 1 ]",              // biasShape
182                                 "[ 10, 0, 0, 0 ]" )   // biasData
183     { }
184 };
185 
// Named-tensor variant of the multiple-outputs test, run against the model whose
// output shape is left empty. The expected values match the static-shape test.
BOOST_FIXTURE_TEST_CASE(DynamicFullyConnectedWithBiasMultipleOutputs,
                        DynamicFullyConnectedWithBiasMultipleOutputsFixture)
{
    RunTest<2, armnn::DataType::QAsymmU8, armnn::DataType::QAsymmU8>(
        0,
        { { "inputTensor",  { 1, 2, 3, 4,
                              10, 20, 30, 40 } } },
        { { "outputTensor", { (1*2 + 2*3 + 3*4 + 4*5 + 10) / 2,
                              (10*2 + 20*3 + 30*4 + 40*5 + 10) / 2 } } },
        // NOTE(review): this flag presumably enables dynamic output-shape handling
        // in RunTest; confirm against ParserFlatbuffersFixture.
        true);
}
197 
198 BOOST_AUTO_TEST_SUITE_END()
199