1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CORE_IR_DTYPE_H_
18 #define MINDSPORE_CORE_IR_DTYPE_H_
19
20 #include <cstddef>
21 #include <iostream>
22 #include <initializer_list>
23 #include <memory>
24 #include <utility>
25 #include <sstream>
26 #include <string>
27 #include <vector>
28 #include <type_traits>
29 #include <unordered_map>
30 #include <algorithm>
31 #include "base/base.h"
32 #include "ir/named.h"
33
34 #include "ir/dtype/type.h"
35 #include "ir/dtype/number.h"
36 #include "ir/dtype/container.h"
37 #include "ir/dtype/empty.h"
38 #include "ir/dtype/tensor_type.h"
39 #include "ir/dtype/ref.h"
40 #include "ir/dtype/monad_type.h"
41
42 /* namespace to support intermediate representation definition */
43 namespace mindspore {
44 // Only few type supported now.
45 MS_CORE_API TypePtr TypeIdToType(TypeId id);
46
47 class MS_CORE_API String : public Object {
48 public:
String()49 String() : Object(kObjectTypeString, false) {}
50 ~String() override = default;
MS_DECLARE_PARENT(String,Object)51 MS_DECLARE_PARENT(String, Object)
52
53 TypeId generic_type_id() const override { return kObjectTypeString; }
54
DeepCopy()55 TypePtr DeepCopy() const override { return std::make_shared<String>(); }
ToString()56 std::string ToString() const override { return std::string("String"); }
ToReprString()57 std::string ToReprString() const override { return "string_"; }
DumpText()58 std::string DumpText() const override { return "String"; }
59 };
60 using StringPtr = std::shared_ptr<String>;
61
62 class MS_CORE_API Keyword : public Object {
63 public:
Keyword()64 Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {}
Keyword(const std::string & key,const TypePtr & value)65 Keyword(const std::string &key, const TypePtr &value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {}
66
67 ~Keyword() override = default;
MS_DECLARE_PARENT(Keyword,Object)68 MS_DECLARE_PARENT(Keyword, Object)
69
70 TypeId generic_type_id() const override { return kObjectTypeKeyword; }
71 TypePtr DeepCopy() const override;
72
73 std::string ToString() const override;
74 std::string DumpText() const override;
75 bool operator==(const Type &other) const override;
76
GetKey()77 std::string GetKey() const { return key_; }
GetValue()78 TypePtr GetValue() const { return value_; }
79
80 private:
81 std::string key_;
82 TypePtr value_;
83 };
84 using KeywordPtr = std::shared_ptr<Keyword>;
85
86 class MS_CORE_API Slice : public Object {
87 public:
Slice()88 Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {}
Slice(const TypePtr & start,const TypePtr & stop,const TypePtr & step)89 Slice(const TypePtr &start, const TypePtr &stop, const TypePtr &step)
90 : Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {}
91
92 ~Slice() override = default;
MS_DECLARE_PARENT(Slice,Object)93 MS_DECLARE_PARENT(Slice, Object)
94
95 TypeId generic_type_id() const override { return kObjectTypeSlice; }
96 TypePtr DeepCopy() const override;
97
98 std::string ToString() const override;
99 std::string DumpText() const override;
100 bool operator==(const Type &other) const override;
101
get_start()102 TypePtr get_start() const { return start_; }
get_stop()103 TypePtr get_stop() const { return stop_; }
get_step()104 TypePtr get_step() const { return step_; }
105
106 private:
107 TypePtr start_;
108 TypePtr stop_;
109 TypePtr step_;
110 };
111 using SlicePtr = std::shared_ptr<Slice>;
112
113 class MS_CORE_API Function : public Object {
114 public:
115 Function();
116 Function(const std::vector<TypePtr> &args, const TypePtr retval);
117 ~Function() override = default;
MS_DECLARE_PARENT(Function,Object)118 MS_DECLARE_PARENT(Function, Object)
119
120 TypeId generic_type_id() const override { return kObjectTypeFunction; }
121
122 // Add temporarily for return abstraction to avoid type checking.
IsTransparent()123 bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); }
args()124 const std::vector<TypePtr> &args() const { return args_; }
retval()125 const TypePtr &retval() const { return retval_; }
126
127 TypePtr DeepCopy() const override;
128 bool operator==(const Type &other) const override;
129 std::string ToString() const override;
ToReprString()130 std::string ToReprString() const override { return "function"; }
131
132 private:
133 std::vector<TypePtr> args_;
134 TypePtr retval_;
135 };
136 using FunctionPtr = std::shared_ptr<Function>;
137
138 class MS_CORE_API JTagged : public Object {
139 public:
JTagged()140 JTagged() : Object(kObjectTypeJTagged) {}
JTagged(const TypePtr & subtype)141 explicit JTagged(const TypePtr &subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {}
142 ~JTagged() override = default;
MS_DECLARE_PARENT(JTagged,Object)143 MS_DECLARE_PARENT(JTagged, Object)
144
145 TypeId generic_type_id() const override { return kObjectTypeJTagged; }
146
147 TypePtr DeepCopy() const override;
148 std::string ToString() const override;
149 std::string DumpText() const override;
150
151 private:
152 TypePtr subtype_;
153 };
154 using JTaggedPtr = std::shared_ptr<JTagged>;
155
156 class MS_CORE_API SymbolicKeyType : public Object {
157 public:
SymbolicKeyType()158 SymbolicKeyType() : Object(kObjectTypeSymbolicKeyType) {}
159 ~SymbolicKeyType() override = default;
MS_DECLARE_PARENT(SymbolicKeyType,Object)160 MS_DECLARE_PARENT(SymbolicKeyType, Object)
161
162 TypeId generic_type_id() const override { return kObjectTypeSymbolicKeyType; }
DeepCopy()163 TypePtr DeepCopy() const override { return std::make_shared<SymbolicKeyType>(); }
ToReprString()164 std::string ToReprString() const override { return "symbolic_key"; }
DumpText()165 std::string DumpText() const override { return "SymType"; }
166 };
167
168 class MS_CORE_API EnvType : public Object {
169 public:
EnvType()170 EnvType() : Object(kObjectTypeEnvType) {}
171 ~EnvType() override = default;
MS_DECLARE_PARENT(EnvType,Object)172 MS_DECLARE_PARENT(EnvType, Object)
173
174 TypePtr DeepCopy() const override { return std::make_shared<EnvType>(); }
ToReprString()175 std::string ToReprString() const override { return "env_type"; }
DumpText()176 std::string DumpText() const override { return "EnvType"; }
177 };
178 using EnvTypePtr = std::shared_ptr<EnvType>;
179
180 class MS_CORE_API TypeType : public Type {
181 public:
TypeType()182 TypeType() : Type(kMetaTypeTypeType) {}
183 ~TypeType() override = default;
MS_DECLARE_PARENT(TypeType,Type)184 MS_DECLARE_PARENT(TypeType, Type)
185
186 TypeId generic_type_id() const override { return kMetaTypeTypeType; }
DeepCopy()187 TypePtr DeepCopy() const override { return std::make_shared<TypeType>(); }
ToReprString()188 std::string ToReprString() const override { return "type_type"; }
DumpText()189 std::string DumpText() const override { return "TypeType"; }
190 };
191 using TypeTypePtr = std::shared_ptr<TypeType>;
192
193 class MS_CORE_API Problem : public Type {
194 public:
Problem()195 Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {}
Problem(const Named & kind)196 explicit Problem(const Named &kind) : Type(kMetaTypeProblem), kind_(kind) {}
197 ~Problem() override = default;
MS_DECLARE_PARENT(Problem,Type)198 MS_DECLARE_PARENT(Problem, Type)
199
200 TypeId generic_type_id() const override { return kMetaTypeProblem; }
DeepCopy()201 TypePtr DeepCopy() const override { return std::make_shared<Problem>(); }
ToString()202 std::string ToString() const override { return kind_.name(); }
DumpText()203 std::string DumpText() const override { return "ProblemType"; }
204
205 friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem);
206
207 private:
208 Named kind_;
209 };
210 using ProblemPtr = std::shared_ptr<Problem>;
211
212 class MS_CORE_API External : public Type {
213 public:
External()214 External() : Type(kMetaTypeExternal) {}
215 ~External() override = default;
MS_DECLARE_PARENT(External,Type)216 MS_DECLARE_PARENT(External, Type)
217
218 TypeId generic_type_id() const override { return kMetaTypeExternal; }
DeepCopy()219 TypePtr DeepCopy() const override { return std::make_shared<External>(); }
DumpText()220 std::string DumpText() const override { return "ExternalType"; }
221
222 private:
223 TypePtr kind;
224 };
225 using ExternalPtr = std::shared_ptr<External>;
226
227 // helper template
228 template <class T>
Clone(const T & t)229 TypePtr Clone(const T &t) {
230 return t.Clone();
231 }
232
233 MS_CORE_API TypePtr StringToType(const std::string &type_name);
234
235 // Judge whether x is predicate or is a subclass of predicate.
236 MS_CORE_API bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type);
237
238 // Whether t1 is identity or a subclass of t2.
239 MS_CORE_API bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr);
240
241 struct MS_CORE_API TypeHasher {
242 std::size_t operator()(TypePtr const &type) const;
243 };
244 struct MS_CORE_API TypeListHasher {
245 std::size_t operator()(const TypePtrList &type_list) const;
246 };
247 struct MS_CORE_API TypeEqual {
248 bool operator()(TypePtr const &t1, TypePtr const &t2) const;
249 };
250 struct MS_CORE_API TypeListEqual {
251 bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const;
252 };
253
254 inline const TypePtr kTypeExternal = std::make_shared<External>();
255 inline const TypePtr kTypeEnv = std::make_shared<EnvType>();
256 inline const TypePtr kTypeType = std::make_shared<TypeType>();
257 inline const TypePtr kString = std::make_shared<String>();
258 inline const TypePtr kList = std::make_shared<List>();
259 inline const TypePtr kTuple = std::make_shared<Tuple>();
260 inline const TypePtr kDict = std::make_shared<Dictionary>();
261 inline const TypePtr kSlice = std::make_shared<Slice>();
262 inline const TypePtr kKeyword = std::make_shared<Keyword>();
263 inline const TypePtr kTensorType = std::make_shared<TensorType>();
264 inline const TypePtr kTensorTypeFP16 = std::make_shared<TensorType>(std::make_shared<Float>(16));
265 inline const TypePtr kTensorTypeFP32 = std::make_shared<TensorType>(std::make_shared<Float>(32));
266 } // namespace mindspore
267
268 #endif // MINDSPORE_CORE_IR_DTYPE_H_
269