• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) 2016-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24 #ifndef ARM_COMPUTE_TENSORINFO_H
25 #define ARM_COMPUTE_TENSORINFO_H
26 
27 #include "arm_compute/core/ITensorInfo.h"
28 
29 #include "ITensorInfo.h"
30 #include "arm_compute/core/Coordinates.h"
31 #include "arm_compute/core/Helpers.h"
32 #include "arm_compute/core/Strides.h"
33 #include "arm_compute/core/TensorShape.h"
34 #include "arm_compute/core/Types.h"
35 #include "arm_compute/core/Utils.h"
36 
37 #include <cstddef>
38 #include <memory>
39 
40 namespace arm_compute
41 {
42 /** Store the tensor's metadata */
43 class TensorInfo final : public ITensorInfo
44 {
45 public:
46     /** Default constructor */
47     TensorInfo();
48     /** Default destructor */
49     ~TensorInfo() = default;
50     /** Allow instances of this class to be copy constructed */
51     TensorInfo(const ITensorInfo &info);
52     /** Allow instances of this class to be copy constructed */
53     TensorInfo(const TensorInfo &);
54     /** Allow instances of this class to be copied */
55     TensorInfo &operator=(const TensorInfo &) = default;
56     /** Allow instances of this class to be move constructed */
57     TensorInfo(TensorInfo &&) = default;
58     /** Allow instances of this class to be moved */
59     TensorInfo &operator=(TensorInfo &&) = default;
60 
61     /** Construct a tensor info with a format.
62      *
63      * Can be used for automatic derivation of the shape by the function.
64      *
65      * @param[in] format Format of the tensor.
66      */
67     TensorInfo(Format format);
68 
69     /** 2D tensor constructor
70      *
71      * @param[in] width  Width of the 2D tensor
72      * @param[in] height Height of the 2D tensor
73      * @param[in] format Single plane format of the tensor.
74      */
75     TensorInfo(unsigned int width, unsigned int height, Format format);
76     /** Constructor
77      *
78      * @param[in] tensor_shape It specifies the size for each dimension of the tensor in number of elements.
79      * @param[in] format       Single plane format of the tensor.
80      */
81     TensorInfo(const TensorShape &tensor_shape, Format format);
82 
83     /** Construct a tensor info with a data type and number of channels.
84      *
85      * Can be used for automatic derivation of the shape by the function.
86      *
87      * @param[in] num_channels It indicates the number of channels for each tensor element
88      * @param[in] data_type    Data type to use for each tensor element
89      */
90     TensorInfo(size_t num_channels, DataType data_type);
91 
92     /** Constructor
93      *
94      * @param[in] tensor_shape It specifies the size for each dimension of the tensor in number of elements.
95      * @param[in] num_channels It indicates the number of channels for each tensor element
96      * @param[in] data_type    Data type to use for each tensor element
97      */
98     TensorInfo(const TensorShape &tensor_shape, size_t num_channels, DataType data_type);
99 
100     /** Constructor
101      *
102      * @param[in] tensor_shape It specifies the size for each dimension of the tensor in number of elements.
103      * @param[in] num_channels It indicates the number of channels for each tensor element
104      * @param[in] data_type    Data type to use for each tensor element
105      * @param[in] data_layout  The data layout setting for the tensor data.
106      */
107     TensorInfo(const TensorShape &tensor_shape, size_t num_channels, DataType data_type, DataLayout data_layout);
108 
109     /** Constructor
110      *
111      * @param[in] tensor_shape      It specifies the size for each dimension of the tensor in number of elements.
112      * @param[in] num_channels      It indicates the number of channels for each tensor element
113      * @param[in] data_type         Data type to use for each tensor element
114      * @param[in] quantization_info The quantization settings for the tensor data.
115      */
116     TensorInfo(const TensorShape &tensor_shape, size_t num_channels, DataType data_type, QuantizationInfo quantization_info);
117 
118     /** Initialize the tensor info with just a format.
119      *
120      * Can be used for automatic derivation of the shape by the function.
121      *
122      * @param[in] format Single plane format of the tensor.
123      */
124     void init(Format format);
125 
126     /** Initialize the metadata structure with the given parameters
127      *
128      * @param[in] tensor_shape Size for each dimension of the tensor in number of elements.
129      * @param[in] format       Single plane format of the tensor.
130      */
131     void init(const TensorShape &tensor_shape, Format format);
132     /** Initialize the metadata structure with the given parameters
133      *
134      * @param[in] tensor_shape                  Size for each dimension of the tensor in number of elements.
135      * @param[in] format                        Single plane format of the tensor.
136      * @param[in] strides_in_bytes              Stride in bytes for accessing each dimension of the tensor.
137      * @param[in] offset_first_element_in_bytes Offset in bytes from the beginning of memory allocation to access the first element.
138      * @param[in] total_size_in_bytes           Size in bytes of the memory allocation (including the offset to the first element).
139      */
140     void init(const TensorShape &tensor_shape, Format format, const Strides &strides_in_bytes, size_t offset_first_element_in_bytes, size_t total_size_in_bytes);
141 
142     /** Initialize the tensor info with just a format.
143      *
144      * Can be used for automatic derivation of the shape by the function.
145      *
146      * @param[in] num_channels Desired number of channels for each tensor element.
147      * @param[in] data_type    Data type to use for each tensor element.
148      */
149     void init(size_t num_channels, DataType data_type);
150 
151     /** Initialize the metadata structure with the given parameters
152      *
153      * @param[in] tensor_shape Size for each dimension of the tensor in number of elements.
154      * @param[in] num_channels Desired number of channels for each tensor element.
155      * @param[in] data_type    Data type to use for each tensor element.
156      */
157     void init(const TensorShape &tensor_shape, size_t num_channels, DataType data_type);
158 
159     /** Initialize the metadata structure with the given parameters
160      *
161      * @param[in] tensor_shape                  Size for each dimension of the tensor in number of elements.
162      * @param[in] num_channels                  Desired number of channels for each tensor element.
163      * @param[in] data_type                     Data type to use for each tensor element.
164      * @param[in] strides_in_bytes              Stride in bytes for accessing each dimension of the tensor.
165      * @param[in] offset_first_element_in_bytes Offset in bytes from the beginning of memory allocation to access the first element.
166      * @param[in] total_size_in_bytes           Size in bytes of the memory allocation (including the offset to the first element).
167      */
168     void init(const TensorShape &tensor_shape, size_t num_channels, DataType data_type, const Strides &strides_in_bytes, size_t offset_first_element_in_bytes,
169               size_t total_size_in_bytes);
170     /** Initialize the metadata structure for the given tensor shape and single-plane format, (Padding is automatically calculated)
171      *
172      * @note The padding used by this method is really conservative so that the tensor can be used for most functions.
173      *
174      * @param[in] tensor_shape It specifies the size for each dimension of the tensor in number of elements
175      * @param[in] format       Single plane format of the image.
176      *
177      * @return Total allocation size including padding in bytes.
178      */
179     size_t init_auto_padding(const TensorShape &tensor_shape, Format format);
180     /** Initialize the metadata structure for the given tensor shape, number of channels and
181      *  data type. (Padding is automatically calculated)
182      *
183      * @note The padding used by this method is really conservative so that the tensor can be used for most functions.
184      *
185      * @param[in] tensor_shape It specifies the size for each dimension of the tensor in number of elements
186      * @param[in] num_channels It indicates the number of channels for each tensor element
187      * @param[in] data_type    Data type to use for each tensor element
188      *
189      * @return Total allocation size including padding in bytes.
190      */
191     size_t init_auto_padding(const TensorShape &tensor_shape, size_t num_channels, DataType data_type);
192 
193     // Inherited methods overridden:
194     std::unique_ptr<ITensorInfo> clone() const override;
195     ITensorInfo &set_data_type(DataType data_type) override;
196     ITensorInfo &set_num_channels(int num_channels) override;
197     ITensorInfo &set_format(Format format) override;
198     ITensorInfo &set_tensor_shape(const TensorShape &shape) override;
199     ITensorInfo &set_tensor_dims_state(const TensorDimsState &state) override;
200     ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) override;
201     ITensorInfo &set_data_layout(const DataLayout &data_layout) override;
202     ITensorInfo &reset_padding() override;
203     bool         auto_padding() override;
204     ITensorInfo &set_lock_paddings(bool flag) override;
205     bool lock_paddings() const override;
206     bool extend_padding(const PaddingSize &padding) override;
dimension(size_t index)207     size_t dimension(size_t index) const override
208     {
209         return _tensor_shape[index];
210     }
dimension(DataLayoutDimension dimension)211     size_t dimension(DataLayoutDimension dimension) const override
212     {
213         return get_data_layout_dimension_index(_data_layout, dimension);
214     }
strides_in_bytes()215     const Strides &strides_in_bytes() const override
216     {
217         return _strides_in_bytes;
218     }
offset_first_element_in_bytes()219     size_t offset_first_element_in_bytes() const override
220     {
221         return _offset_first_element_in_bytes;
222     }
223     int32_t offset_element_in_bytes(const Coordinates &pos) const override;
element_size()224     size_t element_size() const override
225     {
226         return data_size_from_type(_data_type) * _num_channels;
227     }
num_dimensions()228     size_t num_dimensions() const override
229     {
230         return _tensor_shape.num_dimensions();
231     }
num_channels()232     size_t num_channels() const override
233     {
234         return _num_channels;
235     }
tensor_shape()236     const TensorShape &tensor_shape() const override
237     {
238         return _tensor_shape;
239     }
tensor_dims_state()240     const TensorDimsState &tensor_dims_state() const override
241     {
242         return _dims_state;
243     }
data_type()244     DataType data_type() const override
245     {
246         return _data_type;
247     }
format()248     Format format() const override
249     {
250         return _format;
251     }
total_size()252     size_t total_size() const override
253     {
254         return _total_size;
255     }
padding()256     PaddingSize padding() const override
257     {
258         return _padding;
259     }
has_padding()260     bool has_padding() const override
261     {
262         return !_padding.empty();
263     }
is_resizable()264     bool is_resizable() const override
265     {
266         return _is_resizable;
267     }
is_dynamic()268     bool is_dynamic() const override
269     {
270         return std::find(std::cbegin(_dims_state), std::cend(_dims_state), get_dynamic_state_value()) != std::cend(_dims_state);
271     }
are_values_constant()272     bool are_values_constant() const override
273     {
274         return _are_values_constant;
275     }
set_is_resizable(bool is_resizable)276     ITensorInfo &set_is_resizable(bool is_resizable) override
277     {
278         _is_resizable = is_resizable;
279         return *this;
280     }
valid_region()281     ValidRegion valid_region() const override
282     {
283         return _valid_region;
284     }
set_valid_region(const ValidRegion & valid_region)285     void set_valid_region(const ValidRegion &valid_region) override
286     {
287         _valid_region = valid_region;
288     }
quantization_info()289     QuantizationInfo quantization_info() const override
290     {
291         return _quantization_info;
292     }
data_layout()293     DataLayout data_layout() const override
294     {
295         return _data_layout;
296     }
set_are_values_constant(bool are_values_constant)297     ITensorInfo &set_are_values_constant(bool are_values_constant) override
298     {
299         _are_values_constant = are_values_constant;
300         return *this;
301     }
id()302     ITensorInfo::Id id() const override
303     {
304         return _id;
305     }
set_id(ITensorInfo::Id id)306     ITensorInfo &set_id(ITensorInfo::Id id) override
307     {
308         _id = id;
309         return *this;
310     }
311     inline friend bool operator==(const TensorInfo &lhs, const TensorInfo &rhs);
312 
313 private:
314     /** Calculates strides, offset and total size resulting from the specified padding around the XY plane.
315      *
316      * @param[in] padding Padding around the XY plane in elements.
317      */
318     std::tuple<Strides, size_t, size_t> calculate_padding_requirements(const PaddingSize &padding);
319 
320     size_t           _total_size;
321     size_t           _offset_first_element_in_bytes;
322     Strides          _strides_in_bytes;
323     size_t           _num_channels;
324     TensorShape      _tensor_shape;
325     TensorDimsState  _dims_state;
326     DataType         _data_type;
327     Format           _format;
328     bool             _is_resizable;
329     ValidRegion      _valid_region;
330     PaddingSize      _padding;
331     QuantizationInfo _quantization_info;
332     DataLayout       _data_layout;
333     bool             _are_values_constant;
334     ITensorInfo::Id  _id;
335     bool             _lock_paddings;
336 };
337 
338 /** Check whether two tensor info are equal.
339  *
340  * @param[in] lhs LHS tensor info.
341  * @param[in] rhs RHS tensor info.
342  *
343  * @return True if the given tensor infos are the same.
344  */
345 inline bool operator==(const TensorInfo &lhs, const TensorInfo &rhs)
346 {
347     return (lhs._total_size == rhs._total_size) && (lhs._offset_first_element_in_bytes == rhs._offset_first_element_in_bytes) && (lhs._strides_in_bytes == rhs._strides_in_bytes)
348            && (lhs._num_channels == rhs._num_channels) && (lhs._tensor_shape == rhs._tensor_shape) && (lhs._dims_state == rhs._dims_state) && (lhs._data_type == rhs._data_type) && (lhs._format == rhs._format)
349            && (lhs._is_resizable == rhs._is_resizable) && (lhs._valid_region == rhs._valid_region) && (lhs._padding == rhs._padding) && (lhs._quantization_info == rhs._quantization_info)
350            && (lhs._data_layout == rhs._data_layout) && (lhs._are_values_constant == rhs._are_values_constant)
351            && (lhs._id == rhs._id);
352 }
353 } // namespace arm_compute
354 #endif /*ARM_COMPUTE_TENSORINFO_H */
355