• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "../TfLiteParser.hpp"
9 
10 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
11 
12 struct TransposeConvFixture : public ParserFlatbuffersFixture
13 {
TransposeConvFixtureTransposeConvFixture14     explicit TransposeConvFixture(const std::string& inputShape,
15                                   const std::string& outputShape,
16                                   const std::string& filterShape,
17                                   const std::string& filterData,
18                                   const std::string& strideX,
19                                   const std::string& strideY,
20                                   const std::string& dataType)
21     {
22         m_JsonString = R"(
23             {
24                 "version": 3,
25                 "operator_codes": [ { "builtin_code": "TRANSPOSE_CONV" } ],
26                 "subgraphs": [ {
27                     "tensors": [
28                         {
29                             "shape": [ 4 ],
30                             "type": "UINT8",
31                             "buffer": 0,
32                             "name": "outputShapeTensor",
33                             "quantization": {
34                                 "min": [ 0.0 ],
35                                 "max": [ 255.0 ],
36                                 "scale": [ 1.0 ],
37                                 "zero_point": [ 0 ],
38                             }
39                         },
40                         {
41                             "shape": )" + filterShape + R"(,
42                             "type": ")" + dataType + R"(",
43                             "buffer": 1,
44                             "name": "filterTensor",
45                             "quantization": {
46                                 "min": [ 0.0 ],
47                                 "max": [ 255.0 ],
48                                 "scale": [ 1.0 ],
49                                 "zero_point": [ 0 ],
50                             }
51                         },
52                         {
53                             "shape": )" + inputShape + R"(,
54                             "type": ")" + dataType + R"(",
55                             "buffer": 2,
56                             "name": "inputTensor",
57                             "quantization": {
58                                 "min": [ 0.0 ],
59                                 "max": [ 255.0 ],
60                                 "scale": [ 1.0 ],
61                                 "zero_point": [ 0 ],
62                             }
63                         },
64                         {
65                             "shape": )" + outputShape + R"(,
66                             "type": ")" + dataType + R"(",
67                             "buffer": 3,
68                             "name": "outputTensor",
69                             "quantization": {
70                                 "min": [ 0.0 ],
71                                 "max": [ 255.0 ],
72                                 "scale": [ 1.0 ],
73                                 "zero_point": [ 0 ],
74                             }
75                         }
76                     ],
77                     "inputs": [ 2 ],
78                     "outputs": [ 3 ],
79                     "operators": [
80                         {
81                             "opcode_index": 0,
82                             "inputs": [ 0, 1, 2 ],
83                             "outputs": [ 3 ],
84                             "builtin_options_type": "TransposeConvOptions",
85                             "builtin_options": {
86                                 "padding": "VALID",
87                                 "stride_w": )" + strideX + R"(,
88                                 "stride_h": )" + strideY + R"(
89                             },
90                             "custom_options_format": "FLEXBUFFERS"
91                         }
92                     ],
93                 } ],
94                 "buffers" : [
95                     { "data": )" + outputShape + R"( },
96                     { "data": )" + filterData + R"( },
97                     { },
98                     { }
99                 ]
100             }
101         )";
102         SetupSingleInputSingleOutput("inputTensor", "outputTensor");
103     }
104 };
105 
106 struct SimpleTransposeConvFixture : TransposeConvFixture
107 {
SimpleTransposeConvFixtureSimpleTransposeConvFixture108     SimpleTransposeConvFixture()
109     : TransposeConvFixture("[ 1, 2, 2, 1 ]",  // inputShape
110                            "[ 1, 3, 3, 1 ]",  // outputShape
111                            "[ 1, 2, 2, 1 ]",  // filterShape
112                            "[ 0, 1, 2, 4 ]",  // filterData
113                            "1",               // strideX
114                            "1",               // strideY
115                            "UINT8")           // dataType
116     {}
117 };
118 
// Each 2x2 input element scatters a scaled copy of the 2x2 filter
// { 0, 1, 2, 4 } into the 3x3 output (stride 1, VALID padding). Overlapping
// contributions are summed; e.g. the centre is 1*4 + 2*2 + 3*1 + 4*0 = 11.
BOOST_FIXTURE_TEST_CASE(ParseSimpleTransposeConv, SimpleTransposeConvFixture)
{
    // Subgraph 0, rank-4 tensors, QAsymmU8 data.
    RunTest<4, armnn::DataType::QAsymmU8>(0,
                                          { 1, 2, 3, 4 },       // input, 2x2
                                          { 0,  1,  2,          // expected output, 3x3
                                            2, 11, 12,
                                            6, 20, 16 });
}
133 
134 BOOST_AUTO_TEST_SUITE_END()
135