1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "armnnTfParser/ITfParser.hpp"
7
8 #include "ParserPrototxtFixture.hpp"
9 #include <PrototxtConversions.hpp>
10
11 #include <boost/test/unit_test.hpp>
12
13 BOOST_AUTO_TEST_SUITE(TensorflowParser)
14
15 namespace {
16 // helper for setting the dimensions in prototxt
shapeHelper(const armnn::TensorShape & shape,std::string & text)17 void shapeHelper(const armnn::TensorShape& shape, std::string& text){
18 for(unsigned int i = 0; i < shape.GetNumDimensions(); ++i) {
19 text.append(R"(dim {
20 size: )");
21 text.append(std::to_string(shape[i]));
22 text.append(R"(
23 })");
24 }
25 }
26
27 // helper for converting from integer to octal representation
octalHelper(const std::vector<int> & content,std::string & text)28 void octalHelper(const std::vector<int>& content, std::string& text){
29 for (unsigned int i = 0; i < content.size(); ++i)
30 {
31 text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(content[i])));
32 }
33 }
34 } // namespace
35
36 struct StridedSliceFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
37 {
StridedSliceFixtureStridedSliceFixture38 StridedSliceFixture(const armnn::TensorShape& inputShape,
39 const std::vector<int>& beginData,
40 const std::vector<int>& endData,
41 const std::vector<int>& stridesData,
42 int beginMask = 0,
43 int endMask = 0,
44 int ellipsisMask = 0,
45 int newAxisMask = 0,
46 int shrinkAxisMask = 0)
47 {
48 m_Prototext = R"(
49 node {
50 name: "input"
51 op: "Placeholder"
52 attr {
53 key: "dtype"
54 value {
55 type: DT_FLOAT
56 }
57 }
58 attr {
59 key: "shape"
60 value {
61 shape {)";
62 shapeHelper(inputShape, m_Prototext);
63 m_Prototext.append(R"(
64 }
65 }
66 }
67 }
68 node {
69 name: "begin"
70 op: "Const"
71 attr {
72 key: "dtype"
73 value {
74 type: DT_INT32
75 }
76 }
77 attr {
78 key: "value"
79 value {
80 tensor {
81 dtype: DT_INT32
82 tensor_shape {
83 dim {
84 size: )");
85 m_Prototext += std::to_string(beginData.size());
86 m_Prototext.append(R"(
87 }
88 }
89 tensor_content: ")");
90 octalHelper(beginData, m_Prototext);
91 m_Prototext.append(R"("
92 }
93 }
94 }
95 }
96 node {
97 name: "end"
98 op: "Const"
99 attr {
100 key: "dtype"
101 value {
102 type: DT_INT32
103 }
104 }
105 attr {
106 key: "value"
107 value {
108 tensor {
109 dtype: DT_INT32
110 tensor_shape {
111 dim {
112 size: )");
113 m_Prototext += std::to_string(endData.size());
114 m_Prototext.append(R"(
115 }
116 }
117 tensor_content: ")");
118 octalHelper(endData, m_Prototext);
119 m_Prototext.append(R"("
120 }
121 }
122 }
123 }
124 node {
125 name: "strides"
126 op: "Const"
127 attr {
128 key: "dtype"
129 value {
130 type: DT_INT32
131 }
132 }
133 attr {
134 key: "value"
135 value {
136 tensor {
137 dtype: DT_INT32
138 tensor_shape {
139 dim {
140 size: )");
141 m_Prototext += std::to_string(stridesData.size());
142 m_Prototext.append(R"(
143 }
144 }
145 tensor_content: ")");
146 octalHelper(stridesData, m_Prototext);
147 m_Prototext.append(R"("
148 }
149 }
150 }
151 }
152 node {
153 name: "output"
154 op: "StridedSlice"
155 input: "input"
156 input: "begin"
157 input: "end"
158 input: "strides"
159 attr {
160 key: "begin_mask"
161 value {
162 i: )");
163 m_Prototext += std::to_string(beginMask);
164 m_Prototext.append(R"(
165 }
166 }
167 attr {
168 key: "end_mask"
169 value {
170 i: )");
171 m_Prototext += std::to_string(endMask);
172 m_Prototext.append(R"(
173 }
174 }
175 attr {
176 key: "ellipsis_mask"
177 value {
178 i: )");
179 m_Prototext += std::to_string(ellipsisMask);
180 m_Prototext.append(R"(
181 }
182 }
183 attr {
184 key: "new_axis_mask"
185 value {
186 i: )");
187 m_Prototext += std::to_string(newAxisMask);
188 m_Prototext.append(R"(
189 }
190 }
191 attr {
192 key: "shrink_axis_mask"
193 value {
194 i: )");
195 m_Prototext += std::to_string(shrinkAxisMask);
196 m_Prototext.append(R"(
197 }
198 }
199 })");
200
201 Setup({ { "input", inputShape } }, { "output" });
202 }
203 };
204
205 struct StridedSlice4DFixture : StridedSliceFixture
206 {
StridedSlice4DFixtureStridedSlice4DFixture207 StridedSlice4DFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
208 { 1, 0, 0, 0 }, // beginData
209 { 2, 2, 3, 1 }, // endData
210 { 1, 1, 1, 1 } // stridesData
211 ) {}
212 };
213
BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture)
{
    // Three batches of two rows; every row repeats a single value three times.
    const std::vector<float> inputValues = { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
                                             3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
                                             5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };

    // Only the middle batch (the rows of 3s and 4s) is kept.
    const std::vector<float> expectedValues = { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f };

    RunTest<4>({ { "input", inputValues } }, { { "output", expectedValues } });
}
222
223 struct StridedSlice4DReverseFixture : StridedSliceFixture
224 {
225
StridedSlice4DReverseFixtureStridedSlice4DReverseFixture226 StridedSlice4DReverseFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
227 { 1, -1, 0, 0 }, // beginData
228 { 2, -3, 3, 1 }, // endData
229 { 1, -1, 1, 1 } // stridesData
230 ) {}
231 };
232
BOOST_FIXTURE_TEST_CASE(StridedSlice4DReverse, StridedSlice4DReverseFixture)
{
    // Three batches of two rows; every row repeats a single value three times.
    const std::vector<float> inputValues = { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
                                             3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
                                             5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };

    // Middle batch with its rows swapped: the 4s now precede the 3s.
    const std::vector<float> expectedValues = { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f };

    RunTest<4>({ { "input", inputValues } }, { { "output", expectedValues } });
}
241
242 struct StridedSliceSimpleStrideFixture : StridedSliceFixture
243 {
StridedSliceSimpleStrideFixtureStridedSliceSimpleStrideFixture244 StridedSliceSimpleStrideFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
245 { 0, 0, 0, 0 }, // beginData
246 { 3, 2, 3, 1 }, // endData
247 { 2, 2, 2, 1 } // stridesData
248 ) {}
249 };
250
BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleStride, StridedSliceSimpleStrideFixture)
{
    // Three batches of two rows; every row repeats a single value three times.
    const std::vector<float> inputValues = { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
                                             3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
                                             5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };

    // Batches 0 and 2, row 0 of each, columns 0 and 2 of that row.
    const std::vector<float> expectedValues = { 1.0f, 1.0f,
                                                5.0f, 5.0f };

    RunTest<4>({ { "input", inputValues } }, { { "output", expectedValues } });
}
260
261 struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
262 {
StridedSliceSimpleRangeMaskFixtureStridedSliceSimpleRangeMaskFixture263 StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
264 { 1, 1, 1, 1 }, // beginData
265 { 1, 1, 1, 1 }, // endData
266 { 1, 1, 1, 1 }, // stridesData
267 (1 << 4) - 1, // beginMask
268 (1 << 4) - 1 // endMask
269 ) {}
270 };
271
BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleRangeMask, StridedSliceSimpleRangeMaskFixture)
{
    // Three batches of two rows; every row repeats a single value three times.
    const std::vector<float> inputValues = { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
                                             3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
                                             5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };

    // With every dimension masked the slice is the identity, so the output equals the input.
    const std::vector<float> expectedValues = { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
                                                3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
                                                5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };

    RunTest<4>({ { "input", inputValues } }, { { "output", expectedValues } });
}
282
283 BOOST_AUTO_TEST_SUITE_END()
284