• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_META_TENSOR_H_
18 #define MINDSPORE_CORE_IR_META_TENSOR_H_
19 
20 #include <utility>
21 #include <vector>
22 #include <memory>
23 #include <string>
24 
25 #include "base/base.h"
26 #include "ir/param_info.h"
27 #include "ir/dtype.h"
28 #include "utils/convert_utils_base.h"
29 #include "utils/hashing.h"
30 #include "utils/shape_utils.h"
31 
// brief mindspore namespace.
//
// The mindspore namespace is the top-level namespace of the MindSpore project.
// All other namespaces in the ME project should be nested inside it.
36 namespace mindspore {
37 
// brief mindspore::tensor namespace
//
// A sub-namespace of mindspore in ME that holds tensor-related definitions.
41 namespace tensor {
42 // brief Device info of Tensor
43 //
44 // Includes the format, data type and host format of a tensor.
45 struct DeviceInfo {
46   explicit DeviceInfo(std::string format = "DefaultFormat", TypePtr data_type = nullptr,
47                       std::string host_format = "DefaultFormat")
format_DeviceInfo48       : format_(std::move(format)), data_type_(std::move(data_type)), host_format_(std::move(host_format)) {}
49   std::string format_ = "DefaultFormat";
50   TypePtr data_type_ = nullptr;
51   std::string host_format_ = "DefaultFormat";
52 };
53 
54 // brief Metadata of Tensor
55 //
56 // Includes the metadata information of a tensor, such as data type, shape
57 // and so on. But it does not contain values of a tensor.
58 class MS_CORE_API MetaTensor : public Value {
59  public:
60   // Construction
61   MetaTensor();
62 
63   // brief Constructs a meta tensor of a tensor having data_type data and shape.
64   //
65   // The constructed MetaTensor is not a Tensor, but it has the data type and shape
66   // information of a Tensor. The following codes will create a 2x3 float
67   // param data_type The data type of the tensor.
68   // param shape The shape of the tensor.
69   MetaTensor(const TypeId data_type, const ShapeVector &shape);
70 
71   MetaTensor(const TypePtr &type_ptr, const ShapeVector &shape);
72   // brief Constructs a MetaTensor object from an existing MetaTensor instance.
73   //
74   // The constructed MetaTensor object will have the same data type and shape as the
75   // meta_tensor.
76   //
77   // param meta_tensor An existing MetaTensor object.
78   MetaTensor(const MetaTensor &meta_tensor);
79   ~MetaTensor() override = default;
80   MS_DECLARE_PARENT(MetaTensor, Value)
81 
82   // brief Overloads operator = for MetaTensor.
83   //
84   // The constructed MetaTensor object has the same type and shape with meta_tensor.
85   //
86   // param meta_tensor An existing MetaTensor object.
87   virtual MetaTensor &operator=(const MetaTensor &meta_tensor);
88 
89   // brief Compares two MetaTensor objects.
90   //
91   // The constructed MetaTensor object has the same type and shape with meta_tensor.
92   //
93   // param meta_tensor The MetaTensor object to be compared.
94   // return true: If having same type and shape, return true, or return false.
95   virtual bool operator==(const MetaTensor &meta_tensor) const;
96 
97   // brief Returns the data type of the tensor in its MetaTensor.
98   //
99   // All the types are defined in "ir/dtype.h".
100   TypePtr Dtype() const;
101   abstract::AbstractBasePtr ToAbstract() override;
data_type()102   TypeId data_type() const { return data_type_; }
103   std::string ToString() const override;
104   std::string DumpText() const override;
105   // brief Sets the data type of a tensor in its MetaTensor.
106   //
107   // param data_type The data type of the tensor to be set.
set_data_type(const TypeId data_type)108   virtual TypeId set_data_type(const TypeId data_type) {
109     data_type_ = data_type;
110     return data_type_;
111   }
112   virtual TypePtr SetDtype(const TypePtr type_ptr);
113   // brief Get tensor's shape.
114   //
115   // The shape of a tensor is stored in a vector<int>. Each
116   // element of the vector represents the size of a dimension of the tensor.
117   // The order of each element in the vector is as same as the the dimension's
118   // order it represents.
119   //
120   // return A const vector<int> which represents the shape of the tensor.
shape()121   const ShapeVector &shape() const { return shape_; }
122 
123   // brief Sets the shape of a tensor.
124   //
125   // The shape of a tensor is stored in a vector<int>. Each
126   // element of the vector represents the size of a dimension of the tensor.
127   // The order of each element in the vector is as same as the the dimension's
128   // order it represents.
129   //
130   // param shape The shape of the tensor.
131   // return The shape's size.
set_shape(const ShapeVector & shape)132   size_t set_shape(const ShapeVector &shape) {
133     this->shape_ = shape;
134     return shape_.size();
135   }
136 
137   // Get tensor's device info.
device_info()138   DeviceInfo device_info() const { return device_info_; }
139 
140   // Set tensor's device info.
set_device_info(const DeviceInfo & device_info)141   void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; }
142 
143   void SetDeviceInfo(const std::string &format, const TypePtr &data_type,
144                      const std::string &host_format = "DefaultFormat");
145 
146   // Get the size of a given dimension by its index number.
147   int64_t DimensionSize(size_t index) const;
148 
149   // Get total number of elements in a tensor.
150   int ElementsNum() const;
151 
hash()152   std::size_t hash() const override {
153     std::size_t hash_value = std::hash<int>{}(SizeToInt(data_type_));
154     hash_value = hash_combine(hash_value, std::hash<size_t>{}(shape_.size()));
155     // hash all elements may costly, so only take at most 4 elements into account based on
156     // some experiments.
157     for (size_t i = 0; (i < shape_.size()) && (i < 4); ++i) {
158       hash_value = hash_combine(hash_value, (std::hash<int>{}(shape_[i])));
159     }
160     return hash_value;
161   }
162   bool operator==(const Value &other) const override {
163     if (other.isa<MetaTensor>()) {
164       auto other_ = static_cast<const MetaTensor &>(other);
165       return *this == other_;
166     } else {
167       return false;
168     }
169   }
170   // Get tensor's param_info info.
param_info()171   ParamInfoPtr param_info() const { return param_info_; }
is_parameter()172   bool is_parameter() const { return is_parameter_; }
173   // Set tensor's param_info info.
set_param_info(const ParamInfoPtr & param_info)174   void set_param_info(const ParamInfoPtr &param_info) {
175     is_parameter_ = true;
176     param_info_ = param_info;
177   }
178 
179  protected:
180   // brief Data type of the tensor.
181   //
182   // All support data type is in Number Types of [TypeId],
183   // including [kNumberTypeBool], [kNumberTypeInt],
184   // [kNumberTypeUInt32], [kNumberTypeFloat32] and [kNumberTypeFloat64].
185   TypeId data_type_;
186 
187   // brief Shape of the tensor.
188   //
189   // A ShapeVector container is used to store the shape of a tensor.
190   // Each element of the vector represents the size of a dimension of the tensor.
191   // The order of each element in the vector is as same as the the dimension's
192   // order it represents. If the dimension size is not set, its value will be -1.
193   ShapeVector shape_;
194 
195   // brief Device info of Tensor
196   //
197   // Includes the format and data type of a tensor on device.
198   DeviceInfo device_info_;
199 
200   bool is_parameter_{false};
201   ParamInfoPtr param_info_{nullptr};
202 };
203 
204 using MetaTensorPtr = std::shared_ptr<MetaTensor>;
205 
206 }  // namespace tensor
207 }  // namespace mindspore
208 
209 #endif  // MINDSPORE_CORE_IR_META_TENSOR_H_
210