• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_META_TENSOR_H_
18 #define MINDSPORE_CORE_IR_META_TENSOR_H_
19 
20 #include <utility>
21 #include <vector>
22 #include <memory>
23 #include <string>
24 #include "base/base.h"
25 #include "ir/param_info.h"
26 #include "ir/dtype.h"
27 #include "utils/convert_utils_base.h"
28 #include "utils/hashing.h"
29 #include "utils/shape_utils.h"
30 
// brief mindspore namespace.
//
// The mindspore namespace is the top-level namespace of the MindSpore project.
// Every other namespace in the ME project should be nested inside it.
35 namespace mindspore {
36 
// brief mindspore::tensor namespace.
//
// A sub-namespace in ME that holds tensor-related definitions.
40 namespace tensor {
41 // brief Metadata of Tensor
42 //
43 // Includes the metadata information of a tensor, such as data type, shape
44 // and so on. But it does not contain values of a tensor.
45 class MS_CORE_API MetaTensor : public Value {
46  public:
47   /// \brief Construction
48   MetaTensor();
49 
50   /// \brief Constructs a meta tensor of a tensor having data_type data and shape.
51   /// The constructed MetaTensor is not a Tensor, but it has the data type and shape
52   /// information of a Tensor.
53   ///
54   /// \param[in] data_type The data type of the tensor.
55   /// \param[in] shape The shape of the tensor.
56   MetaTensor(TypeId data_type, const ShapeVector &shape);
57 
58   MetaTensor(const TypePtr &type_ptr, const ShapeVector &shape);
59   /// \brief Copy constructor.
60   /// The constructed MetaTensor object will have the same data type and shape as the
61   /// meta_tensor.
62   ///
63   /// \param[in] meta_tensor An existing MetaTensor object.
64   MetaTensor(const MetaTensor &meta_tensor);
65 
66   /// \brief Destrustor of MetaTensor.
67   ~MetaTensor() override = default;
68   MS_DECLARE_PARENT(MetaTensor, Value)
69 
70   /// \brief Overloads operator = for MetaTensor.
71   /// The constructed MetaTensor object has the same type and shape with meta_tensor.
72   ///
73   /// \param[in] meta_tensor An existing MetaTensor object.
74   /// \return A MetaTensor object.
75   MetaTensor &operator=(const MetaTensor &meta_tensor);
76 
77   /// \brief Compares two MetaTensor objects.
78   /// The constructed MetaTensor object has the same type and shape with meta_tensor.
79   ///
80   /// \param[in] meta_tensor The MetaTensor object to be compared.
81   /// \return Return true if having same type and shape, otherwise return false.
82   virtual bool operator==(const MetaTensor &meta_tensor) const;
83 
84   /// \brief Get the data type of the tensor in its MetaTensor.
85   /// All the types are defined in "ir/dtype.h".
86   ///
87   /// \return The data type of the tensor in its MetaTensor.
88   TypePtr Dtype() const;
89 
90   abstract::AbstractBasePtr ToAbstract() override;
91 
92   /// \brief Get the data type of a tensor in its MetaTensor.
93   ///
94   /// \return The data type.
data_type()95   TypeId data_type() const { return data_type_; }
96 
97   std::string ToString() const override;
98 
99   /// \brief Set the data type of a tensor in its MetaTensor.
100   ///
101   /// \param[in] data_type The data type of the tensor to be set.
set_data_type(TypeId data_type)102   virtual TypeId set_data_type(TypeId data_type) {
103     data_type_ = data_type;
104     return data_type_;
105   }
106 
107   /// \brief Set the dtype of a tensor in its MetaTensor.
108   ///
109   /// \param[in] type_ptr The dtype of the tensor to be set.
110   virtual TypePtr SetDtype(const TypePtr type_ptr);
111 
112   /// \brief Get tensor's shape.
113   /// The shape of a tensor is stored in a vector<int>. Each
114   /// element of the vector represents the size of a dimension of the tensor.
115   /// The order of each element in the vector is the same as the the dimension's
116   /// order it represents.
117   ///
118   /// \return A const vector<int> which represents the shape of the tensor.
shape()119   const ShapeVector &shape() const { return shape_; }
120 
121   /// \brief Sets the shape of a tensor.
122   /// The shape of a tensor is stored in a vector<int>. Each
123   /// element of the vector represents the size of a dimension of the tensor.
124   /// The order of each element in the vector is the same as the the dimension's
125   /// order it represents.
126   ///
127   /// \param[in] shape The shape of the tensor.
128   /// \return The shape's size.
set_shape(const ShapeVector & shape)129   virtual size_t set_shape(const ShapeVector &shape) {
130     this->shape_ = shape;
131     return shape_.size();
132   }
133 
134   /// \brief Get the size of a given dimension by its index number.
135   ///
136   /// \return The size of a given dimension by its index number.
137   int64_t DimensionSize(size_t index) const;
138 
139   /// \brief Get total number of elements in a tensor.
140   ///
141   /// \return The total number of elements in a tensor.
142   int64_t ElementsNum() const;
143 
hash()144   std::size_t hash() const override {
145     std::size_t hash_value = std::hash<int>{}(static_cast<int>(data_type_));
146     hash_value = hash_combine(hash_value, std::hash<size_t>{}(shape_.size()));
147     // hash all elements may costly, so only take at most 4 elements into account based on
148     // some experiments.
149     for (size_t i = 0; (i < shape_.size()) && (i < 4); ++i) {
150       hash_value = hash_combine(hash_value, (std::hash<int>{}(LongToInt(shape_[i]))));
151     }
152     return hash_value;
153   }
154   bool operator==(const Value &other) const override {
155     if (other.isa<MetaTensor>()) {
156       auto &other_ = static_cast<const MetaTensor &>(other);
157       return *this == other_;
158     } else {
159       return false;
160     }
161   }
162   /// \brief Get tensor's param_info info.
163   ///
164   /// \return The tensor's param_info info.
param_info()165   ParamInfoPtr param_info() const { return param_info_; }
166 
167   /// \brief Check whether this Tensor is a parameter.
168   ///
169   /// \return Whether this Tensor is a parameter.
is_parameter()170   bool is_parameter() const { return is_parameter_; }
171 
172   /// \brief Set tensor's param_info info.
173   ///
174   /// \param[in] param_info The input param_info.
set_param_info(const ParamInfoPtr & param_info)175   void set_param_info(const ParamInfoPtr &param_info) {
176     is_parameter_ = true;
177     param_info_ = param_info;
178   }
179 
180  protected:
181   // brief Data type of the tensor.
182   //
183   // All support data type is in Number Types of [TypeId],
184   // including [kNumberTypeBool], [kNumberTypeInt],
185   // [kNumberTypeUInt32], [kNumberTypeFloat32] and [kNumberTypeFloat64].
186   TypeId data_type_;
187 
188   // brief Shape of the tensor.
189   //
190   // A ShapeVector container is used to store the shape of a tensor.
191   // Each element of the vector represents the size of a dimension of the tensor.
192   // The order of each element in the vector is as same as the the dimension's
193   // order it represents. If the dimension size is not set, its value will be -1.
194   ShapeVector shape_;
195 
196   bool is_parameter_{false};
197   ParamInfoPtr param_info_{nullptr};
198 };
199 
200 using MetaTensorPtr = std::shared_ptr<MetaTensor>;
201 
202 // brief Metadata of SparseTensor
203 //
204 // Includes the metadata information of a SparseTensor, such as data type, shape
205 // and so on. But it does not contain values of a SparseTensor.
206 class MS_CORE_API MetaSparseTensor : public Value {
207  public:
208   /// \brief Construction
209   MetaSparseTensor();
210 
211   /// \brief Constructs a meta SparseTensor having data_type data and shape.
212   /// The constructed MetaSparseTensor contains the data type and shape information of
213   /// a SparseTensor.
214   ///
215   /// \param[in] data_type The data type of the SparseTensor.
216   /// \param[in] shape The shape of the SparseTensor.
217   MetaSparseTensor(TypeId data_type, const ShapeVector &shape);
218 
219   /// \brief Copy constructor.
220   /// The constructed MetaSparseTensor object will have the same data type and shape as the
221   /// meta_sparse_tensor.
222   ///
223   /// \param[in] meta_tensor An existing MetaSparseTensor object.
224   MetaSparseTensor(const MetaSparseTensor &meta_sparse_tensor);
225 
226   /// \brief Copy assignment operator.
227   ///
228   /// \param[in] meta_sparse_tensor An existing MetaSparseTensor object.
229   /// \return A MetaSparseTensor object set with the same data type and shape as the meta_sparse_tensor.
230   MetaSparseTensor &operator=(const MetaSparseTensor &meta_sparse_tensor);
231 
232   /// \brief Destrustor of MetaSparseTensor.
233   ~MetaSparseTensor() override = default;
234   MS_DECLARE_PARENT(MetaSparseTensor, Value)
235 
236   /// \brief Compares two MetaSparseTensor objects.
237   /// The constructed MetaSparseTensor object has the same type and shape with meta_sparse_tensor.
238   ///
239   /// \param[in] meta_sparse_tensor The MetaSparseTensor object to be compared.
240   /// \return Return true if having same type and shape, otherwise return false.
241   virtual bool operator==(const MetaSparseTensor &meta_sparse_tensor) const {
242     return data_type_ == meta_sparse_tensor.data_type() && shape_ == meta_sparse_tensor.shape();
243   }
244 
245   /// \brief Get the data type of the sparse tensor.
246   /// All the types are defined in "ir/dtype.h".
247   ///
248   /// \return The data type of the sparse tensor.
249   TypePtr Dtype() const;
250 
251   /// \brief Get the data type of a sparse tensor.
252   ///
253   /// \return The data type.
data_type()254   TypeId data_type() const { return data_type_; }
255 
256   /// \brief Set the data type of a sparse tensor.
257   ///
258   /// \param[in] data_type The data type of the tensor to be set.
set_data_type(TypeId data_type)259   void set_data_type(TypeId data_type) { data_type_ = data_type; }
260 
261   /// \brief Get sparsetensor's shape.
262   ///
263   /// \return A const vector<int> which represents the shape of the tensor.
shape()264   const ShapeVector &shape() const { return shape_; }
265 
266   /// \brief Sets the shape of a sparsetensor.
267   ///
268   /// \param[in] shape The shape of the tensor.
set_shape(const ShapeVector & shape)269   void set_shape(const ShapeVector &shape) { this->shape_ = shape; }
270 
271   /// \brief Get display information of this Tensor.
272   ///
273   /// \return The display information of this Tensor.
274   virtual std::string ToString() const = 0;
275 
276  protected:
277   // Data type of the sparsetensor.
278   TypeId data_type_;
279 
280   // Shape of the sparsetensor.
281   ShapeVector shape_;
282 };
283 
284 using MetaSparseTensorPtr = std::shared_ptr<MetaSparseTensor>;
285 }  // namespace tensor
286 }  // namespace mindspore
287 
288 #endif  // MINDSPORE_CORE_IR_META_TENSOR_H_
289