1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CORE_IR_VALUE_H_
18 #define MINDSPORE_CORE_IR_VALUE_H_
19
20 #include <type_traits>
21 #include <algorithm>
22 #include <vector>
23 #include <string>
24 #include <memory>
25 #include <sstream>
26 #include <utility>
27
28 #include "base/base.h"
29 #include "ir/anf.h"
30 #include "ir/dtype.h"
31 #include "ir/scalar.h"
32 #include "ir/dtype/ref.h"
33 #include "utils/hashing.h"
34 #include "utils/ms_utils.h"
35
36 namespace mindspore {
37 class MS_CORE_API ValueSequeue : public Value {
38 public:
ValueSequeue(const ValuePtrList & elements)39 explicit ValueSequeue(const ValuePtrList &elements) : elements_(elements) {
40 TypePtrList t_list;
41 (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr &ele) {
42 MS_EXCEPTION_IF_NULL(ele);
43 return ele->type();
44 });
45 TypePtr t = std::make_shared<Tuple>(t_list);
46 type_ = t;
47 }
ValueSequeue(const std::initializer_list<ValuePtr> & elements)48 ValueSequeue(const std::initializer_list<ValuePtr> &elements) : elements_(elements.begin(), elements.end()) {
49 TypePtrList t_list;
50 (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list),
51 [](const ValuePtr &ele) { return ele->type(); });
52 TypePtr t = std::make_shared<Tuple>(t_list);
53 type_ = t;
54 }
55 ~ValueSequeue() override = default;
MS_DECLARE_PARENT(ValueSequeue,Value)56 MS_DECLARE_PARENT(ValueSequeue, Value)
57 std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(elements_.size())); }
size()58 std::size_t size() const { return elements_.size(); }
59 bool erase(size_t idx);
60 const ValuePtr operator[](const std::size_t &dim) const;
value()61 const ValuePtrList &value() const { return elements_; }
62 bool operator==(const Value &other) const override;
63 bool operator==(const ValueSequeue &other) const;
64 std::string ToString() const override;
65 std::string DumpText() const override;
66
67 protected:
68 ValuePtrList elements_;
69 };
70 using ValueSequeuePtr = std::shared_ptr<ValueSequeue>;
71
72 class MS_CORE_API ValueTuple : public ValueSequeue {
73 public:
ValueTuple(const std::vector<ValuePtr> & elements)74 explicit ValueTuple(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {}
ValueTuple(const std::initializer_list<ValuePtr> & elements)75 ValueTuple(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {}
76 ~ValueTuple() override = default;
77 MS_DECLARE_PARENT(ValueTuple, ValueSequeue)
78 abstract::AbstractBasePtr ToAbstract() override;
79
DumpText()80 std::string DumpText() const override { return "(" + ValueSequeue::DumpText() + ")"; }
ToString()81 std::string ToString() const override { return "(" + ValueSequeue::ToString() + ")"; }
82 };
83 using ValueTuplePtr = std::shared_ptr<ValueTuple>;
84
85 class MS_CORE_API ValueList : public ValueSequeue {
86 public:
ValueList(const std::vector<ValuePtr> & elements)87 explicit ValueList(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {}
ValueList(const std::initializer_list<ValuePtr> & elements)88 ValueList(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {}
89 ~ValueList() override = default;
90 MS_DECLARE_PARENT(ValueList, ValueSequeue)
91 abstract::AbstractBasePtr ToAbstract() override;
92
DumpText()93 std::string DumpText() const override { return "[" + ValueSequeue::DumpText() + "]"; }
ToString()94 std::string ToString() const override { return "[" + ValueSequeue::ToString() + "]"; }
95 };
96 using ValueListPtr = std::shared_ptr<ValueList>;
97
MakeValue(const std::vector<ValuePtr> & v)98 inline ValuePtr MakeValue(const std::vector<ValuePtr> &v) { return std::make_shared<ValueTuple>(v); }
MakeValue(std::initializer_list<ValuePtr> v)99 inline ValuePtr MakeValue(std::initializer_list<ValuePtr> v) { return std::make_shared<ValueTuple>(v); }
100
/// Type trait: `value` is true only for std::vector specializations (any allocator).
template <typename T>
struct is_vector : std::integral_constant<bool, false> {};
template <typename Elem, typename Alloc>
struct is_vector<std::vector<Elem, Alloc>> : std::integral_constant<bool, true> {};
105
106 template <typename T, typename U = typename std::enable_if<is_vector<T>::value, typename T::value_type>::type>
107 ValuePtr MakeValue(const T &vec) {
108 std::vector<ValuePtr> list;
109 (void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); });
110 return std::make_shared<ValueTuple>(list);
111 }
112
113 class MS_CORE_API ValueSlice : public Value {
114 public:
115 ValueSlice(const ValuePtr &start, const ValuePtr &stop, const ValuePtr &step)
116 : start_(start), stop_(stop), step_(step) {}
117 ~ValueSlice() override = default;
118 MS_DECLARE_PARENT(ValueSlice, Value)
119 std::size_t hash() const override;
120 bool operator==(const Value &other) const override;
121 bool operator==(const ValueSlice &other) const;
122
123 std::string ToString() const override;
124
125 abstract::AbstractBasePtr ToAbstract() override;
126 std::string DumpText() const override { return ToString(); }
127 ValuePtr start() const { return start_; }
128 ValuePtr stop() const { return stop_; }
129 ValuePtr step() const { return step_; }
130
131 private:
132 ValuePtr start_;
133 ValuePtr stop_;
134 ValuePtr step_;
135 };
136 using ValueSlicePtr = std::shared_ptr<ValueSlice>;
137
138 class MS_CORE_API KeywordArg : public Value {
139 public:
140 KeywordArg(const std::string &key, const ValuePtr &value) : key_(key), value_(value) {}
141 ~KeywordArg() override = default;
142 MS_DECLARE_PARENT(KeywordArg, Value)
143 std::size_t hash() const override;
144 ValuePtr get_value() const { return value_; }
145 bool operator==(const Value &other) const override;
146 bool operator==(const KeywordArg &other) const;
147
148 std::string ToString() const override;
149
150 abstract::AbstractBasePtr ToAbstract() override;
151 std::string DumpText() const override { return ToString(); }
152
153 private:
154 std::string key_;
155 ValuePtr value_;
156 };
157 using KeywordArgPtr = std::shared_ptr<KeywordArg>;
158
159 class MS_CORE_API ValueDictionary : public Value {
160 public:
161 explicit ValueDictionary(const std::vector<std::pair<std::string, ValuePtr>> &key_values) : key_values_(key_values) {}
162 ~ValueDictionary() override = default;
163 MS_DECLARE_PARENT(ValueDictionary, Value)
164 std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(key_values_.size())); }
165 std::size_t size() const { return key_values_.size(); }
166 const ValuePtr operator[](const std::string &key) const;
167 const std::vector<std::pair<std::string, ValuePtr>> &value() const { return key_values_; }
168 bool operator==(const Value &other) const override;
169 bool operator==(const ValueDictionary &other) const;
170
171 std::string ToString() const override {
172 std::ostringstream buffer;
173 std::vector<std::string> keys;
174 std::vector<ValuePtr> values;
175 for (const auto &kv : key_values_) {
176 keys.push_back(kv.first);
177 values.push_back(kv.second);
178 }
179 buffer << "dict: {keys: (";
180 for (size_t i = 0; i < keys.size(); i++) {
181 buffer << keys[i];
182 if (i != keys.size() - 1) {
183 buffer << ", ";
184 }
185 }
186 buffer << "), values: (";
187 for (size_t i = 0; i < values.size(); i++) {
188 const auto &value = values[i];
189 MS_EXCEPTION_IF_NULL(value);
190 buffer << value->ToString();
191 if (i != values.size() - 1) {
192 buffer << ", ";
193 }
194 }
195 buffer << ")}";
196 return buffer.str();
197 }
198 abstract::AbstractBasePtr ToAbstract() override;
199 std::string DumpText() const override { return ToString(); }
200
201 private:
202 std::vector<std::pair<std::string, ValuePtr>> key_values_;
203 };
204 using ValueDictionaryPtr = std::shared_ptr<ValueDictionary>;
205
206 class MS_CORE_API StringImm : public Value {
207 public:
208 explicit StringImm(const std::string &str) : Value(kString), str_(str), hash_(std::hash<std::string>{}(str_)) {}
209
210 ~StringImm() override = default;
211 MS_DECLARE_PARENT(StringImm, Value)
212 std::size_t hash() const override { return hash_; }
213 const std::string &value() const { return str_; }
214 bool operator==(const Value &other) const override;
215 bool operator==(const StringImm &other) const;
216 abstract::AbstractBasePtr ToAbstract() override;
217 std::string ToString() const override { return str_; }
218
219 std::string DumpText() const override {
220 std::ostringstream oss;
221 oss << "\"" << str_ << "\"";
222 return oss.str();
223 }
224
225 private:
226 std::string str_;
227 std::size_t hash_ = 0;
228 };
229 using StringImmPtr = std::shared_ptr<StringImm>;
230 IMM_TRAITS(StringImmPtr, std::string)
231 IMM_TRAITS(StringImmPtr, const char *)
232
233 class MS_CORE_API RefKey : public Named {
234 public:
235 explicit RefKey(const std::string &tag) : Named(tag) {}
236
237 ~RefKey() override = default;
238 MS_DECLARE_PARENT(RefKey, Named)
239 const std::string &tag() const { return name(); }
240 abstract::AbstractBasePtr ToAbstract() override;
241 std::string ToString() const override { return "RefKey[" + name() + "]"; }
242
243 std::string DumpText() const override {
244 std::ostringstream oss;
245 oss << "RefKey[\"" << name() << "\"]";
246 return oss.str();
247 }
248 };
249 using RefKeyPtr = std::shared_ptr<RefKey>;
250
251 class MS_CORE_API AnyValue : public Value {
252 public:
253 AnyValue() = default;
254 ~AnyValue() override = default;
255 MS_DECLARE_PARENT(AnyValue, Value)
256 std::size_t hash() const override { return tid(); }
257 bool operator==(const Value &other) const override;
258 abstract::AbstractBasePtr ToAbstract() override;
259 };
260
261 inline const ValuePtr kAnyValue = std::make_shared<AnyValue>();
262
263 class MS_CORE_API Monad : public Value {
264 public:
265 ~Monad() override = default;
266 MS_DECLARE_PARENT(Monad, Value)
267 abstract::AbstractBasePtr ToAbstract() override = 0;
268
269 protected:
270 explicit Monad(TypePtr type) : Value(type) {}
271 };
272
273 class MS_CORE_API UMonad : public Monad {
274 public:
275 UMonad() : Monad(kUMonadType) {}
276 ~UMonad() override = default;
277 MS_DECLARE_PARENT(UMonad, Monad)
278 std::size_t hash() const override { return tid(); }
279 bool operator==(const Value &other) const override;
280 abstract::AbstractBasePtr ToAbstract() override;
281 std::string ToString() const override { return "U"; }
282 };
283 using UMonadPtr = std::shared_ptr<UMonad>;
284 extern const ValuePtr kUMonad;
285
286 class MS_CORE_API IOMonad : public Monad {
287 public:
288 IOMonad() : Monad(kIOMonadType) {}
289 ~IOMonad() override = default;
290 MS_DECLARE_PARENT(IOMonad, Monad)
291 std::size_t hash() const override { return tid(); }
292 bool operator==(const Value &other) const override;
293 abstract::AbstractBasePtr ToAbstract() override;
294 std::string ToString() const override { return "IO"; }
295 };
296 using IOMonadPtr = std::shared_ptr<IOMonad>;
297 extern const ValuePtr kIOMonad;
298
299 template <>
300 inline const char *GetValue(const ValuePtr &value) {
301 if (value == nullptr) {
302 MS_LOG(EXCEPTION) << "Value is nullptr";
303 }
304 auto imm = value->cast<StringImmPtr>();
305 if (imm == nullptr) {
306 MS_LOG(EXCEPTION) << "GetValue:" << value->ToString() << ", Type:" << value->type_name();
307 }
308 return common::SafeCStr(imm->value());
309 }
310
311 template <typename T, typename S = typename std::decay<T>::type,
312 typename U = typename std::enable_if<is_vector<S>::value, typename S::value_type>::type>
313 std::vector<U> GetValue(const ValuePtr &value) {
314 if (value == nullptr) {
315 MS_LOG(EXCEPTION) << "Value is nullptr";
316 }
317
318 if (!value->isa<ValueSequeue>()) {
319 MS_LOG(EXCEPTION) << "Error GetValue for value: " << value->ToString() << ", type: vector<" << typeid(U).name()
320 << ">";
321 }
322 std::vector<U> rets;
323 const std::vector<ValuePtr> &vals = value->cast<ValueSequeuePtr>()->value();
324 (void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets),
325 [](const ValuePtr &v) { return GetValue<U>(v); });
326 return rets;
327 }
328
329 inline ValueNodePtr NewValueNode(const ValuePtr &t) { return std::make_shared<ValueNode>(t); }
330
331 template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type>
332 inline ValueNodePtr NewValueNode(const std::shared_ptr<T> &x) {
333 return NewValueNode(MakeValue(x));
334 }
335
336 template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type>
337 inline ValueNodePtr NewValueNode(const T &x) {
338 return NewValueNode(MakeValue(x));
339 }
340 } // namespace mindspore
341
342 #endif // MINDSPORE_CORE_IR_VALUE_H_
343