1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_ 17 #define TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_ 18 19 #include <cassert> 20 #include <cstdint> 21 #include <cstdlib> 22 #include <vector> 23 24 namespace tensorflow { 25 namespace nearest_neighbor { 26 27 // A simple binary heap. We use our own implementation because multiprobe for 28 // the cross-polytope hash interacts with the heap in a way so that about half 29 // of the insertion operations are guaranteed to be on top of the heap. We make 30 // use of this fact in the AugmentedHeap below. 31 32 // HeapBase is a base class for both the SimpleHeap and AugmentedHeap below. 
// HeapBase implements a binary min-heap of (key, data) items stored in a
// std::vector. Only the first num_elements_ entries of v_ are live and in heap
// order; any entries beyond that are leftovers from earlier extractions,
// Reset() or Resize(), and are overwritten by later insertions.
//
// KeyType must support operator<, operator> and operator<= (all are used
// below). KeyType and DataType must be default-constructible (Resize() grows
// the vector) and copy-assignable.
template <typename KeyType, typename DataType>
class HeapBase {
 public:
  // A single heap entry. Items compare by key only.
  class Item {
   public:
    KeyType key;
    DataType data;

    Item() {}
    Item(const KeyType& k, const DataType& d) : key(k), data(d) {}

    bool operator<(const Item& i2) const { return key < i2.key; }
  };

  // Removes the item with the smallest key and stores its key and data in
  // *key and *data. Precondition: the heap is not empty.
  void ExtractMin(KeyType* key, DataType* data) {
    assert(!IsEmpty());
    *key = v_[0].key;
    *data = v_[0].data;
    num_elements_ -= 1;
    // Move the last live item to the root and sift it down. When the heap has
    // just become empty this is a harmless self-assignment of v_[0].
    v_[0] = v_[num_elements_];
    HeapDown(0);
  }

  // Returns true if the heap holds no live items.
  bool IsEmpty() const { return num_elements_ == 0; }

  // Appends an item after the live entries WITHOUT restoring the heap
  // property. This is useful for bulk-loading a heap: after the initial
  // insertions, a single call to Heapify() suffices.
  void InsertUnsorted(const KeyType& key, const DataType& data) {
    if (v_.size() == static_cast<size_t>(num_elements_)) {
      v_.push_back(Item(key, data));
    } else {
      // Reuse a slot left over from earlier extractions / Reset() / Resize().
      v_[num_elements_].key = key;
      v_[num_elements_].data = data;
    }
    num_elements_ += 1;
  }

  // Inserts an item and restores the heap property by sifting it up.
  void Insert(const KeyType& key, const DataType& data) {
    InsertUnsorted(key, data);
    HeapUp(num_elements_ - 1);
  }

  // Restores the heap property over all live entries, e.g. after a series of
  // InsertUnsorted() calls, by sifting down every internal node starting from
  // the last one.
  void Heapify() {
    // With 0 live items parent(-1) == -1 and the loop does not run; with 1
    // live item parent(0) == 0 and HeapDown(0) returns immediately.
    int_fast32_t rightmost = parent(num_elements_ - 1);
    for (int_fast32_t cur_loc = rightmost; cur_loc >= 0; --cur_loc) {
      HeapDown(cur_loc);
    }
  }

  // Marks the heap as empty. The underlying storage is kept for reuse.
  void Reset() { num_elements_ = 0; }

  // Resizes the underlying storage to exactly new_size slots (e.g. to
  // preallocate). If new_size is smaller than the number of live items, the
  // heap is truncated to its first new_size items. Any prefix of a heap array
  // is itself a valid heap, so no re-heapify is needed.
  void Resize(size_t new_size) {
    v_.resize(new_size);
    if (static_cast<size_t>(num_elements_) > new_size) {
      // Without this, num_elements_ would point past the end of v_.
      num_elements_ = static_cast<int_fast32_t>(new_size);
    }
  }

 protected:
  // Index arithmetic for an implicit binary heap rooted at index 0.
  static int_fast32_t lchild(int_fast32_t x) { return 2 * x + 1; }

  static int_fast32_t rchild(int_fast32_t x) { return 2 * x + 2; }

  // Note: parent(0) == 0 because integer division truncates toward zero.
  static int_fast32_t parent(int_fast32_t x) { return (x - 1) / 2; }

  void SwapEntries(int_fast32_t a, int_fast32_t b) {
    Item tmp = v_[a];
    v_[a] = v_[b];
    v_[b] = tmp;
  }

