• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <boost/test/unit_test.hpp>
6 #include "ParserFlatbuffersFixture.hpp"
7 #include "../TfLiteParser.hpp"
8 
9 using armnnTfLiteParser::TfLiteParser;
10 using ModelPtr = TfLiteParser::ModelPtr;
11 using TensorRawPtr = TfLiteParser::TensorRawPtr;
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14 
15 struct GetSubgraphInputsOutputsMainFixture : public ParserFlatbuffersFixture
16 {
GetSubgraphInputsOutputsMainFixtureGetSubgraphInputsOutputsMainFixture17     explicit GetSubgraphInputsOutputsMainFixture(const std::string& inputs, const std::string& outputs)
18     {
19         m_JsonString = R"(
20         {
21             "version": 3,
22             "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
23             "subgraphs": [
24             {
25                 "tensors": [
26                 {
27                     "shape": [ 1, 1, 1, 1 ] ,
28                     "type": "UINT8",
29                             "buffer": 0,
30                             "name": "OutputTensor",
31                             "quantization": {
32                                 "min": [ 0.0 ],
33                                 "max": [ 255.0 ],
34                                 "scale": [ 1.0 ],
35                                 "zero_point": [ 0 ]
36                             }
37                 },
38                 {
39                     "shape": [ 1, 2, 2, 1 ] ,
40                     "type": "UINT8",
41                             "buffer": 1,
42                             "name": "InputTensor",
43                             "quantization": {
44                                 "min": [ -1.2 ],
45                                 "max": [ 25.5 ],
46                                 "scale": [ 0.25 ],
47                                 "zero_point": [ 10 ]
48                             }
49                 }
50                 ],
51                 "inputs": )"
52                             + inputs
53                             + R"(,
54                 "outputs": )"
55                             + outputs
56                             + R"(,
57                 "operators": [ {
58                         "opcode_index": 0,
59                         "inputs": [ 1 ],
60                         "outputs": [ 0 ],
61                         "builtin_options_type": "Pool2DOptions",
62                         "builtin_options":
63                         {
64                             "padding": "VALID",
65                             "stride_w": 2,
66                             "stride_h": 2,
67                             "filter_width": 2,
68                             "filter_height": 2,
69                             "fused_activation_function": "NONE"
70                         },
71                         "custom_options_format": "FLEXBUFFERS"
72                     } ]
73                 },
74                 {
75                     "tensors": [
76                         {
77                             "shape": [ 1, 3, 3, 1 ],
78                             "type": "UINT8",
79                             "buffer": 0,
80                             "name": "ConvInputTensor",
81                             "quantization": {
82                                 "scale": [ 1.0 ],
83                                 "zero_point": [ 0 ],
84                             }
85                         },
86                         {
87                             "shape": [ 1, 1, 1, 1 ],
88                             "type": "UINT8",
89                             "buffer": 1,
90                             "name": "ConvOutputTensor",
91                             "quantization": {
92                                 "min": [ 0.0 ],
93                                 "max": [ 511.0 ],
94                                 "scale": [ 2.0 ],
95                                 "zero_point": [ 0 ],
96                             }
97                         },
98                         {
99                             "shape": [ 1, 3, 3, 1 ],
100                             "type": "UINT8",
101                             "buffer": 2,
102                             "name": "filterTensor",
103                             "quantization": {
104                                 "min": [ 0.0 ],
105                                 "max": [ 255.0 ],
106                                 "scale": [ 1.0 ],
107                                 "zero_point": [ 0 ],
108                             }
109                         }
110                     ],
111                     "inputs": [ 0 ],
112                     "outputs": [ 1 ],
113                     "operators": [
114                         {
115                             "opcode_index": 0,
116                             "inputs": [ 0, 2 ],
117                             "outputs": [ 1 ],
118                             "builtin_options_type": "Conv2DOptions",
119                             "builtin_options": {
120                                 "padding": "VALID",
121                                 "stride_w": 1,
122                                 "stride_h": 1,
123                                 "fused_activation_function": "NONE"
124                             },
125                             "custom_options_format": "FLEXBUFFERS"
126                         }
127                     ],
128                 }
129             ],
130             "description": "Test Subgraph Inputs Outputs",
131             "buffers" : [
132                     { },
133                     { },
134                     { "data": [ 2,1,0, 6,2,1, 4,1,2 ], },
135                     { },
136                 ]
137         })";
138 
139         ReadStringToBinary();
140     }
141 
142 };
143 
144 struct GetEmptySubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture
145 {
GetEmptySubgraphInputsOutputsFixtureGetEmptySubgraphInputsOutputsFixture146     GetEmptySubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ ]", "[ ]") {}
147 };
148 
149 struct GetSubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture
150 {
GetSubgraphInputsOutputsFixtureGetSubgraphInputsOutputsFixture151     GetSubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ 1 ]", "[ 0 ]") {}
152 };
153 
// Subgraph 0 of the empty fixture is built with "inputs": [ ], so no input tensors are reported.
BOOST_FIXTURE_TEST_CASE(GetEmptySubgraphInputs, GetEmptySubgraphInputsOutputsFixture)
{
    const auto binaryData = m_GraphBinary.data();
    const auto binarySize = m_GraphBinary.size();
    ModelPtr model = TfLiteParser::LoadModelFromBinary(binaryData, binarySize);

    TfLiteParser::TensorIdRawPtrVector inputTensors = TfLiteParser::GetSubgraphInputs(model, 0);
    BOOST_CHECK_EQUAL(0, inputTensors.size());
}
160 
// Subgraph 0 of the empty fixture is built with "outputs": [ ], so no output tensors are reported.
BOOST_FIXTURE_TEST_CASE(GetEmptySubgraphOutputs, GetEmptySubgraphInputsOutputsFixture)
{
    const auto binaryData = m_GraphBinary.data();
    const auto binarySize = m_GraphBinary.size();
    ModelPtr model = TfLiteParser::LoadModelFromBinary(binaryData, binarySize);

    TfLiteParser::TensorIdRawPtrVector outputTensors = TfLiteParser::GetSubgraphOutputs(model, 0);
    BOOST_CHECK_EQUAL(0, outputTensors.size());
}
167 
// Subgraph 0 declares tensor 1 ("InputTensor") as its single input; verify its id and full
// tensor description, including min/max quantization parameters.
BOOST_FIXTURE_TEST_CASE(GetSubgraphInputs, GetSubgraphInputsOutputsFixture)
{
    TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
    TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphInputs(model, 0);
    // REQUIRE rather than CHECK: a size mismatch must stop the test before subgraphTensors[0]
    // is read, otherwise an empty result would be indexed out of bounds.
    BOOST_REQUIRE_EQUAL(1u, subgraphTensors.size());
    BOOST_CHECK_EQUAL(1, subgraphTensors[0].first);
    CheckTensors(subgraphTensors[0].second, 4, { 1, 2, 2, 1 }, tflite::TensorType::TensorType_UINT8, 1,
                 "InputTensor", { -1.2f }, { 25.5f }, { 0.25f }, { 10 });
}
177 
// Subgraph 0 declares tensor 0 ("OutputTensor") as its single output; verify its id and full
// tensor description with simple (scale 1, zero point 0) quantization.
BOOST_FIXTURE_TEST_CASE(GetSubgraphOutputsSimpleQuantized, GetSubgraphInputsOutputsFixture)
{
    TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
    TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphOutputs(model, 0);
    // REQUIRE rather than CHECK: a size mismatch must stop the test before subgraphTensors[0]
    // is read, otherwise an empty result would be indexed out of bounds.
    BOOST_REQUIRE_EQUAL(1u, subgraphTensors.size());
    BOOST_CHECK_EQUAL(0, subgraphTensors[0].first);
    CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 0,
                 "OutputTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
}
187 
// Subgraph 1's input tensor ("ConvInputTensor") has scale/zero_point but no min/max in the
// fixture JSON; verify that the parsed tensor reports empty min/max lists.
BOOST_FIXTURE_TEST_CASE(GetSubgraphInputsEmptyMinMax, GetSubgraphInputsOutputsFixture)
{
    TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
    TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphInputs(model, 1);
    // REQUIRE rather than CHECK: a size mismatch must stop the test before subgraphTensors[0]
    // is read, otherwise an empty result would be indexed out of bounds.
    BOOST_REQUIRE_EQUAL(1u, subgraphTensors.size());
    BOOST_CHECK_EQUAL(0, subgraphTensors[0].first);
    CheckTensors(subgraphTensors[0].second, 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 0,
                 "ConvInputTensor", { }, { }, { 1.0f }, { 0 });
}
197 
// Subgraph 1 declares tensor 1 ("ConvOutputTensor") as its single output; verify its id and
// full tensor description.
BOOST_FIXTURE_TEST_CASE(GetSubgraphOutputs, GetSubgraphInputsOutputsFixture)
{
    TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
    TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphOutputs(model, 1);
    // REQUIRE rather than CHECK: a size mismatch must stop the test before subgraphTensors[0]
    // is read, otherwise an empty result would be indexed out of bounds.
    BOOST_REQUIRE_EQUAL(1u, subgraphTensors.size());
    BOOST_CHECK_EQUAL(1, subgraphTensors[0].first);
    CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 1,
                 "ConvOutputTensor", { 0.0f }, { 511.0f }, { 2.0f }, { 0 });
}
207 
// Asking for the inputs of a null model must raise armnn::ParseException.
BOOST_AUTO_TEST_CASE(GetSubgraphInputsNullModel)
{
    auto queryNullModel = [] { return TfLiteParser::GetSubgraphInputs(nullptr, 0); };
    BOOST_CHECK_THROW(queryNullModel(), armnn::ParseException);
}
212 
// Asking for the outputs of a null model must raise armnn::ParseException.
BOOST_AUTO_TEST_CASE(GetSubgraphOutputsNullModel)
{
    auto queryNullModel = [] { return TfLiteParser::GetSubgraphOutputs(nullptr, 0); };
    BOOST_CHECK_THROW(queryNullModel(), armnn::ParseException);
}
217 
// The fixture model defines only subgraphs 0 and 1, so requesting the inputs of
// subgraph 2 must raise armnn::ParseException.
BOOST_FIXTURE_TEST_CASE(GetSubgraphInputsInvalidSubgraph, GetSubgraphInputsOutputsFixture)
{
    const auto graphData = m_GraphBinary.data();
    const auto graphSize = m_GraphBinary.size();
    ModelPtr model = TfLiteParser::LoadModelFromBinary(graphData, graphSize);

    BOOST_CHECK_THROW(TfLiteParser::GetSubgraphInputs(model, 2), armnn::ParseException);
}
223 
BOOST_FIXTURE_TEST_CASE(GetSubgraphOutputsInvalidSubgraph,GetSubgraphInputsOutputsFixture)224 BOOST_FIXTURE_TEST_CASE(GetSubgraphOutputsInvalidSubgraph, GetSubgraphInputsOutputsFixture)
225 {
226     TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
227     BOOST_CHECK_THROW(TfLiteParser::GetSubgraphOutputs(model, 2), armnn::ParseException);
228 }
229 
230 BOOST_AUTO_TEST_SUITE_END()