1 /* 2 * Copyright (c) 2018-2020 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #ifndef ARM_COMPUTE_TEST_SPLIT_FIXTURE 25 #define ARM_COMPUTE_TEST_SPLIT_FIXTURE 26 27 #include "arm_compute/core/TensorShape.h" 28 #include "arm_compute/core/Types.h" 29 30 #include "tests/AssetsLibrary.h" 31 #include "tests/Globals.h" 32 #include "tests/IAccessor.h" 33 #include "tests/RawLutAccessor.h" 34 #include "tests/framework/Asserts.h" 35 #include "tests/framework/Fixture.h" 36 #include "tests/validation/Helpers.h" 37 #include "tests/validation/reference/SliceOperations.h" 38 39 #include <algorithm> 40 41 namespace arm_compute 42 { 43 namespace test 44 { 45 namespace validation 46 { 47 template <typename TensorType, typename ITensorType, typename AccessorType, typename FunctionType, typename T> 48 class SplitFixture : public framework::Fixture 49 { 50 public: 51 template <typename...> setup(TensorShape shape,unsigned int axis,unsigned int splits,DataType data_type)52 void setup(TensorShape shape, unsigned int axis, unsigned int splits, DataType data_type) 53 { 54 _target = compute_target(shape, axis, splits, data_type); 55 _reference = compute_reference(shape, axis, splits, data_type); 56 } 57 58 protected: 59 template <typename U> fill(U && tensor,int i)60 void fill(U &&tensor, int i) 61 { 62 library->fill_tensor_uniform(tensor, i); 63 } 64 compute_target(const TensorShape & shape,unsigned int axis,unsigned int splits,DataType data_type)65 std::vector<TensorType> compute_target(const TensorShape &shape, unsigned int axis, unsigned int splits, DataType data_type) 66 { 67 // Create tensors 68 TensorType src = create_tensor<TensorType>(shape, data_type); 69 std::vector<TensorType> dsts(splits); 70 std::vector<ITensorType *> dsts_ptr; 71 for(auto &dst : dsts) 72 { 73 dsts_ptr.emplace_back(&dst); 74 } 75 76 // Create and configure function 77 FunctionType split; 78 split.configure(&src, dsts_ptr, axis); 79 80 ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); 81 ARM_COMPUTE_EXPECT(std::all_of(dsts.cbegin(), dsts.cend(), [](const TensorType & t) 82 { 83 return t.info()->is_resizable(); 84 }), 85 framework::LogLevel::ERRORS); 86 87 // Allocate tensors 88 src.allocator()->allocate(); 89 for(unsigned int i = 0; i < splits; ++i) 90 { 91 dsts[i].allocator()->allocate(); 92 } 93 94 ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS); 95 ARM_COMPUTE_EXPECT(std::all_of(dsts.cbegin(), dsts.cend(), [](const TensorType & t) 96 { 97 return !t.info()->is_resizable(); 98 }), 99 framework::LogLevel::ERRORS); 100 101 // Fill tensors 102 fill(AccessorType(src), 0); 103 104 // Compute function 105 split.run(); 106 107 return dsts; 108 } 109 compute_reference(const TensorShape & shape,unsigned int axis,unsigned int splits,DataType data_type)110 std::vector<SimpleTensor<T>> compute_reference(const TensorShape &shape, unsigned int axis, unsigned int splits, DataType data_type) 111 { 112 // Create reference 113 SimpleTensor<T> src{ shape, data_type }; 114 std::vector<SimpleTensor<T>> dsts; 115 116 // Fill reference 117 fill(src, 0); 118 119 // Calculate splice for each split 120 const size_t axis_split_step = shape[axis] / splits; 121 unsigned int axis_offset = 0; 122 123 // Start/End coordinates 124 Coordinates start_coords; 125 Coordinates end_coords; 126 for(unsigned int d = 0; d < shape.num_dimensions(); ++d) 127 { 128 end_coords.set(d, -1); 129 } 130 131 for(unsigned int i = 0; i < splits; ++i) 132 { 133 // Update coordinate on axis 134 start_coords.set(axis, axis_offset); 135 end_coords.set(axis, axis_offset + axis_split_step); 136 137 dsts.emplace_back(std::move(reference::slice(src, start_coords, end_coords))); 138 139 axis_offset += axis_split_step; 140 } 141 142 return dsts; 143 } 144 145 std::vector<TensorType> _target{}; 146 std::vector<SimpleTensor<T>> _reference{}; 147 }; 148 149 template <typename TensorType, typename ITensorType, typename AccessorType, typename FunctionType, typename T> 150 class SplitShapesFixture : public framework::Fixture 151 { 152 public: 153 template <typename...> setup(TensorShape shape,unsigned int axis,std::vector<TensorShape> split_shapes,DataType data_type)154 void setup(TensorShape shape, unsigned int axis, std::vector<TensorShape> split_shapes, DataType data_type) 155 { 156 _target = compute_target(shape, axis, split_shapes, data_type); 157 _reference = compute_reference(shape, axis, split_shapes, data_type); 158 } 159 160 protected: 161 template <typename U> fill(U && tensor,int i)162 void fill(U &&tensor, int i) 163 { 164 library->fill_tensor_uniform(tensor, i); 165 } 166 compute_target(TensorShape shape,unsigned int axis,std::vector<TensorShape> split_shapes,DataType data_type)167 std::vector<TensorType> compute_target(TensorShape shape, unsigned int axis, std::vector<TensorShape> split_shapes, DataType data_type) 168 { 169 // Create tensors 170 TensorType src = create_tensor<TensorType>(shape, data_type); 171 std::vector<TensorType> dsts{}; 172 std::vector<ITensorType *> dsts_ptr; 173 174 for(const auto &split_shape : split_shapes) 175 { 176 TensorType dst = create_tensor<TensorType>(split_shape, data_type); 177 dsts.push_back(std::move(dst)); 178 } 179 180 for(auto &dst : dsts) 181 { 182 dsts_ptr.emplace_back(&dst); 183 } 184 185 // Create and configure function 186 FunctionType split; 187 split.configure(&src, dsts_ptr, axis); 188 189 ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); 190 ARM_COMPUTE_EXPECT(std::all_of(dsts.cbegin(), dsts.cend(), [](const TensorType & t) 191 { 192 return t.info()->is_resizable(); 193 }), 194 framework::LogLevel::ERRORS); 195 196 // Allocate tensors 197 src.allocator()->allocate(); 198 for(unsigned int i = 0; i < dsts.size(); ++i) 199 { 200 dsts[i].allocator()->allocate(); 201 } 202 203 ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS); 204 ARM_COMPUTE_EXPECT(std::all_of(dsts.cbegin(), dsts.cend(), [](const TensorType & t) 205 { 206 return !t.info()->is_resizable(); 207 }), 208 framework::LogLevel::ERRORS); 209 210 // Fill tensors 211 fill(AccessorType(src), 0); 212 213 // Compute function 214 split.run(); 215 216 return dsts; 217 } 218 compute_reference(TensorShape shape,unsigned int axis,std::vector<TensorShape> split_shapes,DataType data_type)219 std::vector<SimpleTensor<T>> compute_reference(TensorShape shape, unsigned int axis, std::vector<TensorShape> split_shapes, DataType data_type) 220 { 221 // Create reference 222 SimpleTensor<T> src{ shape, data_type }; 223 std::vector<SimpleTensor<T>> dsts; 224 225 // Fill reference 226 fill(src, 0); 227 228 unsigned int axis_offset{ 0 }; 229 for(const auto &split_shape : split_shapes) 230 { 231 // Calculate splice for each split 232 const size_t axis_split_step = split_shape[axis]; 233 234 // Start/End coordinates 235 Coordinates start_coords; 236 Coordinates end_coords; 237 for(unsigned int d = 0; d < shape.num_dimensions(); ++d) 238 { 239 end_coords.set(d, -1); 240 } 241 242 // Update coordinate on axis 243 start_coords.set(axis, axis_offset); 244 end_coords.set(axis, axis_offset + axis_split_step); 245 246 dsts.emplace_back(std::move(reference::slice(src, start_coords, end_coords))); 247 248 axis_offset += axis_split_step; 249 } 250 251 return dsts; 252 } 253 254 std::vector<TensorType> _target{}; 255 std::vector<SimpleTensor<T>> _reference{}; 256 }; 257 } // namespace validation 258 } // namespace test 259 } // namespace arm_compute 260 #endif /* ARM_COMPUTE_TEST_SPLIT_FIXTURE */ 261