1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ 17 #define TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ 18 19 #include <stddef.h> 20 #include <functional> 21 #include <initializer_list> 22 #include <iterator> 23 #include <utility> 24 #include "tensorflow/core/lib/gtl/flatrep.h" 25 #include "tensorflow/core/lib/hash/hash.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace tensorflow { 30 namespace gtl { 31 32 // FlatSet<K,...> provides a set of K. 33 // 34 // The map is implemented using an open-addressed hash table. A 35 // single array holds entire map contents and collisions are resolved 36 // by probing at a sequence of locations in the array. 37 template <typename Key, class Hash = hash<Key>, class Eq = std::equal_to<Key>> 38 class FlatSet { 39 private: 40 // Forward declare some internal types needed in public section. 41 struct Bucket; 42 43 public: 44 typedef Key key_type; 45 typedef Key value_type; 46 typedef Hash hasher; 47 typedef Eq key_equal; 48 typedef size_t size_type; 49 typedef ptrdiff_t difference_type; 50 typedef value_type* pointer; 51 typedef const value_type* const_pointer; 52 typedef value_type& reference; 53 typedef const value_type& const_reference; 54 FlatSet()55 FlatSet() : FlatSet(1) {} 56 57 explicit FlatSet(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq()) rep_(N,hf,eq)58 : rep_(N, hf, eq) {} 59 FlatSet(const FlatSet & src)60 FlatSet(const FlatSet& src) : rep_(src.rep_) {} 61 62 // Move constructor leaves src in a valid but unspecified state (same as 63 // std::unordered_set). FlatSet(FlatSet && src)64 FlatSet(FlatSet&& src) : rep_(std::move(src.rep_)) {} 65 66 template <typename InputIter> 67 FlatSet(InputIter first, InputIter last, size_t N = 1, 68 const Hash& hf = Hash(), const Eq& eq = Eq()) FlatSet(N,hf,eq)69 : FlatSet(N, hf, eq) { 70 insert(first, last); 71 } 72 73 FlatSet(std::initializer_list<value_type> init, size_t N = 1, 74 const Hash& hf = Hash(), const Eq& eq = Eq()) 75 : FlatSet(init.begin(), init.end(), N, hf, eq) {} 76 77 FlatSet& operator=(const FlatSet& src) { 78 rep_.CopyFrom(src.rep_); 79 return *this; 80 } 81 82 // Move-assignment operator leaves src in a valid but unspecified state (same 83 // as std::unordered_set). 84 FlatSet& operator=(FlatSet&& src) { 85 rep_.MoveFrom(std::move(src.rep_)); 86 return *this; 87 } 88 ~FlatSet()89 ~FlatSet() {} 90 swap(FlatSet & x)91 void swap(FlatSet& x) { rep_.swap(x.rep_); } clear_no_resize()92 void clear_no_resize() { rep_.clear_no_resize(); } clear()93 void clear() { rep_.clear(); } reserve(size_t N)94 void reserve(size_t N) { rep_.Resize(std::max(N, size())); } rehash(size_t N)95 void rehash(size_t N) { rep_.Resize(std::max(N, size())); } resize(size_t N)96 void resize(size_t N) { rep_.Resize(std::max(N, size())); } size()97 size_t size() const { return rep_.size(); } empty()98 bool empty() const { return size() == 0; } bucket_count()99 size_t bucket_count() const { return rep_.bucket_count(); } hash_function()100 hasher hash_function() const { return rep_.hash_function(); } key_eq()101 key_equal key_eq() const { return rep_.key_eq(); } 102 103 class const_iterator { 104 public: 105 typedef typename FlatSet::difference_type difference_type; 106 typedef typename FlatSet::value_type value_type; 107 typedef typename FlatSet::const_pointer pointer; 108 typedef typename FlatSet::const_reference reference; 109 typedef ::std::forward_iterator_tag iterator_category; 110 const_iterator()111 const_iterator() : b_(nullptr), end_(nullptr), i_(0) {} 112 113 // Make iterator pointing at first element at or after b. const_iterator(Bucket * b,Bucket * end)114 const_iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) { 115 SkipUnused(); 116 } 117 118 // Make iterator pointing exactly at ith element in b, which must exist. const_iterator(Bucket * b,Bucket * end,uint32 i)119 const_iterator(Bucket* b, Bucket* end, uint32 i) 120 : b_(b), end_(end), i_(i) {} 121 122 reference operator*() const { return key(); } 123 pointer operator->() const { return &key(); } 124 bool operator==(const const_iterator& x) const { 125 return b_ == x.b_ && i_ == x.i_; 126 } 127 bool operator!=(const const_iterator& x) const { return !(*this == x); } 128 const_iterator& operator++() { 129 DCHECK(b_ != end_); 130 i_++; 131 SkipUnused(); 132 return *this; 133 } 134 const_iterator operator++(int /*indicates postfix*/) { 135 const_iterator tmp(*this); 136 ++*this; 137 return tmp; 138 } 139 140 private: 141 friend class FlatSet; 142 Bucket* b_; 143 Bucket* end_; 144 uint32 i_; 145 key()146 reference key() const { return b_->key(i_); } SkipUnused()147 void SkipUnused() { 148 while (b_ < end_) { 149 if (i_ >= Rep::kWidth) { 150 i_ = 0; 151 b_++; 152 } else if (b_->marker[i_] < 2) { 153 i_++; 154 } else { 155 break; 156 } 157 } 158 } 159 }; 160 161 typedef const_iterator iterator; 162 begin()163 iterator begin() { return iterator(rep_.start(), rep_.limit()); } end()164 iterator end() { return iterator(rep_.limit(), rep_.limit()); } begin()165 const_iterator begin() const { 166 return const_iterator(rep_.start(), rep_.limit()); 167 } end()168 const_iterator end() const { 169 return const_iterator(rep_.limit(), rep_.limit()); 170 } 171 count(const Key & k)172 size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; } find(const Key & k)173 iterator find(const Key& k) { 174 auto r = rep_.Find(k); 175 return r.found ? iterator(r.b, rep_.limit(), r.index) : end(); 176 } find(const Key & k)177 const_iterator find(const Key& k) const { 178 auto r = rep_.Find(k); 179 return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end(); 180 } 181 insert(const Key & k)182 std::pair<iterator, bool> insert(const Key& k) { return Insert(k); } insert(Key && k)183 std::pair<iterator, bool> insert(Key&& k) { return Insert(std::move(k)); } 184 template <typename InputIter> insert(InputIter first,InputIter last)185 void insert(InputIter first, InputIter last) { 186 for (; first != last; ++first) { 187 insert(*first); 188 } 189 } 190 191 template <typename... Args> emplace(Args &&...args)192 std::pair<iterator, bool> emplace(Args&&... args) { 193 rep_.MaybeResize(); 194 auto r = rep_.FindOrInsert(std::forward<Args>(args)...); 195 const bool inserted = !r.found; 196 return {iterator(r.b, rep_.limit(), r.index), inserted}; 197 } 198 erase(const Key & k)199 size_t erase(const Key& k) { 200 auto r = rep_.Find(k); 201 if (!r.found) return 0; 202 rep_.Erase(r.b, r.index); 203 return 1; 204 } erase(iterator pos)205 iterator erase(iterator pos) { 206 rep_.Erase(pos.b_, pos.i_); 207 ++pos; 208 return pos; 209 } erase(iterator pos,iterator last)210 iterator erase(iterator pos, iterator last) { 211 for (; pos != last; ++pos) { 212 rep_.Erase(pos.b_, pos.i_); 213 } 214 return pos; 215 } 216 equal_range(const Key & k)217 std::pair<iterator, iterator> equal_range(const Key& k) { 218 auto pos = find(k); 219 if (pos == end()) { 220 return std::make_pair(pos, pos); 221 } else { 222 auto next = pos; 223 ++next; 224 return std::make_pair(pos, next); 225 } 226 } equal_range(const Key & k)227 std::pair<const_iterator, const_iterator> equal_range(const Key& k) const { 228 auto pos = find(k); 229 if (pos == end()) { 230 return std::make_pair(pos, pos); 231 } else { 232 auto next = pos; 233 ++next; 234 return std::make_pair(pos, next); 235 } 236 } 237 238 bool operator==(const FlatSet& x) const { 239 if (size() != x.size()) return false; 240 for (const auto& elem : x) { 241 auto i = find(elem); 242 if (i == end()) return false; 243 } 244 return true; 245 } 246 bool operator!=(const FlatSet& x) const { return !(*this == x); } 247 248 // If key exists in the table, prefetch it. This is a hint, and may 249 // have no effect. prefetch_value(const Key & key)250 void prefetch_value(const Key& key) const { rep_.Prefetch(key); } 251 252 private: 253 using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>; 254 255 // Bucket stores kWidth <marker, key, value> triples. 256 // The data is organized as three parallel arrays to reduce padding. 257 struct Bucket { 258 uint8 marker[Rep::kWidth]; 259 260 // Wrap keys in union to control construction and destruction. 261 union Storage { 262 Key key[Rep::kWidth]; Storage()263 Storage() {} ~Storage()264 ~Storage() {} 265 } storage; 266 keyBucket267 Key& key(uint32 i) { 268 DCHECK_GE(marker[i], 2); 269 return storage.key[i]; 270 } DestroyBucket271 void Destroy(uint32 i) { storage.key[i].Key::~Key(); } MoveFromBucket272 void MoveFrom(uint32 i, Bucket* src, uint32 src_index) { 273 new (&storage.key[i]) Key(std::move(src->storage.key[src_index])); 274 } CopyFromBucket275 void CopyFrom(uint32 i, Bucket* src, uint32 src_index) { 276 new (&storage.key[i]) Key(src->storage.key[src_index]); 277 } 278 }; 279 280 template <typename K> Insert(K && k)281 std::pair<iterator, bool> Insert(K&& k) { 282 rep_.MaybeResize(); 283 auto r = rep_.FindOrInsert(std::forward<K>(k)); 284 const bool inserted = !r.found; 285 return {iterator(r.b, rep_.limit(), r.index), inserted}; 286 } 287 288 Rep rep_; 289 }; 290 291 } // namespace gtl 292 } // namespace tensorflow 293 294 #endif // TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ 295