• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
17 #define TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
18 
#include <stddef.h>
#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <new>
#include <utility>
#include "tensorflow/core/lib/gtl/flatrep.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
28 
29 namespace tensorflow {
30 namespace gtl {
31 
32 // FlatMap<K,V,...> provides a map from K to V.
33 //
34 // The map is implemented using an open-addressed hash table.  A
35 // single array holds entire map contents and collisions are resolved
36 // by probing at a sequence of locations in the array.
37 template <typename Key, typename Val, class Hash = hash<Key>,
38           class Eq = std::equal_to<Key>>
39 class FlatMap {
40  private:
41   // Forward declare some internal types needed in public section.
42   struct Bucket;
43 
44   // We cannot use std::pair<> since internal representation stores
45   // keys and values in separate arrays, so we make a custom struct
46   // that holds references to the internal key, value elements.
47   //
48   // We define the struct as private ValueType, and typedef it as public
49   // value_type, to work around a gcc bug when compiling the iterators.
50   struct ValueType {
51     typedef Key first_type;
52     typedef Val second_type;
53 
54     const Key& first;
55     Val& second;
ValueTypeValueType56     ValueType(const Key& k, Val& v) : first(k), second(v) {}
57   };
58 
59  public:
60   typedef Key key_type;
61   typedef Val mapped_type;
62   typedef Hash hasher;
63   typedef Eq key_equal;
64   typedef size_t size_type;
65   typedef ptrdiff_t difference_type;
66   typedef ValueType value_type;
67   typedef value_type* pointer;
68   typedef const value_type* const_pointer;
69   typedef value_type& reference;
70   typedef const value_type& const_reference;
71 
FlatMap()72   FlatMap() : FlatMap(1) {}
73 
74   explicit FlatMap(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq())
rep_(N,hf,eq)75       : rep_(N, hf, eq) {}
76 
FlatMap(const FlatMap & src)77   FlatMap(const FlatMap& src) : rep_(src.rep_) {}
78 
79   template <typename InputIter>
80   FlatMap(InputIter first, InputIter last, size_t N = 1,
81           const Hash& hf = Hash(), const Eq& eq = Eq())
FlatMap(N,hf,eq)82       : FlatMap(N, hf, eq) {
83     insert(first, last);
84   }
85 
86   FlatMap(std::initializer_list<std::pair<const Key, Val>> init, size_t N = 1,
87           const Hash& hf = Hash(), const Eq& eq = Eq())
88       : FlatMap(init.begin(), init.end(), N, hf, eq) {}
89 
90   FlatMap& operator=(const FlatMap& src) {
91     rep_.CopyFrom(src.rep_);
92     return *this;
93   }
94 
~FlatMap()95   ~FlatMap() {}
96 
swap(FlatMap & x)97   void swap(FlatMap& x) { rep_.swap(x.rep_); }
clear_no_resize()98   void clear_no_resize() { rep_.clear_no_resize(); }
clear()99   void clear() { rep_.clear(); }
reserve(size_t N)100   void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
rehash(size_t N)101   void rehash(size_t N) { rep_.Resize(std::max(N, size())); }
resize(size_t N)102   void resize(size_t N) { rep_.Resize(std::max(N, size())); }
size()103   size_t size() const { return rep_.size(); }
empty()104   bool empty() const { return size() == 0; }
bucket_count()105   size_t bucket_count() const { return rep_.bucket_count(); }
hash_function()106   hasher hash_function() const { return rep_.hash_function(); }
key_eq()107   key_equal key_eq() const { return rep_.key_eq(); }
108 
109   class iterator {
110    public:
111     typedef typename FlatMap::difference_type difference_type;
112     typedef typename FlatMap::value_type value_type;
113     typedef typename FlatMap::pointer pointer;
114     typedef typename FlatMap::reference reference;
115     typedef ::std::forward_iterator_tag iterator_category;
116 
iterator()117     iterator() : b_(nullptr), end_(nullptr), i_(0) {}
118 
119     // Make iterator pointing at first element at or after b.
iterator(Bucket * b,Bucket * end)120     iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) { SkipUnused(); }
121 
122     // Make iterator pointing exactly at ith element in b, which must exist.
iterator(Bucket * b,Bucket * end,uint32 i)123     iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) {
124       FillValue();
125     }
126 
127     reference operator*() { return *val(); }
128     pointer operator->() { return val(); }
129     bool operator==(const iterator& x) const {
130       return b_ == x.b_ && i_ == x.i_;
131     }
132     bool operator!=(const iterator& x) const { return !(*this == x); }
133     iterator& operator++() {
134       DCHECK(b_ != end_);
135       i_++;
136       SkipUnused();
137       return *this;
138     }
139     iterator operator++(int /*indicates postfix*/) {
140       iterator tmp(*this);
141       ++*this;
142       return tmp;
143     }
144 
145    private:
146     friend class FlatMap;
147     Bucket* b_;
148     Bucket* end_;
149     char space_ alignas(value_type)[sizeof(value_type)];
150     uint32 i_;
151 
val()152     pointer val() { return reinterpret_cast<pointer>(space_); }
FillValue()153     void FillValue() { new (space_) value_type(b_->key(i_), b_->val(i_)); }
SkipUnused()154     void SkipUnused() {
155       while (b_ < end_) {
156         if (i_ >= Rep::kWidth) {
157           i_ = 0;
158           b_++;
159         } else if (b_->marker[i_] < 2) {
160           i_++;
161         } else {
162           FillValue();
163           break;
164         }
165       }
166     }
167   };
168 
169   class const_iterator {
170    private:
171     mutable iterator rep_;  // Share state and logic with non-const iterator.
172    public:
173     typedef typename FlatMap::difference_type difference_type;
174     typedef typename FlatMap::value_type value_type;
175     typedef typename FlatMap::const_pointer pointer;
176     typedef typename FlatMap::const_reference reference;
177     typedef ::std::forward_iterator_tag iterator_category;
178 
const_iterator()179     const_iterator() : rep_() {}
const_iterator(Bucket * start,Bucket * end)180     const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {}
const_iterator(Bucket * b,Bucket * end,uint32 i)181     const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {}
182 
183     reference operator*() const { return *rep_.val(); }
184     pointer operator->() const { return rep_.val(); }
185     bool operator==(const const_iterator& x) const { return rep_ == x.rep_; }
186     bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; }
187     const_iterator& operator++() {
188       ++rep_;
189       return *this;
190     }
191     const_iterator operator++(int /*indicates postfix*/) {
192       const_iterator tmp(*this);
193       ++*this;
194       return tmp;
195     }
196   };
197 
begin()198   iterator begin() { return iterator(rep_.start(), rep_.limit()); }
end()199   iterator end() { return iterator(rep_.limit(), rep_.limit()); }
begin()200   const_iterator begin() const {
201     return const_iterator(rep_.start(), rep_.limit());
202   }
end()203   const_iterator end() const {
204     return const_iterator(rep_.limit(), rep_.limit());
205   }
206 
count(const Key & k)207   size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; }
find(const Key & k)208   iterator find(const Key& k) {
209     auto r = rep_.Find(k);
210     return r.found ? iterator(r.b, rep_.limit(), r.index) : end();
211   }
find(const Key & k)212   const_iterator find(const Key& k) const {
213     auto r = rep_.Find(k);
214     return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end();
215   }
216 
at(const Key & k)217   Val& at(const Key& k) {
218     auto r = rep_.Find(k);
219     DCHECK(r.found);
220     return r.b->val(r.index);
221   }
at(const Key & k)222   const Val& at(const Key& k) const {
223     auto r = rep_.Find(k);
224     DCHECK(r.found);
225     return r.b->val(r.index);
226   }
227 
228   template <typename P>
insert(const P & p)229   std::pair<iterator, bool> insert(const P& p) {
230     return Insert(p.first, p.second);
231   }
insert(const std::pair<const Key,Val> & p)232   std::pair<iterator, bool> insert(const std::pair<const Key, Val>& p) {
233     return Insert(p.first, p.second);
234   }
235   template <typename InputIter>
insert(InputIter first,InputIter last)236   void insert(InputIter first, InputIter last) {
237     for (; first != last; ++first) {
238       insert(*first);
239     }
240   }
241 
242   Val& operator[](const Key& k) { return IndexOp(k); }
243   Val& operator[](Key&& k) { return IndexOp(std::forward<Key>(k)); }
244 
245   template <typename... Args>
emplace(Args &&...args)246   std::pair<iterator, bool> emplace(Args&&... args) {
247     return InsertPair(std::make_pair(std::forward<Args>(args)...));
248   }
249 
erase(const Key & k)250   size_t erase(const Key& k) {
251     auto r = rep_.Find(k);
252     if (!r.found) return 0;
253     rep_.Erase(r.b, r.index);
254     return 1;
255   }
erase(iterator pos)256   iterator erase(iterator pos) {
257     rep_.Erase(pos.b_, pos.i_);
258     ++pos;
259     return pos;
260   }
erase(iterator pos,iterator last)261   iterator erase(iterator pos, iterator last) {
262     for (; pos != last; ++pos) {
263       rep_.Erase(pos.b_, pos.i_);
264     }
265     return pos;
266   }
267 
equal_range(const Key & k)268   std::pair<iterator, iterator> equal_range(const Key& k) {
269     auto pos = find(k);
270     if (pos == end()) {
271       return std::make_pair(pos, pos);
272     } else {
273       auto next = pos;
274       ++next;
275       return std::make_pair(pos, next);
276     }
277   }
equal_range(const Key & k)278   std::pair<const_iterator, const_iterator> equal_range(const Key& k) const {
279     auto pos = find(k);
280     if (pos == end()) {
281       return std::make_pair(pos, pos);
282     } else {
283       auto next = pos;
284       ++next;
285       return std::make_pair(pos, next);
286     }
287   }
288 
289   bool operator==(const FlatMap& x) const {
290     if (size() != x.size()) return false;
291     for (auto& p : x) {
292       auto i = find(p.first);
293       if (i == end()) return false;
294       if (i->second != p.second) return false;
295     }
296     return true;
297   }
298   bool operator!=(const FlatMap& x) const { return !(*this == x); }
299 
300   // If key exists in the table, prefetch the associated value.  This
301   // is a hint, and may have no effect.
prefetch_value(const Key & key)302   void prefetch_value(const Key& key) const { rep_.Prefetch(key); }
303 
304  private:
305   using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>;
306 
307   // Bucket stores kWidth <marker, key, value> triples.
308   // The data is organized as three parallel arrays to reduce padding.
309   struct Bucket {
310     uint8 marker[Rep::kWidth];
311 
312     // Wrap keys and values in union to control construction and destruction.
313     union Storage {
314       struct {
315         Key key[Rep::kWidth];
316         Val val[Rep::kWidth];
317       };
Storage()318       Storage() {}
~Storage()319       ~Storage() {}
320     } storage;
321 
keyBucket322     Key& key(uint32 i) {
323       DCHECK_GE(marker[i], 2);
324       return storage.key[i];
325     }
valBucket326     Val& val(uint32 i) {
327       DCHECK_GE(marker[i], 2);
328       return storage.val[i];
329     }
330     template <typename V>
InitValBucket331     void InitVal(uint32 i, V&& v) {
332       new (&storage.val[i]) Val(std::forward<V>(v));
333     }
DestroyBucket334     void Destroy(uint32 i) {
335       storage.key[i].Key::~Key();
336       storage.val[i].Val::~Val();
337     }
MoveFromBucket338     void MoveFrom(uint32 i, Bucket* src, uint32 src_index) {
339       new (&storage.key[i]) Key(std::move(src->storage.key[src_index]));
340       new (&storage.val[i]) Val(std::move(src->storage.val[src_index]));
341     }
CopyFromBucket342     void CopyFrom(uint32 i, Bucket* src, uint32 src_index) {
343       new (&storage.key[i]) Key(src->storage.key[src_index]);
344       new (&storage.val[i]) Val(src->storage.val[src_index]);
345     }
346   };
347 
348   template <typename Pair>
InsertPair(Pair && p)349   std::pair<iterator, bool> InsertPair(Pair&& p) {
350     return Insert(std::forward<decltype(p.first)>(p.first),
351                   std::forward<decltype(p.second)>(p.second));
352   }
353 
354   template <typename K, typename V>
Insert(K && k,V && v)355   std::pair<iterator, bool> Insert(K&& k, V&& v) {
356     rep_.MaybeResize();
357     auto r = rep_.FindOrInsert(std::forward<K>(k));
358     const bool inserted = !r.found;
359     if (inserted) {
360       r.b->InitVal(r.index, std::forward<V>(v));
361     }
362     return {iterator(r.b, rep_.limit(), r.index), inserted};
363   }
364 
365   template <typename K>
IndexOp(K && k)366   Val& IndexOp(K&& k) {
367     rep_.MaybeResize();
368     auto r = rep_.FindOrInsert(std::forward<K>(k));
369     Val* vptr = &r.b->val(r.index);
370     if (!r.found) {
371       new (vptr) Val();  // Initialize value in new slot.
372     }
373     return *vptr;
374   }
375 
376   Rep rep_;
377 };
378 
379 }  // namespace gtl
380 }  // namespace tensorflow
381 
382 #endif  // TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
383