• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 #include "../TfLiteParser.hpp"
8 
9 #include <armnn/LayerVisitorBase.hpp>
10 #include <armnn/utility/Assert.hpp>
11 #include <armnn/utility/NumericCast.hpp>
12 #include <armnn/utility/PolymorphicDowncast.hpp>
13 
14 #include <layers/StandInLayer.hpp>
15 
16 #include <boost/test/unit_test.hpp>
17 
18 #include <sstream>
19 #include <string>
20 #include <vector>
21 
22 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
23 
24 using namespace armnn;
25 
26 class StandInLayerVerifier : public LayerVisitorBase<VisitorThrowingPolicy>
27 {
28 public:
StandInLayerVerifier(const std::vector<TensorInfo> & inputInfos,const std::vector<TensorInfo> & outputInfos)29     StandInLayerVerifier(const std::vector<TensorInfo>& inputInfos,
30                          const std::vector<TensorInfo>& outputInfos)
31         : LayerVisitorBase<VisitorThrowingPolicy>()
32         , m_InputInfos(inputInfos)
33         , m_OutputInfos(outputInfos) {}
34 
VisitInputLayer(const IConnectableLayer *,LayerBindingId,const char *)35     void VisitInputLayer(const IConnectableLayer*, LayerBindingId, const char*) override {}
36 
VisitOutputLayer(const IConnectableLayer *,LayerBindingId,const char *)37     void VisitOutputLayer(const IConnectableLayer*, LayerBindingId, const char*) override {}
38 
VisitStandInLayer(const IConnectableLayer * layer,const StandInDescriptor & descriptor,const char *)39     void VisitStandInLayer(const IConnectableLayer* layer,
40                            const StandInDescriptor& descriptor,
41                            const char*) override
42     {
43         unsigned int numInputs = armnn::numeric_cast<unsigned int>(m_InputInfos.size());
44         BOOST_CHECK(descriptor.m_NumInputs    == numInputs);
45         BOOST_CHECK(layer->GetNumInputSlots() == numInputs);
46 
47         unsigned int numOutputs = armnn::numeric_cast<unsigned int>(m_OutputInfos.size());
48         BOOST_CHECK(descriptor.m_NumOutputs    == numOutputs);
49         BOOST_CHECK(layer->GetNumOutputSlots() == numOutputs);
50 
51         const StandInLayer* standInLayer = PolymorphicDowncast<const StandInLayer*>(layer);
52         for (unsigned int i = 0u; i < numInputs; ++i)
53         {
54             const OutputSlot* connectedSlot = standInLayer->GetInputSlot(i).GetConnectedOutputSlot();
55             BOOST_CHECK(connectedSlot != nullptr);
56 
57             const TensorInfo& inputInfo = connectedSlot->GetTensorInfo();
58             BOOST_CHECK(inputInfo == m_InputInfos[i]);
59         }
60 
61         for (unsigned int i = 0u; i < numOutputs; ++i)
62         {
63             const TensorInfo& outputInfo = layer->GetOutputSlot(i).GetTensorInfo();
64             BOOST_CHECK(outputInfo == m_OutputInfos[i]);
65         }
66     }
67 
68 private:
69     std::vector<TensorInfo> m_InputInfos;
70     std::vector<TensorInfo> m_OutputInfos;
71 };
72 
73 class DummyCustomFixture : public ParserFlatbuffersFixture
74 {
75 public:
DummyCustomFixture(const std::vector<TensorInfo> & inputInfos,const std::vector<TensorInfo> & outputInfos)76     explicit DummyCustomFixture(const std::vector<TensorInfo>& inputInfos,
77                                 const std::vector<TensorInfo>& outputInfos)
78         : ParserFlatbuffersFixture()
79         , m_StandInLayerVerifier(inputInfos, outputInfos)
80     {
81         const unsigned int numInputs = armnn::numeric_cast<unsigned int>(inputInfos.size());
82         ARMNN_ASSERT(numInputs > 0);
83 
84         const unsigned int numOutputs = armnn::numeric_cast<unsigned int>(outputInfos.size());
85         ARMNN_ASSERT(numOutputs > 0);
86 
87         m_JsonString = R"(
88             {
89                 "version": 3,
90                 "operator_codes": [{
91                     "builtin_code": "CUSTOM",
92                     "custom_code": "DummyCustomOperator"
93                 }],
94                 "subgraphs": [ {
95                     "tensors": [)";
96 
97         // Add input tensors
98         for (unsigned int i = 0u; i < numInputs; ++i)
99         {
100             const TensorInfo& inputInfo = inputInfos[i];
101             m_JsonString += R"(
102                     {
103                         "shape": )" + GetTensorShapeAsString(inputInfo.GetShape()) + R"(,
104                         "type": )" + GetDataTypeAsString(inputInfo.GetDataType()) + R"(,
105                         "buffer": 0,
106                         "name": "inputTensor)" + std::to_string(i) + R"(",
107                         "quantization": {
108                             "min": [ 0.0 ],
109                             "max": [ 255.0 ],
110                             "scale": [ )" + std::to_string(inputInfo.GetQuantizationScale()) + R"( ],
111                             "zero_point": [ )" + std::to_string(inputInfo.GetQuantizationOffset()) + R"( ],
112                         }
113                     },)";
114         }
115 
116         // Add output tensors
117         for (unsigned int i = 0u; i < numOutputs; ++i)
118         {
119             const TensorInfo& outputInfo = outputInfos[i];
120             m_JsonString += R"(
121                     {
122                         "shape": )" + GetTensorShapeAsString(outputInfo.GetShape()) + R"(,
123                         "type": )" + GetDataTypeAsString(outputInfo.GetDataType()) + R"(,
124                         "buffer": 0,
125                         "name": "outputTensor)" + std::to_string(i) + R"(",
126                         "quantization": {
127                             "min": [ 0.0 ],
128                             "max": [ 255.0 ],
129                             "scale": [ )" + std::to_string(outputInfo.GetQuantizationScale()) + R"( ],
130                             "zero_point": [ )" + std::to_string(outputInfo.GetQuantizationOffset()) + R"( ],
131                         }
132                     })";
133 
134             if (i + 1 < numOutputs)
135             {
136                 m_JsonString += ",";
137             }
138         }
139 
140         const std::string inputIndices  = GetIndicesAsString(0u, numInputs - 1u);
141         const std::string outputIndices = GetIndicesAsString(numInputs, numInputs + numOutputs - 1u);
142 
143         // Add dummy custom operator
144         m_JsonString +=  R"(],
145                     "inputs": )" + inputIndices + R"(,
146                     "outputs": )" + outputIndices + R"(,
147                     "operators": [
148                         {
149                             "opcode_index": 0,
150                             "inputs": )" + inputIndices + R"(,
151                             "outputs": )" + outputIndices + R"(,
152                             "builtin_options_type": 0,
153                             "custom_options": [ ],
154                             "custom_options_format": "FLEXBUFFERS"
155                         }
156                     ],
157                 } ],
158                 "buffers" : [
159                     { },
160                     { }
161                 ]
162             }
163         )";
164 
165         ReadStringToBinary();
166     }
167 
RunTest()168     void RunTest()
169     {
170         INetworkPtr network = m_Parser->CreateNetworkFromBinary(m_GraphBinary);
171         network->Accept(m_StandInLayerVerifier);
172     }
173 
174 private:
GetTensorShapeAsString(const TensorShape & tensorShape)175     static std::string GetTensorShapeAsString(const TensorShape& tensorShape)
176     {
177         std::stringstream stream;
178         stream << "[ ";
179         for (unsigned int i = 0u; i < tensorShape.GetNumDimensions(); ++i)
180         {
181             stream << tensorShape[i];
182             if (i + 1 < tensorShape.GetNumDimensions())
183             {
184                 stream << ",";
185             }
186             stream << " ";
187         }
188         stream << "]";
189 
190         return stream.str();
191     }
192 
GetDataTypeAsString(DataType dataType)193     static std::string GetDataTypeAsString(DataType dataType)
194     {
195         switch (dataType)
196         {
197             case DataType::Float32:         return "FLOAT32";
198             case DataType::QAsymmU8: return "UINT8";
199             default:                        return "UNKNOWN";
200         }
201     }
202 
GetIndicesAsString(unsigned int first,unsigned int last)203     static std::string GetIndicesAsString(unsigned int first, unsigned int last)
204     {
205         std::stringstream stream;
206         stream << "[ ";
207         for (unsigned int i = first; i <= last ; ++i)
208         {
209             stream << i;
210             if (i + 1 <= last)
211             {
212                 stream << ",";
213             }
214             stream << " ";
215         }
216         stream << "]";
217 
218         return stream.str();
219     }
220 
221     StandInLayerVerifier m_StandInLayerVerifier;
222 };
223 
224 class DummyCustom1Input1OutputFixture : public DummyCustomFixture
225 {
226 public:
DummyCustom1Input1OutputFixture()227     DummyCustom1Input1OutputFixture()
228         : DummyCustomFixture({ TensorInfo({ 1, 1 }, DataType::Float32) },
229                              { TensorInfo({ 2, 2 }, DataType::Float32) }) {}
230 };
231 
232 class DummyCustom2Inputs1OutputFixture : public DummyCustomFixture
233 {
234 public:
DummyCustom2Inputs1OutputFixture()235     DummyCustom2Inputs1OutputFixture()
236         : DummyCustomFixture({ TensorInfo({ 1, 1 }, DataType::Float32), TensorInfo({ 2, 2 }, DataType::Float32) },
237                              { TensorInfo({ 3, 3 }, DataType::Float32) }) {}
238 };
239 
// Parses a model whose only operator is an unsupported custom op with one input and one output, then
// runs the fixture's StandInLayer verification.
BOOST_FIXTURE_TEST_CASE(UnsupportedCustomOperator1Input1Output, DummyCustom1Input1OutputFixture)
{
    this->RunTest();
}

// As above, but the unsupported custom op has two inputs and one output.
BOOST_FIXTURE_TEST_CASE(UnsupportedCustomOperator2Inputs1Output, DummyCustom2Inputs1OutputFixture)
{
    this->RunTest();
}
249 
250 BOOST_AUTO_TEST_SUITE_END()
251