• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "../TfLiteParser.hpp"
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "ParserPrototxtFixture.hpp"
9 #include "ParserHelper.hpp"
10 #include "test/GraphUtils.hpp"
11 
12 #include <armnn/utility/PolymorphicDowncast.hpp>
13 #include <QuantizeHelper.hpp>
14 
15 #include <boost/test/unit_test.hpp>
16 
17 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
18 
19 struct DetectionPostProcessFixture : ParserFlatbuffersFixture
20 {
DetectionPostProcessFixtureDetectionPostProcessFixture21     explicit DetectionPostProcessFixture(const std::string& custom_options)
22     {
23         /*
24             The following values were used for the custom_options:
25             use_regular_nms = true
26             max_classes_per_detection = 1
27             detections_per_class = 1
28             nms_score_threshold = 0.0
29             nms_iou_threshold = 0.5
30             max_detections = 3
31             max_detections = 3
32             num_classes = 2
33             h_scale = 5
34             w_scale = 5
35             x_scale = 10
36             y_scale = 10
37         */
38         m_JsonString = R"(
39             {
40                 "version": 3,
41                 "operator_codes": [{
42                     "builtin_code": "CUSTOM",
43                     "custom_code": "TFLite_Detection_PostProcess"
44                 }],
45                 "subgraphs": [{
46                     "tensors": [{
47                             "shape": [1, 6, 4],
48                             "type": "UINT8",
49                             "buffer": 0,
50                             "name": "box_encodings",
51                             "quantization": {
52                                 "min": [0.0],
53                                 "max": [255.0],
54                                 "scale": [1.0],
55                                 "zero_point": [ 1 ]
56                             }
57                         },
58                         {
59                             "shape": [1, 6, 3],
60                             "type": "UINT8",
61                             "buffer": 1,
62                             "name": "scores",
63                             "quantization": {
64                                 "min": [0.0],
65                                 "max": [255.0],
66                                 "scale": [0.01],
67                                 "zero_point": [0]
68                             }
69                         },
70                         {
71                             "shape": [6, 4],
72                             "type": "UINT8",
73                             "buffer": 2,
74                             "name": "anchors",
75                             "quantization": {
76                                 "min": [0.0],
77                                 "max": [255.0],
78                                 "scale": [0.5],
79                                 "zero_point": [0]
80                             }
81                         },
82                         {
83                             "type": "FLOAT32",
84                             "buffer": 3,
85                             "name": "detection_boxes",
86                             "quantization": {}
87                         },
88                         {
89                             "type": "FLOAT32",
90                             "buffer": 4,
91                             "name": "detection_classes",
92                             "quantization": {}
93                         },
94                         {
95                             "type": "FLOAT32",
96                             "buffer": 5,
97                             "name": "detection_scores",
98                             "quantization": {}
99                         },
100                         {
101                             "type": "FLOAT32",
102                             "buffer": 6,
103                             "name": "num_detections",
104                             "quantization": {}
105                         }
106                     ],
107                     "inputs": [0, 1, 2],
108                     "outputs": [3, 4, 5, 6],
109                     "operators": [{
110                         "opcode_index": 0,
111                         "inputs": [0, 1, 2],
112                         "outputs": [3, 4, 5, 6],
113                         "builtin_options_type": 0,
114                         "custom_options": [)" + custom_options + R"(],
115                         "custom_options_format": "FLEXBUFFERS"
116                     }]
117                 }],
118                 "buffers": [{},
119                     {},
120                     { "data": [ 1, 1,   2, 2,
121                                 1, 1,   2, 2,
122                                 1, 1,   2, 2,
123                                 1, 21,  2, 2,
124                                 1, 21,  2, 2,
125                                 1, 201, 2, 2]},
126                     {},
127                     {},
128                     {},
129                     {},
130                 ]
131             }
132         )";
133     }
134 };
135 
136 struct ParseDetectionPostProcessCustomOptions : DetectionPostProcessFixture
137 {
138 private:
GenerateDescriptorParseDetectionPostProcessCustomOptions139     static armnn::DetectionPostProcessDescriptor GenerateDescriptor()
140     {
141         static armnn::DetectionPostProcessDescriptor descriptor;
142         descriptor.m_UseRegularNms          = true;
143         descriptor.m_MaxDetections          = 3u;
144         descriptor.m_MaxClassesPerDetection = 1u;
145         descriptor.m_DetectionsPerClass     = 1u;
146         descriptor.m_NumClasses             = 2u;
147         descriptor.m_NmsScoreThreshold      = 0.0f;
148         descriptor.m_NmsIouThreshold        = 0.5f;
149         descriptor.m_ScaleH                 = 5.0f;
150         descriptor.m_ScaleW                 = 5.0f;
151         descriptor.m_ScaleX                 = 10.0f;
152         descriptor.m_ScaleY                 = 10.0f;
153 
154         return descriptor;
155     }
156 
157 public:
ParseDetectionPostProcessCustomOptionsParseDetectionPostProcessCustomOptions158     ParseDetectionPostProcessCustomOptions()
159         : DetectionPostProcessFixture(
160             GenerateDetectionPostProcessJsonString(GenerateDescriptor()))
161     {}
162 };
163 
BOOST_FIXTURE_TEST_CASE(ParseDetectionPostProcess,ParseDetectionPostProcessCustomOptions)164 BOOST_FIXTURE_TEST_CASE( ParseDetectionPostProcess, ParseDetectionPostProcessCustomOptions )
165 {
166     Setup();
167 
168     // Inputs
169     using UnquantizedContainer = std::vector<float>;
170     UnquantizedContainer boxEncodings =
171     {
172         0.0f,  0.0f, 0.0f, 0.0f,
173         0.0f,  1.0f, 0.0f, 0.0f,
174         0.0f, -1.0f, 0.0f, 0.0f,
175         0.0f,  0.0f, 0.0f, 0.0f,
176         0.0f,  1.0f, 0.0f, 0.0f,
177         0.0f,  0.0f, 0.0f, 0.0f
178     };
179 
180     UnquantizedContainer scores =
181     {
182         0.0f, 0.9f,  0.8f,
183         0.0f, 0.75f, 0.72f,
184         0.0f, 0.6f,  0.5f,
185         0.0f, 0.93f, 0.95f,
186         0.0f, 0.5f,  0.4f,
187         0.0f, 0.3f,  0.2f
188     };
189 
190     // Outputs
191     UnquantizedContainer detectionBoxes =
192     {
193         0.0f, 10.0f, 1.0f, 11.0f,
194         0.0f, 10.0f, 1.0f, 11.0f,
195         0.0f, 0.0f,  0.0f, 0.0f
196     };
197 
198     UnquantizedContainer detectionClasses = { 1.0f,  0.0f,  0.0f };
199     UnquantizedContainer detectionScores  = { 0.95f, 0.93f, 0.0f };
200 
201     UnquantizedContainer numDetections    = { 2.0f };
202 
203     // Quantize inputs and outputs
204     using QuantizedContainer = std::vector<uint8_t>;
205 
206     QuantizedContainer quantBoxEncodings = armnnUtils::QuantizedVector<uint8_t>(boxEncodings, 1.00f, 1);
207     QuantizedContainer quantScores       = armnnUtils::QuantizedVector<uint8_t>(scores,       0.01f, 0);
208 
209     std::map<std::string, QuantizedContainer> input =
210     {
211         { "box_encodings", quantBoxEncodings },
212         { "scores", quantScores }
213     };
214 
215     std::map<std::string, UnquantizedContainer> output =
216     {
217         { "detection_boxes", detectionBoxes},
218         { "detection_classes", detectionClasses},
219         { "detection_scores", detectionScores},
220         { "num_detections", numDetections}
221     };
222 
223     RunTest<armnn::DataType::QAsymmU8, armnn::DataType::Float32>(0, input, output);
224 }
225 
BOOST_FIXTURE_TEST_CASE(DetectionPostProcessGraphStructureTest,ParseDetectionPostProcessCustomOptions)226 BOOST_FIXTURE_TEST_CASE(DetectionPostProcessGraphStructureTest, ParseDetectionPostProcessCustomOptions)
227 {
228     /*
229        Inputs:            box_encodings  scores
230                                \          /
231                             DetectionPostProcess
232                           /        /     \       \
233                          /        /       \       \
234        Outputs:     detection detection detection num_detections
235                     boxes     classes   scores
236     */
237 
238     ReadStringToBinary();
239 
240     armnn::INetworkPtr network = m_Parser->CreateNetworkFromBinary(m_GraphBinary);
241 
242     auto optimized = Optimize(*network, { armnn::Compute::CpuRef }, m_Runtime->GetDeviceSpec());
243 
244     auto optimizedNetwork = armnn::PolymorphicDowncast<armnn::OptimizedNetwork*>(optimized.get());
245     auto graph = optimizedNetwork->GetGraph();
246 
247     // Check the number of layers in the graph
248     BOOST_TEST((graph.GetNumInputs() == 2));
249     BOOST_TEST((graph.GetNumOutputs() == 4));
250     BOOST_TEST((graph.GetNumLayers() == 7));
251 
252     // Input layers
253     armnn::Layer* boxEncodingLayer = GetFirstLayerWithName(graph, "box_encodings");
254     BOOST_TEST((boxEncodingLayer->GetType() == armnn::LayerType::Input));
255     BOOST_TEST(CheckNumberOfInputSlot(boxEncodingLayer, 0));
256     BOOST_TEST(CheckNumberOfOutputSlot(boxEncodingLayer, 1));
257 
258     armnn::Layer* scoresLayer = GetFirstLayerWithName(graph, "scores");
259     BOOST_TEST((scoresLayer->GetType() == armnn::LayerType::Input));
260     BOOST_TEST(CheckNumberOfInputSlot(scoresLayer, 0));
261     BOOST_TEST(CheckNumberOfOutputSlot(scoresLayer, 1));
262 
263     // DetectionPostProcess layer
264     armnn::Layer* detectionPostProcessLayer = GetFirstLayerWithName(graph, "DetectionPostProcess:0:0");
265     BOOST_TEST((detectionPostProcessLayer->GetType() == armnn::LayerType::DetectionPostProcess));
266     BOOST_TEST(CheckNumberOfInputSlot(detectionPostProcessLayer, 2));
267     BOOST_TEST(CheckNumberOfOutputSlot(detectionPostProcessLayer, 4));
268 
269     // Output layers
270     armnn::Layer* detectionBoxesLayer = GetFirstLayerWithName(graph, "detection_boxes");
271     BOOST_TEST((detectionBoxesLayer->GetType() == armnn::LayerType::Output));
272     BOOST_TEST(CheckNumberOfInputSlot(detectionBoxesLayer, 1));
273     BOOST_TEST(CheckNumberOfOutputSlot(detectionBoxesLayer, 0));
274 
275     armnn::Layer* detectionClassesLayer = GetFirstLayerWithName(graph, "detection_classes");
276     BOOST_TEST((detectionClassesLayer->GetType() == armnn::LayerType::Output));
277     BOOST_TEST(CheckNumberOfInputSlot(detectionClassesLayer, 1));
278     BOOST_TEST(CheckNumberOfOutputSlot(detectionClassesLayer, 0));
279 
280     armnn::Layer* detectionScoresLayer = GetFirstLayerWithName(graph, "detection_scores");
281     BOOST_TEST((detectionScoresLayer->GetType() == armnn::LayerType::Output));
282     BOOST_TEST(CheckNumberOfInputSlot(detectionScoresLayer, 1));
283     BOOST_TEST(CheckNumberOfOutputSlot(detectionScoresLayer, 0));
284 
285     armnn::Layer* numDetectionsLayer = GetFirstLayerWithName(graph, "num_detections");
286     BOOST_TEST((numDetectionsLayer->GetType() == armnn::LayerType::Output));
287     BOOST_TEST(CheckNumberOfInputSlot(numDetectionsLayer, 1));
288     BOOST_TEST(CheckNumberOfOutputSlot(numDetectionsLayer, 0));
289 
290     // Check the connections
291     armnn::TensorInfo boxEncodingTensor(armnn::TensorShape({ 1, 6, 4 }), armnn::DataType::QAsymmU8, 1, 1);
292     armnn::TensorInfo scoresTensor(armnn::TensorShape({ 1, 6, 3 }), armnn::DataType::QAsymmU8,
293                                                       0.00999999978f, 0);
294 
295     armnn::TensorInfo detectionBoxesTensor(armnn::TensorShape({ 1, 3, 4 }), armnn::DataType::Float32, 0, 0);
296     armnn::TensorInfo detectionClassesTensor(armnn::TensorShape({ 1, 3 }), armnn::DataType::Float32, 0, 0);
297     armnn::TensorInfo detectionScoresTensor(armnn::TensorShape({ 1, 3 }), armnn::DataType::Float32, 0, 0);
298     armnn::TensorInfo numDetectionsTensor(armnn::TensorShape({ 1} ), armnn::DataType::Float32, 0, 0);
299 
300     BOOST_TEST(IsConnected(boxEncodingLayer, detectionPostProcessLayer, 0, 0, boxEncodingTensor));
301     BOOST_TEST(IsConnected(scoresLayer, detectionPostProcessLayer, 0, 1, scoresTensor));
302     BOOST_TEST(IsConnected(detectionPostProcessLayer, detectionBoxesLayer, 0, 0, detectionBoxesTensor));
303     BOOST_TEST(IsConnected(detectionPostProcessLayer, detectionClassesLayer, 1, 0, detectionClassesTensor));
304     BOOST_TEST(IsConnected(detectionPostProcessLayer, detectionScoresLayer, 2, 0, detectionScoresTensor));
305     BOOST_TEST(IsConnected(detectionPostProcessLayer, numDetectionsLayer, 3, 0, numDetectionsTensor));
306 }
307 
308 BOOST_AUTO_TEST_SUITE_END()
309