1 /* 2 * Copyright (c) 2016-2021 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #ifndef ARM_COMPUTE_TENSORSHAPE_H 25 #define ARM_COMPUTE_TENSORSHAPE_H 26 27 #include "arm_compute/core/Dimensions.h" 28 #include "arm_compute/core/Error.h" 29 #include "arm_compute/core/utils/misc/Utility.h" 30 31 #include <algorithm> 32 #include <array> 33 #include <functional> 34 #include <numeric> 35 36 namespace arm_compute 37 { 38 /** Shape of a tensor */ 39 class TensorShape : public Dimensions<size_t> 40 { 41 public: 42 /** Constructor to initialize the tensor shape. 43 * 44 * @param[in] dims Values to initialize the dimensions. 45 */ 46 template <typename... Ts> TensorShape(Ts...dims)47 TensorShape(Ts... dims) 48 : Dimensions{ dims... } 49 { 50 // Initialize unspecified dimensions to 1 51 if(_num_dimensions > 0) 52 { 53 std::fill(_id.begin() + _num_dimensions, _id.end(), 1); 54 } 55 56 // Correct number dimensions to ignore trailing dimensions of size 1 57 apply_dimension_correction(); 58 } 59 /** Allow instances of this class to be copy constructed */ 60 TensorShape(const TensorShape &) = default; 61 /** Allow instances of this class to be copied */ 62 TensorShape &operator=(const TensorShape &) = default; 63 /** Allow instances of this class to be move constructed */ 64 TensorShape(TensorShape &&) = default; 65 /** Allow instances of this class to be moved */ 66 TensorShape &operator=(TensorShape &&) = default; 67 /** Default destructor */ 68 ~TensorShape() = default; 69 70 /** Accessor to set the value of one of the dimensions. 71 * 72 * @param[in] dimension Dimension for which the value is set. 73 * @param[in] value Value to be set for the dimension. 74 * @param[in] apply_dim_correction (Optional) Flag to state whether apply dimension correction after setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but _num_dimensions should be 3 rather than 1. 75 * @param[in] increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of dimensions of the shape. 76 * 77 * @return *this. 78 */ 79 TensorShape &set(size_t dimension, size_t value, bool apply_dim_correction = true, bool increase_dim_unit = true) 80 { 81 // Clear entire shape if one dimension is zero 82 if(value == 0) 83 { 84 _num_dimensions = 0; 85 std::fill(_id.begin(), _id.end(), 0); 86 } 87 else 88 { 89 // Make sure all empty dimensions are filled with 1 90 std::fill(_id.begin() + _num_dimensions, _id.end(), 1); 91 92 // Set the specified dimension and increase the number of dimensions if 93 // necessary 94 Dimensions::set(dimension, value, increase_dim_unit); 95 96 // Correct number dimensions to ignore trailing dimensions of size 1 97 if(apply_dim_correction) 98 { 99 apply_dimension_correction(); 100 } 101 } 102 return *this; 103 } 104 105 /** Accessor to remove the dimension n from the tensor shape. 106 * 107 * @note The upper dimensions of the tensor shape will be shifted down by 1 108 * 109 * @param[in] n Dimension to remove 110 */ remove_dimension(size_t n)111 void remove_dimension(size_t n) 112 { 113 ARM_COMPUTE_ERROR_ON(_num_dimensions < 1); 114 ARM_COMPUTE_ERROR_ON(n >= _num_dimensions); 115 116 std::copy(_id.begin() + n + 1, _id.end(), _id.begin() + n); 117 118 // Reduce number of dimensions 119 _num_dimensions--; 120 121 // Make sure all empty dimensions are filled with 1 122 std::fill(_id.begin() + _num_dimensions, _id.end(), 1); 123 124 // Correct number dimensions to ignore trailing dimensions of size 1 125 apply_dimension_correction(); 126 } 127 128 /** Collapse the first n dimensions. 129 * 130 * @param[in] n Number of dimensions to collapse into @p first 131 * @param[in] first Dimensions into which the following @p n are collapsed. 132 */ 133 void collapse(size_t n, size_t first = 0) 134 { 135 Dimensions::collapse(n, first); 136 137 // Make sure all empty dimensions are filled with 1 138 std::fill(_id.begin() + _num_dimensions, _id.end(), 1); 139 } 140 /** Shifts right the tensor shape increasing its dimensions 141 * 142 * @param[in] step Rotation step 143 */ shift_right(size_t step)144 void shift_right(size_t step) 145 { 146 ARM_COMPUTE_ERROR_ON(step > TensorShape::num_max_dimensions - num_dimensions()); 147 148 std::rotate(begin(), begin() + TensorShape::num_max_dimensions - step, end()); 149 _num_dimensions += step; 150 151 // Correct number dimensions to ignore trailing dimensions of size 1 152 apply_dimension_correction(); 153 } 154 155 /** Return a copy with collapsed dimensions starting from a given point. 156 * 157 * @param[in] start Starting point of collapsing dimensions. 158 * 159 * @return A copy with collapse dimensions starting from start. 160 */ collapsed_from(size_t start)161 TensorShape collapsed_from(size_t start) const 162 { 163 TensorShape copy(*this); 164 copy.collapse(num_dimensions() - start, start); 165 return copy; 166 } 167 168 /** Collapses all dimensions to a single linear total size. 169 * 170 * @return The total tensor size in terms of elements. 171 */ total_size()172 size_t total_size() const 173 { 174 return std::accumulate(_id.begin(), _id.end(), 1, std::multiplies<size_t>()); 175 } 176 /** Collapses given dimension and above. 177 * 178 * @param[in] dimension Size of the wanted dimension 179 * 180 * @return The linear size of the collapsed dimensions 181 */ total_size_upper(size_t dimension)182 size_t total_size_upper(size_t dimension) const 183 { 184 ARM_COMPUTE_ERROR_ON(dimension >= TensorShape::num_max_dimensions); 185 return std::accumulate(_id.begin() + dimension, _id.end(), 1, std::multiplies<size_t>()); 186 } 187 188 /** Compute size of dimensions lower than the given one. 189 * 190 * @param[in] dimension Upper boundary. 191 * 192 * @return The linear size of the collapsed dimensions. 193 */ total_size_lower(size_t dimension)194 size_t total_size_lower(size_t dimension) const 195 { 196 ARM_COMPUTE_ERROR_ON(dimension > TensorShape::num_max_dimensions); 197 return std::accumulate(_id.begin(), _id.begin() + dimension, 1, std::multiplies<size_t>()); 198 } 199 200 /** If shapes are broadcast compatible, return the broadcasted shape. 201 * 202 * Two tensor shapes are broadcast compatible if for each dimension, they're equal or one of them is 1. 203 * 204 * If two shapes are compatible, each dimension in the broadcasted shape is the max of the original dimensions. 205 * 206 * @param[in] shapes Tensor shapes. 207 * 208 * @return The broadcasted shape or an empty shape if the shapes are not broadcast compatible. 209 */ 210 template <typename... Shapes> broadcast_shape(const Shapes &...shapes)211 static TensorShape broadcast_shape(const Shapes &... shapes) 212 { 213 TensorShape bc_shape; 214 215 auto broadcast = [&bc_shape](const TensorShape & other) 216 { 217 if(bc_shape.num_dimensions() == 0) 218 { 219 bc_shape = other; 220 } 221 else if(other.num_dimensions() != 0) 222 { 223 for(size_t d = 0; d < TensorShape::num_max_dimensions; ++d) 224 { 225 const size_t dim_min = std::min(bc_shape[d], other[d]); 226 const size_t dim_max = std::max(bc_shape[d], other[d]); 227 228 if((dim_min != 1) && (dim_min != dim_max)) 229 { 230 bc_shape = TensorShape{ 0U }; 231 break; 232 } 233 234 bc_shape.set(d, dim_max); 235 } 236 } 237 }; 238 239 utility::for_each(broadcast, shapes...); 240 241 return bc_shape; 242 } 243 244 private: 245 /** Remove trailing dimensions of size 1 from the reported number of dimensions. */ apply_dimension_correction()246 void apply_dimension_correction() 247 { 248 for(int i = static_cast<int>(_num_dimensions) - 1; i > 0; --i) 249 { 250 if(_id[i] == 1) 251 { 252 --_num_dimensions; 253 } 254 else 255 { 256 break; 257 } 258 } 259 } 260 }; 261 } 262 #endif /*ARM_COMPUTE_TENSORSHAPE_H*/ 263