• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_DTYPE_H_
18 #define MINDSPORE_CORE_IR_DTYPE_H_
19 
20 #include <cstddef>
21 #include <iostream>
22 #include <initializer_list>
23 #include <memory>
24 #include <utility>
25 #include <sstream>
26 #include <string>
27 #include <vector>
28 #include <type_traits>
29 #include <unordered_map>
30 #include <algorithm>
31 #include "base/base.h"
32 #include "ir/named.h"
33 
34 #include "ir/dtype/type.h"
35 #include "ir/dtype/number.h"
36 #include "ir/dtype/container.h"
37 #include "ir/dtype/empty.h"
38 #include "ir/dtype/tensor_type.h"
39 #include "ir/dtype/ref.h"
40 #include "ir/dtype/monad_type.h"
41 
42 /* namespace to support intermediate representation definition */
43 namespace mindspore {
44 // Only a few types are supported now.
// Converts a TypeId into a TypePtr. Which ids are supported, and what is returned for
// the others, is decided by the out-of-line definition (not visible in this header).
45 MS_CORE_API TypePtr TypeIdToType(TypeId id);
46 
47 class MS_CORE_API String : public Object {
48  public:
String()49   String() : Object(kObjectTypeString, false) {}
50   ~String() override = default;
MS_DECLARE_PARENT(String,Object)51   MS_DECLARE_PARENT(String, Object)
52 
53   TypeId generic_type_id() const override { return kObjectTypeString; }
54 
DeepCopy()55   TypePtr DeepCopy() const override { return std::make_shared<String>(); }
ToString()56   std::string ToString() const override { return std::string("String"); }
ToReprString()57   std::string ToReprString() const override { return "string_"; }
DumpText()58   std::string DumpText() const override { return "String"; }
59 };
60 using StringPtr = std::shared_ptr<String>;
61 
62 class MS_CORE_API Keyword : public Object {
63  public:
Keyword()64   Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {}
Keyword(const std::string & key,const TypePtr & value)65   Keyword(const std::string &key, const TypePtr &value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {}
66 
67   ~Keyword() override = default;
MS_DECLARE_PARENT(Keyword,Object)68   MS_DECLARE_PARENT(Keyword, Object)
69 
70   TypeId generic_type_id() const override { return kObjectTypeKeyword; }
71   TypePtr DeepCopy() const override;
72 
73   std::string ToString() const override;
74   std::string DumpText() const override;
75   bool operator==(const Type &other) const override;
76 
GetKey()77   std::string GetKey() const { return key_; }
GetValue()78   TypePtr GetValue() const { return value_; }
79 
80  private:
81   std::string key_;
82   TypePtr value_;
83 };
84 using KeywordPtr = std::shared_ptr<Keyword>;
85 
86 class MS_CORE_API Slice : public Object {
87  public:
Slice()88   Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {}
Slice(const TypePtr & start,const TypePtr & stop,const TypePtr & step)89   Slice(const TypePtr &start, const TypePtr &stop, const TypePtr &step)
90       : Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {}
91 
92   ~Slice() override = default;
MS_DECLARE_PARENT(Slice,Object)93   MS_DECLARE_PARENT(Slice, Object)
94 
95   TypeId generic_type_id() const override { return kObjectTypeSlice; }
96   TypePtr DeepCopy() const override;
97 
98   std::string ToString() const override;
99   std::string DumpText() const override;
100   bool operator==(const Type &other) const override;
101 
get_start()102   TypePtr get_start() const { return start_; }
get_stop()103   TypePtr get_stop() const { return stop_; }
get_step()104   TypePtr get_step() const { return step_; }
105 
106  private:
107   TypePtr start_;
108   TypePtr stop_;
109   TypePtr step_;
110 };
111 using SlicePtr = std::shared_ptr<Slice>;
112 
113 class MS_CORE_API Function : public Object {
114  public:
115   Function();
116   Function(const std::vector<TypePtr> &args, const TypePtr retval);
117   ~Function() override = default;
MS_DECLARE_PARENT(Function,Object)118   MS_DECLARE_PARENT(Function, Object)
119 
120   TypeId generic_type_id() const override { return kObjectTypeFunction; }
121 
122   // Add temporarily for return abstraction to avoid type checking.
IsTransparent()123   bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); }
args()124   const std::vector<TypePtr> &args() const { return args_; }
retval()125   const TypePtr &retval() const { return retval_; }
126 
127   TypePtr DeepCopy() const override;
128   bool operator==(const Type &other) const override;
129   std::string ToString() const override;
ToReprString()130   std::string ToReprString() const override { return "function"; }
131 
132  private:
133   std::vector<TypePtr> args_;
134   TypePtr retval_;
135 };
136 using FunctionPtr = std::shared_ptr<Function>;
137 
138 class MS_CORE_API JTagged : public Object {
139  public:
JTagged()140   JTagged() : Object(kObjectTypeJTagged) {}
JTagged(const TypePtr & subtype)141   explicit JTagged(const TypePtr &subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {}
142   ~JTagged() override = default;
MS_DECLARE_PARENT(JTagged,Object)143   MS_DECLARE_PARENT(JTagged, Object)
144 
145   TypeId generic_type_id() const override { return kObjectTypeJTagged; }
146 
147   TypePtr DeepCopy() const override;
148   std::string ToString() const override;
149   std::string DumpText() const override;
150 
151  private:
152   TypePtr subtype_;
153 };
154 using JTaggedPtr = std::shared_ptr<JTagged>;
155 
156 class MS_CORE_API SymbolicKeyType : public Object {
157  public:
SymbolicKeyType()158   SymbolicKeyType() : Object(kObjectTypeSymbolicKeyType) {}
159   ~SymbolicKeyType() override = default;
MS_DECLARE_PARENT(SymbolicKeyType,Object)160   MS_DECLARE_PARENT(SymbolicKeyType, Object)
161 
162   TypeId generic_type_id() const override { return kObjectTypeSymbolicKeyType; }
DeepCopy()163   TypePtr DeepCopy() const override { return std::make_shared<SymbolicKeyType>(); }
ToReprString()164   std::string ToReprString() const override { return "symbolic_key"; }
DumpText()165   std::string DumpText() const override { return "SymType"; }
166 };
167 
168 class MS_CORE_API EnvType : public Object {
169  public:
EnvType()170   EnvType() : Object(kObjectTypeEnvType) {}
171   ~EnvType() override = default;
MS_DECLARE_PARENT(EnvType,Object)172   MS_DECLARE_PARENT(EnvType, Object)
173 
174   TypePtr DeepCopy() const override { return std::make_shared<EnvType>(); }
ToReprString()175   std::string ToReprString() const override { return "env_type"; }
DumpText()176   std::string DumpText() const override { return "EnvType"; }
177 };
178 using EnvTypePtr = std::shared_ptr<EnvType>;
179 
180 class MS_CORE_API TypeType : public Type {
181  public:
TypeType()182   TypeType() : Type(kMetaTypeTypeType) {}
183   ~TypeType() override = default;
MS_DECLARE_PARENT(TypeType,Type)184   MS_DECLARE_PARENT(TypeType, Type)
185 
186   TypeId generic_type_id() const override { return kMetaTypeTypeType; }
DeepCopy()187   TypePtr DeepCopy() const override { return std::make_shared<TypeType>(); }
ToReprString()188   std::string ToReprString() const override { return "type_type"; }
DumpText()189   std::string DumpText() const override { return "TypeType"; }
190 };
191 using TypeTypePtr = std::shared_ptr<TypeType>;
192 
193 class MS_CORE_API Problem : public Type {
194  public:
Problem()195   Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {}
Problem(const Named & kind)196   explicit Problem(const Named &kind) : Type(kMetaTypeProblem), kind_(kind) {}
197   ~Problem() override = default;
MS_DECLARE_PARENT(Problem,Type)198   MS_DECLARE_PARENT(Problem, Type)
199 
200   TypeId generic_type_id() const override { return kMetaTypeProblem; }
DeepCopy()201   TypePtr DeepCopy() const override { return std::make_shared<Problem>(); }
ToString()202   std::string ToString() const override { return kind_.name(); }
DumpText()203   std::string DumpText() const override { return "ProblemType"; }
204 
205   friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem);
206 
207  private:
208   Named kind_;
209 };
210 using ProblemPtr = std::shared_ptr<Problem>;
211 
212 class MS_CORE_API External : public Type {
213  public:
External()214   External() : Type(kMetaTypeExternal) {}
215   ~External() override = default;
MS_DECLARE_PARENT(External,Type)216   MS_DECLARE_PARENT(External, Type)
217 
218   TypeId generic_type_id() const override { return kMetaTypeExternal; }
DeepCopy()219   TypePtr DeepCopy() const override { return std::make_shared<External>(); }
DumpText()220   std::string DumpText() const override { return "ExternalType"; }
221 
222  private:
223   TypePtr kind;
224 };
225 using ExternalPtr = std::shared_ptr<External>;
226 
227 // helper template
228 template <class T>
Clone(const T & t)229 TypePtr Clone(const T &t) {
230   return t.Clone();
231 }
232 
233 MS_CORE_API TypePtr StringToType(const std::string &type_name);
234 
235 // Judge whether x is predicate or is a subclass of predicate.
236 MS_CORE_API bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type);
237 
238 // Whether t1 is identity or a subclass of t2.
239 MS_CORE_API bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr);
240 
241 struct MS_CORE_API TypeHasher {
242   std::size_t operator()(TypePtr const &type) const;
243 };
244 struct MS_CORE_API TypeListHasher {
245   std::size_t operator()(const TypePtrList &type_list) const;
246 };
247 struct MS_CORE_API TypeEqual {
248   bool operator()(TypePtr const &t1, TypePtr const &t2) const;
249 };
250 struct MS_CORE_API TypeListEqual {
251   bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const;
252 };
253 
254 inline const TypePtr kTypeExternal = std::make_shared<External>();
255 inline const TypePtr kTypeEnv = std::make_shared<EnvType>();
256 inline const TypePtr kTypeType = std::make_shared<TypeType>();
257 inline const TypePtr kString = std::make_shared<String>();
258 inline const TypePtr kList = std::make_shared<List>();
259 inline const TypePtr kTuple = std::make_shared<Tuple>();
260 inline const TypePtr kDict = std::make_shared<Dictionary>();
261 inline const TypePtr kSlice = std::make_shared<Slice>();
262 inline const TypePtr kKeyword = std::make_shared<Keyword>();
263 inline const TypePtr kTensorType = std::make_shared<TensorType>();
264 inline const TypePtr kTensorTypeFP16 = std::make_shared<TensorType>(std::make_shared<Float>(16));
265 inline const TypePtr kTensorTypeFP32 = std::make_shared<TensorType>(std::make_shared<Float>(32));
266 }  // namespace mindspore
267 
268 #endif  // MINDSPORE_CORE_IR_DTYPE_H_
269