• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
#include <boost/test/unit_test.hpp>
#include "ParserFlatbuffersFixture.hpp"
#include "../TfLiteParser.hpp"

#include <Filesystem.hpp>

#include <string>
#include <system_error>
#include <vector>

using armnnTfLiteParser::TfLiteParser;
using ModelPtr = TfLiteParser::ModelPtr;
using SubgraphPtr = TfLiteParser::SubgraphPtr;
using OperatorPtr = TfLiteParser::OperatorPtr;
15 
16 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
17 
18 struct LoadModelFixture : public ParserFlatbuffersFixture
19 {
LoadModelFixtureLoadModelFixture20     explicit LoadModelFixture()
21     {
22         m_JsonString = R"(
23         {
24             "version": 3,
25             "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
26             "subgraphs": [
27             {
28                 "tensors": [
29                 {
30                     "shape": [ 1, 1, 1, 1 ] ,
31                     "type": "UINT8",
32                             "buffer": 0,
33                             "name": "OutputTensor",
34                             "quantization": {
35                                 "min": [ 0.0 ],
36                                 "max": [ 255.0 ],
37                                 "scale": [ 1.0 ],
38                                 "zero_point": [ 0 ]
39                             }
40                 },
41                 {
42                     "shape": [ 1, 2, 2, 1 ] ,
43                     "type": "UINT8",
44                             "buffer": 1,
45                             "name": "InputTensor",
46                             "quantization": {
47                                 "min": [ 0.0 ],
48                                 "max": [ 255.0 ],
49                                 "scale": [ 1.0 ],
50                                 "zero_point": [ 0 ]
51                             }
52                 }
53                 ],
54                 "inputs": [ 1 ],
55                 "outputs": [ 0 ],
56                 "operators": [ {
57                         "opcode_index": 0,
58                         "inputs": [ 1 ],
59                         "outputs": [ 0 ],
60                         "builtin_options_type": "Pool2DOptions",
61                         "builtin_options":
62                         {
63                             "padding": "VALID",
64                             "stride_w": 2,
65                             "stride_h": 2,
66                             "filter_width": 2,
67                             "filter_height": 2,
68                             "fused_activation_function": "NONE"
69                         },
70                         "custom_options_format": "FLEXBUFFERS"
71                     } ]
72                 },
73                 {
74                     "tensors": [
75                         {
76                             "shape": [ 1, 3, 3, 1 ],
77                             "type": "UINT8",
78                             "buffer": 0,
79                             "name": "ConvInputTensor",
80                             "quantization": {
81                                 "scale": [ 1.0 ],
82                                 "zero_point": [ 0 ],
83                             }
84                         },
85                         {
86                             "shape": [ 1, 1, 1, 1 ],
87                             "type": "UINT8",
88                             "buffer": 1,
89                             "name": "ConvOutputTensor",
90                             "quantization": {
91                                 "min": [ 0.0 ],
92                                 "max": [ 511.0 ],
93                                 "scale": [ 2.0 ],
94                                 "zero_point": [ 0 ],
95                             }
96                         },
97                         {
98                             "shape": [ 1, 3, 3, 1 ],
99                             "type": "UINT8",
100                             "buffer": 2,
101                             "name": "filterTensor",
102                             "quantization": {
103                                 "min": [ 0.0 ],
104                                 "max": [ 255.0 ],
105                                 "scale": [ 1.0 ],
106                                 "zero_point": [ 0 ],
107                             }
108                         }
109                     ],
110                     "inputs": [ 0 ],
111                     "outputs": [ 1 ],
112                     "operators": [
113                         {
114                             "opcode_index": 1,
115                             "inputs": [ 0, 2 ],
116                             "outputs": [ 1 ],
117                             "builtin_options_type": "Conv2DOptions",
118                             "builtin_options": {
119                                 "padding": "VALID",
120                                 "stride_w": 1,
121                                 "stride_h": 1,
122                                 "fused_activation_function": "NONE"
123                             },
124                             "custom_options_format": "FLEXBUFFERS"
125                         }
126                     ],
127                 }
128             ],
129             "description": "Test loading a model",
130             "buffers" : [ {}, {} ]
131         })";
132 
133         ReadStringToBinary();
134     }
135 
CheckModelLoadModelFixture136     void CheckModel(const ModelPtr& model, uint32_t version, size_t opcodeSize,
137                     const std::vector<tflite::BuiltinOperator>& opcodes,
138                     size_t subgraphs, const std::string desc, size_t buffers)
139     {
140         BOOST_CHECK(model);
141         BOOST_CHECK_EQUAL(version, model->version);
142         BOOST_CHECK_EQUAL(opcodeSize, model->operator_codes.size());
143         CheckBuiltinOperators(opcodes, model->operator_codes);
144         BOOST_CHECK_EQUAL(subgraphs, model->subgraphs.size());
145         BOOST_CHECK_EQUAL(desc, model->description);
146         BOOST_CHECK_EQUAL(buffers, model->buffers.size());
147     }
148 
CheckBuiltinOperatorsLoadModelFixture149     void CheckBuiltinOperators(const std::vector<tflite::BuiltinOperator>& expectedOperators,
150                                const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& result)
151     {
152         BOOST_CHECK_EQUAL(expectedOperators.size(), result.size());
153         for (size_t i = 0; i < expectedOperators.size(); i++)
154         {
155             BOOST_CHECK_EQUAL(expectedOperators[i], result[i]->builtin_code);
156         }
157     }
158 
CheckSubgraphLoadModelFixture159     void CheckSubgraph(const SubgraphPtr& subgraph, size_t tensors, const std::vector<int32_t>& inputs,
160                        const std::vector<int32_t>& outputs, size_t operators, const std::string& name)
161     {
162         BOOST_CHECK(subgraph);
163         BOOST_CHECK_EQUAL(tensors, subgraph->tensors.size());
164         BOOST_CHECK_EQUAL_COLLECTIONS(inputs.begin(), inputs.end(), subgraph->inputs.begin(), subgraph->inputs.end());
165         BOOST_CHECK_EQUAL_COLLECTIONS(outputs.begin(), outputs.end(),
166                                       subgraph->outputs.begin(), subgraph->outputs.end());
167         BOOST_CHECK_EQUAL(operators, subgraph->operators.size());
168         BOOST_CHECK_EQUAL(name, subgraph->name);
169     }
170 
CheckOperatorLoadModelFixture171     void CheckOperator(const OperatorPtr& operatorPtr, uint32_t opcode,  const std::vector<int32_t>& inputs,
172                        const std::vector<int32_t>& outputs, tflite::BuiltinOptions optionType,
173                        tflite::CustomOptionsFormat custom_options_format)
174     {
175         BOOST_CHECK(operatorPtr);
176         BOOST_CHECK_EQUAL(opcode, operatorPtr->opcode_index);
177         BOOST_CHECK_EQUAL_COLLECTIONS(inputs.begin(), inputs.end(),
178                                       operatorPtr->inputs.begin(), operatorPtr->inputs.end());
179         BOOST_CHECK_EQUAL_COLLECTIONS(outputs.begin(), outputs.end(),
180                                       operatorPtr->outputs.begin(), operatorPtr->outputs.end());
181         BOOST_CHECK_EQUAL(optionType, operatorPtr->builtin_options.type);
182         BOOST_CHECK_EQUAL(custom_options_format, operatorPtr->custom_options_format);
183     }
184 };
185 
// Loads the fixture model from its in-memory flatbuffer and verifies its structure.
BOOST_FIXTURE_TEST_CASE(LoadModelFromBinary, LoadModelFixture)
{
    TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());

    // Fatal guards: the checks below dereference model and index subgraphs[0..1] and operators[0].
    // A malformed model must fail the test case cleanly instead of crashing the test binary.
    BOOST_REQUIRE(model);
    CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
               2, "Test loading a model", 2);

    BOOST_REQUIRE_EQUAL(model->subgraphs.size(), 2u);
    BOOST_REQUIRE(model->subgraphs[0] && model->subgraphs[1]);
    CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, "");
    CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, "");

    BOOST_REQUIRE(!model->subgraphs[0]->operators.empty() && !model->subgraphs[1]->operators.empty());
    CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions,
                  tflite::CustomOptionsFormat_FLEXBUFFERS);
    CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions,
                  tflite::CustomOptionsFormat_FLEXBUFFERS);
}
198 
// Writes the fixture model to a temporary file, loads it back through the file-based API and
// verifies its structure.
BOOST_FIXTURE_TEST_CASE(LoadModelFromFile, LoadModelFixture)
{
    const fs::path fname = armnnUtils::Filesystem::NamedTempFile("Armnn-tfLite-LoadModelFromFile-TempFile.csv");

    // Removes the temporary file when the test case exits, including when a BOOST_REQUIRE aborts it
    // or the loader throws. Uses the non-throwing overload because a destructor must not throw.
    struct TempFileGuard
    {
        fs::path m_Path;
        ~TempFileGuard()
        {
            std::error_code ignored;
            fs::remove(m_Path, ignored);
        }
    } tempFileGuard{ fname };

    // path::string() instead of path::c_str(): on Windows c_str() yields wchar_t*, while both APIs take char*.
    const std::string fileName = fname.string();
    const bool saved = flatbuffers::SaveFile(fileName.c_str(),
                                             reinterpret_cast<const char*>(m_GraphBinary.data()),
                                             m_GraphBinary.size(), true);
    BOOST_REQUIRE_MESSAGE(saved, "Cannot save test file");

    TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromFile(fileName.c_str());

    // Fatal guards: the checks below dereference model and index subgraphs[0..1] and operators[0].
    BOOST_REQUIRE(model);
    CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
               2, "Test loading a model", 2);

    BOOST_REQUIRE_EQUAL(model->subgraphs.size(), 2u);
    BOOST_REQUIRE(model->subgraphs[0] && model->subgraphs[1]);
    CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, "");
    CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, "");

    BOOST_REQUIRE(!model->subgraphs[0]->operators.empty() && !model->subgraphs[1]->operators.empty());
    CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions,
                  tflite::CustomOptionsFormat_FLEXBUFFERS);
    CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions,
                  tflite::CustomOptionsFormat_FLEXBUFFERS);
}
219 
// A null buffer of length zero must be rejected as an invalid argument.
BOOST_AUTO_TEST_CASE(LoadNullBinary)
{
    const uint8_t* const noContent = nullptr;
    const size_t noLength = 0;
    BOOST_CHECK_THROW(TfLiteParser::LoadModelFromBinary(noContent, noLength),
                      armnn::InvalidArgumentException);
}
224 
// Bytes that are not a valid TfLite flatbuffer must be rejected with a ParseException.
BOOST_AUTO_TEST_CASE(LoadInvalidBinary)
{
    const std::string testData = "invalid data";
    // Pass the string's character data. Taking &testData reinterprets the std::string object itself
    // (its internal bookkeeping fields), not the text "invalid data".
    BOOST_CHECK_THROW(TfLiteParser::LoadModelFromBinary(reinterpret_cast<const uint8_t*>(testData.data()),
                                                        testData.length()), armnn::ParseException);
}
231 
// Loading from a path that does not exist must raise FileNotFoundException.
BOOST_AUTO_TEST_CASE(LoadFileNotFound)
{
    const char* const missingFile = "invalidfile.tflite";
    BOOST_CHECK_THROW(TfLiteParser::LoadModelFromFile(missingFile),
                      armnn::FileNotFoundException);
}
236 
// A null file name must be rejected as an invalid argument.
BOOST_AUTO_TEST_CASE(LoadNullPtrFile)
{
    const char* const noFileName = nullptr;
    BOOST_CHECK_THROW(TfLiteParser::LoadModelFromFile(noFileName),
                      armnn::InvalidArgumentException);
}
241 
242 BOOST_AUTO_TEST_SUITE_END()
243