• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2018 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_TEST_STRIDED_SLICE_DATASET
25 #define ARM_COMPUTE_TEST_STRIDED_SLICE_DATASET
26 
#include "utils/TypePrinter.h"

#include "arm_compute/core/Types.h"

#include <algorithm>
#include <cstdint>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
30 
31 namespace arm_compute
32 {
33 namespace test
34 {
35 namespace datasets
36 {
37 class SliceDataset
38 {
39 public:
40     using type = std::tuple<TensorShape, Coordinates, Coordinates>;
41 
42     struct iterator
43     {
iteratoriterator44         iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it,
45                  std::vector<Coordinates>::const_iterator starts_values_it,
46                  std::vector<Coordinates>::const_iterator ends_values_it)
47             : _tensor_shapes_it{ std::move(tensor_shapes_it) },
48               _starts_values_it{ std::move(starts_values_it) },
49               _ends_values_it{ std::move(ends_values_it) }
50         {
51         }
52 
descriptioniterator53         std::string description() const
54         {
55             std::stringstream description;
56             description << "Shape=" << *_tensor_shapes_it << ":";
57             description << "Starts=" << *_starts_values_it << ":";
58             description << "Ends=" << *_ends_values_it << ":";
59             return description.str();
60         }
61 
62         SliceDataset::type operator*() const
63         {
64             return std::make_tuple(*_tensor_shapes_it, *_starts_values_it, *_ends_values_it);
65         }
66 
67         iterator &operator++()
68         {
69             ++_tensor_shapes_it;
70             ++_starts_values_it;
71             ++_ends_values_it;
72             return *this;
73         }
74 
75     private:
76         std::vector<TensorShape>::const_iterator _tensor_shapes_it;
77         std::vector<Coordinates>::const_iterator _starts_values_it;
78         std::vector<Coordinates>::const_iterator _ends_values_it;
79     };
80 
begin()81     iterator begin() const
82     {
83         return iterator(_tensor_shapes.begin(), _starts_values.begin(), _ends_values.begin());
84     }
85 
size()86     int size() const
87     {
88         return std::min(_tensor_shapes.size(), std::min(_starts_values.size(), _ends_values.size()));
89     }
90 
add_config(TensorShape shape,Coordinates starts,Coordinates ends)91     void add_config(TensorShape shape, Coordinates starts, Coordinates ends)
92     {
93         _tensor_shapes.emplace_back(std::move(shape));
94         _starts_values.emplace_back(std::move(starts));
95         _ends_values.emplace_back(std::move(ends));
96     }
97 
98 protected:
99     SliceDataset()                = default;
100     SliceDataset(SliceDataset &&) = default;
101 
102 private:
103     std::vector<TensorShape> _tensor_shapes{};
104     std::vector<Coordinates> _starts_values{};
105     std::vector<Coordinates> _ends_values{};
106 };
107 
108 class StridedSliceDataset
109 {
110 public:
111     using type = std::tuple<TensorShape, Coordinates, Coordinates, BiStrides, int32_t, int32_t, int32_t>;
112 
113     struct iterator
114     {
iteratoriterator115         iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it,
116                  std::vector<Coordinates>::const_iterator starts_values_it,
117                  std::vector<Coordinates>::const_iterator ends_values_it,
118                  std::vector<BiStrides>::const_iterator   strides_values_it,
119                  std::vector<int32_t>::const_iterator     begin_mask_values_it,
120                  std::vector<int32_t>::const_iterator     end_mask_values_it,
121                  std::vector<int32_t>::const_iterator     shrink_mask_values_it)
122             : _tensor_shapes_it{ std::move(tensor_shapes_it) },
123               _starts_values_it{ std::move(starts_values_it) },
124               _ends_values_it{ std::move(ends_values_it) },
125               _strides_values_it{ std::move(strides_values_it) },
126               _begin_mask_values_it{ std::move(begin_mask_values_it) },
127               _end_mask_values_it{ std::move(end_mask_values_it) },
128               _shrink_mask_values_it{ std::move(shrink_mask_values_it) }
129         {
130         }
131 
descriptioniterator132         std::string description() const
133         {
134             std::stringstream description;
135             description << "Shape=" << *_tensor_shapes_it << ":";
136             description << "Starts=" << *_starts_values_it << ":";
137             description << "Ends=" << *_ends_values_it << ":";
138             description << "Strides=" << *_strides_values_it << ":";
139             description << "BeginMask=" << *_begin_mask_values_it << ":";
140             description << "EndMask=" << *_end_mask_values_it << ":";
141             description << "ShrinkMask=" << *_shrink_mask_values_it << ":";
142             return description.str();
143         }
144 
145         StridedSliceDataset::type operator*() const
146         {
147             return std::make_tuple(*_tensor_shapes_it,
148                                    *_starts_values_it, *_ends_values_it, *_strides_values_it,
149                                    *_begin_mask_values_it, *_end_mask_values_it, *_shrink_mask_values_it);
150         }
151 
152         iterator &operator++()
153         {
154             ++_tensor_shapes_it;
155             ++_starts_values_it;
156             ++_ends_values_it;
157             ++_strides_values_it;
158             ++_begin_mask_values_it;
159             ++_end_mask_values_it;
160             ++_shrink_mask_values_it;
161 
162             return *this;
163         }
164 
165     private:
166         std::vector<TensorShape>::const_iterator _tensor_shapes_it;
167         std::vector<Coordinates>::const_iterator _starts_values_it;
168         std::vector<Coordinates>::const_iterator _ends_values_it;
169         std::vector<BiStrides>::const_iterator   _strides_values_it;
170         std::vector<int32_t>::const_iterator     _begin_mask_values_it;
171         std::vector<int32_t>::const_iterator     _end_mask_values_it;
172         std::vector<int32_t>::const_iterator     _shrink_mask_values_it;
173     };
174 
begin()175     iterator begin() const
176     {
177         return iterator(_tensor_shapes.begin(),
178                         _starts_values.begin(), _ends_values.begin(), _strides_values.begin(),
179                         _begin_mask_values.begin(), _end_mask_values.begin(), _shrink_mask_values.begin());
180     }
181 
size()182     int size() const
183     {
184         return std::min(_tensor_shapes.size(), std::min(_starts_values.size(), std::min(_ends_values.size(), _strides_values.size())));
185     }
186 
187     void add_config(TensorShape shape,
188                     Coordinates starts, Coordinates ends, BiStrides strides,
189                     int32_t begin_mask = 0, int32_t end_mask = 0, int32_t shrink_mask = 0)
190     {
191         _tensor_shapes.emplace_back(std::move(shape));
192         _starts_values.emplace_back(std::move(starts));
193         _ends_values.emplace_back(std::move(ends));
194         _strides_values.emplace_back(std::move(strides));
195         _begin_mask_values.emplace_back(std::move(begin_mask));
196         _end_mask_values.emplace_back(std::move(end_mask));
197         _shrink_mask_values.emplace_back(std::move(shrink_mask));
198     }
199 
200 protected:
201     StridedSliceDataset()                       = default;
202     StridedSliceDataset(StridedSliceDataset &&) = default;
203 
204 private:
205     std::vector<TensorShape> _tensor_shapes{};
206     std::vector<Coordinates> _starts_values{};
207     std::vector<Coordinates> _ends_values{};
208     std::vector<BiStrides>   _strides_values{};
209     std::vector<int32_t>     _begin_mask_values{};
210     std::vector<int32_t>     _end_mask_values{};
211     std::vector<int32_t>     _shrink_mask_values{};
212 };
213 
214 class SmallSliceDataset final : public SliceDataset
215 {
216 public:
SmallSliceDataset()217     SmallSliceDataset()
218     {
219         // 1D
220         add_config(TensorShape(15U), Coordinates(4), Coordinates(9));
221         add_config(TensorShape(15U), Coordinates(0), Coordinates(-1));
222         // 2D
223         add_config(TensorShape(15U, 16U), Coordinates(0, 1), Coordinates(5, -1));
224         add_config(TensorShape(15U, 16U), Coordinates(4, 1), Coordinates(12, -1));
225         // 3D
226         add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4));
227         add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4));
228         // 4D
229         add_config(TensorShape(15U, 16U, 4U, 12U), Coordinates(0, 1, 2, 2), Coordinates(5, -1, 4, 5));
230     }
231 };
232 
233 class LargeSliceDataset final : public SliceDataset
234 {
235 public:
LargeSliceDataset()236     LargeSliceDataset()
237     {
238         // 1D
239         add_config(TensorShape(1025U), Coordinates(128), Coordinates(-100));
240         // 2D
241         add_config(TensorShape(372U, 68U), Coordinates(128, 7), Coordinates(368, -1));
242         // 3D
243         add_config(TensorShape(372U, 68U, 12U), Coordinates(128, 7, 2), Coordinates(368, -1, 4));
244         // 4D
245         add_config(TensorShape(372U, 68U, 7U, 4U), Coordinates(128, 7, 2), Coordinates(368, 17, 5));
246     }
247 };
248 
249 class SmallStridedSliceDataset final : public StridedSliceDataset
250 {
251 public:
SmallStridedSliceDataset()252     SmallStridedSliceDataset()
253     {
254         // 1D
255         add_config(TensorShape(15U), Coordinates(0), Coordinates(5), BiStrides(2));
256         add_config(TensorShape(15U), Coordinates(-1), Coordinates(-8), BiStrides(-2));
257         // 2D
258         add_config(TensorShape(15U, 16U), Coordinates(0, 1), Coordinates(5, -1), BiStrides(2, 1));
259         add_config(TensorShape(15U, 16U), Coordinates(4, 1), Coordinates(12, -1), BiStrides(2, 1), 1);
260         // 3D
261         add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4), BiStrides(2, 1, 2));
262         add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4), BiStrides(2, 1, 2), 0, 1);
263         // 4D
264         add_config(TensorShape(15U, 16U, 4U, 12U), Coordinates(0, 1, 2, 2), Coordinates(5, -1, 4, 5), BiStrides(2, 1, 2, 3));
265 
266         // Shrink axis
267         add_config(TensorShape(1U, 3U, 2U, 3U), Coordinates(0, 1, 0, 0), Coordinates(1, 1, 1, 1), BiStrides(1, 1, 1, 1), 0, 15, 6);
268         add_config(TensorShape(3U, 2U), Coordinates(0, 0), Coordinates(3U, 1U), BiStrides(1, 1), 0, 0, 2);
269         add_config(TensorShape(4U, 7U, 7U), Coordinates(0, 0, 0), Coordinates(1U, 1U, 1U), BiStrides(1, 1, 1), 0, 6, 1);
270         add_config(TensorShape(4U, 7U, 7U), Coordinates(0, 1, 0), Coordinates(1U, 1U, 1U), BiStrides(1, 1, 1), 0, 5, 3);
271     }
272 };
273 
274 class LargeStridedSliceDataset final : public StridedSliceDataset
275 {
276 public:
LargeStridedSliceDataset()277     LargeStridedSliceDataset()
278     {
279         // 1D
280         add_config(TensorShape(1025U), Coordinates(128), Coordinates(-100), BiStrides(20));
281         // 2D
282         add_config(TensorShape(372U, 68U), Coordinates(128, 7), Coordinates(368, -30), BiStrides(10, 7));
283         // 3D
284         add_config(TensorShape(372U, 68U, 12U), Coordinates(128, 7, -1), Coordinates(368, -30, -5), BiStrides(14, 7, -2));
285         // 4D
286         add_config(TensorShape(372U, 68U, 7U, 4U), Coordinates(128, 7, 2), Coordinates(368, -30, 5), BiStrides(20, 7, 2), 1, 1);
287     }
288 };
289 } // namespace datasets
290 } // namespace test
291 } // namespace arm_compute
292 #endif /* ARM_COMPUTE_TEST_STRIDED_SLICE_DATASET */
293