• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
18 #define MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
19 
20 #include <cstddef>
21 #include <iostream>
22 #include <initializer_list>
23 #include <map>
24 #include <memory>
25 #include <utility>
26 #include <sstream>
27 #include <string>
28 #include <vector>
29 #include <type_traits>
30 #include <algorithm>
31 #include "utils/hash_map.h"
32 #include "base/base.h"
33 #include "ir/named.h"
34 #include "ir/dtype/type.h"
35 
36 namespace mindspore {
37 /// \brief UndeterminedType defines interface for tensor undetermined data type.
38 class MS_CORE_API UndeterminedType final : public Object {
39  public:
40   /// \brief Default constructor for UndeterminedType.
UndeterminedType()41   UndeterminedType() : Object(kObjectTypeUndeterminedType) {}
42 
43   /// \brief Constructor for UndeterminedType.
44   ///
45   /// \param[in] ele The element of UndeterminedType.
UndeterminedType(const TypePtr & ele)46   explicit UndeterminedType(const TypePtr &ele)
47       : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {}
48 
49   /// \brief Destructor of UndeterminedType.
50   ~UndeterminedType() override = default;
MS_DECLARE_PARENT(UndeterminedType,Object)51   MS_DECLARE_PARENT(UndeterminedType, Object)
52 
53   TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; }
54 
55   /// \brief Get the element of UndeterminedType object.
56   ///
57   /// \return The element of UndeterminedType object.
element()58   const TypePtr element() const { return element_type_; }
59 
60   /// \brief Set the element of UndeterminedType object.
61   ///
62   /// \param[in] element_type Define the element type to be set.
set_element(const TypePtr & element_type)63   void set_element(const TypePtr &element_type) { element_type_ = element_type; }
64 
65   TypePtr DeepCopy() const override;
66   std::string ToString() const override;
67   std::string ToReprString() const override;
68   std::string DumpText() const override;
69 
70   bool operator==(const Type &other) const override;
71   std::size_t hash() const override;
72 
73  protected:
74   TypePtr element_type_;
75 };
76 using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>;
77 
78 /// \brief TensorType defines interface for tensor data type.
79 class MS_CORE_API TensorType : public Object {
80  public:
81   /// \brief Default constructor for TensorType.
TensorType()82   TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {}
83 
84   /// \brief Constructor for TensorType.
85   ///
86   /// \param[in] ele The element of TensorType.
TensorType(const TypePtr & ele)87   explicit TensorType(const TypePtr &ele)
88       : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
89 
90   /// \brief Destructor of TensorType.
91   ~TensorType() override = default;
MS_DECLARE_PARENT(TensorType,Object)92   MS_DECLARE_PARENT(TensorType, Object)
93 
94   TypeId generic_type_id() const override { return kObjectTypeTensorType; }
95 
96   /// \brief Get the element of TensorType object.
97   ///
98   /// \return The element of TensorType object.
element()99   const TypePtr element() const { return element_type_; }
100 
101   /// \brief Set the element of TensorType object.
102   ///
103   /// \param[in] element_type Define the element type to be set.
set_element(const TypePtr & element_type)104   void set_element(const TypePtr &element_type) { element_type_ = element_type; }
105 
106   TypePtr DeepCopy() const override;
107   std::string ToString() const override;
108   std::string ToReprString() const override;
109   std::string DumpText() const override;
110   bool operator==(const Type &other) const override;
111 
112   /// \brief Overwrite the operator '==' to compare other tensor type.
113   ///
114   /// \param[in] other The other tensor type value to be compared.
115   ///
116   /// \return  A boolean, which indicates whether the type is same.
117   bool operator==(const TensorType &other) const;
118   std::size_t hash() const override;
119 
120  private:
121   TypePtr element_type_;
122 };
123 using TensorTypePtr = std::shared_ptr<TensorType>;
124 
125 /// \brief AnyType defines interface for any data type.
126 class MS_CORE_API AnyType : public TensorType {
127  public:
128   /// \brief Default constructor for AnyType.
129   AnyType() = default;
130 
131   /// \brief Constructor for AnyType.
132   ///
133   /// \param[in] element_type The element type of AnyType.
AnyType(const TypePtr & element_type)134   explicit AnyType(const TypePtr &element_type) : TensorType(element_type) {}
135 
136   /// \brief Destructor of AnyType.
137   ~AnyType() override = default;
138   MS_DECLARE_PARENT(AnyType, TensorType)
139 
140   std::string ToString() const override;
141   std::string DumpText() const override;
142   bool operator==(const Type &other) const override;
143 
144   /// \brief Overwrite the operator '==' to compare other anytype.
145   ///
146   /// \param[in] other The other anytype value to be compared.
147   ///
148   /// \return  A boolean, which indicates whether the type is same.
149   bool operator==(const AnyType &other) const;
150 };
151 using AnyTypePtr = std::shared_ptr<AnyType>;
152 
153 /// \brief NegligibleType defines interface for negligible data type.
154 class MS_CORE_API NegligibleType final : public AnyType {
155  public:
156   /// \brief Default constructor for NegligibleType.
157   NegligibleType() = default;
158 
159   /// \brief Constructor for NegligibleType.
160   ///
161   /// \param[in] element_type The element type of NegligibleType.
NegligibleType(const TypePtr & element_type)162   explicit NegligibleType(const TypePtr &element_type) : AnyType(element_type) {}
163 
164   /// \brief Destructor of NegligibleType.
165   ~NegligibleType() override = default;
166   MS_DECLARE_PARENT(NegligibleType, AnyType)
167 
168   std::string ToString() const override;
169   std::string DumpText() const override;
170 };
171 using NegligibleTypePtr = std::shared_ptr<NegligibleType>;
172 
173 /// \brief SparseTensorType is the base type for all sparse tensors.
174 class MS_CORE_API SparseTensorType : public Object {
175  public:
SparseTensorType()176   SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
177 
SparseTensorType(const TypeId object_type)178   explicit SparseTensorType(const TypeId object_type) : Object(object_type, kObjectTypeUndeterminedType) {}
179 
SparseTensorType(const TypePtrList & objs)180   explicit SparseTensorType(const TypePtrList &objs)
181       : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType), elements_(objs.begin(), objs.end()) {}
182 
SparseTensorType(const TypeId object_type,const TypePtrList & objs)183   SparseTensorType(const TypeId object_type, const TypePtrList &objs)
184       : Object(object_type, kObjectTypeUndeterminedType), elements_(objs.begin(), objs.end()) {}
185 
186   /// \brief Destructor of SparseTensorType.
187   ~SparseTensorType() override = default;
188   MS_DECLARE_PARENT(SparseTensorType, Object)
189 
190   enum StringType : int { kToString = 0, kDumpText, kReprString };
191 
GetSparseTensorTypeName()192   virtual std::string GetSparseTensorTypeName() const { return "SparseTensorType"; }
GetElementIndex()193   virtual size_t GetElementIndex() { return 0; }
element_type()194   virtual TypePtr element_type() {
195     if (elements_.empty()) {
196       return nullptr;
197     }
198     return elements_[GetElementIndex()];
199   }
200   std::string ElementsDtypeStr(const StringType str_type) const;
generic_type_id()201   TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
202 
203   const TypePtr operator[](std::size_t dim) const;
204   bool operator==(const Type &other) const override;
205   std::size_t hash() const override;
elements()206   TypePtrList elements() const { return elements_; }
207 
size()208   std::size_t size() const { return elements_.size(); }
209   std::string ToString() const override;
210   std::string ToReprString() const override;
211   std::string DumpText() const override;
212   const TypePtrList ElementsClone() const;
213   TypePtr DeepCopy() const override;
214 
215  private:
216   TypePtrList elements_;
217 };
218 using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
219 
220 /// \brief RowTensorType defines interface for row tensor data type.
221 class MS_CORE_API RowTensorType final : public Object {
222  public:
223   /// \brief Default constructor for RowTensorType.
RowTensorType()224   RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {}
225 
226   /// \brief Constructor for RowTensorType.
227   ///
228   /// \param[in] ele The element of RowTensorType.
RowTensorType(const TypePtr & ele)229   explicit RowTensorType(const TypePtr &ele)
230       : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
231 
232   /// \brief Destructor of RowTensorType.
233   ~RowTensorType() override = default;
MS_DECLARE_PARENT(RowTensorType,Object)234   MS_DECLARE_PARENT(RowTensorType, Object)
235 
236   TypeId generic_type_id() const override { return kObjectTypeRowTensorType; }
237 
238   /// \brief Get the element of RowTensorType object.
239   ///
240   /// \return The element of RowTensorType object.
element()241   const TypePtr element() const { return element_type_; }
242 
243   /// \brief Set the element of RowTensorType object.
244   ///
245   /// \param[in] element_type Define the element type to be set.
set_element(const TypePtr & element_type)246   void set_element(const TypePtr &element_type) { element_type_ = element_type; }
247 
248   TypePtr DeepCopy() const override;
249   std::string ToString() const override;
250   std::string ToReprString() const override;
251   std::string DumpText() const override;
252   bool operator==(const Type &other) const override;
253   std::size_t hash() const override;
254 
255  private:
256   TypePtr element_type_;
257 };
258 using RowTensorTypePtr = std::shared_ptr<RowTensorType>;
259 
260 /// \brief COOTensorType defines interface for coo tensor data type.
261 class MS_CORE_API COOTensorType final : public SparseTensorType {
262  public:
263   /// \brief Default constructor for COOTensorType.
COOTensorType()264   COOTensorType() : SparseTensorType(kObjectTypeCOOTensorType) {}
265 
266   /// \brief Constructor for COOTensorType.
267   ///
268   /// \param[in] obj The list of COOTensorType.
COOTensorType(const TypePtrList & obj)269   explicit COOTensorType(const TypePtrList &obj) : SparseTensorType(kObjectTypeCOOTensorType, obj) {}
270 
271   /// \brief Destructor of COOTensorType.
272   ~COOTensorType() override = default;
MS_DECLARE_PARENT(COOTensorType,SparseTensorType)273   MS_DECLARE_PARENT(COOTensorType, SparseTensorType)
274 
275   std::string GetSparseTensorTypeName() const override { return "COOTensor"; }
GetElementIndex()276   size_t GetElementIndex() override { return 1; }
277 
generic_type_id()278   TypeId generic_type_id() const override { return kObjectTypeCOOTensorType; }
279   TypePtr DeepCopy() const override;
280 };
281 using COOTensorTypePtr = std::shared_ptr<COOTensorType>;
282 
283 /// \brief CSRTensorType defines interface for csr tensor data type.
284 class MS_CORE_API CSRTensorType : public SparseTensorType {
285  public:
286   /// \brief Default constructor for CSRTensorType.
CSRTensorType()287   CSRTensorType() : SparseTensorType(kObjectTypeCSRTensorType) {}
288 
289   /// \brief Constructor for CSRTensorType.
290   ///
291   /// \param[in] obj The list of CSRTensorType.
CSRTensorType(const TypePtrList & obj)292   explicit CSRTensorType(const TypePtrList &obj) : SparseTensorType(kObjectTypeCSRTensorType, obj) {}
293 
294   /// \brief Destructor of CSRTensorType.
295   ~CSRTensorType() override = default;
MS_DECLARE_PARENT(CSRTensorType,SparseTensorType)296   MS_DECLARE_PARENT(CSRTensorType, SparseTensorType)
297 
298   std::string GetSparseTensorTypeName() const override { return "CSRTensor"; }
GetElementIndex()299   size_t GetElementIndex() override { return 2; }
generic_type_id()300   TypeId generic_type_id() const override { return kObjectTypeCSRTensorType; }
301   TypePtr DeepCopy() const override;
302 };
303 using CSRTensorTypePtr = std::shared_ptr<CSRTensorType>;
304 
305 /// \brief MapTensorType defines interface for map tensor data type.
306 class MS_CORE_API MapTensorType final : public Object {
307  public:
308   /// \brief Construct a generic MapTensorType.
MapTensorType()309   MapTensorType() : Object(kObjectTypeMapTensorType, true) {}
310 
311   /// \brief Construct a MapTensorType.
312   ///
313   /// \param[in] key The key data type.
314   /// \param[in] value The value data type.
MapTensorType(const TypePtr & key,const TypePtr & value)315   explicit MapTensorType(const TypePtr &key, const TypePtr &value)
316       : Object(kObjectTypeMapTensorType, false), key_dtype_(key), value_dtype_(value) {}
317 
318   /// \brief Destructor of MapTensorType.
319   ~MapTensorType() override = default;
MS_DECLARE_PARENT(MapTensorType,Object)320   MS_DECLARE_PARENT(MapTensorType, Object)
321 
322   TypeId generic_type_id() const override { return kObjectTypeMapTensorType; }
323 
324   /// \brief Get the key data type of this MapTensorType.
325   ///
326   /// \return The key data type.
key_dtype()327   const TypePtr &key_dtype() const { return key_dtype_; }
328 
329   /// \brief Get the value data type of this MapTensorType.
330   ///
331   /// \return The key data type.
value_dtype()332   const TypePtr &value_dtype() const { return value_dtype_; }
333 
334   TypePtr DeepCopy() const override;
335   std::string ToString() const override;
336   std::string ToReprString() const override;
337   std::string DumpText() const override;
338   bool operator==(const Type &other) const override;
339   std::size_t hash() const override;
340 
341  private:
342   TypePtr key_dtype_;
343   TypePtr value_dtype_;
344 };
345 using MapTensorTypePtr = std::shared_ptr<MapTensorType>;
346 }  // namespace mindspore
347 
348 #endif  // MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
349