1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ 18 #define MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ 19 20 #include <cstddef> 21 #include <iostream> 22 #include <initializer_list> 23 #include <map> 24 #include <memory> 25 #include <utility> 26 #include <sstream> 27 #include <string> 28 #include <vector> 29 #include <type_traits> 30 #include <unordered_map> 31 #include <algorithm> 32 #include "base/base.h" 33 #include "ir/named.h" 34 #include "ir/dtype/type.h" 35 36 namespace mindspore { 37 class MS_CORE_API UndeterminedType : public Object { 38 public: UndeterminedType()39 UndeterminedType() : Object(kObjectTypeUndeterminedType) {} UndeterminedType(const TypePtr & ele)40 explicit UndeterminedType(const TypePtr &ele) 41 : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {} 42 ~UndeterminedType() override = default; MS_DECLARE_PARENT(UndeterminedType,Object)43 MS_DECLARE_PARENT(UndeterminedType, Object) 44 45 TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; } element()46 const TypePtr element() const { return element_type_; } set_element(const TypePtr & element_type)47 void set_element(const TypePtr &element_type) { element_type_ = element_type; } 48 49 TypePtr DeepCopy() const override; 50 std::string ToString() const override; 51 std::string ToReprString() const override; 52 std::string DumpText() const override; 53 bool operator==(const Type &other) const override; 54 55 protected: 56 TypePtr element_type_; 57 }; 58 using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>; 59 60 class MS_CORE_API TensorType : public Object { 61 public: TensorType()62 TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {} TensorType(const TypePtr & ele)63 explicit TensorType(const TypePtr &ele) 64 : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} 65 ~TensorType() override = default; MS_DECLARE_PARENT(TensorType,Object)66 MS_DECLARE_PARENT(TensorType, Object) 67 68 TypeId generic_type_id() const override { return kObjectTypeTensorType; } element()69 const TypePtr element() const { return element_type_; } set_element(const TypePtr & element_type)70 void set_element(const TypePtr &element_type) { element_type_ = element_type; } 71 72 TypePtr DeepCopy() const override; 73 std::string ToString() const override; 74 std::string ToReprString() const override; 75 std::string DumpText() const override; 76 bool operator==(const Type &other) const override; 77 78 private: 79 TypePtr element_type_; 80 }; 81 using TensorTypePtr = std::shared_ptr<TensorType>; 82 83 class MS_CORE_API RowTensorType : public Object { 84 public: RowTensorType()85 RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {} RowTensorType(const TypePtr & ele)86 explicit RowTensorType(const TypePtr &ele) 87 : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} 88 ~RowTensorType() override = default; MS_DECLARE_PARENT(RowTensorType,Object)89 MS_DECLARE_PARENT(RowTensorType, Object) 90 91 TypeId generic_type_id() const override { return kObjectTypeRowTensorType; } element()92 const TypePtr element() const { return element_type_; } set_element(const TypePtr & element_type)93 void set_element(const TypePtr &element_type) { element_type_ = element_type; } 94 95 TypePtr DeepCopy() const override; 96 std::string ToString() const override; 97 std::string ToReprString() const override; 98 std::string DumpText() const override; 99 bool operator==(const Type &other) const override; 100 101 private: 102 TypePtr element_type_; 103 }; 104 using RowTensorTypePtr = std::shared_ptr<RowTensorType>; 105 106 class MS_CORE_API SparseTensorType : public Object { 107 public: SparseTensorType()108 SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {} SparseTensorType(const TypePtr & ele)109 explicit SparseTensorType(const TypePtr &ele) 110 : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} 111 ~SparseTensorType() override = default; MS_DECLARE_PARENT(SparseTensorType,Object)112 MS_DECLARE_PARENT(SparseTensorType, Object) 113 114 TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; } element()115 const TypePtr element() const { return element_type_; } set_element(const TypePtr & element_type)116 void set_element(const TypePtr &element_type) { element_type_ = element_type; } 117 118 TypePtr DeepCopy() const override; 119 std::string ToString() const override; 120 std::string ToReprString() const override; 121 std::string DumpText() const override; 122 bool operator==(const Type &other) const override; 123 124 private: 125 TypePtr element_type_; 126 }; 127 using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>; 128 } // namespace mindspore 129 130 #endif // MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ 131