• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2018-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_TEST_SPLIT_FIXTURE
25 #define ARM_COMPUTE_TEST_SPLIT_FIXTURE
26 
27 #include "arm_compute/core/TensorShape.h"
28 #include "arm_compute/core/Types.h"
29 
30 #include "tests/AssetsLibrary.h"
31 #include "tests/Globals.h"
32 #include "tests/IAccessor.h"
33 #include "tests/RawLutAccessor.h"
34 #include "tests/framework/Asserts.h"
35 #include "tests/framework/Fixture.h"
36 #include "tests/validation/Helpers.h"
37 #include "tests/validation/reference/SliceOperations.h"
38 
39 #include <algorithm>
40 
41 namespace arm_compute
42 {
43 namespace test
44 {
45 namespace validation
46 {
47 template <typename TensorType, typename ITensorType, typename AccessorType, typename FunctionType, typename T>
48 class SplitFixture : public framework::Fixture
49 {
50 public:
51     template <typename...>
setup(TensorShape shape,unsigned int axis,unsigned int splits,DataType data_type)52     void setup(TensorShape shape, unsigned int axis, unsigned int splits, DataType data_type)
53     {
54         _target    = compute_target(shape, axis, splits, data_type);
55         _reference = compute_reference(shape, axis, splits, data_type);
56     }
57 
58 protected:
59     template <typename U>
fill(U && tensor,int i)60     void fill(U &&tensor, int i)
61     {
62         library->fill_tensor_uniform(tensor, i);
63     }
64 
compute_target(const TensorShape & shape,unsigned int axis,unsigned int splits,DataType data_type)65     std::vector<TensorType> compute_target(const TensorShape &shape, unsigned int axis, unsigned int splits, DataType data_type)
66     {
67         // Create tensors
68         TensorType                 src = create_tensor<TensorType>(shape, data_type);
69         std::vector<TensorType>    dsts(splits);
70         std::vector<ITensorType *> dsts_ptr;
71         for(auto &dst : dsts)
72         {
73             dsts_ptr.emplace_back(&dst);
74         }
75 
76         // Create and configure function
77         FunctionType split;
78         split.configure(&src, dsts_ptr, axis);
79 
80         ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
81         ARM_COMPUTE_EXPECT(std::all_of(dsts.cbegin(), dsts.cend(), [](const TensorType & t)
82         {
83             return t.info()->is_resizable();
84         }),
85         framework::LogLevel::ERRORS);
86 
87         // Allocate tensors
88         src.allocator()->allocate();
89         for(unsigned int i = 0; i < splits; ++i)
90         {
91             dsts[i].allocator()->allocate();
92         }
93 
94         ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
95         ARM_COMPUTE_EXPECT(std::all_of(dsts.cbegin(), dsts.cend(), [](const TensorType & t)
96         {
97             return !t.info()->is_resizable();
98         }),
99         framework::LogLevel::ERRORS);
100 
101         // Fill tensors
102         fill(AccessorType(src), 0);
103 
104         // Compute function
105         split.run();
106 
107         return dsts;
108     }
109 
compute_reference(const TensorShape & shape,unsigned int axis,unsigned int splits,DataType data_type)110     std::vector<SimpleTensor<T>> compute_reference(const TensorShape &shape, unsigned int axis, unsigned int splits, DataType data_type)
111     {
112         // Create reference
113         SimpleTensor<T>              src{ shape, data_type };
114         std::vector<SimpleTensor<T>> dsts;
115 
116         // Fill reference
117         fill(src, 0);
118 
119         // Calculate splice for each split
120         const size_t axis_split_step = shape[axis] / splits;
121         unsigned int axis_offset     = 0;
122 
123         // Start/End coordinates
124         Coordinates start_coords;
125         Coordinates end_coords;
126         for(unsigned int d = 0; d < shape.num_dimensions(); ++d)
127         {
128             end_coords.set(d, -1);
129         }
130 
131         for(unsigned int i = 0; i < splits; ++i)
132         {
133             // Update coordinate on axis
134             start_coords.set(axis, axis_offset);
135             end_coords.set(axis, axis_offset + axis_split_step);
136 
137             dsts.emplace_back(std::move(reference::slice(src, start_coords, end_coords)));
138 
139             axis_offset += axis_split_step;
140         }
141 
142         return dsts;
143     }
144 
145     std::vector<TensorType>      _target{};
146     std::vector<SimpleTensor<T>> _reference{};
147 };
148 
149 template <typename TensorType, typename ITensorType, typename AccessorType, typename FunctionType, typename T>
150 class SplitShapesFixture : public framework::Fixture
151 {
152 public:
153     template <typename...>
setup(TensorShape shape,unsigned int axis,std::vector<TensorShape> split_shapes,DataType data_type)154     void setup(TensorShape shape, unsigned int axis, std::vector<TensorShape> split_shapes, DataType data_type)
155     {
156         _target    = compute_target(shape, axis, split_shapes, data_type);
157         _reference = compute_reference(shape, axis, split_shapes, data_type);
158     }
159 
160 protected:
161     template <typename U>
fill(U && tensor,int i)162     void fill(U &&tensor, int i)
163     {
164         library->fill_tensor_uniform(tensor, i);
165     }
166 
compute_target(TensorShape shape,unsigned int axis,std::vector<TensorShape> split_shapes,DataType data_type)167     std::vector<TensorType> compute_target(TensorShape shape, unsigned int axis, std::vector<TensorShape> split_shapes, DataType data_type)
168     {
169         // Create tensors
170         TensorType                 src = create_tensor<TensorType>(shape, data_type);
171         std::vector<TensorType>    dsts{};
172         std::vector<ITensorType *> dsts_ptr;
173 
174         for(const auto &split_shape : split_shapes)
175         {
176             TensorType dst = create_tensor<TensorType>(split_shape, data_type);
177             dsts.push_back(std::move(dst));
178         }
179 
180         for(auto &dst : dsts)
181         {
182             dsts_ptr.emplace_back(&dst);
183         }
184 
185         // Create and configure function
186         FunctionType split;
187         split.configure(&src, dsts_ptr, axis);
188 
189         ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
190         ARM_COMPUTE_EXPECT(std::all_of(dsts.cbegin(), dsts.cend(), [](const TensorType & t)
191         {
192             return t.info()->is_resizable();
193         }),
194         framework::LogLevel::ERRORS);
195 
196         // Allocate tensors
197         src.allocator()->allocate();
198         for(unsigned int i = 0; i < dsts.size(); ++i)
199         {
200             dsts[i].allocator()->allocate();
201         }
202 
203         ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
204         ARM_COMPUTE_EXPECT(std::all_of(dsts.cbegin(), dsts.cend(), [](const TensorType & t)
205         {
206             return !t.info()->is_resizable();
207         }),
208         framework::LogLevel::ERRORS);
209 
210         // Fill tensors
211         fill(AccessorType(src), 0);
212 
213         // Compute function
214         split.run();
215 
216         return dsts;
217     }
218 
compute_reference(TensorShape shape,unsigned int axis,std::vector<TensorShape> split_shapes,DataType data_type)219     std::vector<SimpleTensor<T>> compute_reference(TensorShape shape, unsigned int axis, std::vector<TensorShape> split_shapes, DataType data_type)
220     {
221         // Create reference
222         SimpleTensor<T>              src{ shape, data_type };
223         std::vector<SimpleTensor<T>> dsts;
224 
225         // Fill reference
226         fill(src, 0);
227 
228         unsigned int axis_offset{ 0 };
229         for(const auto &split_shape : split_shapes)
230         {
231             // Calculate splice for each split
232             const size_t axis_split_step = split_shape[axis];
233 
234             // Start/End coordinates
235             Coordinates start_coords;
236             Coordinates end_coords;
237             for(unsigned int d = 0; d < shape.num_dimensions(); ++d)
238             {
239                 end_coords.set(d, -1);
240             }
241 
242             // Update coordinate on axis
243             start_coords.set(axis, axis_offset);
244             end_coords.set(axis, axis_offset + axis_split_step);
245 
246             dsts.emplace_back(std::move(reference::slice(src, start_coords, end_coords)));
247 
248             axis_offset += axis_split_step;
249         }
250 
251         return dsts;
252     }
253 
254     std::vector<TensorType>      _target{};
255     std::vector<SimpleTensor<T>> _reference{};
256 };
257 } // namespace validation
258 } // namespace test
259 } // namespace arm_compute
260 #endif /* ARM_COMPUTE_TEST_SPLIT_FIXTURE */
261