• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2016-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_TENSORSHAPE_H
25 #define ARM_COMPUTE_TENSORSHAPE_H
26 
27 #include "arm_compute/core/Dimensions.h"
28 #include "arm_compute/core/Error.h"
29 #include "arm_compute/core/utils/misc/Utility.h"
30 
31 #include <algorithm>
32 #include <array>
33 #include <functional>
34 #include <numeric>
35 
36 namespace arm_compute
37 {
38 /** Shape of a tensor */
39 class TensorShape : public Dimensions<size_t>
40 {
41 public:
42     /** Constructor to initialize the tensor shape.
43      *
44      * @param[in] dims Values to initialize the dimensions.
45      */
46     template <typename... Ts>
TensorShape(Ts...dims)47     TensorShape(Ts... dims)
48         : Dimensions{ dims... }
49     {
50         // Initialize unspecified dimensions to 1
51         if(_num_dimensions > 0)
52         {
53             std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
54         }
55 
56         // Correct number dimensions to ignore trailing dimensions of size 1
57         apply_dimension_correction();
58     }
59     /** Allow instances of this class to be copy constructed */
60     TensorShape(const TensorShape &) = default;
61     /** Allow instances of this class to be copied */
62     TensorShape &operator=(const TensorShape &) = default;
63     /** Allow instances of this class to be move constructed */
64     TensorShape(TensorShape &&) = default;
65     /** Allow instances of this class to be moved */
66     TensorShape &operator=(TensorShape &&) = default;
67     /** Default destructor */
68     ~TensorShape() = default;
69 
70     /** Accessor to set the value of one of the dimensions.
71      *
72      * @param[in] dimension            Dimension for which the value is set.
73      * @param[in] value                Value to be set for the dimension.
74      * @param[in] apply_dim_correction (Optional) Flag to state whether apply dimension correction after setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but _num_dimensions should be 3 rather than 1.
75      * @param[in] increase_dim_unit    (Optional) Set to true if new unit dimensions increase the number of dimensions of the shape.
76      *
77      * @return *this.
78      */
79     TensorShape &set(size_t dimension, size_t value, bool apply_dim_correction = true, bool increase_dim_unit = true)
80     {
81         // Clear entire shape if one dimension is zero
82         if(value == 0)
83         {
84             _num_dimensions = 0;
85             std::fill(_id.begin(), _id.end(), 0);
86         }
87         else
88         {
89             // Make sure all empty dimensions are filled with 1
90             std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
91 
92             // Set the specified dimension and increase the number of dimensions if
93             // necessary
94             Dimensions::set(dimension, value, increase_dim_unit);
95 
96             // Correct number dimensions to ignore trailing dimensions of size 1
97             if(apply_dim_correction)
98             {
99                 apply_dimension_correction();
100             }
101         }
102         return *this;
103     }
104 
105     /** Accessor to remove the dimension n from the tensor shape.
106      *
107      * @note The upper dimensions of the tensor shape will be shifted down by 1
108      *
109      * @param[in] n Dimension to remove
110      */
remove_dimension(size_t n)111     void remove_dimension(size_t n)
112     {
113         ARM_COMPUTE_ERROR_ON(_num_dimensions < 1);
114         ARM_COMPUTE_ERROR_ON(n >= _num_dimensions);
115 
116         std::copy(_id.begin() + n + 1, _id.end(), _id.begin() + n);
117 
118         // Reduce number of dimensions
119         _num_dimensions--;
120 
121         // Make sure all empty dimensions are filled with 1
122         std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
123 
124         // Correct number dimensions to ignore trailing dimensions of size 1
125         apply_dimension_correction();
126     }
127 
128     /** Collapse the first n dimensions.
129      *
130      * @param[in] n     Number of dimensions to collapse into @p first
131      * @param[in] first Dimensions into which the following @p n are collapsed.
132      */
133     void collapse(size_t n, size_t first = 0)
134     {
135         Dimensions::collapse(n, first);
136 
137         // Make sure all empty dimensions are filled with 1
138         std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
139     }
140     /** Shifts right the tensor shape increasing its dimensions
141      *
142      * @param[in] step Rotation step
143      */
shift_right(size_t step)144     void shift_right(size_t step)
145     {
146         ARM_COMPUTE_ERROR_ON(step > TensorShape::num_max_dimensions - num_dimensions());
147 
148         std::rotate(begin(), begin() + TensorShape::num_max_dimensions - step, end());
149         _num_dimensions += step;
150 
151         // Correct number dimensions to ignore trailing dimensions of size 1
152         apply_dimension_correction();
153     }
154 
155     /** Return a copy with collapsed dimensions starting from a given point.
156      *
157      * @param[in] start Starting point of collapsing dimensions.
158      *
159      * @return A copy with collapse dimensions starting from start.
160      */
collapsed_from(size_t start)161     TensorShape collapsed_from(size_t start) const
162     {
163         TensorShape copy(*this);
164         copy.collapse(num_dimensions() - start, start);
165         return copy;
166     }
167 
168     /** Collapses all dimensions to a single linear total size.
169      *
170      * @return The total tensor size in terms of elements.
171      */
total_size()172     size_t total_size() const
173     {
174         return std::accumulate(_id.begin(), _id.end(), 1, std::multiplies<size_t>());
175     }
176     /** Collapses given dimension and above.
177      *
178      * @param[in] dimension Size of the wanted dimension
179      *
180      * @return The linear size of the collapsed dimensions
181      */
total_size_upper(size_t dimension)182     size_t total_size_upper(size_t dimension) const
183     {
184         ARM_COMPUTE_ERROR_ON(dimension >= TensorShape::num_max_dimensions);
185         return std::accumulate(_id.begin() + dimension, _id.end(), 1, std::multiplies<size_t>());
186     }
187 
188     /** Compute size of dimensions lower than the given one.
189      *
190      * @param[in] dimension Upper boundary.
191      *
192      * @return The linear size of the collapsed dimensions.
193      */
total_size_lower(size_t dimension)194     size_t total_size_lower(size_t dimension) const
195     {
196         ARM_COMPUTE_ERROR_ON(dimension > TensorShape::num_max_dimensions);
197         return std::accumulate(_id.begin(), _id.begin() + dimension, 1, std::multiplies<size_t>());
198     }
199 
200     /** If shapes are broadcast compatible, return the broadcasted shape.
201      *
202      * Two tensor shapes are broadcast compatible if for each dimension, they're equal or one of them is 1.
203      *
204      * If two shapes are compatible, each dimension in the broadcasted shape is the max of the original dimensions.
205      *
206      * @param[in] shapes Tensor shapes.
207      *
208      * @return The broadcasted shape or an empty shape if the shapes are not broadcast compatible.
209      */
210     template <typename... Shapes>
broadcast_shape(const Shapes &...shapes)211     static TensorShape broadcast_shape(const Shapes &... shapes)
212     {
213         TensorShape bc_shape;
214 
215         auto broadcast = [&bc_shape](const TensorShape & other)
216         {
217             if(bc_shape.num_dimensions() == 0)
218             {
219                 bc_shape = other;
220             }
221             else if(other.num_dimensions() != 0)
222             {
223                 for(size_t d = 0; d < TensorShape::num_max_dimensions; ++d)
224                 {
225                     const size_t dim_min = std::min(bc_shape[d], other[d]);
226                     const size_t dim_max = std::max(bc_shape[d], other[d]);
227 
228                     if((dim_min != 1) && (dim_min != dim_max))
229                     {
230                         bc_shape = TensorShape{ 0U };
231                         break;
232                     }
233 
234                     bc_shape.set(d, dim_max);
235                 }
236             }
237         };
238 
239         utility::for_each(broadcast, shapes...);
240 
241         return bc_shape;
242     }
243 
244 private:
245     /** Remove trailing dimensions of size 1 from the reported number of dimensions. */
apply_dimension_correction()246     void apply_dimension_correction()
247     {
248         for(int i = static_cast<int>(_num_dimensions) - 1; i > 0; --i)
249         {
250             if(_id[i] == 1)
251             {
252                 --_num_dimensions;
253             }
254             else
255             {
256                 break;
257             }
258         }
259     }
260 };
261 }
262 #endif /*ARM_COMPUTE_TENSORSHAPE_H*/
263