• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_UTILS_COUNTER_H_
18 #define MINDSPORE_CORE_UTILS_COUNTER_H_
19 #include <list>
20 #include <vector>
21 #include <utility>
22 #include <functional>
23 #include <unordered_map>
24 #include <memory>
25 #include "utils/ordered_map.h"
26 
27 namespace mindspore {
/// @brief Counts occurrences of values of type T, iterating in insertion order.
///
/// Items are stored as (value, count) pairs in a std::list, which keeps insertion order and gives
/// every element a stable address. An unordered_map indexes those elements by a pointer to the
/// stored value; hashing and equality are applied to the pointee through Hash / Equal.
///
/// Invariant: every element of list_ has exactly one entry in map_ (keyed by the address of that
/// element's value), and no two elements of list_ compare equal under Equal.
template <typename T, class Hash = std::hash<T>, class Equal = std::equal_to<T>>
class Counter {
  using counter_type = Counter<T, Hash, Equal>;
  // The index is keyed by a pointer into list_, so values are stored only once.
  using key_type = T const *;
  using item_type = std::pair<T, int>;
  using list_type = std::list<item_type>;
  using iterator = typename list_type::iterator;
  using const_iterator = typename list_type::const_iterator;

  // Hashes the pointed-to value, not the pointer address.
  struct KeyHash {
    std::size_t operator()(const key_type ptr) const noexcept { return Hash{}(*ptr); }
  };

  // Compares the pointed-to values, not the pointer addresses.
  struct KeyEqual {
    bool operator()(const key_type lhs, const key_type rhs) const noexcept { return Equal{}(*lhs, *rhs); }
  };
  using map_type = std::unordered_map<key_type, iterator, KeyHash, KeyEqual>;

 public:
  Counter() = default;
  ~Counter() = default;

  // Moving a std::list (std::allocator) transfers its nodes, so the pointers and iterators held in
  // map_ remain valid and refer into the destination's list_.
  Counter(Counter &&other) noexcept = default;
  Counter &operator=(Counter &&other) noexcept = default;

  /// @brief Deep copy.
  /// The items are copied first, then the index is rebuilt so that it points into this object's
  /// own list (the source's index points into the source's list).
  Counter(const Counter &other) : list_(other.list_) {
    map_.reserve(list_.size());
    for (auto iter = list_.begin(); iter != list_.end(); ++iter) {
      map_.emplace(&(iter->first), iter);
    }
  }

  /// @brief Copy assignment with the strong exception guarantee.
  /// The complete copy is built before *this is touched; if building it throws, *this is unchanged.
  Counter &operator=(const Counter &other) {
    if (this != &other) {
      Counter copy(other);
      *this = std::move(copy);
    }
    return *this;
  }

  /// @brief Inserts an item constructed from @p args unless an equal value is already present.
  /// @param args Arguments forwarded to the std::pair<T, int> (value, count) constructor.
  /// @return Iterator to the item holding that value (new or pre-existing), and whether it was inserted.
  template <typename... Args>
  std::pair<iterator, bool> emplace(Args &&... args) {
    // Construct the candidate in a scratch list. If it is a duplicate, or indexing it throws,
    // the scratch list destroys it and list_ / map_ are left untouched.
    list_type pending;
    auto new_iter = pending.emplace(pending.end(), std::forward<Args>(args)...);
    auto [map_iter, inserted] = map_.emplace(&(new_iter->first), new_iter);
    if (inserted) {
      // splice relinks the node without copying; new_iter and the stored key pointer now refer to
      // the element inside list_.
      list_.splice(list_.end(), pending, new_iter);
    }
    return {map_iter->second, inserted};
  }

  /// @brief Increments the count of the value built as T{args...}, inserting it with count 1 if absent.
  template <typename... Args>
  void add(Args &&... args) {
    auto [iter, inserted] = emplace(T{std::forward<Args>(args)...}, 1);
    if (!inserted) {
      ++(iter->second);
    }
  }

  /// @brief Returns a mutable reference to the count of @p key, inserting it with count 0 if absent.
  int &operator[](const T &key) {
    auto map_iter = map_.find(&key);
    if (map_iter != map_.end()) {
      return map_iter->second->second;
    }
    return emplace(key, 0).first->second;
  }

  /// @brief Multiset difference: keeps each key of *this whose count minus its count in @p other
  /// is positive, with that remaining count. Keys absent from @p other keep their full count.
  counter_type operator-(const counter_type &other) const {
    counter_type new_counter;
    for (const auto &[key, value] : list_) {
      auto iter = other.find(key);
      if (iter != other.end()) {
        int new_value = value - iter->second;
        if (new_value > 0) {
          new_counter.emplace(key, new_value);
        }
      } else {
        new_counter.emplace(key, value);
      }
    }
    return new_counter;
  }

  /// @brief Multiset sum: counts of keys present in both operands are added together; the order is
  /// this counter's keys first, then keys only present in @p other.
  counter_type operator+(const counter_type &other) const {
    counter_type new_counter = *this;
    for (const auto &[key, value] : other.list_) {
      auto [iter, inserted] = new_counter.emplace(key, value);
      if (!inserted) {
        iter->second += value;
      }
    }
    return new_counter;
  }

  /// @brief Calls @p func on every key (in insertion order) that would survive `*this - other`.
  template <typename Func>
  void subtract_by(const counter_type &other, Func &&func) const {
    for (const auto &[key, value] : list_) {
      auto iter = other.find(key);
      if (iter != other.end()) {
        if ((value - iter->second) > 0) {
          func(key);
        }
      } else {
        func(key);
      }
    }
  }

  /// @brief Returns the keys (in insertion order) that would survive `*this - other`.
  std::vector<T> subtract(const counter_type &other) const {
    std::vector<T> result;
    subtract_by(other, [&result](const T &item) { result.emplace_back(item); });
    return result;
  }

  /// @brief Number of distinct keys.
  std::size_t size() const { return list_.size(); }

  /// @brief Whether @p key has been inserted.
  bool contains(const T &key) const { return map_.find(&key) != map_.end(); }

  /// @brief Returns an iterator to the (key, count) item for @p key, or end() if absent.
  const_iterator find(const T &key) const {
    auto map_iter = map_.find(&key);
    if (map_iter == map_.end()) {
      return list_.end();
    }
    return map_iter->second;
  }

  // Iteration visits (key, count) pairs in insertion order.
  iterator begin() { return list_.begin(); }
  iterator end() { return list_.end(); }

  const_iterator begin() const { return list_.cbegin(); }
  const_iterator end() const { return list_.cend(); }

  const_iterator cbegin() const { return list_.cbegin(); }
  const_iterator cend() const { return list_.cend(); }

 private:
  map_type map_;
  list_type list_;
};
161 
162 // Counter for shared_ptr.
163 template <typename T>
164 class Counter<std::shared_ptr<T>> {
165   using key_type = std::shared_ptr<T>;
166   using counter_type = Counter<key_type>;
167   using map_type = OrderedMap<key_type, int>;
168   using item_type = std::pair<std::shared_ptr<T>, int>;
169   using iterator = typename map_type::iterator;
170   using const_iterator = typename map_type::const_iterator;
171 
172  public:
emplace(const key_type & key,int value)173   std::pair<iterator, bool> emplace(const key_type &key, int value) { return map_.emplace(key, value); }
174 
emplace(key_type && key,int value)175   std::pair<iterator, bool> emplace(key_type &&key, int value) { return map_.emplace(std::move(key), value); }
176 
add(const key_type & key)177   void add(const key_type &key) {
178     auto [iter, inserted] = map_.emplace(key, 1);
179     if (!inserted) {
180       ++(iter->second);
181     }
182   }
183 
add(key_type && key)184   void add(key_type &&key) {
185     auto [iter, inserted] = map_.emplace(std::move(key), 1);
186     if (!inserted) {
187       ++(iter->second);
188     }
189   }
190 
191   int &operator[](const T &key) { return map_[key]; }
192 
193   counter_type operator-(const counter_type &other) const {
194     counter_type new_counter;
195     for (const auto &[key, value] : map_) {
196       auto iter = other.find(key);
197       if (iter != other.end()) {
198         int new_value = value - iter->second;
199         if (new_value > 0) {
200           new_counter.emplace(key, new_value);
201         }
202       } else {
203         new_counter.emplace(key, value);
204       }
205     }
206     return new_counter;
207   }
208 
209   counter_type operator+(const counter_type &other) const {
210     counter_type new_counter = *this;
211     for (const auto &[key, value] : other) {
212       auto [iter, inserted] = new_counter.emplace(key, value);
213       if (!inserted) {
214         iter->second += value;
215       }
216     }
217     return new_counter;
218   }
219 
220   template <typename Func>
subtract_by(const counter_type & other,Func && func)221   void subtract_by(const counter_type &other, Func &&func) const {
222     for (const auto &[key, value] : map_) {
223       auto iter = other.find(key);
224       if (iter != other.end()) {
225         if ((value - iter->second) > 0) {
226           func(key);
227         }
228       } else {
229         func(key);
230       }
231     }
232   }
233 
subtract(const counter_type & other)234   std::vector<key_type> subtract(const counter_type &other) const {
235     std::vector<key_type> result;
236     subtract_by(other, [&result](const key_type &item) { result.emplace_back(item); });
237     return result;
238   }
239 
size()240   std::size_t size() const { return map_.size(); }
241 
contains(const key_type & key)242   bool contains(const key_type &key) const { return map_.contains(key); }
243 
find(const key_type & key)244   const_iterator find(const key_type &key) const { return map_.find(key); }
245 
begin()246   iterator begin() { return map_.begin(); }
end()247   iterator end() { return map_.end(); }
248 
begin()249   const_iterator begin() const { return map_.cbegin(); }
end()250   const_iterator end() const { return map_.cend(); }
251 
cbegin()252   const_iterator cbegin() const { return map_.cbegin(); }
cend()253   const_iterator cend() const { return map_.cend(); }
254 
255  private:
256   map_type map_;
257 };
258 }  // namespace mindspore
259 
260 #endif  // MINDSPORE_CORE_UTILS_COUNTER_H_
261