1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ 18 #define MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ 19 20 #include <cstddef> 21 #include <iostream> 22 #include <initializer_list> 23 #include <map> 24 #include <memory> 25 #include <utility> 26 #include <sstream> 27 #include <string> 28 #include <vector> 29 #include <type_traits> 30 #include <algorithm> 31 #include "utils/hash_map.h" 32 #include "base/base.h" 33 #include "ir/named.h" 34 #include "ir/dtype/type.h" 35 36 namespace mindspore { 37 /// \brief UndeterminedType defines interface for tensor undetermined data type. 38 class MS_CORE_API UndeterminedType final : public Object { 39 public: 40 /// \brief Default constructor for UndeterminedType. UndeterminedType()41 UndeterminedType() : Object(kObjectTypeUndeterminedType) {} 42 43 /// \brief Constructor for UndeterminedType. 44 /// 45 /// \param[in] ele The element of UndeterminedType. UndeterminedType(const TypePtr & ele)46 explicit UndeterminedType(const TypePtr &ele) 47 : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {} 48 49 /// \brief Destructor of UndeterminedType. 50 ~UndeterminedType() override = default; MS_DECLARE_PARENT(UndeterminedType,Object)51 MS_DECLARE_PARENT(UndeterminedType, Object) 52 53 TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; } 54 55 /// \brief Get the element of UndeterminedType object. 56 /// 57 /// \return The element of UndeterminedType object. element()58 const TypePtr element() const { return element_type_; } 59 60 /// \brief Set the element of UndeterminedType object. 61 /// 62 /// \param[in] element_type Define the element type to be set. set_element(const TypePtr & element_type)63 void set_element(const TypePtr &element_type) { element_type_ = element_type; } 64 65 TypePtr DeepCopy() const override; 66 std::string ToString() const override; 67 std::string ToReprString() const override; 68 std::string DumpText() const override; 69 70 bool operator==(const Type &other) const override; 71 std::size_t hash() const override; 72 73 protected: 74 TypePtr element_type_; 75 }; 76 using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>; 77 78 /// \brief TensorType defines interface for tensor data type. 79 class MS_CORE_API TensorType : public Object { 80 public: 81 /// \brief Default constructor for TensorType. TensorType()82 TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {} 83 84 /// \brief Constructor for TensorType. 85 /// 86 /// \param[in] ele The element of TensorType. TensorType(const TypePtr & ele)87 explicit TensorType(const TypePtr &ele) 88 : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} 89 90 /// \brief Destructor of TensorType. 91 ~TensorType() override = default; MS_DECLARE_PARENT(TensorType,Object)92 MS_DECLARE_PARENT(TensorType, Object) 93 94 TypeId generic_type_id() const override { return kObjectTypeTensorType; } 95 96 /// \brief Get the element of TensorType object. 97 /// 98 /// \return The element of TensorType object. element()99 const TypePtr element() const { return element_type_; } 100 101 /// \brief Set the element of TensorType object. 102 /// 103 /// \param[in] element_type Define the element type to be set. set_element(const TypePtr & element_type)104 void set_element(const TypePtr &element_type) { element_type_ = element_type; } 105 106 TypePtr DeepCopy() const override; 107 std::string ToString() const override; 108 std::string ToReprString() const override; 109 std::string DumpText() const override; 110 bool operator==(const Type &other) const override; 111 112 /// \brief Overwrite the operator '==' to compare other tensor type. 113 /// 114 /// \param[in] other The other tensor type value to be compared. 115 /// 116 /// \return A boolean, which indicates whether the type is same. 117 bool operator==(const TensorType &other) const; 118 std::size_t hash() const override; 119 120 private: 121 TypePtr element_type_; 122 }; 123 using TensorTypePtr = std::shared_ptr<TensorType>; 124 125 /// \brief AnyType defines interface for any data type. 126 class MS_CORE_API AnyType : public TensorType { 127 public: 128 /// \brief Default constructor for AnyType. 129 AnyType() = default; 130 131 /// \brief Constructor for AnyType. 132 /// 133 /// \param[in] element_type The element type of AnyType. AnyType(const TypePtr & element_type)134 explicit AnyType(const TypePtr &element_type) : TensorType(element_type) {} 135 136 /// \brief Destructor of AnyType. 137 ~AnyType() override = default; 138 MS_DECLARE_PARENT(AnyType, TensorType) 139 140 std::string ToString() const override; 141 std::string DumpText() const override; 142 bool operator==(const Type &other) const override; 143 144 /// \brief Overwrite the operator '==' to compare other anytype. 145 /// 146 /// \param[in] other The other anytype value to be compared. 147 /// 148 /// \return A boolean, which indicates whether the type is same. 149 bool operator==(const AnyType &other) const; 150 }; 151 using AnyTypePtr = std::shared_ptr<AnyType>; 152 153 /// \brief NegligibleType defines interface for negligible data type. 154 class MS_CORE_API NegligibleType final : public AnyType { 155 public: 156 /// \brief Default constructor for NegligibleType. 157 NegligibleType() = default; 158 159 /// \brief Constructor for NegligibleType. 160 /// 161 /// \param[in] element_type The element type of NegligibleType. NegligibleType(const TypePtr & element_type)162 explicit NegligibleType(const TypePtr &element_type) : AnyType(element_type) {} 163 164 /// \brief Destructor of NegligibleType. 165 ~NegligibleType() override = default; 166 MS_DECLARE_PARENT(NegligibleType, AnyType) 167 168 std::string ToString() const override; 169 std::string DumpText() const override; 170 }; 171 using NegligibleTypePtr = std::shared_ptr<NegligibleType>; 172 173 /// \brief SparseTensorType is the base type for all sparse tensors. 174 class MS_CORE_API SparseTensorType : public Object { 175 public: SparseTensorType()176 SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {} 177 SparseTensorType(const TypeId object_type)178 explicit SparseTensorType(const TypeId object_type) : Object(object_type, kObjectTypeUndeterminedType) {} 179 SparseTensorType(const TypePtrList & objs)180 explicit SparseTensorType(const TypePtrList &objs) 181 : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType), elements_(objs.begin(), objs.end()) {} 182 SparseTensorType(const TypeId object_type,const TypePtrList & objs)183 SparseTensorType(const TypeId object_type, const TypePtrList &objs) 184 : Object(object_type, kObjectTypeUndeterminedType), elements_(objs.begin(), objs.end()) {} 185 186 /// \brief Destructor of SparseTensorType. 187 ~SparseTensorType() override = default; 188 MS_DECLARE_PARENT(SparseTensorType, Object) 189 190 enum StringType : int { kToString = 0, kDumpText, kReprString }; 191 GetSparseTensorTypeName()192 virtual std::string GetSparseTensorTypeName() const { return "SparseTensorType"; } GetElementIndex()193 virtual size_t GetElementIndex() { return 0; } element_type()194 virtual TypePtr element_type() { 195 if (elements_.empty()) { 196 return nullptr; 197 } 198 return elements_[GetElementIndex()]; 199 } 200 std::string ElementsDtypeStr(const StringType str_type) const; generic_type_id()201 TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; } 202 203 const TypePtr operator[](std::size_t dim) const; 204 bool operator==(const Type &other) const override; 205 std::size_t hash() const override; elements()206 TypePtrList elements() const { return elements_; } 207 size()208 std::size_t size() const { return elements_.size(); } 209 std::string ToString() const override; 210 std::string ToReprString() const override; 211 std::string DumpText() const override; 212 const TypePtrList ElementsClone() const; 213 TypePtr DeepCopy() const override; 214 215 private: 216 TypePtrList elements_; 217 }; 218 using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>; 219 220 /// \brief RowTensorType defines interface for row tensor data type. 221 class MS_CORE_API RowTensorType final : public Object { 222 public: 223 /// \brief Default constructor for RowTensorType. RowTensorType()224 RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {} 225 226 /// \brief Constructor for RowTensorType. 227 /// 228 /// \param[in] ele The element of RowTensorType. RowTensorType(const TypePtr & ele)229 explicit RowTensorType(const TypePtr &ele) 230 : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} 231 232 /// \brief Destructor of RowTensorType. 233 ~RowTensorType() override = default; MS_DECLARE_PARENT(RowTensorType,Object)234 MS_DECLARE_PARENT(RowTensorType, Object) 235 236 TypeId generic_type_id() const override { return kObjectTypeRowTensorType; } 237 238 /// \brief Get the element of RowTensorType object. 239 /// 240 /// \return The element of RowTensorType object. element()241 const TypePtr element() const { return element_type_; } 242 243 /// \brief Set the element of RowTensorType object. 244 /// 245 /// \param[in] element_type Define the element type to be set. set_element(const TypePtr & element_type)246 void set_element(const TypePtr &element_type) { element_type_ = element_type; } 247 248 TypePtr DeepCopy() const override; 249 std::string ToString() const override; 250 std::string ToReprString() const override; 251 std::string DumpText() const override; 252 bool operator==(const Type &other) const override; 253 std::size_t hash() const override; 254 255 private: 256 TypePtr element_type_; 257 }; 258 using RowTensorTypePtr = std::shared_ptr<RowTensorType>; 259 260 /// \brief COOTensorType defines interface for coo tensor data type. 261 class MS_CORE_API COOTensorType final : public SparseTensorType { 262 public: 263 /// \brief Default constructor for COOTensorType. COOTensorType()264 COOTensorType() : SparseTensorType(kObjectTypeCOOTensorType) {} 265 266 /// \brief Constructor for COOTensorType. 267 /// 268 /// \param[in] obj The list of COOTensorType. COOTensorType(const TypePtrList & obj)269 explicit COOTensorType(const TypePtrList &obj) : SparseTensorType(kObjectTypeCOOTensorType, obj) {} 270 271 /// \brief Destructor of COOTensorType. 272 ~COOTensorType() override = default; MS_DECLARE_PARENT(COOTensorType,SparseTensorType)273 MS_DECLARE_PARENT(COOTensorType, SparseTensorType) 274 275 std::string GetSparseTensorTypeName() const override { return "COOTensor"; } GetElementIndex()276 size_t GetElementIndex() override { return 1; } 277 generic_type_id()278 TypeId generic_type_id() const override { return kObjectTypeCOOTensorType; } 279 TypePtr DeepCopy() const override; 280 }; 281 using COOTensorTypePtr = std::shared_ptr<COOTensorType>; 282 283 /// \brief CSRTensorType defines interface for csr tensor data type. 284 class MS_CORE_API CSRTensorType : public SparseTensorType { 285 public: 286 /// \brief Default constructor for CSRTensorType. CSRTensorType()287 CSRTensorType() : SparseTensorType(kObjectTypeCSRTensorType) {} 288 289 /// \brief Constructor for CSRTensorType. 290 /// 291 /// \param[in] obj The list of CSRTensorType. CSRTensorType(const TypePtrList & obj)292 explicit CSRTensorType(const TypePtrList &obj) : SparseTensorType(kObjectTypeCSRTensorType, obj) {} 293 294 /// \brief Destructor of CSRTensorType. 295 ~CSRTensorType() override = default; MS_DECLARE_PARENT(CSRTensorType,SparseTensorType)296 MS_DECLARE_PARENT(CSRTensorType, SparseTensorType) 297 298 std::string GetSparseTensorTypeName() const override { return "CSRTensor"; } GetElementIndex()299 size_t GetElementIndex() override { return 2; } generic_type_id()300 TypeId generic_type_id() const override { return kObjectTypeCSRTensorType; } 301 TypePtr DeepCopy() const override; 302 }; 303 using CSRTensorTypePtr = std::shared_ptr<CSRTensorType>; 304 305 /// \brief MapTensorType defines interface for map tensor data type. 306 class MS_CORE_API MapTensorType final : public Object { 307 public: 308 /// \brief Construct a generic MapTensorType. MapTensorType()309 MapTensorType() : Object(kObjectTypeMapTensorType, true) {} 310 311 /// \brief Construct a MapTensorType. 312 /// 313 /// \param[in] key The key data type. 314 /// \param[in] value The value data type. MapTensorType(const TypePtr & key,const TypePtr & value)315 explicit MapTensorType(const TypePtr &key, const TypePtr &value) 316 : Object(kObjectTypeMapTensorType, false), key_dtype_(key), value_dtype_(value) {} 317 318 /// \brief Destructor of MapTensorType. 319 ~MapTensorType() override = default; MS_DECLARE_PARENT(MapTensorType,Object)320 MS_DECLARE_PARENT(MapTensorType, Object) 321 322 TypeId generic_type_id() const override { return kObjectTypeMapTensorType; } 323 324 /// \brief Get the key data type of this MapTensorType. 325 /// 326 /// \return The key data type. key_dtype()327 const TypePtr &key_dtype() const { return key_dtype_; } 328 329 /// \brief Get the value data type of this MapTensorType. 330 /// 331 /// \return The key data type. value_dtype()332 const TypePtr &value_dtype() const { return value_dtype_; } 333 334 TypePtr DeepCopy() const override; 335 std::string ToString() const override; 336 std::string ToReprString() const override; 337 std::string DumpText() const override; 338 bool operator==(const Type &other) const override; 339 std::size_t hash() const override; 340 341 private: 342 TypePtr key_dtype_; 343 TypePtr value_dtype_; 344 }; 345 using MapTensorTypePtr = std::shared_ptr<MapTensorType>; 346 } // namespace mindspore 347 348 #endif // MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ 349