  // Moves the item at cur_loc toward the root while its parent's key is
  // strictly larger.
  void HeapUp(int_fast32_t cur_loc) {
    int_fast32_t p = parent(cur_loc);
    while (cur_loc > 0 && v_[p].key > v_[cur_loc].key) {
      SwapEntries(p, cur_loc);
      cur_loc = p;
      p = parent(cur_loc);
    }
  }

  // Moves the item at cur_loc toward the leaves, swapping it with its smaller
  // child (the left one on ties), until neither live child has a smaller key.
  void HeapDown(int_fast32_t cur_loc) {
    while (true) {
      int_fast32_t lc = lchild(cur_loc);
      int_fast32_t rc = rchild(cur_loc);
      if (lc >= num_elements_) {
        return;  // cur_loc is a leaf.
      }

      if (v_[cur_loc].key <= v_[lc].key) {
        if (rc >= num_elements_ || v_[cur_loc].key <= v_[rc].key) {
          return;  // Heap property holds at cur_loc.
        } else {
          SwapEntries(cur_loc, rc);
          cur_loc = rc;
        }
      } else {
        if (rc >= num_elements_ || v_[lc].key <= v_[rc].key) {
          SwapEntries(cur_loc, lc);
          cur_loc = lc;
        } else {
          SwapEntries(cur_loc, rc);
          cur_loc = rc;
        }
      }
    }
  }

  std::vector<Item> v_;
  int_fast32_t num_elements_ = 0;
};

// A "simple" binary heap.
template <typename KeyType, typename DataType>
class SimpleHeap : public HeapBase<KeyType, DataType> {
 public:
  // Replaces the minimum item with (key, data) and restores the heap property
  // with a single sift-down. This leaves the same items stored as
  // ExtractMin() followed by Insert(). Precondition: the heap is not empty.
  void ReplaceTop(const KeyType& key, const DataType& data) {
    assert(!this->IsEmpty());
    this->v_[0].key = key;
    this->v_[0].data = data;
    this->HeapDown(0);
  }

  // Returns the smallest key without removing it.
  // Precondition: the heap is not empty.
  KeyType MinKey() const {
    assert(!this->IsEmpty());
    return this->v_[0].key;
  }

  // Returns the underlying storage. Its size may exceed the number of live
  // items; only the live prefix is meaningful and in heap order.
  std::vector<typename HeapBase<KeyType, DataType>::Item>& GetData() {
    return this->v_;
  }
};

// An "augmented" heap that can hold an extra element that is guaranteed to
// be at the top of the heap.
This is useful if a significant fraction of the 165 // insertion operations are guaranteed insertions at the top. However, the heap 166 // only stores at most one such special top element, i.e., the heap assumes 167 // that extract_min() is called at least once between successive calls to 168 // insert_guaranteed_top(). 169 template <typename KeyType, typename DataType> 170 class AugmentedHeap : public HeapBase<KeyType, DataType> { 171 public: ExtractMin(KeyType * key,DataType * data)172 void ExtractMin(KeyType* key, DataType* data) { 173 if (has_guaranteed_top_) { 174 has_guaranteed_top_ = false; 175 *key = guaranteed_top_.key; 176 *data = guaranteed_top_.data; 177 } else { 178 *key = this->v_[0].key; 179 *data = this->v_[0].data; 180 this->num_elements_ -= 1; 181 this->v_[0] = this->v_[this->num_elements_]; 182 this->HeapDown(0); 183 } 184 } 185 IsEmpty()186 bool IsEmpty() { return this->num_elements_ == 0 && !has_guaranteed_top_; } 187 InsertGuaranteedTop(const KeyType & key,const DataType & data)188 void InsertGuaranteedTop(const KeyType& key, const DataType& data) { 189 assert(!has_guaranteed_top_); 190 has_guaranteed_top_ = true; 191 guaranteed_top_.key = key; 192 guaranteed_top_.data = data; 193 } 194 Reset()195 void Reset() { 196 this->num_elements_ = 0; 197 has_guaranteed_top_ = false; 198 } 199 200 protected: 201 typename HeapBase<KeyType, DataType>::Item guaranteed_top_; 202 bool has_guaranteed_top_ = false; 203 }; 204 205 } // namespace nearest_neighbor 206 } // namespace tensorflow 207 208 #endif // TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_ 209