1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_META_TENSOR_H_ 18 #define MINDSPORE_CORE_IR_META_TENSOR_H_ 19 20 #include <utility> 21 #include <vector> 22 #include <memory> 23 #include <string> 24 #include "base/base.h" 25 #include "ir/param_info.h" 26 #include "ir/dtype.h" 27 #include "utils/convert_utils_base.h" 28 #include "utils/hashing.h" 29 #include "utils/shape_utils.h" 30 31 // brief mindspore namespace. 32 // 33 // mindspore namespace is the top level namespace of MindSpore project. 34 // Other namespace should be a sub namespace of mindspore namespace in the ME project. 35 namespace mindspore { 36 37 // brief mindspore::tensor namespace 38 // 39 // A sub namespace in ME to support tensor related definition. 40 namespace tensor { 41 // brief Metadata of Tensor 42 // 43 // Includes the metadata information of a tensor, such as data type, shape 44 // and so on. But it does not contain values of a tensor. 45 class MS_CORE_API MetaTensor : public Value { 46 public: 47 /// \brief Construction 48 MetaTensor(); 49 50 /// \brief Constructs a meta tensor of a tensor having data_type data and shape. 51 /// The constructed MetaTensor is not a Tensor, but it has the data type and shape 52 /// information of a Tensor. 53 /// 54 /// \param[in] data_type The data type of the tensor. 55 /// \param[in] shape The shape of the tensor. 56 MetaTensor(TypeId data_type, const ShapeVector &shape); 57 58 MetaTensor(const TypePtr &type_ptr, const ShapeVector &shape); 59 /// \brief Copy constructor. 60 /// The constructed MetaTensor object will have the same data type and shape as the 61 /// meta_tensor. 62 /// 63 /// \param[in] meta_tensor An existing MetaTensor object. 64 MetaTensor(const MetaTensor &meta_tensor); 65 66 /// \brief Destrustor of MetaTensor. 67 ~MetaTensor() override = default; 68 MS_DECLARE_PARENT(MetaTensor, Value) 69 70 /// \brief Overloads operator = for MetaTensor. 71 /// The constructed MetaTensor object has the same type and shape with meta_tensor. 72 /// 73 /// \param[in] meta_tensor An existing MetaTensor object. 74 /// \return A MetaTensor object. 75 MetaTensor &operator=(const MetaTensor &meta_tensor); 76 77 /// \brief Compares two MetaTensor objects. 78 /// The constructed MetaTensor object has the same type and shape with meta_tensor. 79 /// 80 /// \param[in] meta_tensor The MetaTensor object to be compared. 81 /// \return Return true if having same type and shape, otherwise return false. 82 virtual bool operator==(const MetaTensor &meta_tensor) const; 83 84 /// \brief Get the data type of the tensor in its MetaTensor. 85 /// All the types are defined in "ir/dtype.h". 86 /// 87 /// \return The data type of the tensor in its MetaTensor. 88 TypePtr Dtype() const; 89 90 abstract::AbstractBasePtr ToAbstract() override; 91 92 /// \brief Get the data type of a tensor in its MetaTensor. 93 /// 94 /// \return The data type. data_type()95 TypeId data_type() const { return data_type_; } 96 97 std::string ToString() const override; 98 99 /// \brief Set the data type of a tensor in its MetaTensor. 100 /// 101 /// \param[in] data_type The data type of the tensor to be set. set_data_type(TypeId data_type)102 virtual TypeId set_data_type(TypeId data_type) { 103 data_type_ = data_type; 104 return data_type_; 105 } 106 107 /// \brief Set the dtype of a tensor in its MetaTensor. 108 /// 109 /// \param[in] type_ptr The dtype of the tensor to be set. 110 virtual TypePtr SetDtype(const TypePtr type_ptr); 111 112 /// \brief Get tensor's shape. 113 /// The shape of a tensor is stored in a vector<int>. Each 114 /// element of the vector represents the size of a dimension of the tensor. 115 /// The order of each element in the vector is the same as the the dimension's 116 /// order it represents. 117 /// 118 /// \return A const vector<int> which represents the shape of the tensor. shape()119 const ShapeVector &shape() const { return shape_; } 120 121 /// \brief Sets the shape of a tensor. 122 /// The shape of a tensor is stored in a vector<int>. Each 123 /// element of the vector represents the size of a dimension of the tensor. 124 /// The order of each element in the vector is the same as the the dimension's 125 /// order it represents. 126 /// 127 /// \param[in] shape The shape of the tensor. 128 /// \return The shape's size. set_shape(const ShapeVector & shape)129 virtual size_t set_shape(const ShapeVector &shape) { 130 this->shape_ = shape; 131 return shape_.size(); 132 } 133 134 /// \brief Get the size of a given dimension by its index number. 135 /// 136 /// \return The size of a given dimension by its index number. 137 int64_t DimensionSize(size_t index) const; 138 139 /// \brief Get total number of elements in a tensor. 140 /// 141 /// \return The total number of elements in a tensor. 142 int64_t ElementsNum() const; 143 hash()144 std::size_t hash() const override { 145 std::size_t hash_value = std::hash<int>{}(static_cast<int>(data_type_)); 146 hash_value = hash_combine(hash_value, std::hash<size_t>{}(shape_.size())); 147 // hash all elements may costly, so only take at most 4 elements into account based on 148 // some experiments. 149 for (size_t i = 0; (i < shape_.size()) && (i < 4); ++i) { 150 hash_value = hash_combine(hash_value, (std::hash<int>{}(LongToInt(shape_[i])))); 151 } 152 return hash_value; 153 } 154 bool operator==(const Value &other) const override { 155 if (other.isa<MetaTensor>()) { 156 auto &other_ = static_cast<const MetaTensor &>(other); 157 return *this == other_; 158 } else { 159 return false; 160 } 161 } 162 /// \brief Get tensor's param_info info. 163 /// 164 /// \return The tensor's param_info info. param_info()165 ParamInfoPtr param_info() const { return param_info_; } 166 167 /// \brief Check whether this Tensor is a parameter. 168 /// 169 /// \return Whether this Tensor is a parameter. is_parameter()170 bool is_parameter() const { return is_parameter_; } 171 172 /// \brief Set tensor's param_info info. 173 /// 174 /// \param[in] param_info The input param_info. set_param_info(const ParamInfoPtr & param_info)175 void set_param_info(const ParamInfoPtr ¶m_info) { 176 is_parameter_ = true; 177 param_info_ = param_info; 178 } 179 180 protected: 181 // brief Data type of the tensor. 182 // 183 // All support data type is in Number Types of [TypeId], 184 // including [kNumberTypeBool], [kNumberTypeInt], 185 // [kNumberTypeUInt32], [kNumberTypeFloat32] and [kNumberTypeFloat64]. 186 TypeId data_type_; 187 188 // brief Shape of the tensor. 189 // 190 // A ShapeVector container is used to store the shape of a tensor. 191 // Each element of the vector represents the size of a dimension of the tensor. 192 // The order of each element in the vector is as same as the the dimension's 193 // order it represents. If the dimension size is not set, its value will be -1. 194 ShapeVector shape_; 195 196 bool is_parameter_{false}; 197 ParamInfoPtr param_info_{nullptr}; 198 }; 199 200 using MetaTensorPtr = std::shared_ptr<MetaTensor>; 201 202 // brief Metadata of SparseTensor 203 // 204 // Includes the metadata information of a SparseTensor, such as data type, shape 205 // and so on. But it does not contain values of a SparseTensor. 206 class MS_CORE_API MetaSparseTensor : public Value { 207 public: 208 /// \brief Construction 209 MetaSparseTensor(); 210 211 /// \brief Constructs a meta SparseTensor having data_type data and shape. 212 /// The constructed MetaSparseTensor contains the data type and shape information of 213 /// a SparseTensor. 214 /// 215 /// \param[in] data_type The data type of the SparseTensor. 216 /// \param[in] shape The shape of the SparseTensor. 217 MetaSparseTensor(TypeId data_type, const ShapeVector &shape); 218 219 /// \brief Copy constructor. 220 /// The constructed MetaSparseTensor object will have the same data type and shape as the 221 /// meta_sparse_tensor. 222 /// 223 /// \param[in] meta_tensor An existing MetaSparseTensor object. 224 MetaSparseTensor(const MetaSparseTensor &meta_sparse_tensor); 225 226 /// \brief Copy assignment operator. 227 /// 228 /// \param[in] meta_sparse_tensor An existing MetaSparseTensor object. 229 /// \return A MetaSparseTensor object set with the same data type and shape as the meta_sparse_tensor. 230 MetaSparseTensor &operator=(const MetaSparseTensor &meta_sparse_tensor); 231 232 /// \brief Destrustor of MetaSparseTensor. 233 ~MetaSparseTensor() override = default; 234 MS_DECLARE_PARENT(MetaSparseTensor, Value) 235 236 /// \brief Compares two MetaSparseTensor objects. 237 /// The constructed MetaSparseTensor object has the same type and shape with meta_sparse_tensor. 238 /// 239 /// \param[in] meta_sparse_tensor The MetaSparseTensor object to be compared. 240 /// \return Return true if having same type and shape, otherwise return false. 241 virtual bool operator==(const MetaSparseTensor &meta_sparse_tensor) const { 242 return data_type_ == meta_sparse_tensor.data_type() && shape_ == meta_sparse_tensor.shape(); 243 } 244 245 /// \brief Get the data type of the sparse tensor. 246 /// All the types are defined in "ir/dtype.h". 247 /// 248 /// \return The data type of the sparse tensor. 249 TypePtr Dtype() const; 250 251 /// \brief Get the data type of a sparse tensor. 252 /// 253 /// \return The data type. data_type()254 TypeId data_type() const { return data_type_; } 255 256 /// \brief Set the data type of a sparse tensor. 257 /// 258 /// \param[in] data_type The data type of the tensor to be set. set_data_type(TypeId data_type)259 void set_data_type(TypeId data_type) { data_type_ = data_type; } 260 261 /// \brief Get sparsetensor's shape. 262 /// 263 /// \return A const vector<int> which represents the shape of the tensor. shape()264 const ShapeVector &shape() const { return shape_; } 265 266 /// \brief Sets the shape of a sparsetensor. 267 /// 268 /// \param[in] shape The shape of the tensor. set_shape(const ShapeVector & shape)269 void set_shape(const ShapeVector &shape) { this->shape_ = shape; } 270 271 /// \brief Get display information of this Tensor. 272 /// 273 /// \return The display information of this Tensor. 274 virtual std::string ToString() const = 0; 275 276 protected: 277 // Data type of the sparsetensor. 278 TypeId data_type_; 279 280 // Shape of the sparsetensor. 281 ShapeVector shape_; 282 }; 283 284 using MetaSparseTensorPtr = std::shared_ptr<MetaSparseTensor>; 285 } // namespace tensor 286 } // namespace mindspore 287 288 #endif // MINDSPORE_CORE_IR_META_TENSOR_H_ 289