• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_VALUE_H_
18 #define MINDSPORE_CORE_IR_VALUE_H_
19 
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

#include "base/base.h"
#include "ir/anf.h"
#include "ir/dtype.h"
#include "ir/scalar.h"
#include "ir/dtype/ref.h"
#include "utils/hashing.h"
#include "utils/ms_utils.h"
35 
36 namespace mindspore {
37 class MS_CORE_API ValueSequeue : public Value {
38  public:
ValueSequeue(const ValuePtrList & elements)39   explicit ValueSequeue(const ValuePtrList &elements) : elements_(elements) {
40     TypePtrList t_list;
41     (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr &ele) {
42       MS_EXCEPTION_IF_NULL(ele);
43       return ele->type();
44     });
45     TypePtr t = std::make_shared<Tuple>(t_list);
46     type_ = t;
47   }
ValueSequeue(const std::initializer_list<ValuePtr> & elements)48   ValueSequeue(const std::initializer_list<ValuePtr> &elements) : elements_(elements.begin(), elements.end()) {
49     TypePtrList t_list;
50     (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list),
51                          [](const ValuePtr &ele) { return ele->type(); });
52     TypePtr t = std::make_shared<Tuple>(t_list);
53     type_ = t;
54   }
55   ~ValueSequeue() override = default;
MS_DECLARE_PARENT(ValueSequeue,Value)56   MS_DECLARE_PARENT(ValueSequeue, Value)
57   std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(elements_.size())); }
size()58   std::size_t size() const { return elements_.size(); }
59   bool erase(size_t idx);
60   const ValuePtr operator[](const std::size_t &dim) const;
value()61   const ValuePtrList &value() const { return elements_; }
62   bool operator==(const Value &other) const override;
63   bool operator==(const ValueSequeue &other) const;
64   std::string ToString() const override;
65   std::string DumpText() const override;
66 
67  protected:
68   ValuePtrList elements_;
69 };
70 using ValueSequeuePtr = std::shared_ptr<ValueSequeue>;
71 
72 class MS_CORE_API ValueTuple : public ValueSequeue {
73  public:
ValueTuple(const std::vector<ValuePtr> & elements)74   explicit ValueTuple(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {}
ValueTuple(const std::initializer_list<ValuePtr> & elements)75   ValueTuple(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {}
76   ~ValueTuple() override = default;
77   MS_DECLARE_PARENT(ValueTuple, ValueSequeue)
78   abstract::AbstractBasePtr ToAbstract() override;
79 
DumpText()80   std::string DumpText() const override { return "(" + ValueSequeue::DumpText() + ")"; }
ToString()81   std::string ToString() const override { return "(" + ValueSequeue::ToString() + ")"; }
82 };
83 using ValueTuplePtr = std::shared_ptr<ValueTuple>;
84 
85 class MS_CORE_API ValueList : public ValueSequeue {
86  public:
ValueList(const std::vector<ValuePtr> & elements)87   explicit ValueList(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {}
ValueList(const std::initializer_list<ValuePtr> & elements)88   ValueList(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {}
89   ~ValueList() override = default;
90   MS_DECLARE_PARENT(ValueList, ValueSequeue)
91   abstract::AbstractBasePtr ToAbstract() override;
92 
DumpText()93   std::string DumpText() const override { return "[" + ValueSequeue::DumpText() + "]"; }
ToString()94   std::string ToString() const override { return "[" + ValueSequeue::ToString() + "]"; }
95 };
96 using ValueListPtr = std::shared_ptr<ValueList>;
97 
MakeValue(const std::vector<ValuePtr> & v)98 inline ValuePtr MakeValue(const std::vector<ValuePtr> &v) { return std::make_shared<ValueTuple>(v); }
MakeValue(std::initializer_list<ValuePtr> v)99 inline ValuePtr MakeValue(std::initializer_list<ValuePtr> v) { return std::make_shared<ValueTuple>(v); }
100 
/// Type trait: `is_vector<T>::value` is true exactly when T is a std::vector
/// (any element type, any allocator). cv-qualified vectors are not matched.
template <typename T>
struct is_vector : std::integral_constant<bool, false> {};
template <typename Elem, typename Alloc>
struct is_vector<std::vector<Elem, Alloc>> : std::integral_constant<bool, true> {};
105 
106 template <typename T, typename U = typename std::enable_if<is_vector<T>::value, typename T::value_type>::type>
107 ValuePtr MakeValue(const T &vec) {
108   std::vector<ValuePtr> list;
109   (void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); });
110   return std::make_shared<ValueTuple>(list);
111 }
112 
113 class MS_CORE_API ValueSlice : public Value {
114  public:
115   ValueSlice(const ValuePtr &start, const ValuePtr &stop, const ValuePtr &step)
116       : start_(start), stop_(stop), step_(step) {}
117   ~ValueSlice() override = default;
118   MS_DECLARE_PARENT(ValueSlice, Value)
119   std::size_t hash() const override;
120   bool operator==(const Value &other) const override;
121   bool operator==(const ValueSlice &other) const;
122 
123   std::string ToString() const override;
124 
125   abstract::AbstractBasePtr ToAbstract() override;
126   std::string DumpText() const override { return ToString(); }
127   ValuePtr start() const { return start_; }
128   ValuePtr stop() const { return stop_; }
129   ValuePtr step() const { return step_; }
130 
131  private:
132   ValuePtr start_;
133   ValuePtr stop_;
134   ValuePtr step_;
135 };
136 using ValueSlicePtr = std::shared_ptr<ValueSlice>;
137 
138 class MS_CORE_API KeywordArg : public Value {
139  public:
140   KeywordArg(const std::string &key, const ValuePtr &value) : key_(key), value_(value) {}
141   ~KeywordArg() override = default;
142   MS_DECLARE_PARENT(KeywordArg, Value)
143   std::size_t hash() const override;
144   ValuePtr get_value() const { return value_; }
145   bool operator==(const Value &other) const override;
146   bool operator==(const KeywordArg &other) const;
147 
148   std::string ToString() const override;
149 
150   abstract::AbstractBasePtr ToAbstract() override;
151   std::string DumpText() const override { return ToString(); }
152 
153  private:
154   std::string key_;
155   ValuePtr value_;
156 };
157 using KeywordArgPtr = std::shared_ptr<KeywordArg>;
158 
159 class MS_CORE_API ValueDictionary : public Value {
160  public:
161   explicit ValueDictionary(const std::vector<std::pair<std::string, ValuePtr>> &key_values) : key_values_(key_values) {}
162   ~ValueDictionary() override = default;
163   MS_DECLARE_PARENT(ValueDictionary, Value)
164   std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(key_values_.size())); }
165   std::size_t size() const { return key_values_.size(); }
166   const ValuePtr operator[](const std::string &key) const;
167   const std::vector<std::pair<std::string, ValuePtr>> &value() const { return key_values_; }
168   bool operator==(const Value &other) const override;
169   bool operator==(const ValueDictionary &other) const;
170 
171   std::string ToString() const override {
172     std::ostringstream buffer;
173     std::vector<std::string> keys;
174     std::vector<ValuePtr> values;
175     for (const auto &kv : key_values_) {
176       keys.push_back(kv.first);
177       values.push_back(kv.second);
178     }
179     buffer << "dict: {keys: (";
180     for (size_t i = 0; i < keys.size(); i++) {
181       buffer << keys[i];
182       if (i != keys.size() - 1) {
183         buffer << ", ";
184       }
185     }
186     buffer << "), values: (";
187     for (size_t i = 0; i < values.size(); i++) {
188       const auto &value = values[i];
189       MS_EXCEPTION_IF_NULL(value);
190       buffer << value->ToString();
191       if (i != values.size() - 1) {
192         buffer << ", ";
193       }
194     }
195     buffer << ")}";
196     return buffer.str();
197   }
198   abstract::AbstractBasePtr ToAbstract() override;
199   std::string DumpText() const override { return ToString(); }
200 
201  private:
202   std::vector<std::pair<std::string, ValuePtr>> key_values_;
203 };
204 using ValueDictionaryPtr = std::shared_ptr<ValueDictionary>;
205 
206 class MS_CORE_API StringImm : public Value {
207  public:
208   explicit StringImm(const std::string &str) : Value(kString), str_(str), hash_(std::hash<std::string>{}(str_)) {}
209 
210   ~StringImm() override = default;
211   MS_DECLARE_PARENT(StringImm, Value)
212   std::size_t hash() const override { return hash_; }
213   const std::string &value() const { return str_; }
214   bool operator==(const Value &other) const override;
215   bool operator==(const StringImm &other) const;
216   abstract::AbstractBasePtr ToAbstract() override;
217   std::string ToString() const override { return str_; }
218 
219   std::string DumpText() const override {
220     std::ostringstream oss;
221     oss << "\"" << str_ << "\"";
222     return oss.str();
223   }
224 
225  private:
226   std::string str_;
227   std::size_t hash_ = 0;
228 };
229 using StringImmPtr = std::shared_ptr<StringImm>;
230 IMM_TRAITS(StringImmPtr, std::string)
231 IMM_TRAITS(StringImmPtr, const char *)
232 
233 class MS_CORE_API RefKey : public Named {
234  public:
235   explicit RefKey(const std::string &tag) : Named(tag) {}
236 
237   ~RefKey() override = default;
238   MS_DECLARE_PARENT(RefKey, Named)
239   const std::string &tag() const { return name(); }
240   abstract::AbstractBasePtr ToAbstract() override;
241   std::string ToString() const override { return "RefKey[" + name() + "]"; }
242 
243   std::string DumpText() const override {
244     std::ostringstream oss;
245     oss << "RefKey[\"" << name() << "\"]";
246     return oss.str();
247   }
248 };
249 using RefKeyPtr = std::shared_ptr<RefKey>;
250 
251 class MS_CORE_API AnyValue : public Value {
252  public:
253   AnyValue() = default;
254   ~AnyValue() override = default;
255   MS_DECLARE_PARENT(AnyValue, Value)
256   std::size_t hash() const override { return tid(); }
257   bool operator==(const Value &other) const override;
258   abstract::AbstractBasePtr ToAbstract() override;
259 };
260 
261 inline const ValuePtr kAnyValue = std::make_shared<AnyValue>();
262 
263 class MS_CORE_API Monad : public Value {
264  public:
265   ~Monad() override = default;
266   MS_DECLARE_PARENT(Monad, Value)
267   abstract::AbstractBasePtr ToAbstract() override = 0;
268 
269  protected:
270   explicit Monad(TypePtr type) : Value(type) {}
271 };
272 
273 class MS_CORE_API UMonad : public Monad {
274  public:
275   UMonad() : Monad(kUMonadType) {}
276   ~UMonad() override = default;
277   MS_DECLARE_PARENT(UMonad, Monad)
278   std::size_t hash() const override { return tid(); }
279   bool operator==(const Value &other) const override;
280   abstract::AbstractBasePtr ToAbstract() override;
281   std::string ToString() const override { return "U"; }
282 };
283 using UMonadPtr = std::shared_ptr<UMonad>;
284 extern const ValuePtr kUMonad;
285 
286 class MS_CORE_API IOMonad : public Monad {
287  public:
288   IOMonad() : Monad(kIOMonadType) {}
289   ~IOMonad() override = default;
290   MS_DECLARE_PARENT(IOMonad, Monad)
291   std::size_t hash() const override { return tid(); }
292   bool operator==(const Value &other) const override;
293   abstract::AbstractBasePtr ToAbstract() override;
294   std::string ToString() const override { return "IO"; }
295 };
296 using IOMonadPtr = std::shared_ptr<IOMonad>;
297 extern const ValuePtr kIOMonad;
298 
299 template <>
300 inline const char *GetValue(const ValuePtr &value) {
301   if (value == nullptr) {
302     MS_LOG(EXCEPTION) << "Value is nullptr";
303   }
304   auto imm = value->cast<StringImmPtr>();
305   if (imm == nullptr) {
306     MS_LOG(EXCEPTION) << "GetValue:" << value->ToString() << ", Type:" << value->type_name();
307   }
308   return common::SafeCStr(imm->value());
309 }
310 
311 template <typename T, typename S = typename std::decay<T>::type,
312           typename U = typename std::enable_if<is_vector<S>::value, typename S::value_type>::type>
313 std::vector<U> GetValue(const ValuePtr &value) {
314   if (value == nullptr) {
315     MS_LOG(EXCEPTION) << "Value is nullptr";
316   }
317 
318   if (!value->isa<ValueSequeue>()) {
319     MS_LOG(EXCEPTION) << "Error GetValue for value: " << value->ToString() << ", type: vector<" << typeid(U).name()
320                       << ">";
321   }
322   std::vector<U> rets;
323   const std::vector<ValuePtr> &vals = value->cast<ValueSequeuePtr>()->value();
324   (void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets),
325                        [](const ValuePtr &v) { return GetValue<U>(v); });
326   return rets;
327 }
328 
329 inline ValueNodePtr NewValueNode(const ValuePtr &t) { return std::make_shared<ValueNode>(t); }
330 
331 template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type>
332 inline ValueNodePtr NewValueNode(const std::shared_ptr<T> &x) {
333   return NewValueNode(MakeValue(x));
334 }
335 
336 template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type>
337 inline ValueNodePtr NewValueNode(const T &x) {
338   return NewValueNode(MakeValue(x));
339 }
340 }  // namespace mindspore
341 
342 #endif  // MINDSPORE_CORE_IR_VALUE_H_
343