1 /* 2 * Copyright (c) 2016-2020 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #ifndef ARM_COMPUTE_TENSORSHAPE_H 25 #define ARM_COMPUTE_TENSORSHAPE_H 26 27 #include "arm_compute/core/Dimensions.h" 28 #include "arm_compute/core/Error.h" 29 #include "arm_compute/core/utils/misc/Utility.h" 30 31 #include <algorithm> 32 #include <array> 33 #include <functional> 34 #include <numeric> 35 36 namespace arm_compute 37 { 38 /** Shape of a tensor */ 39 class TensorShape : public Dimensions<size_t> 40 { 41 public: 42 /** Constructor to initialize the tensor shape. 43 * 44 * @param[in] dims Values to initialize the dimensions. 45 */ 46 template <typename... Ts> TensorShape(Ts...dims)47 TensorShape(Ts... dims) 48 : Dimensions{ dims... } 49 { 50 // Initialize unspecified dimensions to 1 51 if(_num_dimensions > 0) 52 { 53 std::fill(_id.begin() + _num_dimensions, _id.end(), 1); 54 } 55 56 // Correct number dimensions to ignore trailing dimensions of size 1 57 apply_dimension_correction(); 58 } 59 /** Allow instances of this class to be copy constructed */ 60 TensorShape(const TensorShape &) = default; 61 /** Allow instances of this class to be copied */ 62 TensorShape &operator=(const TensorShape &) = default; 63 /** Allow instances of this class to be move constructed */ 64 TensorShape(TensorShape &&) = default; 65 /** Allow instances of this class to be moved */ 66 TensorShape &operator=(TensorShape &&) = default; 67 /** Default destructor */ 68 ~TensorShape() = default; 69 70 /** Accessor to set the value of one of the dimensions. 71 * 72 * @param[in] dimension Dimension for which the value is set. 73 * @param[in] value Value to be set for the dimension. 74 * @param[in] apply_dim_correction Flag to state whether apply dimension correction after setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but _num_dimensions should be 3 rather than 1. 75 * 76 * @return *this. 77 */ 78 TensorShape &set(size_t dimension, size_t value, bool apply_dim_correction = true) 79 { 80 // Clear entire shape if one dimension is zero 81 if(value == 0) 82 { 83 _num_dimensions = 0; 84 std::fill(_id.begin(), _id.end(), 0); 85 } 86 else 87 { 88 // Make sure all empty dimensions are filled with 1 89 std::fill(_id.begin() + _num_dimensions, _id.end(), 1); 90 91 // Set the specified dimension and increase the number of dimensions if 92 // necessary 93 Dimensions::set(dimension, value); 94 95 // Correct number dimensions to ignore trailing dimensions of size 1 96 if(apply_dim_correction) 97 { 98 apply_dimension_correction(); 99 } 100 } 101 return *this; 102 } 103 104 /** Accessor to remove the dimension n from the tensor shape. 105 * 106 * @note The upper dimensions of the tensor shape will be shifted down by 1 107 * 108 * @param[in] n Dimension to remove 109 */ remove_dimension(size_t n)110 void remove_dimension(size_t n) 111 { 112 ARM_COMPUTE_ERROR_ON(_num_dimensions < 1); 113 ARM_COMPUTE_ERROR_ON(n >= _num_dimensions); 114 115 std::copy(_id.begin() + n + 1, _id.end(), _id.begin() + n); 116 117 // Reduce number of dimensions 118 _num_dimensions--; 119 120 // Make sure all empty dimensions are filled with 1 121 std::fill(_id.begin() + _num_dimensions, _id.end(), 1); 122 123 // Correct number dimensions to ignore trailing dimensions of size 1 124 apply_dimension_correction(); 125 } 126 127 /** Collapse the first n dimensions. 128 * 129 * @param[in] n Number of dimensions to collapse into @p first 130 * @param[in] first Dimensions into which the following @p n are collapsed. 131 */ 132 void collapse(size_t n, size_t first = 0) 133 { 134 Dimensions::collapse(n, first); 135 136 // Make sure all empty dimensions are filled with 1 137 std::fill(_id.begin() + _num_dimensions, _id.end(), 1); 138 } 139 /** Shifts right the tensor shape increasing its dimensions 140 * 141 * @param[in] step Rotation step 142 */ shift_right(size_t step)143 void shift_right(size_t step) 144 { 145 ARM_COMPUTE_ERROR_ON(step > TensorShape::num_max_dimensions - num_dimensions()); 146 147 std::rotate(begin(), begin() + TensorShape::num_max_dimensions - step, end()); 148 _num_dimensions += step; 149 150 // Correct number dimensions to ignore trailing dimensions of size 1 151 apply_dimension_correction(); 152 } 153 154 /** Return a copy with collapsed dimensions starting from a given point. 155 * 156 * @param[in] start Starting point of collapsing dimensions. 157 * 158 * @return A copy with collapse dimensions starting from start. 159 */ collapsed_from(size_t start)160 TensorShape collapsed_from(size_t start) const 161 { 162 TensorShape copy(*this); 163 copy.collapse(num_dimensions() - start, start); 164 return copy; 165 } 166 167 /** Collapses all dimensions to a single linear total size. 168 * 169 * @return The total tensor size in terms of elements. 170 */ total_size()171 size_t total_size() const 172 { 173 return std::accumulate(_id.begin(), _id.end(), 1, std::multiplies<size_t>()); 174 } 175 /** Collapses given dimension and above. 176 * 177 * @param[in] dimension Size of the wanted dimension 178 * 179 * @return The linear size of the collapsed dimensions 180 */ total_size_upper(size_t dimension)181 size_t total_size_upper(size_t dimension) const 182 { 183 ARM_COMPUTE_ERROR_ON(dimension >= TensorShape::num_max_dimensions); 184 return std::accumulate(_id.begin() + dimension, _id.end(), 1, std::multiplies<size_t>()); 185 } 186 187 /** Compute size of dimensions lower than the given one. 188 * 189 * @param[in] dimension Upper boundary. 190 * 191 * @return The linear size of the collapsed dimensions. 192 */ total_size_lower(size_t dimension)193 size_t total_size_lower(size_t dimension) const 194 { 195 ARM_COMPUTE_ERROR_ON(dimension > TensorShape::num_max_dimensions); 196 return std::accumulate(_id.begin(), _id.begin() + dimension, 1, std::multiplies<size_t>()); 197 } 198 199 /** If shapes are broadcast compatible, return the broadcasted shape. 200 * 201 * Two tensor shapes are broadcast compatible if for each dimension, they're equal or one of them is 1. 202 * 203 * If two shapes are compatible, each dimension in the broadcasted shape is the max of the original dimensions. 204 * 205 * @param[in] shapes Tensor shapes. 206 * 207 * @return The broadcasted shape or an empty shape if the shapes are not broadcast compatible. 208 */ 209 template <typename... Shapes> broadcast_shape(const Shapes &...shapes)210 static TensorShape broadcast_shape(const Shapes &... shapes) 211 { 212 TensorShape bc_shape; 213 214 auto broadcast = [&bc_shape](const TensorShape & other) 215 { 216 if(bc_shape.num_dimensions() == 0) 217 { 218 bc_shape = other; 219 } 220 else if(other.num_dimensions() != 0) 221 { 222 for(size_t d = 0; d < TensorShape::num_max_dimensions; ++d) 223 { 224 const size_t dim_min = std::min(bc_shape[d], other[d]); 225 const size_t dim_max = std::max(bc_shape[d], other[d]); 226 227 if((dim_min != 1) && (dim_min != dim_max)) 228 { 229 bc_shape = TensorShape{ 0U }; 230 break; 231 } 232 233 bc_shape.set(d, dim_max); 234 } 235 } 236 }; 237 238 utility::for_each(broadcast, shapes...); 239 240 return bc_shape; 241 } 242 243 private: 244 /** Remove trailing dimensions of size 1 from the reported number of dimensions. */ apply_dimension_correction()245 void apply_dimension_correction() 246 { 247 for(int i = static_cast<int>(_num_dimensions) - 1; i > 0; --i) 248 { 249 if(_id[i] == 1) 250 { 251 --_num_dimensions; 252 } 253 else 254 { 255 break; 256 } 257 } 258 } 259 }; 260 } 261 #endif /*ARM_COMPUTE_TENSORSHAPE_H*/ 262