• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CORE_UTILS_BASE_REF_H_
17 #define MINDSPORE_CORE_UTILS_BASE_REF_H_
18 
19 #include <type_traits>
20 #include <algorithm>
21 #include <vector>
22 #include <set>
23 #include <string>
24 #include <memory>
25 #include <sstream>
26 #include <utility>
27 #include <iterator>
28 
29 #include "ir/value.h"
30 
31 namespace mindspore {
// Forward declarations of the reference wrappers that are defined further down in this header.
class BaseRef;
class VectorRef;
class SetRef;
class RunFunctionRef;

// Iterator shorthands over std::vector<BaseRef>, the container type VectorRef stores its elements in.
using iterator = std::vector<BaseRef>::iterator;
using const_iterator = std::vector<BaseRef>::const_iterator;
using const_reverse_iterator = std::vector<BaseRef>::const_reverse_iterator;

// RunFunc is a callable that receives a VectorRef and returns a VectorRef;
// RunFuncPtr is the shared handle stored by RunFunctionRef.
using RunFunc = std::function<VectorRef(const VectorRef &)>;
using RunFuncPtr = std::shared_ptr<RunFunc>;
43 
44 template <typename T>
45 using remove_reference_t = typename std::remove_reference<T>::type;
46 template <typename T>
47 using remove_const_t = typename std::remove_const<T>::type;
48 template <typename T>
49 using is_base = std::is_base_of<Base, remove_reference_t<T>>;
50 template <typename T>
51 using is_value = std::is_base_of<Value, remove_reference_t<T>>;
52 template <typename T>
53 using is_base_ref = std::is_base_of<BaseRef, remove_reference_t<T>>;
54 
55 MS_CORE_API iterator ConstIteratorCast(std::vector<BaseRef> *v, const const_iterator iter);
56 
MakeNode(const std::vector<BaseRef> & elements)57 inline std::shared_ptr<VectorRef> MakeNode(const std::vector<BaseRef> &elements) {
58   return std::make_shared<VectorRef>(elements);
59 }
60 
MakeNode(std::initializer_list<BaseRef> elements)61 inline std::shared_ptr<VectorRef> MakeNode(std::initializer_list<BaseRef> elements) {
62   return std::make_shared<VectorRef>(elements);
63 }
64 
65 // Anfnode, Funcgraph and some not value node class
66 template <typename T,
67           typename std::enable_if<is_shared_ptr<remove_const_t<T>>::value && is_base<typename T::element_type>::value,
68                                   int64_t>::type = static_cast<int64_t>(0)>
MakeNode(const T & v)69 inline BasePtr MakeNode(const T &v) {
70   return v;
71 }
72 
73 template <typename T, typename std::enable_if<!is_shared_ptr<remove_const_t<T>>::value && !is_base_ref<T>::value,
74                                               int64_t>::type = static_cast<int64_t>(0)>
MakeNode(const T & v)75 inline BasePtr MakeNode(const T &v) {
76   return MakeValue(v);
77 }
78 
MakeNode(const VectorRef & a)79 inline std::shared_ptr<VectorRef> MakeNode(const VectorRef &a) { return std::make_shared<VectorRef>(a); }
MakeNode(const AnfNodePtrList & a)80 inline std::shared_ptr<VectorRef> MakeNode(const AnfNodePtrList &a) {
81   std::vector<BaseRef> ret;
82   (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr &v) { return v; });
83   return std::make_shared<VectorRef>(ret);
84 }
MakeNode(const SetRef & a)85 inline std::shared_ptr<SetRef> MakeNode(const SetRef &a) { return std::make_shared<SetRef>(a); }
MakeNode(const RunFuncPtr & a)86 inline std::shared_ptr<RunFunctionRef> MakeNode(const RunFuncPtr &a) { return std::make_shared<RunFunctionRef>(a); }
87 
88 /// \brief BaseRef is a base class which store a Base pointer to some real data.
89 class MS_CORE_API BaseRef : public Base {
90  public:
91   /// \brief The Constructor of BaseRef.
92   ///
93   /// \return The instance of BaseRef.
BaseRef()94   BaseRef() : m_ptr(nullptr) {}
95 
96   /// \brief The copy constructor of BaseRef.
97   ///
98   /// \param[in] other Define another instance of BaseRef.
99   ///
100   /// \return The instance of BaseRef.
101   BaseRef(const BaseRef &other);
102 
103   /// \brief Get the Base pointer to some real data.
104   ///
105   /// \return The Base pointer.
copy()106   virtual std::shared_ptr<Base> copy() const { return m_ptr; }
107 
108   /// \brief The move constructor of BaseRef.
109   ///
110   /// \param[in] other Define another instance of BaseRef.
111   ///
112   /// \return The instance of BaseRef.
BaseRef(BaseRef && other)113   BaseRef(BaseRef &&other) : Base(other) {
114     m_ptr = other.m_ptr;
115     other.m_ptr = nullptr;
116   }
117 
118   /// \brief The move constructor of BaseRef with template.
119   ///
120   /// \param[in] t Define an instance of T.
121   ///
122   /// \return The instance of BaseRef.
123   template <class T,
124             class = typename std::enable_if<!std::is_same<typename std::decay<T>::type, BaseRef>::value, T>::type>
BaseRef(T && t)125   BaseRef(T &&t) {  // NOLINT
126     m_ptr = MakeNode(t);
127   }
128 
129   /// \brief The destructor of BaseRef.
~BaseRef()130   ~BaseRef() override { m_ptr = nullptr; }
131 
132   MS_DECLARE_PARENT(BaseRef, Base)
133 
134   /// \brief The operator overloading for "!=".
135   ///
136   /// \param[in] other Define the right operand of "!=".
137   ///
138   /// \return The comparison result.
139   bool operator!=(const BaseRef &other) const { return !(operator==(other)); }
140 
141   /// \brief The operator overloading for "==".
142   ///
143   /// \param[in] other Define the right operand of "==".
144   ///
145   /// \return The comparison result.
146   virtual bool operator==(const BaseRef &other) const;
147 
148   /// \brief The copy assignment operator of BaseRef.
149   ///
150   /// \param[in] other Define another instance of BaseRef.
151   ///
152   /// \return The instance of BaseRef.
153   BaseRef &operator=(const BaseRef &other);
154 
155   /// \brief The move assignment operator of BaseRef.
156   ///
157   /// \param[in] other Define another instance of BaseRef.
158   ///
159   /// \return The instance of BaseRef.
160   virtual BaseRef &operator=(BaseRef &&other);
161 
hash()162   std::size_t hash() const override {
163     if (m_ptr == nullptr) {
164       MS_LOG(ERROR) << "Invalid m_ptr";
165       return 0;
166     }
167     return m_ptr->hash();
168   }
169 
170   std::string ToString() const override;
171 
172   /// \brief Judge whether the real data is null.
173   ///
174   /// \return The result of the judgment.
is_null()175   bool is_null() const { return m_ptr == nullptr; }
176 
177   /// \brief Get the type id of the real data.
178   ///
179   /// \return The type id of the real data.
180   virtual uint32_t type() const;
181 
182   BasePtr m_ptr; /**< pointer to the real data */
183 };
184 using BaseRefPtr = std::shared_ptr<BaseRef>;
185 
186 struct BaseRefHash {
operatorBaseRefHash187   std::size_t operator()(const BaseRef &c) const { return c.hash(); }
188 };
189 
190 struct BaseRefLess {
operatorBaseRefLess191   bool operator()(const BaseRef &a, const BaseRef &b) const { return a.hash() < b.hash(); }
192 };
193 
194 namespace utils {
195 // judge isa relation
196 // examples: isa<Int32Imm>(handle), isa<FuncGraph>(handle)
197 template <typename T,
198           typename std::enable_if<is_base<T>::value && !is_base_ref<T>::value, int64_t>::type = static_cast<int64_t>(0)>
isa(const BaseRef & handle)199 bool isa(const BaseRef &handle) {
200   if (!handle.m_ptr) {
201     return false;
202   }
203   return handle.m_ptr->isa<T>();
204 }
205 
206 // noderef isa ptr isa<AnfNodePtr>(x) or isa<SeqPtr>()
207 template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type,
208           typename std::enable_if<is_base<U>::value || is_base_ref<U>::value, int64_t>::type = static_cast<int64_t>(0)>
isa(const BaseRef & handle)209 bool isa(const BaseRef &handle) {
210   if (handle.m_ptr == nullptr) {
211     return typeid(handle.m_ptr) == typeid(T);
212   }
213 
214   if (handle.m_ptr->isa<U>()) {
215     return true;
216   }
217 
218   // constptr isa<anfnodeptr> can be true
219   return std::dynamic_pointer_cast<U>(handle.m_ptr) != nullptr;
220 }
221 
222 // isa<int32>(handle)
223 template <typename S, typename U = typename ImmTraits<S>::type::element_type>
isa(const BaseRef & handle)224 bool isa(const BaseRef &handle) {
225   if (handle.m_ptr == nullptr) {
226     return false;
227   }
228   return handle.m_ptr->isa<U>();
229 }
230 
231 // isa<BaseRef>(handle), judge reference or ptr
232 template <typename T, typename std::enable_if<is_base_ref<T>::value, int64_t>::type = static_cast<int64_t>(0)>
isa(const BaseRef & handle)233 bool isa(const BaseRef &handle) {
234   return handle.isa<T>() || (handle.m_ptr && handle.m_ptr->isa<T>());
235 }
236 
237 // valueref -> C++ type
238 // cast<int64_t>(handle)
239 template <typename T, typename std::enable_if<!is_base_ref<T>::value && !is_shared_ptr<T>::value, int64_t>::type =
240                         static_cast<int64_t>(0)>
cast(const BaseRef & handle)241 T cast(const BaseRef &handle) {
242   T ret = GetValue<T>(std::static_pointer_cast<Value>(handle.m_ptr));
243   return std::move(ret);
244 }
245 
246 // valueref -> valueref type
247 // cast<VectorRef>(handle)
248 template <typename T, typename std::enable_if<is_base_ref<T>::value, int64_t>::type = static_cast<int64_t>(0)>
cast(const BaseRef & handle)249 const T &cast(const BaseRef &handle) {
250   if (handle.m_ptr) {
251     return static_cast<const T &>(*handle.m_ptr);
252   }
253 
254   return static_cast<const T &>(handle);
255 }
256 
257 // valueref -> nodeptr type
258 // cast<FuncGraphPtr>(handle)
259 template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type,
260           typename std::enable_if<is_shared_ptr<T>::value && std::is_base_of<Base, typename T::element_type>::value,
261                                   int64_t>::type = static_cast<int64_t>(0)>
cast(const BaseRef & handle)262 T cast(const BaseRef &handle) {
263   if (!handle.m_ptr) {
264     MS_LOG(INTERNAL_EXCEPTION) << "Can not cast to " << typeid(T).name() << ", pointer is null";
265   }
266 
267   auto m = handle.m_ptr->cast<T>();
268   if (nullptr != m) {
269     return m;
270   }
271   return std::static_pointer_cast<U>(handle.m_ptr);
272 }
273 }  // namespace utils
274 
275 class MS_CORE_API VectorRef : public BaseRef {
276  public:
277   using value_type = BaseRef;
278 
VectorRef()279   VectorRef() {}
VectorRef(const std::vector<BaseRef> & elements)280   explicit VectorRef(const std::vector<BaseRef> &elements) : elements_(elements) {}
VectorRef(const const_iterator & begin,const const_iterator & end)281   VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {}
282 
283   // left reference
284   VectorRef(const VectorRef &other);
285   VectorRef &operator=(const VectorRef &other);
286 
287   ~VectorRef() override = default;
288 
copy()289   std::shared_ptr<Base> copy() const override { return std::make_shared<VectorRef>(elements_); }
290 
empty()291   bool empty() const { return (elements_.size() == 0); }
292 
size()293   std::size_t size() const { return elements_.size(); }
MS_DECLARE_PARENT(VectorRef,BaseRef)294   MS_DECLARE_PARENT(VectorRef, BaseRef)
295 
296   const BaseRef &operator[](const std::size_t &dim) const {
297     if (dim >= size()) {
298       MS_LOG(INTERNAL_EXCEPTION) << "Out of the size of the tuple.";
299     }
300     return elements_[dim];
301   }
302 
303   BaseRef &operator[](const std::size_t &dim) {
304     if (dim >= size()) {
305       MS_LOG(INTERNAL_EXCEPTION) << "Out of the size of the tuple.";
306     }
307     return elements_[dim];
308   }
309 
type()310   uint32_t type() const override { return tid(); }
311   std::string ToString() const override;
elements()312   std::vector<BaseRef> &elements() { return elements_; }
clear()313   void clear() { elements_.clear(); }
314 
315   bool operator==(const BaseRef &other) const override;
316   bool operator==(const VectorRef &other) const;
317 
push_back(const BaseRef & value)318   void push_back(const BaseRef &value) { elements_.push_back(value); }
push_back(BaseRef && value)319   void push_back(BaseRef &&value) { elements_.push_back(value); }
320 
emplace_back(const BaseRef & value)321   void emplace_back(const BaseRef &value) { elements_.emplace_back(value); }
emplace_back(BaseRef && value)322   void emplace_back(BaseRef &&value) { elements_.emplace_back(value); }
323 
324   template <class InputIt>
insert(const iterator pos,const InputIt first,const InputIt last)325   void insert(const iterator pos, const InputIt first, const InputIt last) {
326     (void)elements_.insert(pos, first, last);
327   }
328 
329   template <class InputIt>
insert(const const_iterator cpos,const InputIt first,const InputIt last)330   void insert(const const_iterator cpos, const InputIt first, const InputIt last) {
331     auto pos = ConstIteratorCast(&elements_, cpos);
332     (void)elements_.insert(pos, first, last);
333   }
334 
begin()335   const_iterator begin() const { return elements_.begin(); }
end()336   const_iterator end() const { return elements_.end(); }
337 
rbegin()338   const_reverse_iterator rbegin() const { return elements_.rbegin(); }
rend()339   const_reverse_iterator rend() const { return elements_.rend(); }
340 
erase(const const_iterator cpos)341   iterator erase(const const_iterator cpos) {
342     auto pos = ConstIteratorCast(&elements_, cpos);
343     return elements_.erase(pos);
344   }
345 
erase(const const_iterator cfirst,const const_iterator clast)346   iterator erase(const const_iterator cfirst, const const_iterator clast) {
347     auto first = ConstIteratorCast(&elements_, cfirst);
348     auto last = ConstIteratorCast(&elements_, clast);
349     return elements_.erase(first, last);
350   }
351 
hash()352   std::size_t hash() const override {
353     std::stringstream buffer;
354     buffer << ToString();
355     return std::hash<std::string>()(buffer.str());
356   }
357 
358   std::vector<BaseRef> elements_;
359 };
360 
361 using VectorRefPtr = std::shared_ptr<VectorRef>;
362 
363 using set_iterator = std::set<BaseRef, BaseRefLess>::iterator;
364 using const_set_iterator = std::set<BaseRef, BaseRefLess>::const_iterator;
365 
366 struct VectorRefHash {
operatorVectorRefHash367   std::size_t operator()(const VectorRef &c) const { return c.hash(); }
368 };
369 
370 class MS_CORE_API SetRef : public BaseRef {
371  public:
SetRef()372   SetRef() {}
SetRef(const std::set<BaseRef,BaseRefLess> & elements)373   explicit SetRef(const std::set<BaseRef, BaseRefLess> &elements) : elements_(elements) {}
SetRef(const std::initializer_list<BaseRef> elements)374   SetRef(const std::initializer_list<BaseRef> elements) : elements_(elements.begin(), elements.end()) {}
SetRef(const const_set_iterator & begin,const const_set_iterator & end)375   SetRef(const const_set_iterator &begin, const const_set_iterator &end) : elements_(begin, end) {}
376 
377   // left reference
378   SetRef(const SetRef &other);
379   SetRef &operator=(const SetRef &other);
380 
381   bool operator==(const BaseRef &other) const override;
382   bool operator==(const SetRef &other) const;
383 
384   ~SetRef() override = default;
385 
copy()386   std::shared_ptr<Base> copy() const override { return std::make_shared<SetRef>(elements_); }
387 
empty()388   bool empty() const { return (elements_.size() == 0); }
389 
size()390   std::size_t size() const { return elements_.size(); }
MS_DECLARE_PARENT(SetRef,BaseRef)391   MS_DECLARE_PARENT(SetRef, BaseRef)
392 
393   uint32_t type() const override { return tid(); }
394   std::string ToString() const override;
elements()395   std::set<BaseRef, BaseRefLess> &elements() { return elements_; }
clear()396   void clear() { elements_.clear(); }
397 
insert(const BaseRef & elem)398   void insert(const BaseRef &elem) { (void)elements_.insert(elem); }
399 
begin()400   const_set_iterator begin() const { return elements_.begin(); }
end()401   const_set_iterator end() const { return elements_.end(); }
402 
403   template <class InputIt>
insert(const InputIt first,const InputIt last)404   void insert(const InputIt first, const InputIt last) {
405     (void)elements_.insert(first, last);
406   }
407 
count(const BaseRef & elem)408   std::size_t count(const BaseRef &elem) const { return elements_.count(elem); }
find(const BaseRef & elem)409   const_set_iterator find(const BaseRef &elem) const { return elements_.find(elem); }
410 
411   std::set<BaseRef, BaseRefLess> elements_;
412 };
413 
414 using SetRefPtr = std::shared_ptr<SetRef>;
415 
416 class MS_CORE_API RunFunctionRef : public BaseRef {
417  public:
RunFunctionRef()418   RunFunctionRef() {}
RunFunctionRef(const RunFuncPtr & ref_func)419   explicit RunFunctionRef(const RunFuncPtr &ref_func) : func_(ref_func) {}
420 
421   ~RunFunctionRef() override = default;
MS_DECLARE_PARENT(RunFunctionRef,BaseRef)422   MS_DECLARE_PARENT(RunFunctionRef, BaseRef)
423 
424   uint32_t type() const override { return tid(); }
ToString()425   std::string ToString() const override { return std::string("RunFunctionRef"); }
426   bool operator==(const BaseRef &other) const override;
427   bool operator==(const RunFunctionRef &other) const;
428 
429   RunFuncPtr func_;
430 };
431 }  // namespace mindspore
432 
433 #endif  // MINDSPORE_CORE_UTILS_BASE_REF_H_
434