1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_CORE_UTILS_BASE_REF_H_
17 #define MINDSPORE_CORE_UTILS_BASE_REF_H_
18
#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

#include "ir/value.h"
30
31 namespace mindspore {
// Forward declarations of the reference types defined further down in this header.
class BaseRef;
class VectorRef;
class SetRef;
class RunFunctionRef;

// A callable that takes a VectorRef of arguments and yields a VectorRef, and the shared handle used to pass it around.
using RunFunc = std::function<VectorRef(const VectorRef &)>;
using RunFuncPtr = std::shared_ptr<RunFunc>;

// Iterator types over a std::vector<BaseRef>, the element storage of VectorRef.
using iterator = std::vector<BaseRef>::iterator;
using const_iterator = std::vector<BaseRef>::const_iterator;
using const_reverse_iterator = std::vector<BaseRef>::const_reverse_iterator;
43
44 template <typename T>
45 using remove_reference_t = typename std::remove_reference<T>::type;
46 template <typename T>
47 using remove_const_t = typename std::remove_const<T>::type;
48 template <typename T>
49 using is_base = std::is_base_of<Base, remove_reference_t<T>>;
50 template <typename T>
51 using is_value = std::is_base_of<Value, remove_reference_t<T>>;
52 template <typename T>
53 using is_base_ref = std::is_base_of<BaseRef, remove_reference_t<T>>;
54
55 MS_CORE_API iterator ConstIteratorCast(std::vector<BaseRef> *v, const const_iterator iter);
56
MakeNode(const std::vector<BaseRef> & elements)57 inline std::shared_ptr<VectorRef> MakeNode(const std::vector<BaseRef> &elements) {
58 return std::make_shared<VectorRef>(elements);
59 }
60
MakeNode(std::initializer_list<BaseRef> elements)61 inline std::shared_ptr<VectorRef> MakeNode(std::initializer_list<BaseRef> elements) {
62 return std::make_shared<VectorRef>(elements);
63 }
64
// MakeNode for shared pointers whose element type derives from Base (e.g. AnfNode, FuncGraph and other
// node classes that are not wrapped as values): the pointer is returned unchanged, converted to BasePtr.
// A shared_ptr to a non-Base type satisfies neither template below; specific pointer types such as
// RunFuncPtr are handled by dedicated non-template overloads further down.
template <typename T,
          typename std::enable_if<is_shared_ptr<remove_const_t<T>>::value && is_base<typename T::element_type>::value,
                                  int64_t>::type = static_cast<int64_t>(0)>
inline BasePtr MakeNode(const T &v) {
  return v;
}

// MakeNode for plain values: anything that is neither a shared_ptr nor a BaseRef-derived type is forwarded
// to MakeValue (declared outside this header; presumably wraps v in a Value node — confirm in ir/value.h).
template <typename T, typename std::enable_if<!is_shared_ptr<remove_const_t<T>>::value && !is_base_ref<T>::value,
                                              int64_t>::type = static_cast<int64_t>(0)>
inline BasePtr MakeNode(const T &v) {
  return MakeValue(v);
}
78
MakeNode(const VectorRef & a)79 inline std::shared_ptr<VectorRef> MakeNode(const VectorRef &a) { return std::make_shared<VectorRef>(a); }
MakeNode(const AnfNodePtrList & a)80 inline std::shared_ptr<VectorRef> MakeNode(const AnfNodePtrList &a) {
81 std::vector<BaseRef> ret;
82 (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr &v) { return v; });
83 return std::make_shared<VectorRef>(ret);
84 }
MakeNode(const SetRef & a)85 inline std::shared_ptr<SetRef> MakeNode(const SetRef &a) { return std::make_shared<SetRef>(a); }
MakeNode(const RunFuncPtr & a)86 inline std::shared_ptr<RunFunctionRef> MakeNode(const RunFuncPtr &a) { return std::make_shared<RunFunctionRef>(a); }
87
88 /// \brief BaseRef is a base class which store a Base pointer to some real data.
89 class MS_CORE_API BaseRef : public Base {
90 public:
91 /// \brief The Constructor of BaseRef.
92 ///
93 /// \return The instance of BaseRef.
BaseRef()94 BaseRef() : m_ptr(nullptr) {}
95
96 /// \brief The copy constructor of BaseRef.
97 ///
98 /// \param[in] other Define another instance of BaseRef.
99 ///
100 /// \return The instance of BaseRef.
101 BaseRef(const BaseRef &other);
102
103 /// \brief Get the Base pointer to some real data.
104 ///
105 /// \return The Base pointer.
copy()106 virtual std::shared_ptr<Base> copy() const { return m_ptr; }
107
108 /// \brief The move constructor of BaseRef.
109 ///
110 /// \param[in] other Define another instance of BaseRef.
111 ///
112 /// \return The instance of BaseRef.
BaseRef(BaseRef && other)113 BaseRef(BaseRef &&other) : Base(other) {
114 m_ptr = other.m_ptr;
115 other.m_ptr = nullptr;
116 }
117
118 /// \brief The move constructor of BaseRef with template.
119 ///
120 /// \param[in] t Define an instance of T.
121 ///
122 /// \return The instance of BaseRef.
123 template <class T,
124 class = typename std::enable_if<!std::is_same<typename std::decay<T>::type, BaseRef>::value, T>::type>
BaseRef(T && t)125 BaseRef(T &&t) { // NOLINT
126 m_ptr = MakeNode(t);
127 }
128
129 /// \brief The destructor of BaseRef.
~BaseRef()130 ~BaseRef() override { m_ptr = nullptr; }
131
132 MS_DECLARE_PARENT(BaseRef, Base)
133
134 /// \brief The operator overloading for "!=".
135 ///
136 /// \param[in] other Define the right operand of "!=".
137 ///
138 /// \return The comparison result.
139 bool operator!=(const BaseRef &other) const { return !(operator==(other)); }
140
141 /// \brief The operator overloading for "==".
142 ///
143 /// \param[in] other Define the right operand of "==".
144 ///
145 /// \return The comparison result.
146 virtual bool operator==(const BaseRef &other) const;
147
148 /// \brief The copy assignment operator of BaseRef.
149 ///
150 /// \param[in] other Define another instance of BaseRef.
151 ///
152 /// \return The instance of BaseRef.
153 BaseRef &operator=(const BaseRef &other);
154
155 /// \brief The move assignment operator of BaseRef.
156 ///
157 /// \param[in] other Define another instance of BaseRef.
158 ///
159 /// \return The instance of BaseRef.
160 virtual BaseRef &operator=(BaseRef &&other);
161
hash()162 std::size_t hash() const override {
163 if (m_ptr == nullptr) {
164 MS_LOG(ERROR) << "Invalid m_ptr";
165 return 0;
166 }
167 return m_ptr->hash();
168 }
169
170 std::string ToString() const override;
171
172 /// \brief Judge whether the real data is null.
173 ///
174 /// \return The result of the judgment.
is_null()175 bool is_null() const { return m_ptr == nullptr; }
176
177 /// \brief Get the type id of the real data.
178 ///
179 /// \return The type id of the real data.
180 virtual uint32_t type() const;
181
182 BasePtr m_ptr; /**< pointer to the real data */
183 };
184 using BaseRefPtr = std::shared_ptr<BaseRef>;
185
186 struct BaseRefHash {
operatorBaseRefHash187 std::size_t operator()(const BaseRef &c) const { return c.hash(); }
188 };
189
190 struct BaseRefLess {
operatorBaseRefLess191 bool operator()(const BaseRef &a, const BaseRef &b) const { return a.hash() < b.hash(); }
192 };
193
194 namespace utils {
195 // judge isa relation
196 // examples: isa<Int32Imm>(handle), isa<FuncGraph>(handle)
197 template <typename T,
198 typename std::enable_if<is_base<T>::value && !is_base_ref<T>::value, int64_t>::type = static_cast<int64_t>(0)>
isa(const BaseRef & handle)199 bool isa(const BaseRef &handle) {
200 if (!handle.m_ptr) {
201 return false;
202 }
203 return handle.m_ptr->isa<T>();
204 }
205
206 // noderef isa ptr isa<AnfNodePtr>(x) or isa<SeqPtr>()
207 template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type,
208 typename std::enable_if<is_base<U>::value || is_base_ref<U>::value, int64_t>::type = static_cast<int64_t>(0)>
isa(const BaseRef & handle)209 bool isa(const BaseRef &handle) {
210 if (handle.m_ptr == nullptr) {
211 return typeid(handle.m_ptr) == typeid(T);
212 }
213
214 if (handle.m_ptr->isa<U>()) {
215 return true;
216 }
217
218 // constptr isa<anfnodeptr> can be true
219 return std::dynamic_pointer_cast<U>(handle.m_ptr) != nullptr;
220 }
221
222 // isa<int32>(handle)
223 template <typename S, typename U = typename ImmTraits<S>::type::element_type>
isa(const BaseRef & handle)224 bool isa(const BaseRef &handle) {
225 if (handle.m_ptr == nullptr) {
226 return false;
227 }
228 return handle.m_ptr->isa<U>();
229 }
230
231 // isa<BaseRef>(handle), judge reference or ptr
232 template <typename T, typename std::enable_if<is_base_ref<T>::value, int64_t>::type = static_cast<int64_t>(0)>
isa(const BaseRef & handle)233 bool isa(const BaseRef &handle) {
234 return handle.isa<T>() || (handle.m_ptr && handle.m_ptr->isa<T>());
235 }
236
237 // valueref -> C++ type
238 // cast<int64_t>(handle)
239 template <typename T, typename std::enable_if<!is_base_ref<T>::value && !is_shared_ptr<T>::value, int64_t>::type =
240 static_cast<int64_t>(0)>
cast(const BaseRef & handle)241 T cast(const BaseRef &handle) {
242 T ret = GetValue<T>(std::static_pointer_cast<Value>(handle.m_ptr));
243 return std::move(ret);
244 }
245
246 // valueref -> valueref type
247 // cast<VectorRef>(handle)
248 template <typename T, typename std::enable_if<is_base_ref<T>::value, int64_t>::type = static_cast<int64_t>(0)>
cast(const BaseRef & handle)249 const T &cast(const BaseRef &handle) {
250 if (handle.m_ptr) {
251 return static_cast<const T &>(*handle.m_ptr);
252 }
253
254 return static_cast<const T &>(handle);
255 }
256
257 // valueref -> nodeptr type
258 // cast<FuncGraphPtr>(handle)
259 template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type,
260 typename std::enable_if<is_shared_ptr<T>::value && std::is_base_of<Base, typename T::element_type>::value,
261 int64_t>::type = static_cast<int64_t>(0)>
cast(const BaseRef & handle)262 T cast(const BaseRef &handle) {
263 if (!handle.m_ptr) {
264 MS_LOG(INTERNAL_EXCEPTION) << "Can not cast to " << typeid(T).name() << ", pointer is null";
265 }
266
267 auto m = handle.m_ptr->cast<T>();
268 if (nullptr != m) {
269 return m;
270 }
271 return std::static_pointer_cast<U>(handle.m_ptr);
272 }
273 } // namespace utils
274
275 class MS_CORE_API VectorRef : public BaseRef {
276 public:
277 using value_type = BaseRef;
278
VectorRef()279 VectorRef() {}
VectorRef(const std::vector<BaseRef> & elements)280 explicit VectorRef(const std::vector<BaseRef> &elements) : elements_(elements) {}
VectorRef(const const_iterator & begin,const const_iterator & end)281 VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {}
282
283 // left reference
284 VectorRef(const VectorRef &other);
285 VectorRef &operator=(const VectorRef &other);
286
287 ~VectorRef() override = default;
288
copy()289 std::shared_ptr<Base> copy() const override { return std::make_shared<VectorRef>(elements_); }
290
empty()291 bool empty() const { return (elements_.size() == 0); }
292
size()293 std::size_t size() const { return elements_.size(); }
MS_DECLARE_PARENT(VectorRef,BaseRef)294 MS_DECLARE_PARENT(VectorRef, BaseRef)
295
296 const BaseRef &operator[](const std::size_t &dim) const {
297 if (dim >= size()) {
298 MS_LOG(INTERNAL_EXCEPTION) << "Out of the size of the tuple.";
299 }
300 return elements_[dim];
301 }
302
303 BaseRef &operator[](const std::size_t &dim) {
304 if (dim >= size()) {
305 MS_LOG(INTERNAL_EXCEPTION) << "Out of the size of the tuple.";
306 }
307 return elements_[dim];
308 }
309
type()310 uint32_t type() const override { return tid(); }
311 std::string ToString() const override;
elements()312 std::vector<BaseRef> &elements() { return elements_; }
clear()313 void clear() { elements_.clear(); }
314
315 bool operator==(const BaseRef &other) const override;
316 bool operator==(const VectorRef &other) const;
317
push_back(const BaseRef & value)318 void push_back(const BaseRef &value) { elements_.push_back(value); }
push_back(BaseRef && value)319 void push_back(BaseRef &&value) { elements_.push_back(value); }
320
emplace_back(const BaseRef & value)321 void emplace_back(const BaseRef &value) { elements_.emplace_back(value); }
emplace_back(BaseRef && value)322 void emplace_back(BaseRef &&value) { elements_.emplace_back(value); }
323
324 template <class InputIt>
insert(const iterator pos,const InputIt first,const InputIt last)325 void insert(const iterator pos, const InputIt first, const InputIt last) {
326 (void)elements_.insert(pos, first, last);
327 }
328
329 template <class InputIt>
insert(const const_iterator cpos,const InputIt first,const InputIt last)330 void insert(const const_iterator cpos, const InputIt first, const InputIt last) {
331 auto pos = ConstIteratorCast(&elements_, cpos);
332 (void)elements_.insert(pos, first, last);
333 }
334
begin()335 const_iterator begin() const { return elements_.begin(); }
end()336 const_iterator end() const { return elements_.end(); }
337
rbegin()338 const_reverse_iterator rbegin() const { return elements_.rbegin(); }
rend()339 const_reverse_iterator rend() const { return elements_.rend(); }
340
erase(const const_iterator cpos)341 iterator erase(const const_iterator cpos) {
342 auto pos = ConstIteratorCast(&elements_, cpos);
343 return elements_.erase(pos);
344 }
345
erase(const const_iterator cfirst,const const_iterator clast)346 iterator erase(const const_iterator cfirst, const const_iterator clast) {
347 auto first = ConstIteratorCast(&elements_, cfirst);
348 auto last = ConstIteratorCast(&elements_, clast);
349 return elements_.erase(first, last);
350 }
351
hash()352 std::size_t hash() const override {
353 std::stringstream buffer;
354 buffer << ToString();
355 return std::hash<std::string>()(buffer.str());
356 }
357
358 std::vector<BaseRef> elements_;
359 };
360
361 using VectorRefPtr = std::shared_ptr<VectorRef>;
362
363 using set_iterator = std::set<BaseRef, BaseRefLess>::iterator;
364 using const_set_iterator = std::set<BaseRef, BaseRefLess>::const_iterator;
365
366 struct VectorRefHash {
operatorVectorRefHash367 std::size_t operator()(const VectorRef &c) const { return c.hash(); }
368 };
369
370 class MS_CORE_API SetRef : public BaseRef {
371 public:
SetRef()372 SetRef() {}
SetRef(const std::set<BaseRef,BaseRefLess> & elements)373 explicit SetRef(const std::set<BaseRef, BaseRefLess> &elements) : elements_(elements) {}
SetRef(const std::initializer_list<BaseRef> elements)374 SetRef(const std::initializer_list<BaseRef> elements) : elements_(elements.begin(), elements.end()) {}
SetRef(const const_set_iterator & begin,const const_set_iterator & end)375 SetRef(const const_set_iterator &begin, const const_set_iterator &end) : elements_(begin, end) {}
376
377 // left reference
378 SetRef(const SetRef &other);
379 SetRef &operator=(const SetRef &other);
380
381 bool operator==(const BaseRef &other) const override;
382 bool operator==(const SetRef &other) const;
383
384 ~SetRef() override = default;
385
copy()386 std::shared_ptr<Base> copy() const override { return std::make_shared<SetRef>(elements_); }
387
empty()388 bool empty() const { return (elements_.size() == 0); }
389
size()390 std::size_t size() const { return elements_.size(); }
MS_DECLARE_PARENT(SetRef,BaseRef)391 MS_DECLARE_PARENT(SetRef, BaseRef)
392
393 uint32_t type() const override { return tid(); }
394 std::string ToString() const override;
elements()395 std::set<BaseRef, BaseRefLess> &elements() { return elements_; }
clear()396 void clear() { elements_.clear(); }
397
insert(const BaseRef & elem)398 void insert(const BaseRef &elem) { (void)elements_.insert(elem); }
399
begin()400 const_set_iterator begin() const { return elements_.begin(); }
end()401 const_set_iterator end() const { return elements_.end(); }
402
403 template <class InputIt>
insert(const InputIt first,const InputIt last)404 void insert(const InputIt first, const InputIt last) {
405 (void)elements_.insert(first, last);
406 }
407
count(const BaseRef & elem)408 std::size_t count(const BaseRef &elem) const { return elements_.count(elem); }
find(const BaseRef & elem)409 const_set_iterator find(const BaseRef &elem) const { return elements_.find(elem); }
410
411 std::set<BaseRef, BaseRefLess> elements_;
412 };
413
414 using SetRefPtr = std::shared_ptr<SetRef>;
415
416 class MS_CORE_API RunFunctionRef : public BaseRef {
417 public:
RunFunctionRef()418 RunFunctionRef() {}
RunFunctionRef(const RunFuncPtr & ref_func)419 explicit RunFunctionRef(const RunFuncPtr &ref_func) : func_(ref_func) {}
420
421 ~RunFunctionRef() override = default;
MS_DECLARE_PARENT(RunFunctionRef,BaseRef)422 MS_DECLARE_PARENT(RunFunctionRef, BaseRef)
423
424 uint32_t type() const override { return tid(); }
ToString()425 std::string ToString() const override { return std::string("RunFunctionRef"); }
426 bool operator==(const BaseRef &other) const override;
427 bool operator==(const RunFunctionRef &other) const;
428
429 RunFuncPtr func_;
430 };
431 } // namespace mindspore
432
433 #endif // MINDSPORE_CORE_UTILS_BASE_REF_H_
434