• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_
17 #define TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_
18 
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/contrib/nearest_neighbor/kernels/heap.h"
22 
23 namespace tensorflow {
24 namespace nearest_neighbor {
25 
26 // This class implements hyperplane multiprobe LSH as described in the
27 // following paper:
28 //
29 //   Multi-probe LSH: efficient indexing for high-dimensional similarity search
30 //   Qin Lv, William Josephson, Zhe Wang, Moses Charikar, Kai Li
31 //
32 // The class is only responsible for generating the probing sequence of given
33 // length for a given batch of points. The actual hash table lookups are
34 // implemented in other classes.
35 template <typename CoordinateType, typename HashType>
36 class HyperplaneMultiprobe {
37  public:
38   using Matrix = Eigen::Matrix<CoordinateType, Eigen::Dynamic, Eigen::Dynamic,
39                                Eigen::RowMajor>;
40   using ConstMatrixMap = Eigen::Map<const Matrix>;
41   using MatrixMap = Eigen::Map<Matrix>;
42   using Vector =
43       Eigen::Matrix<CoordinateType, Eigen::Dynamic, 1, Eigen::ColMajor>;
44 
HyperplaneMultiprobe(int num_hyperplanes_per_table,int num_tables)45   HyperplaneMultiprobe(int num_hyperplanes_per_table, int num_tables)
46       : num_hyperplanes_per_table_(num_hyperplanes_per_table),
47         num_tables_(num_tables),
48         num_probes_(0),
49         cur_probe_counter_(0),
50         sorted_hyperplane_indices_(0),
51         main_table_probe_(num_tables) {}
52 
53   // The first input hash_vector is the matrix-vector product between the
54   // hyperplane matrix and the vector for which we want to generate a probing
55   // sequence. We assume that each index in hash_vector is proportional to the
56   // distance between vector and hyperplane (i.e., the hyperplane vectors should
57   // all have the same norm).
58   //
59   // The second input is the number of probes we want to retrieve. If this
60   // number is fixed in advance, it should be passed in here in order to enable
61   // some (minor) internal optimizations. If the number of probes it not known
62   // in advance, the multiprobe sequence can still produce an arbitrary length
63   // probing sequence (up to the maximum number of probes) by calling
64   // get_next_probe multiple times.
65   //
66   // If num_probes is at most num_tables, it is not necessary to generate an
67   // actual multiprobe sequence and the multiprobe object will simply return
68   // the "standard" LSH probes without incurring any multiprobe overhead.
SetupProbing(const Vector & hash_vector,int_fast64_t num_probes)69   void SetupProbing(const Vector& hash_vector, int_fast64_t num_probes) {
70     // We accept a copy here for now.
71     hash_vector_ = hash_vector;
72     num_probes_ = num_probes;
73     cur_probe_counter_ = -1;
74 
75     // Compute the initial probes for each table, i.e., the "true" hash
76     // locations LSH without multiprobe would give.
77     for (int_fast32_t ii = 0; ii < num_tables_; ++ii) {
78       main_table_probe_[ii] = 0;
79       for (int_fast32_t jj = 0; jj < num_hyperplanes_per_table_; ++jj) {
80         main_table_probe_[ii] = main_table_probe_[ii] << 1;
81         main_table_probe_[ii] =
82             main_table_probe_[ii] |
83             (hash_vector_[ii * num_hyperplanes_per_table_ + jj] >= 0.0);
84       }
85     }
86 
87     if (num_probes_ >= 0 && num_probes_ <= num_tables_) {
88       return;
89     }
90 
91     if (sorted_hyperplane_indices_.size() == 0) {
92       sorted_hyperplane_indices_.resize(num_tables_);
93       for (int_fast32_t ii = 0; ii < num_tables_; ++ii) {
94         sorted_hyperplane_indices_[ii].resize(num_hyperplanes_per_table_);
95         for (int_fast32_t jj = 0; jj < num_hyperplanes_per_table_; ++jj) {
96           sorted_hyperplane_indices_[ii][jj] = jj;
97         }
98       }
99     }
100 
101     for (int_fast32_t ii = 0; ii < num_tables_; ++ii) {
102       HyperplaneComparator comp(hash_vector_, ii * num_hyperplanes_per_table_);
103       std::sort(sorted_hyperplane_indices_[ii].begin(),
104                 sorted_hyperplane_indices_[ii].end(), comp);
105     }
106 
107     if (num_probes_ >= 0) {
108       heap_.Resize(2 * num_probes_);
109     }
110     heap_.Reset();
111     for (int_fast32_t ii = 0; ii < num_tables_; ++ii) {
112       int_fast32_t best_index = sorted_hyperplane_indices_[ii][0];
113       CoordinateType score =
114           hash_vector_[ii * num_hyperplanes_per_table_ + best_index];
115       score = score * score;
116       HashType hash_mask = 1;
117       hash_mask = hash_mask << (num_hyperplanes_per_table_ - best_index - 1);
118       heap_.InsertUnsorted(score, ProbeCandidate(ii, hash_mask, 0));
119     }
120     heap_.Heapify();
121   }
122 
123   // This method stores the current probe (= hash table location) and
124   // corresponding table in the output parameters. The return value indicates
125   // whether this succeeded (true) or the current probing sequence is exhausted
126   // (false). Here, we say a probing sequence is exhausted if one of the
127   // following two conditions occurs:
128   // - We have used a non-negative value for num_probes in setup_probing, and
129   //   we have produced this many number of probes in the current sequence.
130   // - We have used a negative value for num_probes in setup_probing, and we
131   //   have produced all possible probes in the probing sequence.
GetNextProbe(HashType * cur_probe,int_fast32_t * cur_table)132   bool GetNextProbe(HashType* cur_probe, int_fast32_t* cur_table) {
133     cur_probe_counter_ += 1;
134 
135     if (num_probes_ >= 0 && cur_probe_counter_ >= num_probes_) {
136       // We are out of probes in the current probing sequence.
137       return false;
138     }
139 
140     // For the first num_tables_ probes, we directly return the "standard LSH"
141     // probes to guarantee that they always come first and we avoid any
142     // multiprobe overhead.
143     if (cur_probe_counter_ < num_tables_) {
144       *cur_probe = main_table_probe_[cur_probe_counter_];
145       *cur_table = cur_probe_counter_;
146       return true;
147     }
148 
149     // If the heap is empty, the current probing sequence is exhausted.
150     if (heap_.IsEmpty()) {
151       return false;
152     }
153 
154     CoordinateType cur_score;
155     ProbeCandidate cur_candidate;
156     heap_.ExtractMin(&cur_score, &cur_candidate);
157     *cur_table = cur_candidate.table_;
158     int_fast32_t cur_index =
159         sorted_hyperplane_indices_[*cur_table][cur_candidate.last_index_];
160     *cur_probe = main_table_probe_[*cur_table] ^ cur_candidate.hash_mask_;
161 
162     if (cur_candidate.last_index_ != num_hyperplanes_per_table_ - 1) {
163       // swapping out the last flipped index
164       int_fast32_t next_index =
165           sorted_hyperplane_indices_[*cur_table][cur_candidate.last_index_ + 1];
166 
167       // xor out previous bit, xor in new bit.
168       HashType next_mask =
169           cur_candidate.hash_mask_ ^
170           (HashType(1) << (num_hyperplanes_per_table_ - cur_index - 1)) ^
171           (HashType(1) << (num_hyperplanes_per_table_ - next_index - 1));
172 
173       CoordinateType cur_coord =
174           hash_vector_[*cur_table * num_hyperplanes_per_table_ + cur_index];
175       CoordinateType next_coord =
176           hash_vector_[*cur_table * num_hyperplanes_per_table_ + next_index];
177       CoordinateType next_score =
178           cur_score - cur_coord * cur_coord + next_coord * next_coord;
179 
180       heap_.Insert(next_score, ProbeCandidate(*cur_table, next_mask,
181                                               cur_candidate.last_index_ + 1));
182 
183       // adding a new flipped index
184       next_mask =
185           cur_candidate.hash_mask_ ^
186           (HashType(1) << (num_hyperplanes_per_table_ - next_index - 1));
187       next_score = cur_score + next_coord * next_coord;
188 
189       heap_.Insert(next_score, ProbeCandidate(*cur_table, next_mask,
190                                               cur_candidate.last_index_ + 1));
191     }
192 
193     return true;
194   }
195 
196  private:
197   class ProbeCandidate {
198    public:
199     ProbeCandidate(int_fast32_t table = 0, HashType hash_mask = 0,
200                    int_fast32_t last_index = 0)
table_(table)201         : table_(table), hash_mask_(hash_mask), last_index_(last_index) {}
202 
203     int_fast32_t table_;
204     HashType hash_mask_;
205     int_fast32_t last_index_;
206   };
207 
208   class HyperplaneComparator {
209    public:
HyperplaneComparator(const Vector & values,int_fast32_t offset)210     HyperplaneComparator(const Vector& values, int_fast32_t offset)
211         : values_(values), offset_(offset) {}
212 
operator()213     bool operator()(int_fast32_t ii, int_fast32_t jj) const {
214       return std::abs(values_[offset_ + ii]) < std::abs(values_[offset_ + jj]);
215     }
216 
217    private:
218     const Vector& values_;
219     int_fast32_t offset_;
220   };
221 
222   int_fast32_t num_hyperplanes_per_table_;
223   int_fast32_t num_tables_;
224   int_fast64_t num_probes_;
225   int_fast64_t cur_probe_counter_;
226   std::vector<std::vector<int_fast32_t>> sorted_hyperplane_indices_;
227   std::vector<HashType> main_table_probe_;
228   SimpleHeap<CoordinateType, ProbeCandidate> heap_;
229   Vector hash_vector_;
230 };
231 
232 }  // namespace nearest_neighbor
233 }  // namespace tensorflow
234 
235 #endif  // TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_
236