• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
#include <boost/test/unit_test.hpp>
#include "ParserFlatbuffersFixture.hpp"
#include "../TfLiteParser.hpp"

#include <cstddef>
#include <string>
11 
12 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
13 
14 struct StridedSliceFixture : public ParserFlatbuffersFixture
15 {
StridedSliceFixtureStridedSliceFixture16     explicit StridedSliceFixture(const std::string & inputShape,
17                                  const std::string & outputShape,
18                                  const std::string & beginData,
19                                  const std::string & endData,
20                                  const std::string & stridesData,
21                                  int beginMask = 0,
22                                  int endMask = 0)
23     {
24         m_JsonString = R"(
25             {
26                 "version": 3,
27                 "operator_codes": [ { "builtin_code": "STRIDED_SLICE" } ],
28                 "subgraphs": [ {
29                     "tensors": [
30                         {
31                             "shape": )" + inputShape + R"(,
32                             "type": "FLOAT32",
33                             "buffer": 0,
34                             "name": "inputTensor",
35                             "quantization": {
36                                 "min": [ 0.0 ],
37                                 "max": [ 255.0 ],
38                                 "scale": [ 1.0 ],
39                                 "zero_point": [ 0 ],
40                             }
41                         },
42                         {
43                             "shape": [ 4 ],
44                             "type": "INT32",
45                             "buffer": 1,
46                             "name": "beginTensor",
47                             "quantization": {
48                             }
49                         },
50                         {
51                            "shape": [ 4 ],
52                             "type": "INT32",
53                             "buffer": 2,
54                             "name": "endTensor",
55                             "quantization": {
56                             }
57                         },
58                         {
59                            "shape": [ 4 ],
60                             "type": "INT32",
61                             "buffer": 3,
62                             "name": "stridesTensor",
63                             "quantization": {
64                             }
65                         },
66                         {
67                             "shape": )" + outputShape + R"( ,
68                             "type": "FLOAT32",
69                             "buffer": 4,
70                             "name": "outputTensor",
71                             "quantization": {
72                                 "min": [ 0.0 ],
73                                 "max": [ 255.0 ],
74                                 "scale": [ 1.0 ],
75                                 "zero_point": [ 0 ],
76                             }
77                         }
78                     ],
79                     "inputs": [ 0, 1, 2, 3 ],
80                     "outputs": [ 4 ],
81                     "operators": [
82                         {
83                             "opcode_index": 0,
84                             "inputs": [ 0, 1, 2, 3 ],
85                             "outputs": [ 4 ],
86                             "builtin_options_type": "StridedSliceOptions",
87                             "builtin_options": {
88                                "begin_mask": )"       + std::to_string(beginMask)      + R"(,
89                                "end_mask": )"         + std::to_string(endMask)        + R"(
90                             },
91                             "custom_options_format": "FLEXBUFFERS"
92                         }
93                     ],
94                 } ],
95                 "buffers" : [
96                     { },
97                     { "data": )" + beginData + R"(, },
98                     { "data": )" + endData + R"(, },
99                     { "data": )" + stridesData + R"(, },
100                     { }
101                 ]
102             }
103         )";
104         Setup();
105     }
106 };
107 
108 struct StridedSlice4DFixture : StridedSliceFixture
109 {
StridedSlice4DFixtureStridedSlice4DFixture110     StridedSlice4DFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",  // inputShape
111                                                   "[ 1, 2, 3, 1 ]",  // outputShape
112                                                   "[ 1,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]",  // beginData
113                                                   "[ 2,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]",  // endData
114                                                   "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]"   // stridesData
115                                                  ) {}
116 };
117 
// The input's first-dimension slice at index 1 is { 3, 3, 3, 4, 4, 4 }, and that is the expected output.
BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture)
{
    RunTest<4, armnn::DataType::Float32>(
        0,
        {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,

                           3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,

                           5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},

        {{"outputTensor", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}});
}
130 
131 struct StridedSlice4DReverseFixture : StridedSliceFixture
132 {
StridedSlice4DReverseFixtureStridedSlice4DReverseFixture133     StridedSlice4DReverseFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",    // inputShape
134                                                          "[ 1, 2, 3, 1 ]",    // outputShape
135                                                          "[ 1,0,0,0, "
136                                                          "255,255,255,255, "
137                                                          "0,0,0,0, "
138                                                          "0,0,0,0 ]",  // beginData    [ 1 -1 0 0 ]
139                                                          "[ 2,0,0,0, "
140                                                          "253,255,255,255, "
141                                                          "3,0,0,0, "
142                                                          "1,0,0,0 ]",  // endData      [ 2 -3 3 1 ]
143                                                          "[ 1,0,0,0, "
144                                                          "255,255,255,255, "
145                                                          "1,0,0,0, "
146                                                          "1,0,0,0 ]"   // stridesData  [ 1 -1 1 1 ]
147                                                         ) {}
148 };
149 
// First-dimension index 1 holds the rows { 3, 3, 3 } and { 4, 4, 4 }. The negative stride on
// dimension 1 emits them in reverse order.
BOOST_FIXTURE_TEST_CASE(StridedSlice4DReverse, StridedSlice4DReverseFixture)
{
    RunTest<4, armnn::DataType::Float32>(
        0,
        {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,

                           3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,

                           5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},

        {{"outputTensor", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}});
}
162 
163 struct StridedSliceSimpleStrideFixture : StridedSliceFixture
164 {
StridedSliceSimpleStrideFixtureStridedSliceSimpleStrideFixture165     StridedSliceSimpleStrideFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",  // inputShape
166                                                             "[ 2, 1, 2, 1 ]",  // outputShape
167                                                             "[ 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]",  // beginData
168                                                             "[ 3,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]",  // endData
169                                                             "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 1,0,0,0 ]"   // stridesData
170                                                  ) {}
171 };
172 
// Picks first-dimension indices 0 and 2, dimension-1 index 0, and dimension-2 indices 0 and 2.
// That yields { 1, 1 } and { 5, 5 }.
BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleStride, StridedSliceSimpleStrideFixture)
{
    RunTest<4, armnn::DataType::Float32>(
        0,
        {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,

                           3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,

                           5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},

        {{"outputTensor", { 1.0f, 1.0f,

                            5.0f, 5.0f }}});
}
187 
188 struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
189 {
StridedSliceSimpleRangeMaskFixtureStridedSliceSimpleRangeMaskFixture190     StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",  // inputShape
191                                                                "[ 3, 2, 3, 1 ]",  // outputShape
192                                                                "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]",  // beginData
193                                                                "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]",  // endData
194                                                                "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]",  // stridesData
195                                                                (1 << 4) - 1,  // beginMask
196                                                                (1 << 4) - 1   // endMask
197                                                  ) {}
198 };
199 
// With every dimension masked to its full range, the output is an unchanged copy of the input.
BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleRangeMask, StridedSliceSimpleRangeMaskFixture)
{
    RunTest<4, armnn::DataType::Float32>(
        0,
        {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,

                           3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,

                           5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},

        {{"outputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,

                            3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,

                            5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}});
}
216 
217 BOOST_AUTO_TEST_SUITE_END()
218