1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_UTILS_COUNTER_H_ 18 #define MINDSPORE_CORE_UTILS_COUNTER_H_ 19 #include <list> 20 #include <vector> 21 #include <utility> 22 #include <functional> 23 #include <unordered_map> 24 #include <memory> 25 #include "utils/ordered_map.h" 26 27 namespace mindspore { 28 template <typename T, class Hash = std::hash<T>, class Equal = std::equal_to<T>> 29 class Counter { 30 using counter_type = Counter<T, Hash, Equal>; 31 using key_type = T const *; 32 using item_type = std::pair<T, int>; 33 using list_type = std::list<item_type>; 34 using iterator = typename list_type::iterator; 35 using const_iterator = typename list_type::const_iterator; 36 37 struct KeyHash { operatorKeyHash38 std::size_t operator()(const key_type ptr) const noexcept { return Hash{}(*ptr); } 39 }; 40 41 struct KeyEqual { operatorKeyEqual42 bool operator()(const key_type lhs, const key_type rhs) const noexcept { return Equal{}(*lhs, *rhs); } 43 }; 44 using map_type = std::unordered_map<key_type, iterator, KeyHash, KeyEqual>; 45 46 public: 47 Counter() = default; 48 ~Counter() = default; 49 50 Counter(Counter &&other) noexcept = default; 51 Counter &operator=(Counter &&other) noexcept = default; 52 Counter(const Counter & other)53 Counter(const Counter &other) { *this = other; } 54 Counter &operator=(const Counter &other) { 55 map_.clear(); 56 list_ = other.list_; 57 for (auto iter = list_.begin(); iter != list_.end(); ++iter) { 58 map_.emplace(&(iter->first), iter); 59 } 60 return *this; 61 } 62 63 template <typename... Args> emplace(Args &&...args)64 std::pair<iterator, bool> emplace(Args &&... args) { 65 auto new_iter = list_.emplace(list_.end(), std::forward<Args>(args)...); 66 auto [map_iter, inserted] = map_.emplace(&(new_iter->first), new_iter); 67 if (!inserted) { 68 list_.erase(new_iter); 69 } 70 return {map_iter->second, inserted}; 71 } 72 73 template <typename... Args> add(Args &&...args)74 void add(Args &&... args) { 75 auto [iter, inserted] = emplace(T{std::forward<Args>(args)...}, 1); 76 if (!inserted) { 77 ++(iter->second); 78 } 79 } 80 81 int &operator[](const T &key) { 82 auto map_iter = map_.find(&key); 83 if (map_iter != map_.end()) { 84 return map_iter->second->second; 85 } 86 return emplace(key, 0).first->second; 87 } 88 89 counter_type operator-(const counter_type &other) const { 90 counter_type new_counter; 91 for (const auto &[key, value] : list_) { 92 auto iter = other.find(key); 93 if (iter != other.end()) { 94 int new_value = value - iter->second; 95 if (new_value > 0) { 96 new_counter.emplace(key, new_value); 97 } 98 } else { 99 new_counter.emplace(key, value); 100 } 101 } 102 return new_counter; 103 } 104 105 counter_type operator+(const counter_type &other) const { 106 counter_type new_counter = *this; 107 for (const auto &[key, value] : other.list_) { 108 auto [iter, inserted] = new_counter.emplace(key, value); 109 if (!inserted) { 110 iter->second += value; 111 } 112 } 113 return new_counter; 114 } 115 116 template <typename Func> subtract_by(const counter_type & other,Func && func)117 void subtract_by(const counter_type &other, Func &&func) const { 118 for (const auto &[key, value] : list_) { 119 auto iter = other.find(key); 120 if (iter != other.end()) { 121 if ((value - iter->second) > 0) { 122 func(key); 123 } 124 } else { 125 func(key); 126 } 127 } 128 } 129 subtract(const counter_type & other)130 std::vector<T> subtract(const counter_type &other) const { 131 std::vector<T> result; 132 subtract_by(other, [&result](const T &item) { result.emplace_back(item); }); 133 return result; 134 } 135 size()136 std::size_t size() const { return list_.size(); } 137 contains(const T & key)138 bool contains(const T &key) const { return map_.find(&key) != map_.end(); } 139 find(const T & key)140 const_iterator find(const T &key) const { 141 auto map_iter = map_.find(&key); 142 if (map_iter == map_.end()) { 143 return list_.end(); 144 } 145 return map_iter->second; 146 } 147 begin()148 iterator begin() { return list_.begin(); } end()149 iterator end() { return list_.end(); } 150 begin()151 const_iterator begin() const { return list_.cbegin(); } end()152 const_iterator end() const { return list_.cend(); } 153 cbegin()154 const_iterator cbegin() const { return list_.cbegin(); } cend()155 const_iterator cend() const { return list_.cend(); } 156 157 private: 158 map_type map_; 159 list_type list_; 160 }; 161 162 // Counter for shared_ptr. 163 template <typename T> 164 class Counter<std::shared_ptr<T>> { 165 using key_type = std::shared_ptr<T>; 166 using counter_type = Counter<key_type>; 167 using map_type = OrderedMap<key_type, int>; 168 using item_type = std::pair<std::shared_ptr<T>, int>; 169 using iterator = typename map_type::iterator; 170 using const_iterator = typename map_type::const_iterator; 171 172 public: emplace(const key_type & key,int value)173 std::pair<iterator, bool> emplace(const key_type &key, int value) { return map_.emplace(key, value); } 174 emplace(key_type && key,int value)175 std::pair<iterator, bool> emplace(key_type &&key, int value) { return map_.emplace(std::move(key), value); } 176 add(const key_type & key)177 void add(const key_type &key) { 178 auto [iter, inserted] = map_.emplace(key, 1); 179 if (!inserted) { 180 ++(iter->second); 181 } 182 } 183 add(key_type && key)184 void add(key_type &&key) { 185 auto [iter, inserted] = map_.emplace(std::move(key), 1); 186 if (!inserted) { 187 ++(iter->second); 188 } 189 } 190 191 int &operator[](const T &key) { return map_[key]; } 192 193 counter_type operator-(const counter_type &other) const { 194 counter_type new_counter; 195 for (const auto &[key, value] : map_) { 196 auto iter = other.find(key); 197 if (iter != other.end()) { 198 int new_value = value - iter->second; 199 if (new_value > 0) { 200 new_counter.emplace(key, new_value); 201 } 202 } else { 203 new_counter.emplace(key, value); 204 } 205 } 206 return new_counter; 207 } 208 209 counter_type operator+(const counter_type &other) const { 210 counter_type new_counter = *this; 211 for (const auto &[key, value] : other) { 212 auto [iter, inserted] = new_counter.emplace(key, value); 213 if (!inserted) { 214 iter->second += value; 215 } 216 } 217 return new_counter; 218 } 219 220 template <typename Func> subtract_by(const counter_type & other,Func && func)221 void subtract_by(const counter_type &other, Func &&func) const { 222 for (const auto &[key, value] : map_) { 223 auto iter = other.find(key); 224 if (iter != other.end()) { 225 if ((value - iter->second) > 0) { 226 func(key); 227 } 228 } else { 229 func(key); 230 } 231 } 232 } 233 subtract(const counter_type & other)234 std::vector<key_type> subtract(const counter_type &other) const { 235 std::vector<key_type> result; 236 subtract_by(other, [&result](const key_type &item) { result.emplace_back(item); }); 237 return result; 238 } 239 size()240 std::size_t size() const { return map_.size(); } 241 contains(const key_type & key)242 bool contains(const key_type &key) const { return map_.contains(key); } 243 find(const key_type & key)244 const_iterator find(const key_type &key) const { return map_.find(key); } 245 begin()246 iterator begin() { return map_.begin(); } end()247 iterator end() { return map_.end(); } 248 begin()249 const_iterator begin() const { return map_.cbegin(); } end()250 const_iterator end() const { return map_.cend(); } 251 cbegin()252 const_iterator cbegin() const { return map_.cbegin(); } cend()253 const_iterator cend() const { return map_.cend(); } 254 255 private: 256 map_type map_; 257 }; 258 } // namespace mindspore 259 260 #endif // MINDSPORE_CORE_UTILS_COUNTER_H_ 261