• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_PTR_LIST_REF_H
17 #define MINDSPORE_PI_JIT_PTR_LIST_REF_H
18 #include <iterator>
19 #include "utils/log_adapter.h"
20 
21 namespace mindspore {
22 namespace pijit {
/**
 * @brief Intrusive link storage for a node of type `T` that lives in a `PtrListRef<T>`.
 *
 * Holds raw, non-owning pointers to the neighbouring nodes. Both links start out as nullptr and are only changed
 * through the setters below.
 */
template <typename T>
class PtrListNodeBase {
 public:
  PtrListNodeBase() = default;
  ~PtrListNodeBase() = default;

  /// @return The node linked before this one, or nullptr.
  T *GetPrev() const { return prev_node_; }

  /// @return The node linked after this one, or nullptr.
  T *GetNext() const { return next_node_; }

  /// @brief Overwrites the backward link (nullptr detaches it).
  void SetPrev(T *ptr) { prev_node_ = ptr; }

  /// @brief Overwrites the forward link (nullptr detaches it).
  void SetNext(T *ptr) { next_node_ = ptr; }

 private:
  T *prev_node_{nullptr};
  T *next_node_{nullptr};
};
40 
/**
 * @brief Adapter that walks a bidirectional iterator backwards.
 *
 * Unlike `std::reverse_iterator`, there is no off-by-one shift: `operator*` dereferences `base()` itself, so a reverse
 * iterator built from an iterator to node X yields X first.
 *
 * @tparam T The wrapped (forward) iterator type.
 */
template <typename T>
class ReversePtrListRefIterator {
 public:
  using iterator_type = T;

  using iterator_category = typename std::iterator_traits<iterator_type>::iterator_category;
  using value_type = typename std::iterator_traits<iterator_type>::value_type;
  using difference_type = typename std::iterator_traits<iterator_type>::difference_type;
  using pointer = typename std::iterator_traits<iterator_type>::pointer;
  using reference = typename std::iterator_traits<iterator_type>::reference;

  /// @brief Wraps a value-initialized underlying iterator.
  ReversePtrListRefIterator() : current() {}

  /// @brief Wraps `it`; the first dereference yields `*it`.
  explicit ReversePtrListRefIterator(T it) : current(it) {}

  /// @brief Converts from a reverse iterator over a compatible underlying iterator (e.g. mutable -> const).
  template <class Other>
  ReversePtrListRefIterator(const ReversePtrListRefIterator<Other> &other) : current(other.base()) {}

  /// @brief Re-targets this iterator at the position held by `other`.
  template <class Other>
  ReversePtrListRefIterator &operator=(const ReversePtrListRefIterator<Other> &other) {
    current = other.base();
    return *this;
  }

  ~ReversePtrListRefIterator() = default;

  /// @return The wrapped iterator; it refers to the same element `operator*` yields.
  T base() const { return current; }

  reference operator*() const { return *current; }

  pointer operator->() const { return &**this; }

  /// Advancing in reverse means stepping the wrapped iterator backwards.
  ReversePtrListRefIterator &operator++() {
    --current;
    return *this;
  }

  ReversePtrListRefIterator operator++(int) {
    ReversePtrListRefIterator before(*this);
    ++*this;
    return before;
  }

  /// Retreating in reverse means stepping the wrapped iterator forwards.
  ReversePtrListRefIterator &operator--() {
    ++current;
    return *this;
  }

  ReversePtrListRefIterator operator--(int) {
    ReversePtrListRefIterator before(*this);
    --*this;
    return before;
  }

  bool operator==(const ReversePtrListRefIterator &rhs) const { return base() == rhs.base(); }

  bool operator!=(const ReversePtrListRefIterator &rhs) const { return !operator==(rhs); }

 protected:
  T current;
};
103 
/**
 * @brief Bidirectional iterator over an intrusive list.
 *
 * It wraps a raw node pointer and moves by following the node's own `GetNext()` / `GetPrev()` links.
 * A default-constructed iterator holds nullptr. `PtrListRefIterator<const X>` can be implicitly built from
 * `PtrListRefIterator<X>`.
 */
template <typename T>
class PtrListRefIterator {
 public:
  using iterator_category = std::bidirectional_iterator_tag;
  using value_type = T;
  using difference_type = std::ptrdiff_t;
  using pointer = T *;
  using reference = T &;
  using const_pointer = const T *;
  using const_reference = const T &;

  PtrListRefIterator() = default;

  /// @brief Points the iterator at `list_ref_iterator_ptr` (nullptr is allowed).
  explicit PtrListRefIterator(pointer list_ref_iterator_ptr) : node_(list_ref_iterator_ptr) {}

  /// @brief Mutable -> const conversion; only enabled when `U` is `T` without its const qualifier.
  template <typename U, typename = std::enable_if_t<std::is_same<U, std::remove_const_t<T>>::value>>
  PtrListRefIterator(const PtrListRefIterator<U> &ListIter_) : node_(ListIter_.d()) {}

  ~PtrListRefIterator() = default;

  /// @return The raw node pointer this iterator refers to.
  pointer d() const { return node_; }

  reference operator*() const { return *node_; }

  pointer operator->() const { return node_; }

  PtrListRefIterator &operator++() {
    node_ = node_->GetNext();
    return *this;
  }

  PtrListRefIterator &operator--() {
    node_ = node_->GetPrev();
    return *this;
  }

  PtrListRefIterator operator++(int) {
    PtrListRefIterator before(*this);
    node_ = node_->GetNext();
    return before;
  }

  PtrListRefIterator operator--(int) {
    PtrListRefIterator before(*this);
    node_ = node_->GetPrev();
    return before;
  }

  bool operator==(const PtrListRefIterator &other) const { return node_ == other.node_; }

  bool operator!=(const PtrListRefIterator &other) const { return node_ != other.node_; }

 private:
  pointer node_ = nullptr;
};
159 
160 template <typename T>
161 class PtrListRef {
162  public:
163   using value_type = T;
164   using size_type = size_t;
165   using difference_type = std::ptrdiff_t;
166   using pointer = T *;
167   using const_pointer = const T *;
168   using reference = T &;
169   using const_reference = const T &;
170 
171   using iterator = PtrListRefIterator<T>;
172   using const_iterator = PtrListRefIterator<const T>;
173   using reverse_iterator = ReversePtrListRefIterator<iterator>;
174   using const_reverse_iterator = ReversePtrListRefIterator<const_iterator>;
175 
176   PtrListRef() = default;
PtrListRef(pointer list_value)177   explicit PtrListRef(pointer list_value) : first(list_value), last(list_value) {}
178 
PtrListRef(pointer FirstList_,pointer LastList_)179   PtrListRef(pointer FirstList_, pointer LastList_)
180       : first(FirstList_), last(LastList_ == nullptr ? FirstList_ : LastList_) {}
181 
182   ~PtrListRef() = default;
183 
begin()184   iterator begin() { return iterator(this->first); }
185 
begin()186   const_iterator begin() const { return const_iterator(this->first); }
187 
cbegin()188   const_iterator cbegin() const { return const_iterator(this->first); }
189 
end()190   iterator end() { return iterator(this->last == nullptr ? nullptr : this->last->GetNext()); }
191 
end()192   const_iterator end() const { return const_iterator(this->last == nullptr ? nullptr : this->last->GetNext()); }
193 
cend()194   const_iterator cend() const { return const_iterator(this->last == nullptr ? nullptr : this->last->GetNext()); }
195 
rbegin()196   reverse_iterator rbegin() { return reverse_iterator(iterator(this->last)); }
197 
rbegin()198   const_reverse_iterator rbegin() const { return const_reverse_iterator(const_iterator(this->last)); }
199 
crbegin()200   const_reverse_iterator crbegin() const { return const_reverse_iterator(const_iterator(this->last)); }
201 
rend()202   reverse_iterator rend() {
203     return reverse_iterator(iterator(this->first == nullptr ? nullptr : this->first->GetPrev()));
204   }
205 
rend()206   const_reverse_iterator rend() const {
207     return const_reverse_iterator(const_iterator(this->first == nullptr ? nullptr : this->first->GetPrev()));
208   }
209 
crend()210   const_reverse_iterator crend() const {
211     return const_reverse_iterator(const_iterator(this->first == nullptr ? nullptr : this->first->GetPrev()));
212   }
213 
front()214   reference front() { return *(this->first); }
215 
back()216   reference back() { return *(this->last); }
217 
front()218   const_reference front() const { return *(this->first); }
219 
back()220   const_reference back() const { return *(this->last); }
221 
empty()222   bool empty() const { return first == nullptr; }
223 
update_front(pointer list_value)224   void update_front(pointer list_value) {
225     if (list_value != nullptr) {
226       list_value->SetPrev(nullptr);
227     }
228     this->first = list_value;
229   }
230 
push_front(pointer list_value)231   void push_front(pointer list_value) {
232     if (this->last == nullptr) {
233       this->first = list_value;
234       this->last = list_value;
235       list_value->SetPrev(nullptr);
236       list_value->SetNext(nullptr);
237     } else {
238       MS_ASSERT(this->first != nullptr);
239       this->first->SetPrev(list_value);
240       list_value->SetPrev(nullptr);
241       list_value->SetNext(this->first);
242       this->first = list_value;
243     }
244   }
245 
pop_front()246   void pop_front() {
247     if (this->first == nullptr) {
248       return;
249     }
250 
251     this->first = this->first->GetNext();
252     if (this->first != nullptr) {
253       this->first->SetPrev(nullptr);
254     }
255   }
256 
update_back(pointer list_value)257   void update_back(pointer list_value) {
258     if (list_value != nullptr) {
259       list_value->SetNext(nullptr);
260     }
261     this->last = list_value;
262   }
263 
push_back(pointer list_value)264   void push_back(pointer list_value) {
265     if (this->last == nullptr) {
266       this->first = list_value;
267       this->last = list_value;
268       list_value->SetPrev(nullptr);
269     } else {
270       this->last->SetNext(list_value);
271       list_value->SetPrev(this->last);
272       this->last = list_value;
273     }
274     list_value->SetNext(nullptr);
275   }
276 
pop_back()277   void pop_back() {
278     if (this->last == nullptr) {
279       return;
280     }
281 
282     if (this->last->GetPrev() == nullptr) {
283       this->first = nullptr;
284       this->last = nullptr;
285     } else {
286       this->last = this->last->GetPrev();
287       this->last->SetNext(nullptr);
288     }
289   }
290 
insert(const_iterator list_where,pointer list_value)291   void insert(const_iterator list_where, pointer list_value) {
292     if (list_where == const_iterator(this->first)) {
293       this->push_front(list_value);
294     } else if (list_where == this->cend()) {
295       this->push_back(list_value);
296     } else {
297       // `list_where` stands for the position, however we made the data and node combined, so a const_cast is needed.
298       auto *ptr = const_cast<T *>(&*list_where);
299       list_value->SetPrev(ptr->GetPrev());
300       list_value->SetNext(ptr);
301       list_value->GetPrev()->SetNext(list_value);
302       ptr->SetPrev(list_value);
303     }
304   }
305 
insert(const_pointer list_where,pointer list_value)306   void insert(const_pointer list_where, pointer list_value) { this->insert(const_iterator(list_where), list_value); }
307 
308   // cut list two half, list_where is head of second half
CutList(pointer list_where)309   PtrListRef CutList(pointer list_where) {
310     MS_ASSERT(!list_where || list_where == this->first || this->first == this->last);
311     PtrListRef other = {const_cast<T *>(list_where), this->last};
312     this->last = list_where->GetPrev();
313     other.front().SetPrev(nullptr);
314     this->last->SetNext(nullptr);
315     return other;
316   }
317 
CutList(iterator list_where)318   PtrListRef CutList(iterator list_where) { return CutList(*list_where); }
319 
insertAfter(const_iterator list_where,pointer list_value)320   void insertAfter(const_iterator list_where, pointer list_value) {
321     if (list_where == const_iterator(nullptr)) {
322       this->push_front(list_value);
323     } else if (list_where == const_iterator(this->last)) {
324       this->push_back(list_value);
325     } else {
326       // `list_where` stands for the position, however we made the data and node combined, so a const_cast is needed.
327       auto *ptr = const_cast<T *>(&*list_where);
328       list_value->SetPrev(ptr);
329       list_value->SetNext(ptr->GetNext());
330       list_value->GetNext()->SetPrev(list_value);
331       ptr->SetNext(list_value);
332     }
333   }
334 
insertAfter(const_pointer list_where,pointer list_value)335   void insertAfter(const_pointer list_where, pointer list_value) {
336     this->insertAfter(const_iterator(list_where), list_value);
337   }
338 
339   // clear other
splice(const_iterator list_where,PtrListRef * other)340   void splice(const_iterator list_where, PtrListRef *other) {
341     if (other->empty()) {
342       return;
343     }
344     MS_ASSERT(other->first && !other->first->GetPrev() && other->last && !other->last->GetNext());
345     if (empty()) {
346       this->first = other->first;
347       this->last = other->last;
348       other->clear();
349       return;
350     }
351     if (list_where == this->end()) {
352       this->last->SetNext(other->first);
353       other->first->SetPrev(this->first);
354       this->last = other->last;
355       other->clear();
356       return;
357     }
358     auto *ptr = const_cast<T *>(&*list_where);
359     if (list_where == this->begin()) {
360       this->first = other->first;
361     } else {
362       list_where->GetPrev()->SetNext(other->first);
363       other->first->SetPrev(list_where->GetPrev());
364     }
365     ptr->SetPrev(other->last);
366     other->last->SetNext(ptr);
367     other->clear();
368   }
369 
splice(const_pointer list_where,PtrListRef * other)370   void splice(const_pointer list_where, PtrListRef *other) { listSplice(const_iterator(list_where), other); }
371 
clear()372   void clear() {
373     this->first = nullptr;
374     this->last = nullptr;
375   }
376 
erase(const_iterator list_where)377   iterator erase(const_iterator list_where) {
378     if (list_where == this->cbegin() && list_where == this->rbegin().base()) {
379       this->first = nullptr;
380       this->last = nullptr;
381     } else if (list_where == this->cbegin()) {
382       // `list_where` stands for the position, however we made the data and node combined, so a const_cast is needed.
383       auto *ptr = const_cast<T *>(&*list_where);
384       this->first = ptr->GetNext();
385       MS_ASSERT(this->first != nullptr);
386       this->first->SetPrev(nullptr);
387     } else if (list_where == this->rbegin().base()) {
388       pop_back();
389     } else {
390       MS_ASSERT(list_where->GetPrev() != nullptr);
391       // `list_where` stands for the position, however we made the data and node combined, so a const_cast is needed.
392       auto *ptr = const_cast<T *>(&*list_where);
393       ptr->GetPrev()->SetNext(ptr->GetNext());
394       if (ptr->GetNext()) {
395         ptr->GetNext()->SetPrev(ptr->GetPrev());
396       }
397     }
398     return iterator(nullptr);
399   }
400 
erase(const_pointer list_where)401   iterator erase(const_pointer list_where) { return this->erase(const_iterator(list_where)); }
402 
set_first(T * f)403   void set_first(T *f) { this->first = f; }
404 
set_last(T * f)405   void set_last(T *f) { this->last = f; }
406 
407  private:
408   T *first = nullptr;
409   T *last = nullptr;
410 };
411 
/// @brief Extracts the raw node pointer held by a forward list iterator (via its `d()` accessor).
template <typename Iterator>
typename std::iterator_traits<Iterator>::pointer to_ptr(Iterator it) {
  auto raw = it.d();
  return raw;
}
416 
417 template <typename Iterator>
418 auto to_ptr(ReversePtrListRefIterator<Iterator> it) -> typename std::iterator_traits<Iterator>::pointer {
419   return it.base().d();
420 }
421 }  // namespace pijit
422 }  // namespace mindspore
423 #endif  // MINDSPORE_PI_JIT_PTR_LIST_REF_H
424