1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_UTILS_ORDERED_MAP_H_ 18 #define MINDSPORE_CORE_UTILS_ORDERED_MAP_H_ 19 20 #include <list> 21 #include <algorithm> 22 #include <functional> 23 #include <utility> 24 #include <memory> 25 #include <type_traits> 26 #include <unordered_map> 27 #include "utils/hashing.h" 28 #include "utils/hash_map.h" 29 30 namespace mindspore { 31 // Implementation of OrderedMap that keeps insertion order 32 // using hash map to improve the performance of find/erase, and use list to keep insertion order 33 template <typename KeyT, typename ValueT, class Hash = std::hash<KeyT>, class Equal = std::equal_to<KeyT>, 34 bool UseStd = false> 35 class OrderedMap { 36 using key_ptr_t = const KeyT *; 37 struct KeyPtrHash { operatorKeyPtrHash38 std::size_t operator()(key_ptr_t ptr) const noexcept { return Hash{}(*ptr); } 39 }; 40 struct KeyPtrEqual { operatorKeyPtrEqual41 bool operator()(key_ptr_t lhs, key_ptr_t rhs) const noexcept { return Equal{}(*lhs, *rhs); } 42 }; 43 44 public: 45 using key_t = KeyT; 46 using value_t = ValueT; 47 using pair_type = std::pair<key_t, value_t>; 48 using sequential_type = std::list<pair_type>; 49 using iterator = typename sequential_type::iterator; 50 using const_iterator = typename sequential_type::const_iterator; 51 using reverse_iterator = typename sequential_type::reverse_iterator; 52 using const_reverse_iterator = typename sequential_type::const_reverse_iterator; 53 using map_type = typename std::conditional<UseStd, std::unordered_map<key_ptr_t, iterator, KeyPtrHash, KeyPtrEqual>, 54 mindspore::HashMap<key_ptr_t, iterator, KeyPtrHash, KeyPtrEqual>>::type; 55 using value_type = typename sequential_type::value_type; 56 using size_type = typename sequential_type::size_type; 57 begin()58 iterator begin() { return sequential_data_.begin(); } end()59 iterator end() { return sequential_data_.end(); } begin()60 const_iterator begin() const { return sequential_data_.cbegin(); } end()61 const_iterator end() const { return sequential_data_.cend(); } cbegin()62 const_iterator cbegin() const { return sequential_data_.cbegin(); } cend()63 const_iterator cend() const { return sequential_data_.cend(); } 64 rbegin()65 reverse_iterator rbegin() { return sequential_data_.rbegin(); } rend()66 reverse_iterator rend() { return sequential_data_.rend(); } rbegin()67 const_reverse_iterator rbegin() const { return sequential_data_.rbegin(); } rend()68 const_reverse_iterator rend() const { return sequential_data_.rend(); } 69 front()70 pair_type &front() { return sequential_data_.front(); } front()71 const pair_type &front() const { return sequential_data_.front(); } back()72 pair_type &back() { return sequential_data_.back(); } back()73 const pair_type &back() const { return sequential_data_.back(); } 74 75 OrderedMap() = default; 76 ~OrderedMap() = default; 77 78 OrderedMap(OrderedMap &&other) noexcept = default; 79 OrderedMap &operator=(OrderedMap &&other) noexcept = default; 80 OrderedMap(const sequential_type & other)81 explicit OrderedMap(const sequential_type &other) { 82 reserve(other.size()); 83 for (auto &item : other) { 84 (void)emplace(item.first, item.second); 85 } 86 } 87 OrderedMap(const OrderedMap & other)88 OrderedMap(const OrderedMap &other) : OrderedMap(other.sequential_data_) {} 89 90 OrderedMap &operator=(const OrderedMap &other) { 91 if (this != &other) { 92 clear(); 93 reserve(other.size()); 94 for (auto &item : other.sequential_data_) { 95 (void)emplace(item.first, item.second); 96 } 97 } 98 return *this; 99 } 100 clear()101 void clear() { 102 if (!map_data_.empty()) { 103 map_data_.clear(); 104 sequential_data_.clear(); 105 } 106 } 107 swap(OrderedMap & rhs)108 void swap(OrderedMap &rhs) noexcept { 109 std::swap(map_data_, rhs.map_data_); 110 std::swap(sequential_data_, rhs.sequential_data_); 111 } 112 reserve(size_type num_entries)113 void reserve(size_type num_entries) { map_data_.reserve(num_entries); } 114 115 template <typename... Args> emplace(Args &&...args)116 std::pair<iterator, bool> emplace(Args &&... args) { 117 auto new_iter = sequential_data_.emplace(sequential_data_.end(), std::forward<Args>(args)...); 118 auto [map_iter, inserted] = map_data_.emplace(&(new_iter->first), new_iter); 119 if (!inserted) { 120 sequential_data_.erase(new_iter); 121 } 122 return {map_iter->second, inserted}; 123 } 124 insert(const pair_type & kv)125 std::pair<iterator, bool> insert(const pair_type &kv) { 126 auto iter = map_data_.find(&(kv.first)); 127 if (iter != map_data_.end()) { 128 return {iter->second, false}; 129 } 130 auto new_iter = sequential_data_.emplace(sequential_data_.end(), kv); 131 auto result = map_data_.emplace(&(new_iter->first), new_iter); 132 return {result.first->second, true}; 133 } 134 insert(pair_type && kv)135 std::pair<iterator, bool> insert(pair_type &&kv) { 136 auto iter = map_data_.find(&(kv.first)); 137 if (iter != map_data_.end()) { 138 return {iter->second, false}; 139 } 140 auto new_iter = sequential_data_.emplace(sequential_data_.end(), std::move(kv)); 141 auto result = map_data_.emplace(&(new_iter->first), new_iter); 142 return {result.first->second, true}; 143 } 144 add(const key_t & key)145 std::pair<iterator, bool> add(const key_t &key) { return insert(pair_type{key, ValueT{}}); } 146 147 ValueT &operator[](const key_t &key) { 148 auto iter = map_data_.find(&key); 149 if (iter != map_data_.end()) { 150 return iter->second->second; 151 } 152 auto new_iter = sequential_data_.emplace(sequential_data_.end(), key, ValueT{}); 153 auto result = map_data_.emplace(&(new_iter->first), new_iter); 154 return result.first->second->second; 155 } 156 empty()157 bool empty() const { return sequential_data_.empty(); } 158 size()159 size_type size() const { return sequential_data_.size(); } 160 at(const key_t & key)161 const ValueT &at(const key_t &key) const { 162 auto &list_iter = map_data_.at(&key); 163 return list_iter->second; 164 } 165 count(const key_t & key)166 size_type count(const key_t &key) const { 167 auto pos = map_data_.find(&key); 168 return pos == map_data_.end() ? 0 : 1; 169 } 170 find(const key_t & key)171 iterator find(const key_t &key) { 172 auto pos = map_data_.find(&key); 173 return pos == map_data_.end() ? sequential_data_.end() : (pos->second); 174 } 175 find(const key_t & key)176 const_iterator find(const key_t &key) const { 177 auto pos = map_data_.find(&key); 178 return pos == map_data_.end() ? sequential_data_.end() : (pos->second); 179 } 180 181 // Remove the last element from the sequential_data_. pop_back()182 void pop_back() { 183 (void)map_data_.erase(&(sequential_data_.back().first)); 184 sequential_data_.pop_back(); 185 } 186 187 // Remove the first element from the sequential_data_. pop_front()188 void pop_front() { 189 (void)map_data_.erase(&(sequential_data_.front().first)); 190 sequential_data_.pop_front(); 191 } 192 193 // Remove the element given by Iterator. erase(const iterator & iter)194 iterator erase(const iterator &iter) { 195 (void)map_data_.erase(&(iter->first)); 196 return sequential_data_.erase(iter); 197 } 198 199 // Remove the element with the given key erase(const key_t & key)200 size_type erase(const key_t &key) { 201 auto itr = find(key); 202 if (itr == end()) { 203 return 0; 204 } 205 (void)erase(itr); 206 return 1; 207 } 208 209 private: 210 map_type map_data_; 211 sequential_type sequential_data_; 212 }; 213 214 // OrderedMap that specially optimized for shared_ptr key type. 215 template <typename T, typename ValueT> 216 class OrderedMap<std::shared_ptr<T>, ValueT> { 217 public: 218 using raw_key_t = const T *; 219 using key_t = std::shared_ptr<T>; 220 using value_t = ValueT; 221 using pair_type = std::pair<key_t, value_t>; 222 using sequential_type = std::list<pair_type>; 223 using iterator = typename sequential_type::iterator; 224 using const_iterator = typename sequential_type::const_iterator; 225 using reverse_iterator = typename sequential_type::reverse_iterator; 226 using const_reverse_iterator = typename sequential_type::const_reverse_iterator; 227 using map_type = mindspore::HashMap<raw_key_t, iterator>; 228 using value_type = typename sequential_type::value_type; 229 using size_type = typename sequential_type::size_type; 230 begin()231 iterator begin() { return sequential_data_.begin(); } end()232 iterator end() { return sequential_data_.end(); } begin()233 const_iterator begin() const { return sequential_data_.cbegin(); } end()234 const_iterator end() const { return sequential_data_.cend(); } cbegin()235 const_iterator cbegin() const { return sequential_data_.cbegin(); } cend()236 const_iterator cend() const { return sequential_data_.cend(); } 237 rbegin()238 reverse_iterator rbegin() { return sequential_data_.rbegin(); } rend()239 reverse_iterator rend() { return sequential_data_.rend(); } rbegin()240 const_reverse_iterator rbegin() const { return sequential_data_.rbegin(); } rend()241 const_reverse_iterator rend() const { return sequential_data_.rend(); } 242 front()243 pair_type &front() { return sequential_data_.front(); } front()244 const pair_type &front() const { return sequential_data_.front(); } back()245 pair_type &back() { return sequential_data_.back(); } back()246 const pair_type &back() const { return sequential_data_.back(); } 247 248 OrderedMap() = default; 249 ~OrderedMap() = default; 250 251 OrderedMap(OrderedMap &&other) noexcept = default; 252 OrderedMap &operator=(OrderedMap &&other) noexcept = default; 253 OrderedMap(const sequential_type & other)254 explicit OrderedMap(const sequential_type &other) { 255 reserve(other.size()); 256 for (auto &item : other) { 257 (void)emplace(item.first, item.second); 258 } 259 } 260 OrderedMap(const OrderedMap & other)261 OrderedMap(const OrderedMap &other) : OrderedMap(other.sequential_data_) {} 262 263 OrderedMap &operator=(const OrderedMap &other) { 264 if (this != &other) { 265 clear(); 266 reserve(other.size()); 267 for (auto &item : other.sequential_data_) { 268 (void)emplace(item.first, item.second); 269 } 270 } 271 return *this; 272 } 273 clear()274 void clear() { 275 if (!map_data_.empty()) { 276 map_data_.clear(); 277 sequential_data_.clear(); 278 } 279 } 280 swap(OrderedMap & rhs)281 void swap(OrderedMap &rhs) noexcept { 282 std::swap(map_data_, rhs.map_data_); 283 std::swap(sequential_data_, rhs.sequential_data_); 284 } 285 reserve(size_type num_entries)286 void reserve(size_type num_entries) { map_data_.reserve(num_entries); } 287 288 template <typename K, typename V> emplace(K && key,V && value)289 std::pair<iterator, bool> emplace(K &&key, V &&value) { 290 auto [map_iter, inserted] = map_data_.emplace(key.get(), iterator{}); 291 if (inserted) { 292 map_iter->second = sequential_data_.emplace(sequential_data_.end(), std::forward<K>(key), std::forward<V>(value)); 293 } 294 return {map_iter->second, inserted}; 295 } 296 insert(const pair_type & kv)297 std::pair<iterator, bool> insert(const pair_type &kv) { 298 auto [map_iter, inserted] = map_data_.emplace(kv.first.get(), iterator{}); 299 if (inserted) { 300 map_iter->second = sequential_data_.emplace(sequential_data_.end(), kv); 301 } 302 return {map_iter->second, inserted}; 303 } 304 insert(pair_type && kv)305 std::pair<iterator, bool> insert(pair_type &&kv) { 306 auto [map_iter, inserted] = map_data_.emplace(kv.first.get(), iterator{}); 307 if (inserted) { 308 map_iter->second = sequential_data_.emplace(sequential_data_.end(), std::move(kv)); 309 } 310 return {map_iter->second, inserted}; 311 } 312 add(const key_t & key)313 std::pair<iterator, bool> add(const key_t &key) { return insert(pair_type{key, ValueT{}}); } 314 315 ValueT &operator[](const key_t &key) { 316 auto result = emplace(key, ValueT{}); 317 return result.first->second; 318 } 319 empty()320 bool empty() const { return sequential_data_.empty(); } 321 size()322 size_type size() const { return sequential_data_.size(); } 323 at(const key_t & key)324 const ValueT &at(const key_t &key) const { 325 auto &list_iter = map_data_.at(key.get()); 326 return list_iter->second; 327 } 328 count(const key_t & key)329 size_type count(const key_t &key) const { 330 auto pos = map_data_.find(key.get()); 331 return pos == map_data_.end() ? 0 : 1; 332 } 333 find(const key_t & key)334 iterator find(const key_t &key) { 335 auto pos = map_data_.find(key.get()); 336 return pos == map_data_.end() ? sequential_data_.end() : (pos->second); 337 } 338 find(const key_t & key)339 const_iterator find(const key_t &key) const { 340 auto pos = map_data_.find(key.get()); 341 return pos == map_data_.end() ? sequential_data_.end() : (pos->second); 342 } 343 344 // Remove the last element from the sequential_data_. pop_back()345 void pop_back() { 346 (void)map_data_.erase(sequential_data_.back().first.get()); 347 sequential_data_.pop_back(); 348 } 349 350 // Remove the first element from the sequential_data_. pop_front()351 void pop_front() { 352 (void)map_data_.erase(sequential_data_.front().first.get()); 353 sequential_data_.pop_front(); 354 } 355 356 // Remove the element given by Iterator. erase(const iterator & iter)357 iterator erase(const iterator &iter) { 358 (void)map_data_.erase(iter->first.get()); 359 return sequential_data_.erase(iter); 360 } 361 362 // Remove the element with the given key. erase(const key_t & key)363 size_type erase(const key_t &key) { 364 auto itr = find(key); 365 if (itr == end()) { 366 return 0; 367 } 368 (void)erase(itr); 369 return 1; 370 } 371 372 private: 373 map_type map_data_; 374 sequential_type sequential_data_; 375 }; 376 } // namespace mindspore 377 378 #endif // MINDSPORE_CORE_UTILS_ORDERED_MAP_H_ 379