1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <boost/test/unit_test.hpp>
6 #include "ParserFlatbuffersFixture.hpp"
7 #include "../TfLiteParser.hpp"
8
9 using armnnTfLiteParser::TfLiteParser;
10 using ModelPtr = TfLiteParser::ModelPtr;
11 using TensorRawPtr = TfLiteParser::TensorRawPtr;
12
13 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14
15 struct GetSubgraphInputsOutputsMainFixture : public ParserFlatbuffersFixture
16 {
GetSubgraphInputsOutputsMainFixtureGetSubgraphInputsOutputsMainFixture17 explicit GetSubgraphInputsOutputsMainFixture(const std::string& inputs, const std::string& outputs)
18 {
19 m_JsonString = R"(
20 {
21 "version": 3,
22 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
23 "subgraphs": [
24 {
25 "tensors": [
26 {
27 "shape": [ 1, 1, 1, 1 ] ,
28 "type": "UINT8",
29 "buffer": 0,
30 "name": "OutputTensor",
31 "quantization": {
32 "min": [ 0.0 ],
33 "max": [ 255.0 ],
34 "scale": [ 1.0 ],
35 "zero_point": [ 0 ]
36 }
37 },
38 {
39 "shape": [ 1, 2, 2, 1 ] ,
40 "type": "UINT8",
41 "buffer": 1,
42 "name": "InputTensor",
43 "quantization": {
44 "min": [ -1.2 ],
45 "max": [ 25.5 ],
46 "scale": [ 0.25 ],
47 "zero_point": [ 10 ]
48 }
49 }
50 ],
51 "inputs": )"
52 + inputs
53 + R"(,
54 "outputs": )"
55 + outputs
56 + R"(,
57 "operators": [ {
58 "opcode_index": 0,
59 "inputs": [ 1 ],
60 "outputs": [ 0 ],
61 "builtin_options_type": "Pool2DOptions",
62 "builtin_options":
63 {
64 "padding": "VALID",
65 "stride_w": 2,
66 "stride_h": 2,
67 "filter_width": 2,
68 "filter_height": 2,
69 "fused_activation_function": "NONE"
70 },
71 "custom_options_format": "FLEXBUFFERS"
72 } ]
73 },
74 {
75 "tensors": [
76 {
77 "shape": [ 1, 3, 3, 1 ],
78 "type": "UINT8",
79 "buffer": 0,
80 "name": "ConvInputTensor",
81 "quantization": {
82 "scale": [ 1.0 ],
83 "zero_point": [ 0 ],
84 }
85 },
86 {
87 "shape": [ 1, 1, 1, 1 ],
88 "type": "UINT8",
89 "buffer": 1,
90 "name": "ConvOutputTensor",
91 "quantization": {
92 "min": [ 0.0 ],
93 "max": [ 511.0 ],
94 "scale": [ 2.0 ],
95 "zero_point": [ 0 ],
96 }
97 },
98 {
99 "shape": [ 1, 3, 3, 1 ],
100 "type": "UINT8",
101 "buffer": 2,
102 "name": "filterTensor",
103 "quantization": {
104 "min": [ 0.0 ],
105 "max": [ 255.0 ],
106 "scale": [ 1.0 ],
107 "zero_point": [ 0 ],
108 }
109 }
110 ],
111 "inputs": [ 0 ],
112 "outputs": [ 1 ],
113 "operators": [
114 {
115 "opcode_index": 0,
116 "inputs": [ 0, 2 ],
117 "outputs": [ 1 ],
118 "builtin_options_type": "Conv2DOptions",
119 "builtin_options": {
120 "padding": "VALID",
121 "stride_w": 1,
122 "stride_h": 1,
123 "fused_activation_function": "NONE"
124 },
125 "custom_options_format": "FLEXBUFFERS"
126 }
127 ],
128 }
129 ],
130 "description": "Test Subgraph Inputs Outputs",
131 "buffers" : [
132 { },
133 { },
134 { "data": [ 2,1,0, 6,2,1, 4,1,2 ], },
135 { },
136 ]
137 })";
138
139 ReadStringToBinary();
140 }
141
142 };
143
144 struct GetEmptySubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture
145 {
GetEmptySubgraphInputsOutputsFixtureGetEmptySubgraphInputsOutputsFixture146 GetEmptySubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ ]", "[ ]") {}
147 };
148
149 struct GetSubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture
150 {
GetSubgraphInputsOutputsFixtureGetSubgraphInputsOutputsFixture151 GetSubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ 1 ]", "[ 0 ]") {}
152 };
153
// Subgraph 0 of this fixture declares "inputs": [ ], so no input tensors are expected.
BOOST_FIXTURE_TEST_CASE(GetEmptySubgraphInputs, GetEmptySubgraphInputsOutputsFixture)
{
    TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
    TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphInputs(model, 0);
    // Unsigned literal avoids a signed/unsigned comparison against size().
    BOOST_CHECK_EQUAL(0u, subgraphTensors.size());
}
160
// Subgraph 0 of this fixture declares "outputs": [ ], so no output tensors are expected.
BOOST_FIXTURE_TEST_CASE(GetEmptySubgraphOutputs, GetEmptySubgraphInputsOutputsFixture)
{
    TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
    TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphOutputs(model, 0);
    // Unsigned literal avoids a signed/unsigned comparison against size().
    BOOST_CHECK_EQUAL(0u, subgraphTensors.size());
}
167
// Subgraph 0 declares "inputs": [ 1 ]; expect exactly tensor 1 ("InputTensor") with its quantization data.
BOOST_FIXTURE_TEST_CASE(GetSubgraphInputs, GetSubgraphInputsOutputsFixture)
{
    TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
    TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphInputs(model, 0);
    // REQUIRE, not CHECK: if the size were wrong, subgraphTensors[0] below would be an out-of-bounds access.
    BOOST_REQUIRE_EQUAL(1u, subgraphTensors.size());
    BOOST_CHECK_EQUAL(1, subgraphTensors[0].first);
    CheckTensors(subgraphTensors[0].second, 4, { 1, 2, 2, 1 }, tflite::TensorType::TensorType_UINT8, 1,
                 "InputTensor", { -1.2f }, { 25.5f }, { 0.25f }, { 10 });
}
177
// Subgraph 0 declares "outputs": [ 0 ]; expect exactly tensor 0 ("OutputTensor") with its quantization data.
BOOST_FIXTURE_TEST_CASE(GetSubgraphOutputsSimpleQuantized, GetSubgraphInputsOutputsFixture)
{
    TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
    TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphOutputs(model, 0);
    // REQUIRE, not CHECK: if the size were wrong, subgraphTensors[0] below would be an out-of-bounds access.
    BOOST_REQUIRE_EQUAL(1u, subgraphTensors.size());
    BOOST_CHECK_EQUAL(0, subgraphTensors[0].first);
    CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 0,
                 "OutputTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
}
187
// Subgraph 1's only input, "ConvInputTensor", has scale/zero_point but no min/max;
// expect empty min/max vectors back.
BOOST_FIXTURE_TEST_CASE(GetSubgraphInputsEmptyMinMax, GetSubgraphInputsOutputsFixture)
{
    TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
    TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphInputs(model, 1);
    // REQUIRE, not CHECK: if the size were wrong, subgraphTensors[0] below would be an out-of-bounds access.
    BOOST_REQUIRE_EQUAL(1u, subgraphTensors.size());
    BOOST_CHECK_EQUAL(0, subgraphTensors[0].first);
    CheckTensors(subgraphTensors[0].second, 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 0,
                 "ConvInputTensor", { }, { }, { 1.0f }, { 0 });
}
197
// Subgraph 1 declares "outputs": [ 1 ]; expect exactly tensor 1 ("ConvOutputTensor").
BOOST_FIXTURE_TEST_CASE(GetSubgraphOutputs, GetSubgraphInputsOutputsFixture)
{
    TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
    TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphOutputs(model, 1);
    // REQUIRE, not CHECK: if the size were wrong, subgraphTensors[0] below would be an out-of-bounds access.
    BOOST_REQUIRE_EQUAL(1u, subgraphTensors.size());
    BOOST_CHECK_EQUAL(1, subgraphTensors[0].first);
    CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 1,
                 "ConvOutputTensor", { 0.0f }, { 511.0f }, { 2.0f }, { 0 });
}
207
// Asking a null model for subgraph inputs must be rejected with armnn::ParseException.
BOOST_AUTO_TEST_CASE(GetSubgraphInputsNullModel)
{
    const TfLiteParser::ModelPtr nullModel{ nullptr };
    BOOST_CHECK_THROW(TfLiteParser::GetSubgraphInputs(nullModel, 0), armnn::ParseException);
}
212
// Asking a null model for subgraph outputs must be rejected with armnn::ParseException.
BOOST_AUTO_TEST_CASE(GetSubgraphOutputsNullModel)
{
    const TfLiteParser::ModelPtr nullModel{ nullptr };
    BOOST_CHECK_THROW(TfLiteParser::GetSubgraphOutputs(nullModel, 0), armnn::ParseException);
}
217
// The fixture model defines only subgraphs 0 and 1, so index 2 must be rejected.
BOOST_FIXTURE_TEST_CASE(GetSubgraphInputsInvalidSubgraph, GetSubgraphInputsOutputsFixture)
{
    const TfLiteParser::ModelPtr loadedModel =
        TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
    BOOST_CHECK_THROW(TfLiteParser::GetSubgraphInputs(loadedModel, 2), armnn::ParseException);
}
223
// The fixture model defines only subgraphs 0 and 1, so index 2 must be rejected.
BOOST_FIXTURE_TEST_CASE(GetSubgraphOutputsInvalidSubgraph, GetSubgraphInputsOutputsFixture)
{
    const TfLiteParser::ModelPtr loadedModel =
        TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
    BOOST_CHECK_THROW(TfLiteParser::GetSubgraphOutputs(loadedModel, 2), armnn::ParseException);
}
229
230 BOOST_AUTO_TEST_SUITE_END()