1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "../TfLiteParser.hpp"
9
10 #include <string>
11
12 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
13
14 struct StridedSliceFixture : public ParserFlatbuffersFixture
15 {
StridedSliceFixtureStridedSliceFixture16 explicit StridedSliceFixture(const std::string & inputShape,
17 const std::string & outputShape,
18 const std::string & beginData,
19 const std::string & endData,
20 const std::string & stridesData,
21 int beginMask = 0,
22 int endMask = 0)
23 {
24 m_JsonString = R"(
25 {
26 "version": 3,
27 "operator_codes": [ { "builtin_code": "STRIDED_SLICE" } ],
28 "subgraphs": [ {
29 "tensors": [
30 {
31 "shape": )" + inputShape + R"(,
32 "type": "FLOAT32",
33 "buffer": 0,
34 "name": "inputTensor",
35 "quantization": {
36 "min": [ 0.0 ],
37 "max": [ 255.0 ],
38 "scale": [ 1.0 ],
39 "zero_point": [ 0 ],
40 }
41 },
42 {
43 "shape": [ 4 ],
44 "type": "INT32",
45 "buffer": 1,
46 "name": "beginTensor",
47 "quantization": {
48 }
49 },
50 {
51 "shape": [ 4 ],
52 "type": "INT32",
53 "buffer": 2,
54 "name": "endTensor",
55 "quantization": {
56 }
57 },
58 {
59 "shape": [ 4 ],
60 "type": "INT32",
61 "buffer": 3,
62 "name": "stridesTensor",
63 "quantization": {
64 }
65 },
66 {
67 "shape": )" + outputShape + R"( ,
68 "type": "FLOAT32",
69 "buffer": 4,
70 "name": "outputTensor",
71 "quantization": {
72 "min": [ 0.0 ],
73 "max": [ 255.0 ],
74 "scale": [ 1.0 ],
75 "zero_point": [ 0 ],
76 }
77 }
78 ],
79 "inputs": [ 0, 1, 2, 3 ],
80 "outputs": [ 4 ],
81 "operators": [
82 {
83 "opcode_index": 0,
84 "inputs": [ 0, 1, 2, 3 ],
85 "outputs": [ 4 ],
86 "builtin_options_type": "StridedSliceOptions",
87 "builtin_options": {
88 "begin_mask": )" + std::to_string(beginMask) + R"(,
89 "end_mask": )" + std::to_string(endMask) + R"(
90 },
91 "custom_options_format": "FLEXBUFFERS"
92 }
93 ],
94 } ],
95 "buffers" : [
96 { },
97 { "data": )" + beginData + R"(, },
98 { "data": )" + endData + R"(, },
99 { "data": )" + stridesData + R"(, },
100 { }
101 ]
102 }
103 )";
104 Setup();
105 }
106 };
107
108 struct StridedSlice4DFixture : StridedSliceFixture
109 {
StridedSlice4DFixtureStridedSlice4DFixture110 StridedSlice4DFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
111 "[ 1, 2, 3, 1 ]", // outputShape
112 "[ 1,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]", // beginData
113 "[ 2,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]", // endData
114 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]" // stridesData
115 ) {}
116 };
117
BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture)
{
    // Values for the [ 3, 2, 3, 1 ] input, outermost index varying slowest.
    const std::vector<float> inputValues =
    {
        1.0f, 1.0f, 1.0f,   2.0f, 2.0f, 2.0f,
        3.0f, 3.0f, 3.0f,   4.0f, 4.0f, 4.0f,
        5.0f, 5.0f, 5.0f,   6.0f, 6.0f, 6.0f
    };

    // Only outer index 1 is selected by the slice.
    const std::vector<float> expectedValues =
    {
        3.0f, 3.0f, 3.0f,   4.0f, 4.0f, 4.0f
    };

    RunTest<4, armnn::DataType::Float32>(0,
                                         { { "inputTensor",  inputValues } },
                                         { { "outputTensor", expectedValues } });
}
130
131 struct StridedSlice4DReverseFixture : StridedSliceFixture
132 {
StridedSlice4DReverseFixtureStridedSlice4DReverseFixture133 StridedSlice4DReverseFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
134 "[ 1, 2, 3, 1 ]", // outputShape
135 "[ 1,0,0,0, "
136 "255,255,255,255, "
137 "0,0,0,0, "
138 "0,0,0,0 ]", // beginData [ 1 -1 0 0 ]
139 "[ 2,0,0,0, "
140 "253,255,255,255, "
141 "3,0,0,0, "
142 "1,0,0,0 ]", // endData [ 2 -3 3 1 ]
143 "[ 1,0,0,0, "
144 "255,255,255,255, "
145 "1,0,0,0, "
146 "1,0,0,0 ]" // stridesData [ 1 -1 1 1 ]
147 ) {}
148 };
149
BOOST_FIXTURE_TEST_CASE(StridedSlice4DReverse, StridedSlice4DReverseFixture)
{
    // Values for the [ 3, 2, 3, 1 ] input, outermost index varying slowest.
    const std::vector<float> inputValues =
    {
        1.0f, 1.0f, 1.0f,   2.0f, 2.0f, 2.0f,
        3.0f, 3.0f, 3.0f,   4.0f, 4.0f, 4.0f,
        5.0f, 5.0f, 5.0f,   6.0f, 6.0f, 6.0f
    };

    // Outer index 1 is selected, and its two rows come out in reverse order.
    const std::vector<float> expectedValues =
    {
        4.0f, 4.0f, 4.0f,   3.0f, 3.0f, 3.0f
    };

    RunTest<4, armnn::DataType::Float32>(0,
                                         { { "inputTensor",  inputValues } },
                                         { { "outputTensor", expectedValues } });
}
162
163 struct StridedSliceSimpleStrideFixture : StridedSliceFixture
164 {
StridedSliceSimpleStrideFixtureStridedSliceSimpleStrideFixture165 StridedSliceSimpleStrideFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
166 "[ 2, 1, 2, 1 ]", // outputShape
167 "[ 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]", // beginData
168 "[ 3,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]", // endData
169 "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 1,0,0,0 ]" // stridesData
170 ) {}
171 };
172
BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleStride, StridedSliceSimpleStrideFixture)
{
    // Values for the [ 3, 2, 3, 1 ] input, outermost index varying slowest.
    const std::vector<float> inputValues =
    {
        1.0f, 1.0f, 1.0f,   2.0f, 2.0f, 2.0f,
        3.0f, 3.0f, 3.0f,   4.0f, 4.0f, 4.0f,
        5.0f, 5.0f, 5.0f,   6.0f, 6.0f, 6.0f
    };

    // Elements at outer indices {0, 2}, row 0, columns {0, 2}.
    const std::vector<float> expectedValues =
    {
        1.0f, 1.0f,
        5.0f, 5.0f
    };

    RunTest<4, armnn::DataType::Float32>(0,
                                         { { "inputTensor",  inputValues } },
                                         { { "outputTensor", expectedValues } });
}
187
188 struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
189 {
StridedSliceSimpleRangeMaskFixtureStridedSliceSimpleRangeMaskFixture190 StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
191 "[ 3, 2, 3, 1 ]", // outputShape
192 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // beginData
193 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // endData
194 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // stridesData
195 (1 << 4) - 1, // beginMask
196 (1 << 4) - 1 // endMask
197 ) {}
198 };
199
BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleRangeMask, StridedSliceSimpleRangeMaskFixture)
{
    // Values for the [ 3, 2, 3, 1 ] input, outermost index varying slowest.
    const std::vector<float> inputValues =
    {
        1.0f, 1.0f, 1.0f,   2.0f, 2.0f, 2.0f,
        3.0f, 3.0f, 3.0f,   4.0f, 4.0f, 4.0f,
        5.0f, 5.0f, 5.0f,   6.0f, 6.0f, 6.0f
    };

    // With every dimension masked the slice is the identity, so the output equals the input.
    RunTest<4, armnn::DataType::Float32>(0,
                                         { { "inputTensor",  inputValues } },
                                         { { "outputTensor", inputValues } });
}
216
217 BOOST_AUTO_TEST_SUITE_END()
218