1 /**
2 * Copyright 2019-2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_CORE_UTILS_BASE_REF_H_
17 #define MINDSPORE_CORE_UTILS_BASE_REF_H_
18
19 #include <type_traits>
20 #include <algorithm>
21 #include <vector>
22 #include <set>
23 #include <string>
24 #include <memory>
25 #include <sstream>
26 #include <utility>
27 #include <iterator>
28
29 #include "ir/value.h"
30
31 namespace mindspore {
32 class BaseRef;
33 class VectorRef;
34 class SetRef;
35 class RunFunctionRef;
36
37 using iterator = std::vector<BaseRef>::iterator;
38 using const_iterator = std::vector<BaseRef>::const_iterator;
39 using const_reverse_iterator = std::vector<BaseRef>::const_reverse_iterator;
40
41 using RunFunc = std::function<VectorRef(const VectorRef &args)>;
42 using RunFuncPtr = std::shared_ptr<RunFunc>;
43
44 template <typename T>
45 using remove_reference_t = typename std::remove_reference<T>::type;
46 template <typename T>
47 using remove_const_t = typename std::remove_const<T>::type;
48 template <typename T>
49 using is_base = std::is_base_of<Base, remove_reference_t<T>>;
50 template <typename T>
51 using is_value = std::is_base_of<Value, remove_reference_t<T>>;
52 template <typename T>
53 using is_base_ref = std::is_base_of<BaseRef, remove_reference_t<T>>;
54
55 iterator ConstIteratorCast(std::vector<BaseRef> *v, const_iterator iter);
56
MakeNode(const std::vector<BaseRef> & elements)57 inline std::shared_ptr<VectorRef> MakeNode(const std::vector<BaseRef> &elements) {
58 return std::make_shared<VectorRef>(elements);
59 }
60
MakeNode(std::initializer_list<BaseRef> elements)61 inline std::shared_ptr<VectorRef> MakeNode(std::initializer_list<BaseRef> elements) {
62 return std::make_shared<VectorRef>(elements);
63 }
64
65 // Anfnode, Funcgraph and some not value node class
66 template <typename T,
67 typename std::enable_if<is_shared_ptr<remove_const_t<T>>::value && is_base<typename T::element_type>::value,
68 int64_t>::type = static_cast<int64_t>(0)>
MakeNode(const T & v)69 inline BasePtr MakeNode(const T &v) {
70 return v;
71 }
72
73 template <typename T, typename std::enable_if<!is_shared_ptr<remove_const_t<T>>::value && !is_base_ref<T>::value,
74 int64_t>::type = static_cast<int64_t>(0)>
MakeNode(const T & v)75 inline BasePtr MakeNode(const T &v) {
76 return MakeValue(v);
77 }
78
MakeNode(const VectorRef & a)79 inline std::shared_ptr<VectorRef> MakeNode(const VectorRef &a) { return std::make_shared<VectorRef>(std::move(a)); }
MakeNode(const AnfNodePtrList & a)80 inline std::shared_ptr<VectorRef> MakeNode(const AnfNodePtrList &a) {
81 std::vector<BaseRef> ret;
82 (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr &v) { return v; });
83 return std::make_shared<VectorRef>(ret);
84 }
MakeNode(const SetRef & a)85 inline std::shared_ptr<SetRef> MakeNode(const SetRef &a) { return std::make_shared<SetRef>(std::move(a)); }
MakeNode(const RunFuncPtr & a)86 inline std::shared_ptr<RunFunctionRef> MakeNode(const RunFuncPtr &a) { return std::make_shared<RunFunctionRef>(a); }
87
88 class MS_CORE_API BaseRef : public Base {
89 public:
BaseRef()90 BaseRef() : m_ptr(nullptr) {}
91 BaseRef(const BaseRef &other);
copy()92 virtual std::shared_ptr<Base> copy() const { return m_ptr; }
93
BaseRef(BaseRef && other)94 BaseRef(BaseRef &&other) : Base(other) {
95 m_ptr = other.m_ptr;
96 other.m_ptr = nullptr;
97 }
98
99 // right reference constructor
100 template <class T,
101 class = typename std::enable_if<!std::is_same<typename std::decay<T>::type, BaseRef>::value, T>::type>
BaseRef(T && t)102 BaseRef(T &&t) { // NOLINT
103 m_ptr = MakeNode(t);
104 }
105
~BaseRef()106 ~BaseRef() override { m_ptr = nullptr; }
107
108 MS_DECLARE_PARENT(BaseRef, Base)
109
110 bool operator!=(const BaseRef &other) const { return !(operator==(other)); }
111
112 virtual bool operator==(const BaseRef &other) const;
113
114 // left reference
115 virtual BaseRef &operator=(const BaseRef &other);
116 // right reference
117 virtual BaseRef &operator=(BaseRef &&other);
118
hash()119 std::size_t hash() const override {
120 if (m_ptr == nullptr) {
121 MS_LOG(ERROR) << "Invalid m_ptr";
122 return 0;
123 }
124 return m_ptr->hash();
125 }
126
127 std::string ToString() const override;
128
is_null()129 bool is_null() const { return m_ptr == nullptr; }
130
131 virtual uint32_t type() const;
132
133 BasePtr m_ptr; // point to real data
134 };
135 using BaseRefPtr = std::shared_ptr<BaseRef>;
136
137 struct BaseRefHash {
operatorBaseRefHash138 std::size_t operator()(const BaseRef &c) const { return c.hash(); }
139 };
140
141 struct BaseRefLess {
operatorBaseRefLess142 bool operator()(const BaseRef &a, const BaseRef &b) const { return a.hash() < b.hash(); }
143 };
144
145 namespace utils {
146 // judge isa relation
147 // examples: isa<Int32Imm>(handle), isa<FuncGraph>(handle)
148 template <typename T,
149 typename std::enable_if<is_base<T>::value && !is_base_ref<T>::value, int64_t>::type = static_cast<int64_t>(0)>
isa(const BaseRef & handle)150 bool isa(const BaseRef &handle) {
151 if (!handle.m_ptr) {
152 return false;
153 }
154 return handle.m_ptr->isa<T>();
155 }
156
157 // noderef isa ptr isa<AnfNodePtr>(x) or isa<SeqPtr>()
158 template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type,
159 typename std::enable_if<is_base<U>::value || is_base_ref<U>::value, int64_t>::type = static_cast<int64_t>(0)>
isa(const BaseRef & handle)160 bool isa(const BaseRef &handle) {
161 if (handle.m_ptr == nullptr) {
162 return typeid(handle.m_ptr) == typeid(T);
163 }
164
165 if (handle.m_ptr->isa<U>()) {
166 return true;
167 }
168
169 // constptr isa<anfnodeptr> can be true
170 return std::dynamic_pointer_cast<U>(handle.m_ptr) != nullptr;
171 }
172
173 // isa<int32>(handle)
174 template <typename S, typename U = typename ImmTraits<S>::type::element_type>
isa(const BaseRef & handle)175 bool isa(const BaseRef &handle) {
176 if (handle.m_ptr == nullptr) {
177 return false;
178 }
179 return handle.m_ptr->isa<U>();
180 }
181
182 // isa<BaseRef>(handle), judge reference or ptr
183 template <typename T, typename std::enable_if<is_base_ref<T>::value, int64_t>::type = static_cast<int64_t>(0)>
isa(const BaseRef & handle)184 bool isa(const BaseRef &handle) {
185 static const uint32_t tid = Base::GetTypeId(typeid(T).name());
186 return handle.IsFromTypeId(tid) || (handle.m_ptr && handle.m_ptr->isa<T>());
187 }
188
189 // valueref -> C++ type
190 // cast<int64_t>(handle)
191 template <typename T, typename std::enable_if<!is_base_ref<T>::value && !is_shared_ptr<T>::value, int64_t>::type =
192 static_cast<int64_t>(0)>
cast(const BaseRef & handle)193 T cast(const BaseRef &handle) {
194 T ret = GetValue<T>(std::static_pointer_cast<Value>(handle.m_ptr));
195 return std::move(ret);
196 }
197
198 // valueref -> valueref type
199 // cast<VectorRef>(handle)
200 template <typename T, typename std::enable_if<is_base_ref<T>::value, int64_t>::type = static_cast<int64_t>(0)>
cast(const BaseRef & handle)201 const T &cast(const BaseRef &handle) {
202 if (handle.m_ptr) {
203 return static_cast<const T &>(*handle.m_ptr);
204 }
205
206 return std::move(static_cast<const T &>(handle));
207 }
208
209 // valueref -> nodeptr type
210 // cast<FuncGraphPtr>(handle)
211 template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type,
212 typename std::enable_if<is_shared_ptr<T>::value && std::is_base_of<Base, typename T::element_type>::value,
213 int64_t>::type = static_cast<int64_t>(0)>
cast(const BaseRef & handle)214 T cast(const BaseRef &handle) {
215 if (!handle.m_ptr) {
216 MS_LOG(EXCEPTION) << "Can not cast to " << typeid(T).name() << ", pointer is null";
217 }
218
219 auto m = handle.m_ptr->cast<T>();
220 if (nullptr != m) {
221 return m;
222 }
223 return std::static_pointer_cast<U>(handle.m_ptr);
224 }
225 } // namespace utils
226
227 class VectorRef : public BaseRef {
228 public:
229 using value_type = BaseRef;
230
VectorRef()231 VectorRef() {}
VectorRef(const std::vector<BaseRef> & elements)232 explicit VectorRef(const std::vector<BaseRef> &elements) : elements_(elements) {}
VectorRef(const const_iterator & begin,const const_iterator & end)233 VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {}
234
235 // left reference
236 virtual VectorRef &operator=(const VectorRef &other);
237
238 ~VectorRef() override = default;
239
copy()240 std::shared_ptr<Base> copy() const override { return std::make_shared<VectorRef>(elements_); }
241
empty()242 bool empty() const { return (elements_.size() == 0); }
243
size()244 std::size_t size() const { return elements_.size(); }
MS_DECLARE_PARENT(VectorRef,BaseRef)245 MS_DECLARE_PARENT(VectorRef, BaseRef)
246
247 const BaseRef &operator[](const std::size_t &dim) const {
248 if (dim >= size()) {
249 MS_LOG(EXCEPTION) << "Out of the size of the tuple.";
250 }
251 return elements_[dim];
252 }
253
254 BaseRef &operator[](const std::size_t &dim) {
255 if (dim >= size()) {
256 MS_LOG(EXCEPTION) << "Out of the size of the tuple.";
257 }
258 return elements_[dim];
259 }
260
type()261 uint32_t type() const override { return tid(); }
262 std::string ToString() const override;
elements()263 std::vector<BaseRef> &elements() { return elements_; }
clear()264 void clear() { elements_.clear(); }
265
266 bool operator==(const BaseRef &other) const override;
267 bool operator==(const VectorRef &other) const;
268
push_back(const BaseRef & value)269 void push_back(const BaseRef &value) { elements_.push_back(value); }
push_back(BaseRef && value)270 void push_back(BaseRef &&value) { elements_.push_back(value); }
271
emplace_back(const BaseRef & value)272 void emplace_back(const BaseRef &value) { elements_.emplace_back(value); }
emplace_back(BaseRef && value)273 void emplace_back(BaseRef &&value) { elements_.emplace_back(value); }
274
275 template <class InputIt>
insert(const iterator pos,const InputIt first,const InputIt last)276 void insert(const iterator pos, const InputIt first, const InputIt last) {
277 (void)elements_.insert(pos, first, last);
278 }
279
280 template <class InputIt>
insert(const const_iterator cpos,const InputIt first,const InputIt last)281 void insert(const const_iterator cpos, const InputIt first, const InputIt last) {
282 auto pos = ConstIteratorCast(&elements_, cpos);
283 (void)elements_.insert(pos, first, last);
284 }
285
begin()286 const_iterator begin() const { return elements_.begin(); }
end()287 const_iterator end() const { return elements_.end(); }
288
rbegin()289 const_reverse_iterator rbegin() const { return elements_.rbegin(); }
rend()290 const_reverse_iterator rend() const { return elements_.rend(); }
291
erase(const const_iterator cpos)292 iterator erase(const const_iterator cpos) {
293 auto pos = ConstIteratorCast(&elements_, cpos);
294 return elements_.erase(pos);
295 }
296
erase(const const_iterator cfirst,const const_iterator clast)297 iterator erase(const const_iterator cfirst, const const_iterator clast) {
298 auto first = ConstIteratorCast(&elements_, cfirst);
299 auto last = ConstIteratorCast(&elements_, clast);
300 return elements_.erase(first, last);
301 }
302
hash()303 std::size_t hash() const override {
304 std::stringstream buffer;
305 buffer << ToString();
306 return std::hash<std::string>()(buffer.str());
307 }
308
309 std::vector<BaseRef> elements_;
310 };
311
312 using VectorRefPtr = std::shared_ptr<VectorRef>;
313
314 using set_iterator = std::set<BaseRef, BaseRefLess>::iterator;
315 using const_set_iterator = std::set<BaseRef, BaseRefLess>::const_iterator;
316
317 struct VectorRefHash {
operatorVectorRefHash318 std::size_t operator()(const VectorRef &c) const { return c.hash(); }
319 };
320
321 class SetRef : public BaseRef {
322 public:
SetRef()323 SetRef() {}
SetRef(const std::set<BaseRef,BaseRefLess> & elements)324 explicit SetRef(const std::set<BaseRef, BaseRefLess> &elements) : elements_(elements) {}
SetRef(const std::initializer_list<BaseRef> elements)325 SetRef(const std::initializer_list<BaseRef> elements) : elements_(elements.begin(), elements.end()) {}
SetRef(const const_set_iterator & begin,const const_set_iterator & end)326 SetRef(const const_set_iterator &begin, const const_set_iterator &end) : elements_(begin, end) {}
327
328 // left reference
329 virtual SetRef &operator=(const SetRef &other);
330
331 bool operator==(const BaseRef &other) const override;
332 bool operator==(const SetRef &other) const;
333
334 ~SetRef() override = default;
335
copy()336 std::shared_ptr<Base> copy() const override { return std::make_shared<SetRef>(elements_); }
337
empty()338 bool empty() const { return (elements_.size() == 0); }
339
size()340 std::size_t size() const { return elements_.size(); }
MS_DECLARE_PARENT(SetRef,BaseRef)341 MS_DECLARE_PARENT(SetRef, BaseRef)
342
343 uint32_t type() const override { return tid(); }
344 std::string ToString() const override;
elements()345 std::set<BaseRef, BaseRefLess> &elements() { return elements_; }
clear()346 void clear() { elements_.clear(); }
347
insert(const BaseRef & elem)348 void insert(const BaseRef &elem) { (void)elements_.insert(elem); }
349
begin()350 const_set_iterator begin() const { return elements_.begin(); }
end()351 const_set_iterator end() const { return elements_.end(); }
352
353 template <class InputIt>
insert(const InputIt first,const InputIt last)354 void insert(const InputIt first, const InputIt last) {
355 (void)elements_.insert(first, last);
356 }
357
count(const BaseRef & elem)358 std::size_t count(const BaseRef &elem) const { return elements_.count(elem); }
find(const BaseRef & elem)359 const_set_iterator find(const BaseRef &elem) const { return elements_.find(elem); }
360
361 std::set<BaseRef, BaseRefLess> elements_;
362 };
363
364 using SetRefPtr = std::shared_ptr<SetRef>;
365
366 class RunFunctionRef : public BaseRef {
367 public:
RunFunctionRef()368 RunFunctionRef() {}
RunFunctionRef(const RunFuncPtr & ref_func)369 explicit RunFunctionRef(const RunFuncPtr &ref_func) : func_(ref_func) {}
370
371 ~RunFunctionRef() override = default;
MS_DECLARE_PARENT(RunFunctionRef,BaseRef)372 MS_DECLARE_PARENT(RunFunctionRef, BaseRef)
373
374 uint32_t type() const override { return tid(); }
ToString()375 std::string ToString() const override { return std::string("RunFunctionRef"); }
376 bool operator==(const BaseRef &other) const override;
377 bool operator==(const RunFunctionRef &other) const;
378
379 RunFuncPtr func_;
380 };
381 } // namespace mindspore
382
383 #endif // MINDSPORE_CORE_UTILS_BASE_REF_H_
384