• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnTfParser/ITfParser.hpp"
7 
8 #include "ParserPrototxtFixture.hpp"
9 #include <PrototxtConversions.hpp>
10 
11 #include <boost/test/unit_test.hpp>
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowParser)
14 
15 namespace {
16 // helper for setting the dimensions in prototxt
shapeHelper(const armnn::TensorShape & shape,std::string & text)17 void shapeHelper(const armnn::TensorShape& shape, std::string& text){
18     for(unsigned int i = 0; i < shape.GetNumDimensions(); ++i) {
19         text.append(R"(dim {
20       size: )");
21         text.append(std::to_string(shape[i]));
22         text.append(R"(
23     })");
24     }
25 }
26 
27 // helper for converting from integer to octal representation
octalHelper(const std::vector<int> & content,std::string & text)28 void octalHelper(const std::vector<int>& content, std::string& text){
29     for (unsigned int i = 0; i < content.size(); ++i)
30     {
31         text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(content[i])));
32     }
33 }
34 } // namespace
35 
36 struct StridedSliceFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
37 {
StridedSliceFixtureStridedSliceFixture38     StridedSliceFixture(const armnn::TensorShape& inputShape,
39                         const std::vector<int>& beginData,
40                         const std::vector<int>& endData,
41                         const std::vector<int>& stridesData,
42                         int beginMask = 0,
43                         int endMask = 0,
44                         int ellipsisMask = 0,
45                         int newAxisMask = 0,
46                         int shrinkAxisMask = 0)
47     {
48         m_Prototext = R"(
49                          node {
50                            name: "input"
51                            op: "Placeholder"
52                            attr {
53                              key: "dtype"
54                              value {
55                                type: DT_FLOAT
56                              }
57                            }
58                            attr {
59                              key: "shape"
60                              value {
61                                shape {)";
62                                  shapeHelper(inputShape, m_Prototext);
63                                  m_Prototext.append(R"(
64                                }
65                              }
66                            }
67                          }
68                          node {
69                            name: "begin"
70                            op: "Const"
71                            attr {
72                              key: "dtype"
73                              value {
74                                type: DT_INT32
75                              }
76                            }
77                            attr {
78                              key: "value"
79                              value {
80                               tensor {
81                                dtype: DT_INT32
82                                  tensor_shape {
83                                    dim {
84                                     size: )");
85                                       m_Prototext += std::to_string(beginData.size());
86                                       m_Prototext.append(R"(
87                                     }
88                                  }
89                                  tensor_content: ")");
90                                    octalHelper(beginData, m_Prototext);
91                                    m_Prototext.append(R"("
92                                }
93                              }
94                            }
95                          }
96                          node {
97                            name: "end"
98                            op: "Const"
99                            attr {
100                              key: "dtype"
101                              value {
102                                type: DT_INT32
103                              }
104                            }
105                            attr {
106                              key: "value"
107                              value {
108                               tensor {
109                                dtype: DT_INT32
110                                  tensor_shape {
111                                    dim {
112                                     size: )");
113                                       m_Prototext += std::to_string(endData.size());
114                                       m_Prototext.append(R"(
115                                     }
116                                  }
117                                  tensor_content: ")");
118                                    octalHelper(endData, m_Prototext);
119                                    m_Prototext.append(R"("
120                                }
121                              }
122                            }
123                          }
124                          node {
125                            name: "strides"
126                            op: "Const"
127                            attr {
128                              key: "dtype"
129                              value {
130                                type: DT_INT32
131                              }
132                            }
133                            attr {
134                              key: "value"
135                              value {
136                               tensor {
137                                dtype: DT_INT32
138                                  tensor_shape {
139                                    dim {
140                                     size: )");
141                                       m_Prototext += std::to_string(stridesData.size());
142                                       m_Prototext.append(R"(
143                                     }
144                                  }
145                                  tensor_content: ")");
146                                    octalHelper(stridesData, m_Prototext);
147                                    m_Prototext.append(R"("
148                                }
149                              }
150                            }
151                          }
152                          node {
153                            name: "output"
154                            op: "StridedSlice"
155                            input: "input"
156                            input: "begin"
157                            input: "end"
158                            input: "strides"
159                            attr {
160                              key: "begin_mask"
161                              value {
162                                i: )");
163                                m_Prototext += std::to_string(beginMask);
164                                m_Prototext.append(R"(
165                              }
166                            }
167                            attr {
168                              key: "end_mask"
169                              value {
170                                i: )");
171                                  m_Prototext += std::to_string(endMask);
172                                  m_Prototext.append(R"(
173                              }
174                            }
175                            attr {
176                              key: "ellipsis_mask"
177                              value {
178                                i: )");
179                                  m_Prototext += std::to_string(ellipsisMask);
180                                  m_Prototext.append(R"(
181                              }
182                            }
183                            attr {
184                              key: "new_axis_mask"
185                              value {
186                                i: )");
187                                  m_Prototext += std::to_string(newAxisMask);
188                                  m_Prototext.append(R"(
189                              }
190                            }
191                            attr {
192                              key: "shrink_axis_mask"
193                              value {
194                                i: )");
195                                  m_Prototext += std::to_string(shrinkAxisMask);
196                                  m_Prototext.append(R"(
197                              }
198                            }
199                          })");
200 
201         Setup({ { "input", inputShape } }, { "output" });
202     }
203 };
204 
205 struct StridedSlice4DFixture : StridedSliceFixture
206 {
StridedSlice4DFixtureStridedSlice4DFixture207     StridedSlice4DFixture() : StridedSliceFixture({ 3, 2, 3, 1 },  // inputShape
208                                                   { 1, 0, 0, 0 },  // beginData
209                                                   { 2, 2, 3, 1 },  // endData
210                                                   { 1, 1, 1, 1 }   // stridesData
211     ) {}
212 };
213 
BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture)
{
    // Row-major values of the [3][2][3][1] input: outer index a holds rows (2a+1) and (2a+2).
    const std::vector<float> inputValues{ 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
                                          3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
                                          5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };

    // Only outer index 1 survives the slice.
    const std::vector<float> expectedValues{ 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f };

    RunTest<4>({ { "input", inputValues } }, { { "output", expectedValues } });
}
222 
223 struct StridedSlice4DReverseFixture : StridedSliceFixture
224 {
225 
StridedSlice4DReverseFixtureStridedSlice4DReverseFixture226     StridedSlice4DReverseFixture() : StridedSliceFixture({ 3, 2, 3, 1 },   // inputShape
227                                                          { 1, -1, 0, 0 },  // beginData
228                                                          { 2, -3, 3, 1 },  // endData
229                                                          { 1, -1, 1, 1 }   // stridesData
230     ) {}
231 };
232 
BOOST_FIXTURE_TEST_CASE(StridedSlice4DReverse, StridedSlice4DReverseFixture)
{
    // Row-major values of the [3][2][3][1] input.
    const std::vector<float> inputValues{ 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
                                          3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
                                          5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };

    // Outer index 1 with its two rows emitted in reverse order.
    const std::vector<float> expectedValues{ 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f };

    RunTest<4>({ { "input", inputValues } }, { { "output", expectedValues } });
}
241 
242 struct StridedSliceSimpleStrideFixture : StridedSliceFixture
243 {
StridedSliceSimpleStrideFixtureStridedSliceSimpleStrideFixture244     StridedSliceSimpleStrideFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
245                                                             { 0, 0, 0, 0 }, // beginData
246                                                             { 3, 2, 3, 1 }, // endData
247                                                             { 2, 2, 2, 1 }  // stridesData
248     ) {}
249 };
250 
BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleStride, StridedSliceSimpleStrideFixture)
{
    // Row-major values of the [3][2][3][1] input.
    const std::vector<float> inputValues{ 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
                                          3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
                                          5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };

    // Stride 2 keeps outer indices {0, 2}, row 0, and columns {0, 2}.
    const std::vector<float> expectedValues{ 1.0f, 1.0f,
                                             5.0f, 5.0f };

    RunTest<4>({ { "input", inputValues } }, { { "output", expectedValues } });
}
260 
261 struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
262 {
StridedSliceSimpleRangeMaskFixtureStridedSliceSimpleRangeMaskFixture263     StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
264                                                                { 1, 1, 1, 1 }, // beginData
265                                                                { 1, 1, 1, 1 }, // endData
266                                                                { 1, 1, 1, 1 }, // stridesData
267                                                                (1 << 4) - 1,   // beginMask
268                                                                (1 << 4) - 1    // endMask
269     ) {}
270 };
271 
BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleRangeMask, StridedSliceSimpleRangeMaskFixture)
{
    // Row-major values of the [3][2][3][1] input.
    const std::vector<float> inputValues{ 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
                                          3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
                                          5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };

    // With every dimension masked, the slice is the identity.
    const std::vector<float> expectedValues{ 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
                                             3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
                                             5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };

    RunTest<4>({ { "input", inputValues } }, { { "output", expectedValues } });
}
282 
283 BOOST_AUTO_TEST_SUITE_END()
284