1 /** 2 * Copyright 2019-2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_META_TENSOR_H_ 18 #define MINDSPORE_CORE_IR_META_TENSOR_H_ 19 20 #include <utility> 21 #include <vector> 22 #include <memory> 23 #include <string> 24 25 #include "base/base.h" 26 #include "ir/param_info.h" 27 #include "ir/dtype.h" 28 #include "utils/convert_utils_base.h" 29 #include "utils/hashing.h" 30 #include "utils/shape_utils.h" 31 32 // brief mindspore namespace. 33 // 34 // mindspore namespace is the top level namespace of MindSpore project. 35 // Other namespace should be a sub namespace of mindspore namespace in the ME project. 36 namespace mindspore { 37 38 // brief mindspore::tensor namespace 39 // 40 // A sub namespace in ME to support tensor related definition. 41 namespace tensor { 42 // brief Device info of Tensor 43 // 44 // Includes the format, data type and host format of a tensor. 45 struct DeviceInfo { 46 explicit DeviceInfo(std::string format = "DefaultFormat", TypePtr data_type = nullptr, 47 std::string host_format = "DefaultFormat") format_DeviceInfo48 : format_(std::move(format)), data_type_(std::move(data_type)), host_format_(std::move(host_format)) {} 49 std::string format_ = "DefaultFormat"; 50 TypePtr data_type_ = nullptr; 51 std::string host_format_ = "DefaultFormat"; 52 }; 53 54 // brief Metadata of Tensor 55 // 56 // Includes the metadata information of a tensor, such as data type, shape 57 // and so on. But it does not contain values of a tensor. 58 class MS_CORE_API MetaTensor : public Value { 59 public: 60 // Construction 61 MetaTensor(); 62 63 // brief Constructs a meta tensor of a tensor having data_type data and shape. 64 // 65 // The constructed MetaTensor is not a Tensor, but it has the data type and shape 66 // information of a Tensor. The following codes will create a 2x3 float 67 // param data_type The data type of the tensor. 68 // param shape The shape of the tensor. 69 MetaTensor(const TypeId data_type, const ShapeVector &shape); 70 71 MetaTensor(const TypePtr &type_ptr, const ShapeVector &shape); 72 // brief Constructs a MetaTensor object from an existing MetaTensor instance. 73 // 74 // The constructed MetaTensor object will have the same data type and shape as the 75 // meta_tensor. 76 // 77 // param meta_tensor An existing MetaTensor object. 78 MetaTensor(const MetaTensor &meta_tensor); 79 ~MetaTensor() override = default; 80 MS_DECLARE_PARENT(MetaTensor, Value) 81 82 // brief Overloads operator = for MetaTensor. 83 // 84 // The constructed MetaTensor object has the same type and shape with meta_tensor. 85 // 86 // param meta_tensor An existing MetaTensor object. 87 virtual MetaTensor &operator=(const MetaTensor &meta_tensor); 88 89 // brief Compares two MetaTensor objects. 90 // 91 // The constructed MetaTensor object has the same type and shape with meta_tensor. 92 // 93 // param meta_tensor The MetaTensor object to be compared. 94 // return true: If having same type and shape, return true, or return false. 95 virtual bool operator==(const MetaTensor &meta_tensor) const; 96 97 // brief Returns the data type of the tensor in its MetaTensor. 98 // 99 // All the types are defined in "ir/dtype.h". 100 TypePtr Dtype() const; 101 abstract::AbstractBasePtr ToAbstract() override; data_type()102 TypeId data_type() const { return data_type_; } 103 std::string ToString() const override; 104 std::string DumpText() const override; 105 // brief Sets the data type of a tensor in its MetaTensor. 106 // 107 // param data_type The data type of the tensor to be set. set_data_type(const TypeId data_type)108 virtual TypeId set_data_type(const TypeId data_type) { 109 data_type_ = data_type; 110 return data_type_; 111 } 112 virtual TypePtr SetDtype(const TypePtr type_ptr); 113 // brief Get tensor's shape. 114 // 115 // The shape of a tensor is stored in a vector<int>. Each 116 // element of the vector represents the size of a dimension of the tensor. 117 // The order of each element in the vector is as same as the the dimension's 118 // order it represents. 119 // 120 // return A const vector<int> which represents the shape of the tensor. shape()121 const ShapeVector &shape() const { return shape_; } 122 123 // brief Sets the shape of a tensor. 124 // 125 // The shape of a tensor is stored in a vector<int>. Each 126 // element of the vector represents the size of a dimension of the tensor. 127 // The order of each element in the vector is as same as the the dimension's 128 // order it represents. 129 // 130 // param shape The shape of the tensor. 131 // return The shape's size. set_shape(const ShapeVector & shape)132 size_t set_shape(const ShapeVector &shape) { 133 this->shape_ = shape; 134 return shape_.size(); 135 } 136 137 // Get tensor's device info. device_info()138 DeviceInfo device_info() const { return device_info_; } 139 140 // Set tensor's device info. set_device_info(const DeviceInfo & device_info)141 void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; } 142 143 void SetDeviceInfo(const std::string &format, const TypePtr &data_type, 144 const std::string &host_format = "DefaultFormat"); 145 146 // Get the size of a given dimension by its index number. 147 int64_t DimensionSize(size_t index) const; 148 149 // Get total number of elements in a tensor. 150 int ElementsNum() const; 151 hash()152 std::size_t hash() const override { 153 std::size_t hash_value = std::hash<int>{}(SizeToInt(data_type_)); 154 hash_value = hash_combine(hash_value, std::hash<size_t>{}(shape_.size())); 155 // hash all elements may costly, so only take at most 4 elements into account based on 156 // some experiments. 157 for (size_t i = 0; (i < shape_.size()) && (i < 4); ++i) { 158 hash_value = hash_combine(hash_value, (std::hash<int>{}(shape_[i]))); 159 } 160 return hash_value; 161 } 162 bool operator==(const Value &other) const override { 163 if (other.isa<MetaTensor>()) { 164 auto other_ = static_cast<const MetaTensor &>(other); 165 return *this == other_; 166 } else { 167 return false; 168 } 169 } 170 // Get tensor's param_info info. param_info()171 ParamInfoPtr param_info() const { return param_info_; } is_parameter()172 bool is_parameter() const { return is_parameter_; } 173 // Set tensor's param_info info. set_param_info(const ParamInfoPtr & param_info)174 void set_param_info(const ParamInfoPtr ¶m_info) { 175 is_parameter_ = true; 176 param_info_ = param_info; 177 } 178 179 protected: 180 // brief Data type of the tensor. 181 // 182 // All support data type is in Number Types of [TypeId], 183 // including [kNumberTypeBool], [kNumberTypeInt], 184 // [kNumberTypeUInt32], [kNumberTypeFloat32] and [kNumberTypeFloat64]. 185 TypeId data_type_; 186 187 // brief Shape of the tensor. 188 // 189 // A ShapeVector container is used to store the shape of a tensor. 190 // Each element of the vector represents the size of a dimension of the tensor. 191 // The order of each element in the vector is as same as the the dimension's 192 // order it represents. If the dimension size is not set, its value will be -1. 193 ShapeVector shape_; 194 195 // brief Device info of Tensor 196 // 197 // Includes the format and data type of a tensor on device. 198 DeviceInfo device_info_; 199 200 bool is_parameter_{false}; 201 ParamInfoPtr param_info_{nullptr}; 202 }; 203 204 using MetaTensorPtr = std::shared_ptr<MetaTensor>; 205 206 } // namespace tensor 207 } // namespace mindspore 208 209 #endif // MINDSPORE_CORE_IR_META_TENSOR_H_ 210