• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CORE_UTILS_BASE_REF_H_
17 #define MINDSPORE_CORE_UTILS_BASE_REF_H_
18 
19 #include <type_traits>
20 #include <algorithm>
21 #include <vector>
22 #include <set>
23 #include <string>
24 #include <memory>
25 #include <sstream>
26 #include <utility>
27 #include <iterator>
28 
29 #include "ir/value.h"
30 
31 namespace mindspore {
32 class BaseRef;
33 class VectorRef;
34 class SetRef;
35 class RunFunctionRef;
36 
// Iterator aliases over std::vector<BaseRef>, the storage type used by VectorRef.
37 using iterator = std::vector<BaseRef>::iterator;
38 using const_iterator = std::vector<BaseRef>::const_iterator;
39 using const_reverse_iterator = std::vector<BaseRef>::const_reverse_iterator;
40 
// Callable that maps an argument VectorRef to a result VectorRef, and a shared handle to it.
// NOTE(review): std::function lives in <functional>, which this header does not include
// directly -- presumably it arrives through "ir/value.h"; confirm.
41 using RunFunc = std::function<VectorRef(const VectorRef &args)>;
42 using RunFuncPtr = std::shared_ptr<RunFunc>;
43 
// Shorthand for std::remove_reference<T>::type.
44 template <typename T>
45 using remove_reference_t = typename std::remove_reference<T>::type;
// Shorthand for std::remove_const<T>::type.
46 template <typename T>
47 using remove_const_t = typename std::remove_const<T>::type;
// True when T (reference stripped) is Base or derives from it.
48 template <typename T>
49 using is_base = std::is_base_of<Base, remove_reference_t<T>>;
// True when T (reference stripped) is Value or derives from it.
50 template <typename T>
51 using is_value = std::is_base_of<Value, remove_reference_t<T>>;
// True when T (reference stripped) is BaseRef or derives from it.
52 template <typename T>
53 using is_base_ref = std::is_base_of<BaseRef, remove_reference_t<T>>;
54 
// Declared here, defined elsewhere. Takes a vector and a const_iterator and returns a mutable
// iterator; VectorRef::insert and VectorRef::erase call it before mutating their storage.
// NOTE(review): presumably `iter` must belong to `*v` -- confirm against the definition.
55 iterator ConstIteratorCast(std::vector<BaseRef> *v, const_iterator iter);
56 
MakeNode(const std::vector<BaseRef> & elements)57 inline std::shared_ptr<VectorRef> MakeNode(const std::vector<BaseRef> &elements) {
58   return std::make_shared<VectorRef>(elements);
59 }
60 
MakeNode(std::initializer_list<BaseRef> elements)61 inline std::shared_ptr<VectorRef> MakeNode(std::initializer_list<BaseRef> elements) {
62   return std::make_shared<VectorRef>(elements);
63 }
64 
65 // Anfnode, Funcgraph and some not value node class
66 template <typename T,
67           typename std::enable_if<is_shared_ptr<remove_const_t<T>>::value && is_base<typename T::element_type>::value,
68                                   int64_t>::type = static_cast<int64_t>(0)>
MakeNode(const T & v)69 inline BasePtr MakeNode(const T &v) {
70   return v;
71 }
72 
73 template <typename T, typename std::enable_if<!is_shared_ptr<remove_const_t<T>>::value && !is_base_ref<T>::value,
74                                               int64_t>::type = static_cast<int64_t>(0)>
MakeNode(const T & v)75 inline BasePtr MakeNode(const T &v) {
76   return MakeValue(v);
77 }
78 
MakeNode(const VectorRef & a)79 inline std::shared_ptr<VectorRef> MakeNode(const VectorRef &a) { return std::make_shared<VectorRef>(std::move(a)); }
MakeNode(const AnfNodePtrList & a)80 inline std::shared_ptr<VectorRef> MakeNode(const AnfNodePtrList &a) {
81   std::vector<BaseRef> ret;
82   (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr &v) { return v; });
83   return std::make_shared<VectorRef>(ret);
84 }
MakeNode(const SetRef & a)85 inline std::shared_ptr<SetRef> MakeNode(const SetRef &a) { return std::make_shared<SetRef>(std::move(a)); }
MakeNode(const RunFuncPtr & a)86 inline std::shared_ptr<RunFunctionRef> MakeNode(const RunFuncPtr &a) { return std::make_shared<RunFunctionRef>(a); }
87 
88 class MS_CORE_API BaseRef : public Base {
89  public:
BaseRef()90   BaseRef() : m_ptr(nullptr) {}
91   BaseRef(const BaseRef &other);
copy()92   virtual std::shared_ptr<Base> copy() const { return m_ptr; }
93 
BaseRef(BaseRef && other)94   BaseRef(BaseRef &&other) : Base(other) {
95     m_ptr = other.m_ptr;
96     other.m_ptr = nullptr;
97   }
98 
99   // right reference constructor
100   template <class T,
101             class = typename std::enable_if<!std::is_same<typename std::decay<T>::type, BaseRef>::value, T>::type>
BaseRef(T && t)102   BaseRef(T &&t) {  // NOLINT
103     m_ptr = MakeNode(t);
104   }
105 
~BaseRef()106   ~BaseRef() override { m_ptr = nullptr; }
107 
108   MS_DECLARE_PARENT(BaseRef, Base)
109 
110   bool operator!=(const BaseRef &other) const { return !(operator==(other)); }
111 
112   virtual bool operator==(const BaseRef &other) const;
113 
114   // left reference
115   virtual BaseRef &operator=(const BaseRef &other);
116   // right reference
117   virtual BaseRef &operator=(BaseRef &&other);
118 
hash()119   std::size_t hash() const override {
120     if (m_ptr == nullptr) {
121       MS_LOG(ERROR) << "Invalid m_ptr";
122       return 0;
123     }
124     return m_ptr->hash();
125   }
126 
127   std::string ToString() const override;
128 
is_null()129   bool is_null() const { return m_ptr == nullptr; }
130 
131   virtual uint32_t type() const;
132 
133   BasePtr m_ptr;  // point to real data
134 };
135 using BaseRefPtr = std::shared_ptr<BaseRef>;
136 
137 struct BaseRefHash {
operatorBaseRefHash138   std::size_t operator()(const BaseRef &c) const { return c.hash(); }
139 };
140 
141 struct BaseRefLess {
operatorBaseRefLess142   bool operator()(const BaseRef &a, const BaseRef &b) const { return a.hash() < b.hash(); }
143 };
144 
145 namespace utils {
146 // judge isa relation
147 // examples: isa<Int32Imm>(handle), isa<FuncGraph>(handle)
148 template <typename T,
149           typename std::enable_if<is_base<T>::value && !is_base_ref<T>::value, int64_t>::type = static_cast<int64_t>(0)>
isa(const BaseRef & handle)150 bool isa(const BaseRef &handle) {
151   if (!handle.m_ptr) {
152     return false;
153   }
154   return handle.m_ptr->isa<T>();
155 }
156 
157 // noderef isa ptr isa<AnfNodePtr>(x) or isa<SeqPtr>()
158 template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type,
159           typename std::enable_if<is_base<U>::value || is_base_ref<U>::value, int64_t>::type = static_cast<int64_t>(0)>
isa(const BaseRef & handle)160 bool isa(const BaseRef &handle) {
161   if (handle.m_ptr == nullptr) {
162     return typeid(handle.m_ptr) == typeid(T);
163   }
164 
165   if (handle.m_ptr->isa<U>()) {
166     return true;
167   }
168 
169   // constptr isa<anfnodeptr> can be true
170   return std::dynamic_pointer_cast<U>(handle.m_ptr) != nullptr;
171 }
172 
173 // isa<int32>(handle)
174 template <typename S, typename U = typename ImmTraits<S>::type::element_type>
isa(const BaseRef & handle)175 bool isa(const BaseRef &handle) {
176   if (handle.m_ptr == nullptr) {
177     return false;
178   }
179   return handle.m_ptr->isa<U>();
180 }
181 
182 // isa<BaseRef>(handle), judge reference or ptr
183 template <typename T, typename std::enable_if<is_base_ref<T>::value, int64_t>::type = static_cast<int64_t>(0)>
isa(const BaseRef & handle)184 bool isa(const BaseRef &handle) {
185   static const uint32_t tid = Base::GetTypeId(typeid(T).name());
186   return handle.IsFromTypeId(tid) || (handle.m_ptr && handle.m_ptr->isa<T>());
187 }
188 
189 // valueref -> C++ type
190 // cast<int64_t>(handle)
191 template <typename T, typename std::enable_if<!is_base_ref<T>::value && !is_shared_ptr<T>::value, int64_t>::type =
192                         static_cast<int64_t>(0)>
cast(const BaseRef & handle)193 T cast(const BaseRef &handle) {
194   T ret = GetValue<T>(std::static_pointer_cast<Value>(handle.m_ptr));
195   return std::move(ret);
196 }
197 
198 // valueref -> valueref type
199 // cast<VectorRef>(handle)
200 template <typename T, typename std::enable_if<is_base_ref<T>::value, int64_t>::type = static_cast<int64_t>(0)>
cast(const BaseRef & handle)201 const T &cast(const BaseRef &handle) {
202   if (handle.m_ptr) {
203     return static_cast<const T &>(*handle.m_ptr);
204   }
205 
206   return std::move(static_cast<const T &>(handle));
207 }
208 
209 // valueref -> nodeptr type
210 // cast<FuncGraphPtr>(handle)
211 template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type,
212           typename std::enable_if<is_shared_ptr<T>::value && std::is_base_of<Base, typename T::element_type>::value,
213                                   int64_t>::type = static_cast<int64_t>(0)>
cast(const BaseRef & handle)214 T cast(const BaseRef &handle) {
215   if (!handle.m_ptr) {
216     MS_LOG(EXCEPTION) << "Can not cast to " << typeid(T).name() << ", pointer is null";
217   }
218 
219   auto m = handle.m_ptr->cast<T>();
220   if (nullptr != m) {
221     return m;
222   }
223   return std::static_pointer_cast<U>(handle.m_ptr);
224 }
225 }  // namespace utils
226 
227 class VectorRef : public BaseRef {
228  public:
229   using value_type = BaseRef;
230 
VectorRef()231   VectorRef() {}
VectorRef(const std::vector<BaseRef> & elements)232   explicit VectorRef(const std::vector<BaseRef> &elements) : elements_(elements) {}
VectorRef(const const_iterator & begin,const const_iterator & end)233   VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {}
234 
235   // left reference
236   virtual VectorRef &operator=(const VectorRef &other);
237 
238   ~VectorRef() override = default;
239 
copy()240   std::shared_ptr<Base> copy() const override { return std::make_shared<VectorRef>(elements_); }
241 
empty()242   bool empty() const { return (elements_.size() == 0); }
243 
size()244   std::size_t size() const { return elements_.size(); }
MS_DECLARE_PARENT(VectorRef,BaseRef)245   MS_DECLARE_PARENT(VectorRef, BaseRef)
246 
247   const BaseRef &operator[](const std::size_t &dim) const {
248     if (dim >= size()) {
249       MS_LOG(EXCEPTION) << "Out of the size of the tuple.";
250     }
251     return elements_[dim];
252   }
253 
254   BaseRef &operator[](const std::size_t &dim) {
255     if (dim >= size()) {
256       MS_LOG(EXCEPTION) << "Out of the size of the tuple.";
257     }
258     return elements_[dim];
259   }
260 
type()261   uint32_t type() const override { return tid(); }
262   std::string ToString() const override;
elements()263   std::vector<BaseRef> &elements() { return elements_; }
clear()264   void clear() { elements_.clear(); }
265 
266   bool operator==(const BaseRef &other) const override;
267   bool operator==(const VectorRef &other) const;
268 
push_back(const BaseRef & value)269   void push_back(const BaseRef &value) { elements_.push_back(value); }
push_back(BaseRef && value)270   void push_back(BaseRef &&value) { elements_.push_back(value); }
271 
emplace_back(const BaseRef & value)272   void emplace_back(const BaseRef &value) { elements_.emplace_back(value); }
emplace_back(BaseRef && value)273   void emplace_back(BaseRef &&value) { elements_.emplace_back(value); }
274 
275   template <class InputIt>
insert(const iterator pos,const InputIt first,const InputIt last)276   void insert(const iterator pos, const InputIt first, const InputIt last) {
277     (void)elements_.insert(pos, first, last);
278   }
279 
280   template <class InputIt>
insert(const const_iterator cpos,const InputIt first,const InputIt last)281   void insert(const const_iterator cpos, const InputIt first, const InputIt last) {
282     auto pos = ConstIteratorCast(&elements_, cpos);
283     (void)elements_.insert(pos, first, last);
284   }
285 
begin()286   const_iterator begin() const { return elements_.begin(); }
end()287   const_iterator end() const { return elements_.end(); }
288 
rbegin()289   const_reverse_iterator rbegin() const { return elements_.rbegin(); }
rend()290   const_reverse_iterator rend() const { return elements_.rend(); }
291 
erase(const const_iterator cpos)292   iterator erase(const const_iterator cpos) {
293     auto pos = ConstIteratorCast(&elements_, cpos);
294     return elements_.erase(pos);
295   }
296 
erase(const const_iterator cfirst,const const_iterator clast)297   iterator erase(const const_iterator cfirst, const const_iterator clast) {
298     auto first = ConstIteratorCast(&elements_, cfirst);
299     auto last = ConstIteratorCast(&elements_, clast);
300     return elements_.erase(first, last);
301   }
302 
hash()303   std::size_t hash() const override {
304     std::stringstream buffer;
305     buffer << ToString();
306     return std::hash<std::string>()(buffer.str());
307   }
308 
309   std::vector<BaseRef> elements_;
310 };
311 
312 using VectorRefPtr = std::shared_ptr<VectorRef>;
313 
// Iterator aliases over std::set<BaseRef, BaseRefLess>, the storage type used by SetRef.
314 using set_iterator = std::set<BaseRef, BaseRefLess>::iterator;
315 using const_set_iterator = std::set<BaseRef, BaseRefLess>::const_iterator;
316 
317 struct VectorRefHash {
operatorVectorRefHash318   std::size_t operator()(const VectorRef &c) const { return c.hash(); }
319 };
320 
321 class SetRef : public BaseRef {
322  public:
SetRef()323   SetRef() {}
SetRef(const std::set<BaseRef,BaseRefLess> & elements)324   explicit SetRef(const std::set<BaseRef, BaseRefLess> &elements) : elements_(elements) {}
SetRef(const std::initializer_list<BaseRef> elements)325   SetRef(const std::initializer_list<BaseRef> elements) : elements_(elements.begin(), elements.end()) {}
SetRef(const const_set_iterator & begin,const const_set_iterator & end)326   SetRef(const const_set_iterator &begin, const const_set_iterator &end) : elements_(begin, end) {}
327 
328   // left reference
329   virtual SetRef &operator=(const SetRef &other);
330 
331   bool operator==(const BaseRef &other) const override;
332   bool operator==(const SetRef &other) const;
333 
334   ~SetRef() override = default;
335 
copy()336   std::shared_ptr<Base> copy() const override { return std::make_shared<SetRef>(elements_); }
337 
empty()338   bool empty() const { return (elements_.size() == 0); }
339 
size()340   std::size_t size() const { return elements_.size(); }
MS_DECLARE_PARENT(SetRef,BaseRef)341   MS_DECLARE_PARENT(SetRef, BaseRef)
342 
343   uint32_t type() const override { return tid(); }
344   std::string ToString() const override;
elements()345   std::set<BaseRef, BaseRefLess> &elements() { return elements_; }
clear()346   void clear() { elements_.clear(); }
347 
insert(const BaseRef & elem)348   void insert(const BaseRef &elem) { (void)elements_.insert(elem); }
349 
begin()350   const_set_iterator begin() const { return elements_.begin(); }
end()351   const_set_iterator end() const { return elements_.end(); }
352 
353   template <class InputIt>
insert(const InputIt first,const InputIt last)354   void insert(const InputIt first, const InputIt last) {
355     (void)elements_.insert(first, last);
356   }
357 
count(const BaseRef & elem)358   std::size_t count(const BaseRef &elem) const { return elements_.count(elem); }
find(const BaseRef & elem)359   const_set_iterator find(const BaseRef &elem) const { return elements_.find(elem); }
360 
361   std::set<BaseRef, BaseRefLess> elements_;
362 };
363 
364 using SetRefPtr = std::shared_ptr<SetRef>;
365 
366 class RunFunctionRef : public BaseRef {
367  public:
RunFunctionRef()368   RunFunctionRef() {}
RunFunctionRef(const RunFuncPtr & ref_func)369   explicit RunFunctionRef(const RunFuncPtr &ref_func) : func_(ref_func) {}
370 
371   ~RunFunctionRef() override = default;
MS_DECLARE_PARENT(RunFunctionRef,BaseRef)372   MS_DECLARE_PARENT(RunFunctionRef, BaseRef)
373 
374   uint32_t type() const override { return tid(); }
ToString()375   std::string ToString() const override { return std::string("RunFunctionRef"); }
376   bool operator==(const BaseRef &other) const override;
377   bool operator==(const RunFunctionRef &other) const;
378 
379   RunFuncPtr func_;
380 };
381 }  // namespace mindspore
382 
383 #endif  // MINDSPORE_CORE_UTILS_BASE_REF_H_
384