1 /* 2 * Copyright (c) 2018-2020 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #ifndef ARM_COMPUTE_TEST_SPLIT_DATASET 25 #define ARM_COMPUTE_TEST_SPLIT_DATASET 26 27 #include "utils/TypePrinter.h" 28 29 #include "arm_compute/core/Types.h" 30 31 namespace arm_compute 32 { 33 namespace test 34 { 35 namespace datasets 36 { 37 class SplitDataset 38 { 39 public: 40 using type = std::tuple<TensorShape, unsigned int, unsigned int>; 41 42 struct iterator 43 { iteratoriterator44 iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it, 45 std::vector<unsigned int>::const_iterator axis_values_it, 46 std::vector<unsigned int>::const_iterator splits_values_it) 47 : _tensor_shapes_it{ std::move(tensor_shapes_it) }, 48 _axis_values_it{ std::move(axis_values_it) }, 49 _splits_values_it{ std::move(splits_values_it) } 50 { 51 } 52 descriptioniterator53 std::string description() const 54 { 55 std::stringstream description; 56 description << "Shape=" << *_tensor_shapes_it << ":"; 57 description << "Axis=" << *_axis_values_it << ":"; 58 description << "Splits=" << *_splits_values_it << ":"; 59 return description.str(); 60 } 61 62 SplitDataset::type operator*() const 63 { 64 return std::make_tuple(*_tensor_shapes_it, *_axis_values_it, *_splits_values_it); 65 } 66 67 iterator &operator++() 68 { 69 ++_tensor_shapes_it; 70 ++_axis_values_it; 71 ++_splits_values_it; 72 return *this; 73 } 74 75 private: 76 std::vector<TensorShape>::const_iterator _tensor_shapes_it; 77 std::vector<unsigned int>::const_iterator _axis_values_it; 78 std::vector<unsigned int>::const_iterator _splits_values_it; 79 }; 80 begin()81 iterator begin() const 82 { 83 return iterator(_tensor_shapes.begin(), _axis_values.begin(), _splits_values.begin()); 84 } 85 size()86 int size() const 87 { 88 return std::min(_tensor_shapes.size(), std::min(_axis_values.size(), _splits_values.size())); 89 } 90 add_config(TensorShape shape,unsigned int axis,unsigned int splits)91 void add_config(TensorShape shape, unsigned int axis, unsigned int splits) 92 { 93 _tensor_shapes.emplace_back(std::move(shape)); 94 _axis_values.emplace_back(axis); 95 _splits_values.emplace_back(splits); 96 } 97 98 protected: 99 SplitDataset() = default; 100 SplitDataset(SplitDataset &&) = default; 101 102 private: 103 std::vector<TensorShape> _tensor_shapes{}; 104 std::vector<unsigned int> _axis_values{}; 105 std::vector<unsigned int> _splits_values{}; 106 }; 107 108 class SmallSplitDataset final : public SplitDataset 109 { 110 public: SmallSplitDataset()111 SmallSplitDataset() 112 { 113 add_config(TensorShape(128U), 0U, 4U); 114 add_config(TensorShape(6U, 3U, 4U), 2U, 2U); 115 add_config(TensorShape(27U, 14U, 2U), 1U, 2U); 116 add_config(TensorShape(64U, 32U, 4U, 6U), 3U, 3U); 117 } 118 }; 119 120 class LargeSplitDataset final : public SplitDataset 121 { 122 public: LargeSplitDataset()123 LargeSplitDataset() 124 { 125 add_config(TensorShape(512U), 0U, 8U); 126 add_config(TensorShape(128U, 64U, 8U), 2U, 2U); 127 add_config(TensorShape(128U, 64U, 8U, 2U), 1U, 2U); 128 add_config(TensorShape(128U, 64U, 32U, 4U), 3U, 4U); 129 } 130 }; 131 132 class SplitShapesDataset 133 { 134 public: 135 using type = std::tuple<TensorShape, unsigned int, std::vector<TensorShape>>; 136 137 struct iterator 138 { iteratoriterator139 iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it, 140 std::vector<unsigned int>::const_iterator axis_values_it, 141 std::vector<std::vector<TensorShape>>::const_iterator split_shapes_values_it) 142 : _tensor_shapes_it{ std::move(tensor_shapes_it) }, 143 _axis_values_it{ std::move(axis_values_it) }, 144 _split_shapes_values_it{ std::move(split_shapes_values_it) } 145 { 146 } 147 descriptioniterator148 std::string description() const 149 { 150 std::stringstream description; 151 description << "Shape=" << *_tensor_shapes_it << ":"; 152 description << "Axis=" << *_axis_values_it << ":"; 153 description << "Split shapes=" << *_split_shapes_values_it << ":"; 154 return description.str(); 155 } 156 157 SplitShapesDataset::type operator*() const 158 { 159 return std::make_tuple(*_tensor_shapes_it, *_axis_values_it, *_split_shapes_values_it); 160 } 161 162 iterator &operator++() 163 { 164 ++_tensor_shapes_it; 165 ++_axis_values_it; 166 ++_split_shapes_values_it; 167 return *this; 168 } 169 170 private: 171 std::vector<TensorShape>::const_iterator _tensor_shapes_it; 172 std::vector<unsigned int>::const_iterator _axis_values_it; 173 std::vector<std::vector<TensorShape>>::const_iterator _split_shapes_values_it; 174 }; 175 begin()176 iterator begin() const 177 { 178 return iterator(_tensor_shapes.begin(), _axis_values.begin(), _split_shapes_values.begin()); 179 } 180 size()181 int size() const 182 { 183 return std::min(_tensor_shapes.size(), std::min(_axis_values.size(), _split_shapes_values.size())); 184 } 185 add_config(TensorShape shape,unsigned int axis,std::vector<TensorShape> split_shapes)186 void add_config(TensorShape shape, unsigned int axis, std::vector<TensorShape> split_shapes) 187 { 188 _tensor_shapes.emplace_back(std::move(shape)); 189 _axis_values.emplace_back(axis); 190 _split_shapes_values.emplace_back(split_shapes); 191 } 192 193 protected: 194 SplitShapesDataset() = default; 195 SplitShapesDataset(SplitShapesDataset &&) = default; 196 197 private: 198 std::vector<TensorShape> _tensor_shapes{}; 199 std::vector<unsigned int> _axis_values{}; 200 std::vector<std::vector<TensorShape>> _split_shapes_values{}; 201 }; 202 203 class SmallSplitShapesDataset final : public SplitShapesDataset 204 { 205 public: SmallSplitShapesDataset()206 SmallSplitShapesDataset() 207 { 208 add_config(TensorShape(27U, 3U, 16U, 2U), 2U, std::vector<TensorShape> { TensorShape(27U, 3U, 4U, 2U), 209 TensorShape(27U, 3U, 4U, 2U), 210 TensorShape(27U, 3U, 8U, 2U) 211 }); 212 } 213 }; 214 215 } // namespace datasets 216 } // namespace test 217 } // namespace arm_compute 218 #endif /* ARM_COMPUTE_TEST_SPLIT_DATASET */ 219