1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_UTILS_ORDERED_MAP_H_ 18 #define MINDSPORE_CORE_UTILS_ORDERED_MAP_H_ 19 20 #include <list> 21 #include <unordered_map> 22 #include <algorithm> 23 #include <functional> 24 #include <utility> 25 #include <memory> 26 #include "utils/hashing.h" 27 28 namespace mindspore { 29 // Implementation of OrderedMap that keeps insertion order 30 // using unordered_map to improve the performance of find/erase, and use list to keep insertion order 31 template <typename KeyT, typename ValueT, class Hash = std::hash<KeyT>, class Equal = std::equal_to<KeyT>> 32 class OrderedMap { 33 using key_ptr_t = const KeyT *; 34 struct KeyPtrHash { operatorKeyPtrHash35 std::size_t operator()(key_ptr_t ptr) const noexcept { return Hash{}(*ptr); } 36 }; 37 struct KeyPtrEqual { operatorKeyPtrEqual38 bool operator()(key_ptr_t lhs, key_ptr_t rhs) const noexcept { return Equal{}(*lhs, *rhs); } 39 }; 40 41 public: 42 using key_t = KeyT; 43 using value_t = ValueT; 44 using pair_type = std::pair<key_t, value_t>; 45 using sequential_type = std::list<pair_type>; 46 using iterator = typename sequential_type::iterator; 47 using const_iterator = typename sequential_type::const_iterator; 48 using reverse_iterator = typename sequential_type::reverse_iterator; 49 using const_reverse_iterator = typename sequential_type::const_reverse_iterator; 50 using map_type = std::unordered_map<key_ptr_t, iterator, KeyPtrHash, KeyPtrEqual>; 51 using value_type = typename sequential_type::value_type; 52 using size_type = typename sequential_type::size_type; 53 begin()54 iterator begin() { return sequential_data_.begin(); } end()55 iterator end() { return sequential_data_.end(); } begin()56 const_iterator begin() const { return sequential_data_.cbegin(); } end()57 const_iterator end() const { return sequential_data_.cend(); } cbegin()58 const_iterator cbegin() const { return sequential_data_.cbegin(); } cend()59 const_iterator cend() const { return sequential_data_.cend(); } 60 rbegin()61 reverse_iterator rbegin() { return sequential_data_.rbegin(); } rend()62 reverse_iterator rend() { return sequential_data_.rend(); } rbegin()63 const_reverse_iterator rbegin() const { return sequential_data_.rbegin(); } rend()64 const_reverse_iterator rend() const { return sequential_data_.rend(); } 65 front()66 pair_type &front() { return sequential_data_.front(); } front()67 const pair_type &front() const { return sequential_data_.front(); } back()68 pair_type &back() { return sequential_data_.back(); } back()69 const pair_type &back() const { return sequential_data_.back(); } 70 71 OrderedMap() = default; 72 ~OrderedMap() = default; 73 74 OrderedMap(OrderedMap &&other) noexcept = default; 75 OrderedMap &operator=(OrderedMap &&other) noexcept = default; 76 OrderedMap(const sequential_type & other)77 explicit OrderedMap(const sequential_type &other) { 78 reserve(other.size()); 79 for (auto &item : other) { 80 (void)emplace(item.first, item.second); 81 } 82 } 83 OrderedMap(const OrderedMap & other)84 OrderedMap(const OrderedMap &other) : OrderedMap(other.sequential_data_) {} 85 86 OrderedMap &operator=(const OrderedMap &other) { 87 if (this != &other) { 88 clear(); 89 reserve(other.size()); 90 for (auto &item : other.sequential_data_) { 91 (void)emplace(item.first, item.second); 92 } 93 } 94 return *this; 95 } 96 clear()97 void clear() { 98 map_data_.clear(); 99 sequential_data_.clear(); 100 } 101 swap(OrderedMap & rhs)102 void swap(OrderedMap &rhs) noexcept { 103 std::swap(map_data_, rhs.map_data_); 104 std::swap(sequential_data_, rhs.sequential_data_); 105 } 106 reserve(size_type num_entries)107 void reserve(size_type num_entries) { map_data_.reserve(num_entries); } 108 109 template <typename... Args> emplace(Args &&...args)110 std::pair<iterator, bool> emplace(Args &&... args) { 111 auto new_iter = sequential_data_.emplace(sequential_data_.end(), std::forward<Args>(args)...); 112 auto [map_iter, inserted] = map_data_.emplace(&(new_iter->first), new_iter); 113 if (!inserted) { 114 sequential_data_.erase(new_iter); 115 } 116 return {map_iter->second, inserted}; 117 } 118 insert(const pair_type & kv)119 std::pair<iterator, bool> insert(const pair_type &kv) { 120 auto iter = map_data_.find(&(kv.first)); 121 if (iter != map_data_.end()) { 122 return {iter->second, false}; 123 } 124 auto new_iter = sequential_data_.emplace(sequential_data_.end(), kv); 125 auto result = map_data_.emplace(&(new_iter->first), new_iter); 126 return {result.first->second, true}; 127 } 128 insert(pair_type && kv)129 std::pair<iterator, bool> insert(pair_type &&kv) { 130 auto iter = map_data_.find(&(kv.first)); 131 if (iter != map_data_.end()) { 132 return {iter->second, false}; 133 } 134 auto new_iter = sequential_data_.emplace(sequential_data_.end(), std::move(kv)); 135 auto result = map_data_.emplace(&(new_iter->first), new_iter); 136 return {result.first->second, true}; 137 } 138 add(const key_t & key)139 std::pair<iterator, bool> add(const key_t &key) { return insert(pair_type{key, ValueT{}}); } 140 141 ValueT &operator[](const key_t &key) { 142 auto iter = map_data_.find(&key); 143 if (iter != map_data_.end()) { 144 return iter->second->second; 145 } 146 auto new_iter = sequential_data_.emplace(sequential_data_.end(), key, ValueT{}); 147 auto result = map_data_.emplace(&(new_iter->first), new_iter); 148 return result.first->second->second; 149 } 150 empty()151 bool empty() const { return sequential_data_.empty(); } 152 size()153 size_type size() const { return sequential_data_.size(); } 154 at(const key_t & key)155 const ValueT &at(const key_t &key) const { 156 auto &list_iter = map_data_.at(&key); 157 return list_iter->second; 158 } 159 count(const key_t & key)160 size_type count(const key_t &key) const { 161 auto pos = map_data_.find(&key); 162 return pos == map_data_.end() ? 0 : 1; 163 } 164 find(const key_t & key)165 iterator find(const key_t &key) { 166 auto pos = map_data_.find(&key); 167 return pos == map_data_.end() ? sequential_data_.end() : (pos->second); 168 } 169 find(const key_t & key)170 const_iterator find(const key_t &key) const { 171 auto pos = map_data_.find(&key); 172 return pos == map_data_.end() ? sequential_data_.end() : (pos->second); 173 } 174 175 // Remove the last element from the sequential_data_. pop_back()176 void pop_back() { 177 (void)map_data_.erase(&(sequential_data_.back().first)); 178 sequential_data_.pop_back(); 179 } 180 181 // Remove the first element from the sequential_data_. pop_front()182 void pop_front() { 183 (void)map_data_.erase(&(sequential_data_.front().first)); 184 sequential_data_.pop_front(); 185 } 186 187 // Remove the element given by Iterator. erase(iterator iter)188 iterator erase(iterator iter) { 189 (void)map_data_.erase(&(iter->first)); 190 return sequential_data_.erase(iter); 191 } 192 193 // Remove the element with the given key erase(const key_t & key)194 size_type erase(const key_t &key) { 195 auto itr = find(key); 196 if (itr == end()) { 197 return 0; 198 } 199 (void)erase(itr); 200 return 1; 201 } 202 203 private: 204 map_type map_data_; 205 sequential_type sequential_data_; 206 }; 207 208 // OrderedMap that specially optimized for shared_ptr key type. 209 template <typename T, typename ValueT> 210 class OrderedMap<std::shared_ptr<T>, ValueT> { 211 public: 212 using raw_key_t = const T *; 213 using key_t = std::shared_ptr<T>; 214 using hash_t = PointerHash<T>; 215 using value_t = ValueT; 216 using pair_type = std::pair<key_t, value_t>; 217 using sequential_type = std::list<pair_type>; 218 using iterator = typename sequential_type::iterator; 219 using const_iterator = typename sequential_type::const_iterator; 220 using reverse_iterator = typename sequential_type::reverse_iterator; 221 using const_reverse_iterator = typename sequential_type::const_reverse_iterator; 222 using map_type = std::unordered_map<raw_key_t, iterator, hash_t>; 223 using value_type = typename sequential_type::value_type; 224 using size_type = typename sequential_type::size_type; 225 begin()226 iterator begin() { return sequential_data_.begin(); } end()227 iterator end() { return sequential_data_.end(); } begin()228 const_iterator begin() const { return sequential_data_.cbegin(); } end()229 const_iterator end() const { return sequential_data_.cend(); } cbegin()230 const_iterator cbegin() const { return sequential_data_.cbegin(); } cend()231 const_iterator cend() const { return sequential_data_.cend(); } 232 rbegin()233 reverse_iterator rbegin() { return sequential_data_.rbegin(); } rend()234 reverse_iterator rend() { return sequential_data_.rend(); } rbegin()235 const_reverse_iterator rbegin() const { return sequential_data_.rbegin(); } rend()236 const_reverse_iterator rend() const { return sequential_data_.rend(); } 237 front()238 pair_type &front() { return sequential_data_.front(); } front()239 const pair_type &front() const { return sequential_data_.front(); } back()240 pair_type &back() { return sequential_data_.back(); } back()241 const pair_type &back() const { return sequential_data_.back(); } 242 243 OrderedMap() = default; 244 ~OrderedMap() = default; 245 246 OrderedMap(OrderedMap &&other) noexcept = default; 247 OrderedMap &operator=(OrderedMap &&other) noexcept = default; 248 OrderedMap(const sequential_type & other)249 explicit OrderedMap(const sequential_type &other) { 250 reserve(other.size()); 251 for (auto &item : other) { 252 (void)emplace(item.first, item.second); 253 } 254 } 255 OrderedMap(const OrderedMap & other)256 OrderedMap(const OrderedMap &other) : OrderedMap(other.sequential_data_) {} 257 258 OrderedMap &operator=(const OrderedMap &other) { 259 if (this != &other) { 260 clear(); 261 reserve(other.size()); 262 for (auto &item : other.sequential_data_) { 263 (void)emplace(item.first, item.second); 264 } 265 } 266 return *this; 267 } 268 clear()269 void clear() { 270 map_data_.clear(); 271 sequential_data_.clear(); 272 } 273 swap(OrderedMap & rhs)274 void swap(OrderedMap &rhs) noexcept { 275 std::swap(map_data_, rhs.map_data_); 276 std::swap(sequential_data_, rhs.sequential_data_); 277 } 278 reserve(size_type num_entries)279 void reserve(size_type num_entries) { map_data_.reserve(num_entries); } 280 281 template <typename K, typename V> emplace(K && key,V && value)282 std::pair<iterator, bool> emplace(K &&key, V &&value) { 283 auto [map_iter, inserted] = map_data_.emplace(key.get(), iterator{}); 284 if (inserted) { 285 map_iter->second = sequential_data_.emplace(sequential_data_.end(), std::forward<K>(key), std::forward<V>(value)); 286 } 287 return {map_iter->second, inserted}; 288 } 289 insert(const pair_type & kv)290 std::pair<iterator, bool> insert(const pair_type &kv) { 291 auto [map_iter, inserted] = map_data_.emplace(kv.first.get(), iterator{}); 292 if (inserted) { 293 map_iter->second = sequential_data_.emplace(sequential_data_.end(), kv); 294 } 295 return {map_iter->second, inserted}; 296 } 297 insert(pair_type && kv)298 std::pair<iterator, bool> insert(pair_type &&kv) { 299 auto [map_iter, inserted] = map_data_.emplace(kv.first.get(), iterator{}); 300 if (inserted) { 301 map_iter->second = sequential_data_.emplace(sequential_data_.end(), std::move(kv)); 302 } 303 return {map_iter->second, inserted}; 304 } 305 add(const key_t & key)306 std::pair<iterator, bool> add(const key_t &key) { return insert(pair_type{key, ValueT{}}); } 307 308 ValueT &operator[](const key_t &key) { 309 auto result = emplace(key, ValueT{}); 310 return result.first->second; 311 } 312 empty()313 bool empty() const { return sequential_data_.empty(); } 314 size()315 size_type size() const { return sequential_data_.size(); } 316 at(const key_t & key)317 const ValueT &at(const key_t &key) const { 318 auto &list_iter = map_data_.at(key.get()); 319 return list_iter->second; 320 } 321 count(const key_t & key)322 size_type count(const key_t &key) const { 323 auto pos = map_data_.find(key.get()); 324 return pos == map_data_.end() ? 0 : 1; 325 } 326 find(const key_t & key)327 iterator find(const key_t &key) { 328 auto pos = map_data_.find(key.get()); 329 return pos == map_data_.end() ? sequential_data_.end() : (pos->second); 330 } 331 find(const key_t & key)332 const_iterator find(const key_t &key) const { 333 auto pos = map_data_.find(key.get()); 334 return pos == map_data_.end() ? sequential_data_.end() : (pos->second); 335 } 336 337 // Remove the last element from the sequential_data_. pop_back()338 void pop_back() { 339 (void)map_data_.erase(sequential_data_.back().first.get()); 340 sequential_data_.pop_back(); 341 } 342 343 // Remove the first element from the sequential_data_. pop_front()344 void pop_front() { 345 (void)map_data_.erase(sequential_data_.front().first.get()); 346 sequential_data_.pop_front(); 347 } 348 349 // Remove the element given by Iterator. erase(iterator iter)350 iterator erase(iterator iter) { 351 (void)map_data_.erase(iter->first.get()); 352 return sequential_data_.erase(iter); 353 } 354 355 // Remove the element with the given key. erase(const key_t & key)356 size_type erase(const key_t &key) { 357 auto itr = find(key); 358 if (itr == end()) { 359 return 0; 360 } 361 (void)erase(itr); 362 return 1; 363 } 364 365 private: 366 map_type map_data_; 367 sequential_type sequential_data_; 368 }; 369 } // namespace mindspore 370 371 #endif // MINDSPORE_CORE_UTILS_ORDERED_MAP_H_ 372