1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "../TfLiteParser.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
11
12 struct TransposeConvFixture : public ParserFlatbuffersFixture
13 {
TransposeConvFixtureTransposeConvFixture14 explicit TransposeConvFixture(const std::string& inputShape,
15 const std::string& outputShape,
16 const std::string& filterShape,
17 const std::string& filterData,
18 const std::string& strideX,
19 const std::string& strideY,
20 const std::string& dataType)
21 {
22 m_JsonString = R"(
23 {
24 "version": 3,
25 "operator_codes": [ { "builtin_code": "TRANSPOSE_CONV" } ],
26 "subgraphs": [ {
27 "tensors": [
28 {
29 "shape": [ 4 ],
30 "type": "UINT8",
31 "buffer": 0,
32 "name": "outputShapeTensor",
33 "quantization": {
34 "min": [ 0.0 ],
35 "max": [ 255.0 ],
36 "scale": [ 1.0 ],
37 "zero_point": [ 0 ],
38 }
39 },
40 {
41 "shape": )" + filterShape + R"(,
42 "type": ")" + dataType + R"(",
43 "buffer": 1,
44 "name": "filterTensor",
45 "quantization": {
46 "min": [ 0.0 ],
47 "max": [ 255.0 ],
48 "scale": [ 1.0 ],
49 "zero_point": [ 0 ],
50 }
51 },
52 {
53 "shape": )" + inputShape + R"(,
54 "type": ")" + dataType + R"(",
55 "buffer": 2,
56 "name": "inputTensor",
57 "quantization": {
58 "min": [ 0.0 ],
59 "max": [ 255.0 ],
60 "scale": [ 1.0 ],
61 "zero_point": [ 0 ],
62 }
63 },
64 {
65 "shape": )" + outputShape + R"(,
66 "type": ")" + dataType + R"(",
67 "buffer": 3,
68 "name": "outputTensor",
69 "quantization": {
70 "min": [ 0.0 ],
71 "max": [ 255.0 ],
72 "scale": [ 1.0 ],
73 "zero_point": [ 0 ],
74 }
75 }
76 ],
77 "inputs": [ 2 ],
78 "outputs": [ 3 ],
79 "operators": [
80 {
81 "opcode_index": 0,
82 "inputs": [ 0, 1, 2 ],
83 "outputs": [ 3 ],
84 "builtin_options_type": "TransposeConvOptions",
85 "builtin_options": {
86 "padding": "VALID",
87 "stride_w": )" + strideX + R"(,
88 "stride_h": )" + strideY + R"(
89 },
90 "custom_options_format": "FLEXBUFFERS"
91 }
92 ],
93 } ],
94 "buffers" : [
95 { "data": )" + outputShape + R"( },
96 { "data": )" + filterData + R"( },
97 { },
98 { }
99 ]
100 }
101 )";
102 SetupSingleInputSingleOutput("inputTensor", "outputTensor");
103 }
104 };
105
106 struct SimpleTransposeConvFixture : TransposeConvFixture
107 {
SimpleTransposeConvFixtureSimpleTransposeConvFixture108 SimpleTransposeConvFixture()
109 : TransposeConvFixture("[ 1, 2, 2, 1 ]", // inputShape
110 "[ 1, 3, 3, 1 ]", // outputShape
111 "[ 1, 2, 2, 1 ]", // filterShape
112 "[ 0, 1, 2, 4 ]", // filterData
113 "1", // strideX
114 "1", // strideY
115 "UINT8") // dataType
116 {}
117 };
118
// Runs the model set up by SimpleTransposeConvFixture on the four input values 1, 2, 3, 4 and
// compares the result against nine expected values.
//
// The expected values follow from the fixture's filter data (0, 1, 2, 4): each input value scales
// the 2x2 filter and is accumulated into the 3x3 output at that input's own row/column offset,
// e.g. the centre element is 1*4 + 2*2 + 3*1 + 4*0 = 11.
BOOST_FIXTURE_TEST_CASE(ParseSimpleTransposeConv, SimpleTransposeConvFixture)
{
    // NOTE(review): the template argument 4 and the leading 0 are presumably the output tensor's
    // number of dimensions and the subgraph index - confirm against RunTest's declaration.
    RunTest<4, armnn::DataType::QAsymmU8>(
        0,
        { 1, 2,
          3, 4 },
        { 0,  1,  2,
          2, 11, 12,
          6, 20, 16 });
}
133
134 BOOST_AUTO_TEST_SUITE_END()
135