1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_PTR_LIST_REF_H 17 #define MINDSPORE_PI_JIT_PTR_LIST_REF_H 18 #include <iterator> 19 #include "utils/log_adapter.h" 20 21 namespace mindspore { 22 namespace pijit { 23 template <typename T> 24 class PtrListNodeBase { 25 public: 26 PtrListNodeBase() = default; 27 ~PtrListNodeBase() = default; GetPrev()28 T *GetPrev() const { return prev; } 29 GetNext()30 T *GetNext() const { return next; } 31 SetPrev(T * ptr)32 void SetPrev(T *ptr) { prev = ptr; } 33 SetNext(T * ptr)34 void SetNext(T *ptr) { next = ptr; } 35 36 private: 37 T *prev = nullptr; 38 T *next = nullptr; 39 }; 40 41 // wrap iterator to run it backwards 42 template <typename T> 43 class ReversePtrListRefIterator { 44 public: 45 using iterator_category = typename std::iterator_traits<T>::iterator_category; 46 using value_type = typename std::iterator_traits<T>::value_type; 47 using difference_type = typename std::iterator_traits<T>::difference_type; 48 using pointer = typename std::iterator_traits<T>::pointer; 49 using reference = typename std::iterator_traits<T>::reference; 50 51 using iterator_type = T; 52 ReversePtrListRefIterator()53 ReversePtrListRefIterator() : current() {} 54 ReversePtrListRefIterator(T right)55 explicit ReversePtrListRefIterator(T right) : current(right) {} 56 57 template <class Other> ReversePtrListRefIterator(const ReversePtrListRefIterator<Other> & right)58 ReversePtrListRefIterator(const ReversePtrListRefIterator<Other> &right) : current(right.base()) {} 59 60 template <class Other> 61 ReversePtrListRefIterator &operator=(const ReversePtrListRefIterator<Other> &right) { 62 current = right.base(); 63 return (*this); 64 } 65 66 ~ReversePtrListRefIterator() = default; 67 base()68 T base() const { return current; } 69 70 reference operator*() const { return *current; } 71 72 pointer operator->() const { return &(operator*()); } 73 74 ReversePtrListRefIterator &operator++() { 75 --current; 76 return (*this); 77 } 78 79 ReversePtrListRefIterator operator++(int) { 80 ReversePtrListRefIterator tmp = *this; 81 --current; 82 return (tmp); 83 } 84 85 ReversePtrListRefIterator &operator--() { 86 ++current; 87 return (*this); 88 } 89 90 ReversePtrListRefIterator operator--(int) { 91 ReversePtrListRefIterator tmp = *this; 92 ++current; 93 return (tmp); 94 } 95 96 bool operator==(const ReversePtrListRefIterator &Iterator) const { return this->base() == Iterator.base(); } 97 98 bool operator!=(const ReversePtrListRefIterator &Iterator) const { return !(*this == Iterator); } 99 100 protected: 101 T current; 102 }; 103 104 template <typename T> 105 class PtrListRefIterator { 106 public: 107 using iterator_category = std::bidirectional_iterator_tag; 108 using value_type = T; 109 using difference_type = std::ptrdiff_t; 110 using pointer = T *; 111 using reference = T &; 112 using const_pointer = const T *; 113 using const_reference = const T &; 114 115 PtrListRefIterator() = default; 116 PtrListRefIterator(pointer list_ref_iterator_ptr)117 explicit PtrListRefIterator(pointer list_ref_iterator_ptr) : ptr(list_ref_iterator_ptr) {} 118 119 template <typename U, typename = std::enable_if_t<std::is_same<U, std::remove_const_t<T>>::value>> PtrListRefIterator(const PtrListRefIterator<U> & ListIter_)120 PtrListRefIterator(const PtrListRefIterator<U> &ListIter_) : ptr(ListIter_.d()) {} 121 122 ~PtrListRefIterator() = default; 123 d()124 pointer d() const { return ptr; } 125 126 reference operator*() const { return *ptr; } 127 128 pointer operator->() const { return ptr; } 129 130 PtrListRefIterator &operator++() { 131 this->ptr = this->ptr->GetNext(); 132 return *this; 133 } 134 135 PtrListRefIterator &operator--() { 136 this->ptr = this->ptr->GetPrev(); 137 return *this; 138 } 139 140 PtrListRefIterator operator++(int) { 141 PtrListRefIterator it = *this; 142 ++(*this); 143 return it; 144 } 145 146 PtrListRefIterator operator--(int) { 147 PtrListRefIterator it = *this; 148 --(*this); 149 return it; 150 } 151 152 bool operator==(const PtrListRefIterator &Iterator) const { return this->ptr == Iterator.ptr; } 153 154 bool operator!=(const PtrListRefIterator &Iterator) const { return !(*this == Iterator); } 155 156 private: 157 pointer ptr = nullptr; 158 }; 159 160 template <typename T> 161 class PtrListRef { 162 public: 163 using value_type = T; 164 using size_type = size_t; 165 using difference_type = std::ptrdiff_t; 166 using pointer = T *; 167 using const_pointer = const T *; 168 using reference = T &; 169 using const_reference = const T &; 170 171 using iterator = PtrListRefIterator<T>; 172 using const_iterator = PtrListRefIterator<const T>; 173 using reverse_iterator = ReversePtrListRefIterator<iterator>; 174 using const_reverse_iterator = ReversePtrListRefIterator<const_iterator>; 175 176 PtrListRef() = default; PtrListRef(pointer list_value)177 explicit PtrListRef(pointer list_value) : first(list_value), last(list_value) {} 178 PtrListRef(pointer FirstList_,pointer LastList_)179 PtrListRef(pointer FirstList_, pointer LastList_) 180 : first(FirstList_), last(LastList_ == nullptr ? FirstList_ : LastList_) {} 181 182 ~PtrListRef() = default; 183 begin()184 iterator begin() { return iterator(this->first); } 185 begin()186 const_iterator begin() const { return const_iterator(this->first); } 187 cbegin()188 const_iterator cbegin() const { return const_iterator(this->first); } 189 end()190 iterator end() { return iterator(this->last == nullptr ? nullptr : this->last->GetNext()); } 191 end()192 const_iterator end() const { return const_iterator(this->last == nullptr ? nullptr : this->last->GetNext()); } 193 cend()194 const_iterator cend() const { return const_iterator(this->last == nullptr ? nullptr : this->last->GetNext()); } 195 rbegin()196 reverse_iterator rbegin() { return reverse_iterator(iterator(this->last)); } 197 rbegin()198 const_reverse_iterator rbegin() const { return const_reverse_iterator(const_iterator(this->last)); } 199 crbegin()200 const_reverse_iterator crbegin() const { return const_reverse_iterator(const_iterator(this->last)); } 201 rend()202 reverse_iterator rend() { 203 return reverse_iterator(iterator(this->first == nullptr ? nullptr : this->first->GetPrev())); 204 } 205 rend()206 const_reverse_iterator rend() const { 207 return const_reverse_iterator(const_iterator(this->first == nullptr ? nullptr : this->first->GetPrev())); 208 } 209 crend()210 const_reverse_iterator crend() const { 211 return const_reverse_iterator(const_iterator(this->first == nullptr ? nullptr : this->first->GetPrev())); 212 } 213 front()214 reference front() { return *(this->first); } 215 back()216 reference back() { return *(this->last); } 217 front()218 const_reference front() const { return *(this->first); } 219 back()220 const_reference back() const { return *(this->last); } 221 empty()222 bool empty() const { return first == nullptr; } 223 update_front(pointer list_value)224 void update_front(pointer list_value) { 225 if (list_value != nullptr) { 226 list_value->SetPrev(nullptr); 227 } 228 this->first = list_value; 229 } 230 push_front(pointer list_value)231 void push_front(pointer list_value) { 232 if (this->last == nullptr) { 233 this->first = list_value; 234 this->last = list_value; 235 list_value->SetPrev(nullptr); 236 list_value->SetNext(nullptr); 237 } else { 238 MS_ASSERT(this->first != nullptr); 239 this->first->SetPrev(list_value); 240 list_value->SetPrev(nullptr); 241 list_value->SetNext(this->first); 242 this->first = list_value; 243 } 244 } 245 pop_front()246 void pop_front() { 247 if (this->first == nullptr) { 248 return; 249 } 250 251 this->first = this->first->GetNext(); 252 if (this->first != nullptr) { 253 this->first->SetPrev(nullptr); 254 } 255 } 256 update_back(pointer list_value)257 void update_back(pointer list_value) { 258 if (list_value != nullptr) { 259 list_value->SetNext(nullptr); 260 } 261 this->last = list_value; 262 } 263 push_back(pointer list_value)264 void push_back(pointer list_value) { 265 if (this->last == nullptr) { 266 this->first = list_value; 267 this->last = list_value; 268 list_value->SetPrev(nullptr); 269 } else { 270 this->last->SetNext(list_value); 271 list_value->SetPrev(this->last); 272 this->last = list_value; 273 } 274 list_value->SetNext(nullptr); 275 } 276 pop_back()277 void pop_back() { 278 if (this->last == nullptr) { 279 return; 280 } 281 282 if (this->last->GetPrev() == nullptr) { 283 this->first = nullptr; 284 this->last = nullptr; 285 } else { 286 this->last = this->last->GetPrev(); 287 this->last->SetNext(nullptr); 288 } 289 } 290 insert(const_iterator list_where,pointer list_value)291 void insert(const_iterator list_where, pointer list_value) { 292 if (list_where == const_iterator(this->first)) { 293 this->push_front(list_value); 294 } else if (list_where == this->cend()) { 295 this->push_back(list_value); 296 } else { 297 // `list_where` stands for the position, however we made the data and node combined, so a const_cast is needed. 298 auto *ptr = const_cast<T *>(&*list_where); 299 list_value->SetPrev(ptr->GetPrev()); 300 list_value->SetNext(ptr); 301 list_value->GetPrev()->SetNext(list_value); 302 ptr->SetPrev(list_value); 303 } 304 } 305 insert(const_pointer list_where,pointer list_value)306 void insert(const_pointer list_where, pointer list_value) { this->insert(const_iterator(list_where), list_value); } 307 308 // cut list two half, list_where is head of second half CutList(pointer list_where)309 PtrListRef CutList(pointer list_where) { 310 MS_ASSERT(!list_where || list_where == this->first || this->first == this->last); 311 PtrListRef other = {const_cast<T *>(list_where), this->last}; 312 this->last = list_where->GetPrev(); 313 other.front().SetPrev(nullptr); 314 this->last->SetNext(nullptr); 315 return other; 316 } 317 CutList(iterator list_where)318 PtrListRef CutList(iterator list_where) { return CutList(*list_where); } 319 insertAfter(const_iterator list_where,pointer list_value)320 void insertAfter(const_iterator list_where, pointer list_value) { 321 if (list_where == const_iterator(nullptr)) { 322 this->push_front(list_value); 323 } else if (list_where == const_iterator(this->last)) { 324 this->push_back(list_value); 325 } else { 326 // `list_where` stands for the position, however we made the data and node combined, so a const_cast is needed. 327 auto *ptr = const_cast<T *>(&*list_where); 328 list_value->SetPrev(ptr); 329 list_value->SetNext(ptr->GetNext()); 330 list_value->GetNext()->SetPrev(list_value); 331 ptr->SetNext(list_value); 332 } 333 } 334 insertAfter(const_pointer list_where,pointer list_value)335 void insertAfter(const_pointer list_where, pointer list_value) { 336 this->insertAfter(const_iterator(list_where), list_value); 337 } 338 339 // clear other splice(const_iterator list_where,PtrListRef * other)340 void splice(const_iterator list_where, PtrListRef *other) { 341 if (other->empty()) { 342 return; 343 } 344 MS_ASSERT(other->first && !other->first->GetPrev() && other->last && !other->last->GetNext()); 345 if (empty()) { 346 this->first = other->first; 347 this->last = other->last; 348 other->clear(); 349 return; 350 } 351 if (list_where == this->end()) { 352 this->last->SetNext(other->first); 353 other->first->SetPrev(this->first); 354 this->last = other->last; 355 other->clear(); 356 return; 357 } 358 auto *ptr = const_cast<T *>(&*list_where); 359 if (list_where == this->begin()) { 360 this->first = other->first; 361 } else { 362 list_where->GetPrev()->SetNext(other->first); 363 other->first->SetPrev(list_where->GetPrev()); 364 } 365 ptr->SetPrev(other->last); 366 other->last->SetNext(ptr); 367 other->clear(); 368 } 369 splice(const_pointer list_where,PtrListRef * other)370 void splice(const_pointer list_where, PtrListRef *other) { listSplice(const_iterator(list_where), other); } 371 clear()372 void clear() { 373 this->first = nullptr; 374 this->last = nullptr; 375 } 376 erase(const_iterator list_where)377 iterator erase(const_iterator list_where) { 378 if (list_where == this->cbegin() && list_where == this->rbegin().base()) { 379 this->first = nullptr; 380 this->last = nullptr; 381 } else if (list_where == this->cbegin()) { 382 // `list_where` stands for the position, however we made the data and node combined, so a const_cast is needed. 383 auto *ptr = const_cast<T *>(&*list_where); 384 this->first = ptr->GetNext(); 385 MS_ASSERT(this->first != nullptr); 386 this->first->SetPrev(nullptr); 387 } else if (list_where == this->rbegin().base()) { 388 pop_back(); 389 } else { 390 MS_ASSERT(list_where->GetPrev() != nullptr); 391 // `list_where` stands for the position, however we made the data and node combined, so a const_cast is needed. 392 auto *ptr = const_cast<T *>(&*list_where); 393 ptr->GetPrev()->SetNext(ptr->GetNext()); 394 if (ptr->GetNext()) { 395 ptr->GetNext()->SetPrev(ptr->GetPrev()); 396 } 397 } 398 return iterator(nullptr); 399 } 400 erase(const_pointer list_where)401 iterator erase(const_pointer list_where) { return this->erase(const_iterator(list_where)); } 402 set_first(T * f)403 void set_first(T *f) { this->first = f; } 404 set_last(T * f)405 void set_last(T *f) { this->last = f; } 406 407 private: 408 T *first = nullptr; 409 T *last = nullptr; 410 }; 411 412 template <typename Iterator> 413 auto to_ptr(Iterator it) -> typename std::iterator_traits<Iterator>::pointer { 414 return it.d(); 415 } 416 417 template <typename Iterator> 418 auto to_ptr(ReversePtrListRefIterator<Iterator> it) -> typename std::iterator_traits<Iterator>::pointer { 419 return it.base().d(); 420 } 421 } // namespace pijit 422 } // namespace mindspore 423 #endif // MINDSPORE_PI_JIT_PTR_LIST_REF_H 424