1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ 17 #define TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ 18 19 #include <stddef.h> 20 #include <functional> 21 #include <initializer_list> 22 #include <iterator> 23 #include <utility> 24 #include "tensorflow/core/lib/gtl/flatrep.h" 25 #include "tensorflow/core/lib/hash/hash.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace tensorflow { 30 namespace gtl { 31 32 // FlatMap<K,V,...> provides a map from K to V. 33 // 34 // The map is implemented using an open-addressed hash table. A 35 // single array holds entire map contents and collisions are resolved 36 // by probing at a sequence of locations in the array. 37 template <typename Key, typename Val, class Hash = hash<Key>, 38 class Eq = std::equal_to<Key>> 39 class FlatMap { 40 private: 41 // Forward declare some internal types needed in public section. 42 struct Bucket; 43 44 // We cannot use std::pair<> since internal representation stores 45 // keys and values in separate arrays, so we make a custom struct 46 // that holds references to the internal key, value elements. 47 // 48 // We define the struct as private ValueType, and typedef it as public 49 // value_type, to work around a gcc bug when compiling the iterators. 50 struct ValueType { 51 typedef Key first_type; 52 typedef Val second_type; 53 54 const Key& first; 55 Val& second; ValueTypeValueType56 ValueType(const Key& k, Val& v) : first(k), second(v) {} 57 }; 58 59 public: 60 typedef Key key_type; 61 typedef Val mapped_type; 62 typedef Hash hasher; 63 typedef Eq key_equal; 64 typedef size_t size_type; 65 typedef ptrdiff_t difference_type; 66 typedef ValueType value_type; 67 typedef value_type* pointer; 68 typedef const value_type* const_pointer; 69 typedef value_type& reference; 70 typedef const value_type& const_reference; 71 FlatMap()72 FlatMap() : FlatMap(1) {} 73 74 explicit FlatMap(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq()) rep_(N,hf,eq)75 : rep_(N, hf, eq) {} 76 FlatMap(const FlatMap & src)77 FlatMap(const FlatMap& src) : rep_(src.rep_) {} 78 79 template <typename InputIter> 80 FlatMap(InputIter first, InputIter last, size_t N = 1, 81 const Hash& hf = Hash(), const Eq& eq = Eq()) FlatMap(N,hf,eq)82 : FlatMap(N, hf, eq) { 83 insert(first, last); 84 } 85 86 FlatMap(std::initializer_list<std::pair<const Key, Val>> init, size_t N = 1, 87 const Hash& hf = Hash(), const Eq& eq = Eq()) 88 : FlatMap(init.begin(), init.end(), N, hf, eq) {} 89 90 FlatMap& operator=(const FlatMap& src) { 91 rep_.CopyFrom(src.rep_); 92 return *this; 93 } 94 ~FlatMap()95 ~FlatMap() {} 96 swap(FlatMap & x)97 void swap(FlatMap& x) { rep_.swap(x.rep_); } clear_no_resize()98 void clear_no_resize() { rep_.clear_no_resize(); } clear()99 void clear() { rep_.clear(); } reserve(size_t N)100 void reserve(size_t N) { rep_.Resize(std::max(N, size())); } rehash(size_t N)101 void rehash(size_t N) { rep_.Resize(std::max(N, size())); } resize(size_t N)102 void resize(size_t N) { rep_.Resize(std::max(N, size())); } size()103 size_t size() const { return rep_.size(); } empty()104 bool empty() const { return size() == 0; } bucket_count()105 size_t bucket_count() const { return rep_.bucket_count(); } hash_function()106 hasher hash_function() const { return rep_.hash_function(); } key_eq()107 key_equal key_eq() const { return rep_.key_eq(); } 108 109 class iterator { 110 public: 111 typedef typename FlatMap::difference_type difference_type; 112 typedef typename FlatMap::value_type value_type; 113 typedef typename FlatMap::pointer pointer; 114 typedef typename FlatMap::reference reference; 115 typedef ::std::forward_iterator_tag iterator_category; 116 iterator()117 iterator() : b_(nullptr), end_(nullptr), i_(0) {} 118 119 // Make iterator pointing at first element at or after b. iterator(Bucket * b,Bucket * end)120 iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) { SkipUnused(); } 121 122 // Make iterator pointing exactly at ith element in b, which must exist. iterator(Bucket * b,Bucket * end,uint32 i)123 iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) { 124 FillValue(); 125 } 126 127 reference operator*() { return *val(); } 128 pointer operator->() { return val(); } 129 bool operator==(const iterator& x) const { 130 return b_ == x.b_ && i_ == x.i_; 131 } 132 bool operator!=(const iterator& x) const { return !(*this == x); } 133 iterator& operator++() { 134 DCHECK(b_ != end_); 135 i_++; 136 SkipUnused(); 137 return *this; 138 } 139 iterator operator++(int /*indicates postfix*/) { 140 iterator tmp(*this); 141 ++*this; 142 return tmp; 143 } 144 145 private: 146 friend class FlatMap; 147 Bucket* b_; 148 Bucket* end_; 149 char space_ alignas(value_type)[sizeof(value_type)]; 150 uint32 i_; 151 val()152 pointer val() { return reinterpret_cast<pointer>(space_); } FillValue()153 void FillValue() { new (space_) value_type(b_->key(i_), b_->val(i_)); } SkipUnused()154 void SkipUnused() { 155 while (b_ < end_) { 156 if (i_ >= Rep::kWidth) { 157 i_ = 0; 158 b_++; 159 } else if (b_->marker[i_] < 2) { 160 i_++; 161 } else { 162 FillValue(); 163 break; 164 } 165 } 166 } 167 }; 168 169 class const_iterator { 170 private: 171 mutable iterator rep_; // Share state and logic with non-const iterator. 172 public: 173 typedef typename FlatMap::difference_type difference_type; 174 typedef typename FlatMap::value_type value_type; 175 typedef typename FlatMap::const_pointer pointer; 176 typedef typename FlatMap::const_reference reference; 177 typedef ::std::forward_iterator_tag iterator_category; 178 const_iterator()179 const_iterator() : rep_() {} const_iterator(Bucket * start,Bucket * end)180 const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {} const_iterator(Bucket * b,Bucket * end,uint32 i)181 const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {} 182 183 reference operator*() const { return *rep_.val(); } 184 pointer operator->() const { return rep_.val(); } 185 bool operator==(const const_iterator& x) const { return rep_ == x.rep_; } 186 bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; } 187 const_iterator& operator++() { 188 ++rep_; 189 return *this; 190 } 191 const_iterator operator++(int /*indicates postfix*/) { 192 const_iterator tmp(*this); 193 ++*this; 194 return tmp; 195 } 196 }; 197 begin()198 iterator begin() { return iterator(rep_.start(), rep_.limit()); } end()199 iterator end() { return iterator(rep_.limit(), rep_.limit()); } begin()200 const_iterator begin() const { 201 return const_iterator(rep_.start(), rep_.limit()); 202 } end()203 const_iterator end() const { 204 return const_iterator(rep_.limit(), rep_.limit()); 205 } 206 count(const Key & k)207 size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; } find(const Key & k)208 iterator find(const Key& k) { 209 auto r = rep_.Find(k); 210 return r.found ? iterator(r.b, rep_.limit(), r.index) : end(); 211 } find(const Key & k)212 const_iterator find(const Key& k) const { 213 auto r = rep_.Find(k); 214 return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end(); 215 } 216 at(const Key & k)217 Val& at(const Key& k) { 218 auto r = rep_.Find(k); 219 DCHECK(r.found); 220 return r.b->val(r.index); 221 } at(const Key & k)222 const Val& at(const Key& k) const { 223 auto r = rep_.Find(k); 224 DCHECK(r.found); 225 return r.b->val(r.index); 226 } 227 228 template <typename P> insert(const P & p)229 std::pair<iterator, bool> insert(const P& p) { 230 return Insert(p.first, p.second); 231 } insert(const std::pair<const Key,Val> & p)232 std::pair<iterator, bool> insert(const std::pair<const Key, Val>& p) { 233 return Insert(p.first, p.second); 234 } 235 template <typename InputIter> insert(InputIter first,InputIter last)236 void insert(InputIter first, InputIter last) { 237 for (; first != last; ++first) { 238 insert(*first); 239 } 240 } 241 242 Val& operator[](const Key& k) { return IndexOp(k); } 243 Val& operator[](Key&& k) { return IndexOp(std::forward<Key>(k)); } 244 245 template <typename... Args> emplace(Args &&...args)246 std::pair<iterator, bool> emplace(Args&&... args) { 247 return InsertPair(std::make_pair(std::forward<Args>(args)...)); 248 } 249 erase(const Key & k)250 size_t erase(const Key& k) { 251 auto r = rep_.Find(k); 252 if (!r.found) return 0; 253 rep_.Erase(r.b, r.index); 254 return 1; 255 } erase(iterator pos)256 iterator erase(iterator pos) { 257 rep_.Erase(pos.b_, pos.i_); 258 ++pos; 259 return pos; 260 } erase(iterator pos,iterator last)261 iterator erase(iterator pos, iterator last) { 262 for (; pos != last; ++pos) { 263 rep_.Erase(pos.b_, pos.i_); 264 } 265 return pos; 266 } 267 equal_range(const Key & k)268 std::pair<iterator, iterator> equal_range(const Key& k) { 269 auto pos = find(k); 270 if (pos == end()) { 271 return std::make_pair(pos, pos); 272 } else { 273 auto next = pos; 274 ++next; 275 return std::make_pair(pos, next); 276 } 277 } equal_range(const Key & k)278 std::pair<const_iterator, const_iterator> equal_range(const Key& k) const { 279 auto pos = find(k); 280 if (pos == end()) { 281 return std::make_pair(pos, pos); 282 } else { 283 auto next = pos; 284 ++next; 285 return std::make_pair(pos, next); 286 } 287 } 288 289 bool operator==(const FlatMap& x) const { 290 if (size() != x.size()) return false; 291 for (auto& p : x) { 292 auto i = find(p.first); 293 if (i == end()) return false; 294 if (i->second != p.second) return false; 295 } 296 return true; 297 } 298 bool operator!=(const FlatMap& x) const { return !(*this == x); } 299 300 // If key exists in the table, prefetch the associated value. This 301 // is a hint, and may have no effect. prefetch_value(const Key & key)302 void prefetch_value(const Key& key) const { rep_.Prefetch(key); } 303 304 private: 305 using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>; 306 307 // Bucket stores kWidth <marker, key, value> triples. 308 // The data is organized as three parallel arrays to reduce padding. 309 struct Bucket { 310 uint8 marker[Rep::kWidth]; 311 312 // Wrap keys and values in union to control construction and destruction. 313 union Storage { 314 struct { 315 Key key[Rep::kWidth]; 316 Val val[Rep::kWidth]; 317 }; Storage()318 Storage() {} ~Storage()319 ~Storage() {} 320 } storage; 321 keyBucket322 Key& key(uint32 i) { 323 DCHECK_GE(marker[i], 2); 324 return storage.key[i]; 325 } valBucket326 Val& val(uint32 i) { 327 DCHECK_GE(marker[i], 2); 328 return storage.val[i]; 329 } 330 template <typename V> InitValBucket331 void InitVal(uint32 i, V&& v) { 332 new (&storage.val[i]) Val(std::forward<V>(v)); 333 } DestroyBucket334 void Destroy(uint32 i) { 335 storage.key[i].Key::~Key(); 336 storage.val[i].Val::~Val(); 337 } MoveFromBucket338 void MoveFrom(uint32 i, Bucket* src, uint32 src_index) { 339 new (&storage.key[i]) Key(std::move(src->storage.key[src_index])); 340 new (&storage.val[i]) Val(std::move(src->storage.val[src_index])); 341 } CopyFromBucket342 void CopyFrom(uint32 i, Bucket* src, uint32 src_index) { 343 new (&storage.key[i]) Key(src->storage.key[src_index]); 344 new (&storage.val[i]) Val(src->storage.val[src_index]); 345 } 346 }; 347 348 template <typename Pair> InsertPair(Pair && p)349 std::pair<iterator, bool> InsertPair(Pair&& p) { 350 return Insert(std::forward<decltype(p.first)>(p.first), 351 std::forward<decltype(p.second)>(p.second)); 352 } 353 354 template <typename K, typename V> Insert(K && k,V && v)355 std::pair<iterator, bool> Insert(K&& k, V&& v) { 356 rep_.MaybeResize(); 357 auto r = rep_.FindOrInsert(std::forward<K>(k)); 358 const bool inserted = !r.found; 359 if (inserted) { 360 r.b->InitVal(r.index, std::forward<V>(v)); 361 } 362 return {iterator(r.b, rep_.limit(), r.index), inserted}; 363 } 364 365 template <typename K> IndexOp(K && k)366 Val& IndexOp(K&& k) { 367 rep_.MaybeResize(); 368 auto r = rep_.FindOrInsert(std::forward<K>(k)); 369 Val* vptr = &r.b->val(r.index); 370 if (!r.found) { 371 new (vptr) Val(); // Initialize value in new slot. 372 } 373 return *vptr; 374 } 375 376 Rep rep_; 377 }; 378 379 } // namespace gtl 380 } // namespace tensorflow 381 382 #endif // TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ 383