• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License.  You may obtain a copy
5 // of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 // ==============================================================================
15 
16 #define EIGEN_USE_THREADS
17 
18 #include <algorithm>
19 #include <memory>
20 #include <numeric>
21 #include <tuple>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/blocking_counter.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/threadpool.h"
32 #include "tensorflow/core/lib/gtl/top_n.h"
33 #include "tensorflow/core/lib/random/philox_random.h"
34 #include "tensorflow/core/lib/random/simple_philox.h"
35 #include "tensorflow/core/platform/cpu_info.h"
36 #include "tensorflow/core/platform/logging.h"
37 
38 namespace tensorflow {
39 namespace {
40 using errors::InvalidArgument;
41 
42 template <typename Scalar>
43 using RowMajorMatrix =
44     Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
45 
46 using MatrixXfRowMajor = RowMajorMatrix<float>;
47 using MatrixXi64RowMajor = RowMajorMatrix<int64>;
48 
49 // Ideally this should be computed by dividing L3 cache size by the number of
50 // physical CPUs. Since there isn't a portable method to do this, we are using
51 // a conservative estimate here.
52 const int64 kDefaultL3CachePerCpu = 1 << 20;
53 
54 // These values were determined by performing a parameter sweep on the
55 // NearestNeighborsOp benchmark.
56 const int64 kNearestNeighborsCentersMaxBlockSize = 1024;
57 const int64 kNearestNeighborsPointsMinBlockSize = 16;
58 
59 // Returns the smallest multiple of a that is not smaller than b.
NextMultiple(int64 a,int64 b)60 int64 NextMultiple(int64 a, int64 b) {
61   const int64 remainder = b % a;
62   return remainder == 0 ? b : (b + a - remainder);
63 }
64 
65 // Returns a / b rounded up to the next higher integer.
CeilOfRatio(int64 a,int64 b)66 int64 CeilOfRatio(int64 a, int64 b) { return (a + b - 1) / b; }
67 
68 }  // namespace
69 
70 // Implementation of K-means++ initialization. Samples points iteratively in
71 // proportion to the squared distances from selected points.
72 // TODO(ands): Add support for other distance metrics.
73 class KmeansPlusPlusInitializationOp : public OpKernel {
74  public:
KmeansPlusPlusInitializationOp(OpKernelConstruction * context)75   explicit KmeansPlusPlusInitializationOp(OpKernelConstruction* context)
76       : OpKernel(context) {
77     OP_REQUIRES_OK(context,
78                    context->MatchSignature(
79                        {DT_FLOAT, DT_INT64, DT_INT64, DT_INT64}, {DT_FLOAT}));
80   }
81 
Compute(OpKernelContext * context)82   void Compute(OpKernelContext* context) override {
83     const Tensor& points_tensor = context->input(0);
84     const Tensor& num_to_sample_tensor = context->input(1);
85     const Tensor& seed_tensor = context->input(2);
86     const Tensor& num_retries_per_sample_tensor = context->input(3);
87 
88     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
89                 InvalidArgument("Input points should be a matrix."));
90     OP_REQUIRES(context,
91                 TensorShapeUtils::IsScalar(num_to_sample_tensor.shape()),
92                 InvalidArgument("Input num_to_sample should be a scalar."));
93     OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
94                 InvalidArgument("Input seed should be a scalar."));
95     OP_REQUIRES(
96         context,
97         TensorShapeUtils::IsScalar(num_retries_per_sample_tensor.shape()),
98         InvalidArgument("Input num_retries_per_sample should be a scalar."));
99 
100     const int64 num_points = points_tensor.dim_size(0);
101     const int64 point_dimensions = points_tensor.dim_size(1);
102     const int64 num_to_sample = num_to_sample_tensor.scalar<int64>()();
103     const int64 seed = seed_tensor.scalar<int64>()();
104     const int64 num_retries_per_sample = [&]() {
105       const int64 value = num_retries_per_sample_tensor.scalar<int64>()();
106       return value >= 0 ? value
107                         : 2 + static_cast<int64>(std::log(num_to_sample));
108     }();
109 
110     OP_REQUIRES(context, num_points > 0,
111                 InvalidArgument("Expected points.rows() > 0."));
112     OP_REQUIRES(context, num_to_sample > 0,
113                 InvalidArgument("Expected num_to_sample > 0."));
114     OP_REQUIRES(context, num_to_sample <= num_points,
115                 InvalidArgument("Expected num_to_sample <= points.rows(). ",
116                                 num_to_sample, " vs ", num_points, "."));
117 
118     Tensor* output_sampled_points_tensor;
119     OP_REQUIRES_OK(context,
120                    context->allocate_output(
121                        0, TensorShape({num_to_sample, point_dimensions}),
122                        &output_sampled_points_tensor));
123 
124     const Eigen::Map<const MatrixXfRowMajor> points(
125         points_tensor.matrix<float>().data(), num_points, point_dimensions);
126     const Eigen::VectorXf points_half_squared_norm =
127         0.5 * points.rowwise().squaredNorm();
128 
129     Eigen::Map<MatrixXfRowMajor> sampled_points(
130         output_sampled_points_tensor->matrix<float>().data(), num_to_sample,
131         point_dimensions);
132     std::unordered_set<int64> sampled_indices;
133 
134     random::PhiloxRandom random(seed);
135     random::SimplePhilox rng(&random);
136 
137     auto add_one_point = [&](int64 from, int64 to) {
138       from = std::min(from, num_points - 1);
139       sampled_points.row(to) = points.row(from);
140       sampled_indices.insert(from);
141     };
142 
143     // Distances from all points to nearest selected point. Initialize with
144     // distances to first selected point.
145     Eigen::VectorXf min_distances(num_points);
146     min_distances.fill(std::numeric_limits<float>::infinity());
147     Eigen::VectorXf min_distances_cumsum(num_points);
148 
149     auto draw_one_sample = [&]() -> int64 {
150       if (sampled_indices.empty()) return rng.Uniform64(num_points);
151       int64 index = 0;
152       do {
153         // If v is drawn from Uniform[0, distances.sum()), then
154         // Prob[cumsum(distances)(i - 1) <= v < cumsum(distances)(i)] is
155         // proportional to distances(i).
156         index = std::upper_bound(
157                     min_distances_cumsum.data(),
158                     min_distances_cumsum.data() + num_points,
159                     rng.RandFloat() * min_distances_cumsum(num_points - 1)) -
160                 min_distances_cumsum.data();
161       } while (sampled_indices.find(index) != sampled_indices.end());
162       return index;
163     };
164 
165     auto sample_one_point = [&]() {
166       const int64 sampled_index = draw_one_sample();
167       min_distances = min_distances.cwiseMin(GetHalfSquaredDistancesToY(
168           points, points_half_squared_norm, points.row(sampled_index),
169           points_half_squared_norm(sampled_index)));
170       return sampled_index;
171     };
172 
173     auto sample_one_point_with_retries = [&]() {
174       Eigen::VectorXf best_new_min_distances(num_points);
175       float best_potential = std::numeric_limits<float>::infinity();
176       int64 best_sampled_index = 0;
177       for (int i = 1 + num_retries_per_sample; i > 0; --i) {
178         const int64 sampled_index = draw_one_sample();
179         Eigen::VectorXf new_min_distances =
180             min_distances.cwiseMin(GetHalfSquaredDistancesToY(
181                 points, points_half_squared_norm, points.row(sampled_index),
182                 points_half_squared_norm(sampled_index)));
183         const float potential = new_min_distances.sum();
184         if (potential < best_potential) {
185           best_potential = potential;
186           best_sampled_index = sampled_index;
187           best_new_min_distances.swap(new_min_distances);
188         }
189       }
190       min_distances.swap(best_new_min_distances);
191       return best_sampled_index;
192     };
193 
194     for (int64 i = 0; i < num_to_sample; ++i) {
195       if (i > 0) {
196         std::partial_sum(min_distances.data(),
197                          min_distances.data() + num_points,
198                          min_distances_cumsum.data());
199       }
200       int64 next = num_retries_per_sample == 0
201                        ? sample_one_point()
202                        : sample_one_point_with_retries();
203       add_one_point(next, i);
204     }
205   }
206 
207  private:
208   // Returns a column vector with the i-th element set to half the squared
209   // euclidean distance between the i-th row of xs, and y. Precomputed norms for
210   // each row of xs and y must be provided for efficiency.
211   // TODO(ands): Parallelize this for large xs.
GetHalfSquaredDistancesToY(const Eigen::Ref<const MatrixXfRowMajor> & xs,const Eigen::Ref<const Eigen::VectorXf> & xs_half_squared_norm,const Eigen::Ref<const Eigen::RowVectorXf> & y,float y_half_squared_norm)212   static Eigen::VectorXf GetHalfSquaredDistancesToY(
213       const Eigen::Ref<const MatrixXfRowMajor>& xs,
214       const Eigen::Ref<const Eigen::VectorXf>& xs_half_squared_norm,
215       const Eigen::Ref<const Eigen::RowVectorXf>& y,
216       float y_half_squared_norm) {
217     // Squared distance between points xs_i and y is:
218     //   || xs_i ||^2 - 2 <xs_i, y> + || y ||^2
219     return (xs_half_squared_norm - xs * y.transpose()).array() +
220            y_half_squared_norm;
221   }
222 };
223 
224 REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU),
225                         KmeansPlusPlusInitializationOp);
226 
227 // Implementation of one single Markov Chain for the k-MC^2 algorithm
228 class KMC2ChainInitializationOp : public OpKernel {
229  public:
KMC2ChainInitializationOp(OpKernelConstruction * context)230   explicit KMC2ChainInitializationOp(OpKernelConstruction* context)
231       : OpKernel(context) {
232     OP_REQUIRES_OK(context,
233                    context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64}));
234   }
235 
Compute(OpKernelContext * context)236   void Compute(OpKernelContext* context) override {
237     const Tensor& distances_tensor = context->input(0);
238     const Tensor& seed_tensor = context->input(1);
239     OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()),
240                 InvalidArgument("Input distances should be a vector."));
241     OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
242                 InvalidArgument("Input seed should be a scalar."));
243     const int64 num_points = distances_tensor.dim_size(0);
244     const int64 seed = seed_tensor.scalar<int64>()();
245     OP_REQUIRES(context, num_points > 0,
246                 InvalidArgument("Expected distances_tensor.size() > 0."));
247 
248     random::PhiloxRandom random(seed);
249     random::SimplePhilox rng(&random);
250 
251     auto distances = distances_tensor.flat<float>();
252     // Set the initial state of the Markov chain to be the first candidate.
253     int64 selected_index = 0;
254     float selected_distance = distances(selected_index);
255     // Build a Markov chain of length num_points.
256     for (int64 i = 1; i < num_points; ++i) {
257       const float candidate_distance = distances(i);
258       // Set the next state of the Markov chain to be the candidate with
259       // probability min(1, candidate_distance/selected_distance).
260       if (candidate_distance > rng.RandFloat() * selected_distance) {
261         selected_index = i;
262         selected_distance = candidate_distance;
263       }
264     }
265 
266     Tensor* output_sampled_index_tensor;
267     OP_REQUIRES_OK(context,
268                    context->allocate_output(0, TensorShape({}),
269                                             &output_sampled_index_tensor));
270     auto output = output_sampled_index_tensor->scalar<int64>();
271     // Return the last state of the Markov chain as the new center.
272     output() = selected_index;
273   }
274 };
275 
276 REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU),
277                         KMC2ChainInitializationOp);
278 
279 // Operator for computing the nearest neighbors for a set of points.
280 class NearestNeighborsOp : public OpKernel {
281  public:
NearestNeighborsOp(OpKernelConstruction * context)282   explicit NearestNeighborsOp(OpKernelConstruction* context)
283       : OpKernel(context) {
284     OP_REQUIRES_OK(context,
285                    context->MatchSignature({DT_FLOAT, DT_FLOAT, DT_INT64},
286                                            {DT_INT64, DT_FLOAT}));
287   }
288 
Compute(OpKernelContext * context)289   void Compute(OpKernelContext* context) override {
290     const Tensor& points_tensor = context->input(0);
291     const Tensor& centers_tensor = context->input(1);
292     const Tensor& k_tensor = context->input(2);
293 
294     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
295                 InvalidArgument("Input points should be a matrix."));
296     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(centers_tensor.shape()),
297                 InvalidArgument("Input centers should be a matrix."));
298     OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_tensor.shape()),
299                 InvalidArgument("Input k should be a scalar."));
300 
301     const int64 num_points = points_tensor.dim_size(0);
302     const int64 point_dimensions = points_tensor.dim_size(1);
303     const int64 num_centers = centers_tensor.dim_size(0);
304     const int64 center_dimensions = centers_tensor.dim_size(1);
305 
306     OP_REQUIRES(context, num_points > 0,
307                 InvalidArgument("Expected points.rows() > 0."));
308     OP_REQUIRES(
309         context, point_dimensions == center_dimensions,
310         InvalidArgument("Expected point_dimensions == center_dimensions: ",
311                         point_dimensions, " vs ", center_dimensions, "."));
312 
313     const Eigen::Map<const MatrixXfRowMajor> points(
314         points_tensor.matrix<float>().data(), num_points, point_dimensions);
315     const Eigen::Map<const MatrixXfRowMajor> centers(
316         centers_tensor.matrix<float>().data(), num_centers, center_dimensions);
317     const int64 k = std::min<int64>(num_centers, k_tensor.scalar<int64>()());
318 
319     Tensor* output_nearest_center_indices_tensor;
320     Tensor* output_nearest_center_distances_tensor;
321     OP_REQUIRES_OK(context, context->allocate_output(
322                                 0, TensorShape({num_points, k}),
323                                 &output_nearest_center_indices_tensor));
324     OP_REQUIRES_OK(context, context->allocate_output(
325                                 1, TensorShape({num_points, k}),
326                                 &output_nearest_center_distances_tensor));
327 
328     if (k == 0) return;
329 
330     Eigen::Map<MatrixXi64RowMajor> nearest_center_indices(
331         output_nearest_center_indices_tensor->matrix<int64>().data(),
332         num_points, k);
333     Eigen::Map<MatrixXfRowMajor> nearest_center_distances(
334         output_nearest_center_distances_tensor->matrix<float>().data(),
335         num_points, k);
336 
337     const Eigen::VectorXf centers_half_squared_norm =
338         0.5 * centers.rowwise().squaredNorm();
339 
340     // The distance computation is sharded to take advantage of multiple cores
341     // and to allow intermediate values to reside in L3 cache. This is done by
342     // sharding the points and centers as follows:
343     //
344     // 1. Centers are sharded such that each block of centers has at most
345     //    kNearestNeighborsCentersMaxBlockSize rows.
346     // 2. Points are sharded, and each block of points is multiplied with each
347     //    block of centers. The block size of points is chosen such that the
348     //    point coordinates (point_dimensions) and the matrix of distances to
349     //    each center in one block -- the intermediate data -- fits in L3 cache.
350     // 3. After performing each block-block distance computation, the results
351     //    are reduced to a set of k nearest centers as soon as possible. This
352     //    decreases total memory I/O.
353     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
354     const int64 num_threads = worker_threads.num_threads;
355     // This kernel might be configured to use fewer than the total number of
356     // available CPUs on the host machine. To avoid descructive interference
357     // with other jobs running on the host machine, we must only use a fraction
358     // of total available L3 cache. Unfortunately, we cannot query the host
359     // machine to get the number of physical CPUs. So, we use a fixed per-CPU
360     // budget and scale it by the number of CPUs available to this operation.
361     const int64 total_memory_budget =
362         kDefaultL3CachePerCpu * port::NumSchedulableCPUs();
363     // Compute the number of blocks into which rows of points must be split so
364     // that the distance matrix and the block of points can fit in cache. One
365     // row of points will yield a vector of distances to each center in a block.
366     const int64 bytes_per_row =
367         (std::min(kNearestNeighborsCentersMaxBlockSize,
368                   num_centers) /* centers in a block */
369          + point_dimensions /* coordinates of one point */) *
370         sizeof(float);
371     // The memory needed for storing the centers being processed. This is shared
372     // by all workers. Adding slack to the number of threads to avoid incorrect
373     // cache eviction when a new block of centers is loaded.
374     const int64 bytes_for_centers =
375         std::min(num_centers,
376                  (num_threads + 2) * kNearestNeighborsCentersMaxBlockSize) *
377         point_dimensions * sizeof(float);
378     // The memory budget available for workers to store their distance matrices.
379     const int64 available_memory_budget =
380         total_memory_budget - bytes_for_centers;
381     // That memory budget is shared by all threads.
382     const int64 rows_per_block =
383         std::max<int64>(kNearestNeighborsPointsMinBlockSize,
384                         available_memory_budget / num_threads / bytes_per_row);
385     // Divide rows into almost uniformly-sized units of work that are small
386     // enough for the memory budget (rows_per_block). Round up to a multiple of
387     // the number of threads.
388     const int64 num_units =
389         NextMultiple(num_threads, CeilOfRatio(num_points, rows_per_block));
390     auto work = [&](int64 start, int64 limit) {
391       for (; start < limit; ++start) {
392         const int64 start_row = num_points * start / num_units;
393         const int64 limit_row = num_points * (start + 1) / num_units;
394         CHECK_LE(limit_row, num_points);
395         const int64 num_rows = limit_row - start_row;
396         auto points_shard = points.middleRows(start_row, num_rows);
397         const Eigen::VectorXf points_half_squared_norm =
398             0.5 * points_shard.rowwise().squaredNorm();
399         auto nearest_center_indices_shard =
400             nearest_center_indices.middleRows(start_row, num_rows);
401         auto nearest_center_distances_shard =
402             nearest_center_distances.middleRows(start_row, num_rows);
403         FindKNearestCenters(k, points_shard, points_half_squared_norm, centers,
404                             centers_half_squared_norm,
405                             nearest_center_indices_shard,
406                             nearest_center_distances_shard);
407       }
408     };
409 
410     const int64 units_per_thread = num_units / num_threads;
411     BlockingCounter counter(num_threads - 1);
412     for (int64 i = 1; i < num_threads; ++i) {
413       const int64 start = i * units_per_thread;
414       const int64 limit = start + units_per_thread;
415       worker_threads.workers->Schedule([work, &counter, start, limit]() {
416         work(start, limit);
417         counter.DecrementCount();
418       });
419     }
420     work(0, units_per_thread);
421     counter.Wait();
422   }
423 
424  private:
FindKNearestCenters(int64 k,const Eigen::Ref<const MatrixXfRowMajor> & points,const Eigen::Ref<const Eigen::VectorXf> & points_half_squared_norm,const Eigen::Ref<const MatrixXfRowMajor> & centers,const Eigen::Ref<const Eigen::VectorXf> & centers_half_squared_norm,const Eigen::Ref<MatrixXi64RowMajor> & nearest_center_indices,const Eigen::Ref<MatrixXfRowMajor> & nearest_center_distances)425   static void FindKNearestCenters(
426       int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
427       const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
428       const Eigen::Ref<const MatrixXfRowMajor>& centers,
429       const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
430       const Eigen::Ref<MatrixXi64RowMajor>& nearest_center_indices,
431       const Eigen::Ref<MatrixXfRowMajor>& nearest_center_distances) {
432     CHECK_LE(k, centers.rows());
433     if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) {
434       FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers,
435                                   centers_half_squared_norm,
436                                   nearest_center_indices,
437                                   nearest_center_distances);
438     } else {
439       FindKNearestCentersBlockwise(k, points, points_half_squared_norm, centers,
440                                    centers_half_squared_norm,
441                                    nearest_center_indices,
442                                    nearest_center_distances);
443     }
444   }
445 
FindKNearestCentersOneBlock(int64 k,const Eigen::Ref<const MatrixXfRowMajor> & points,const Eigen::Ref<const Eigen::VectorXf> & points_half_squared_norm,const Eigen::Ref<const MatrixXfRowMajor> & centers,const Eigen::Ref<const Eigen::VectorXf> & centers_half_squared_norm,Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,Eigen::Ref<MatrixXfRowMajor> nearest_center_distances)446   static void FindKNearestCentersOneBlock(
447       int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
448       const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
449       const Eigen::Ref<const MatrixXfRowMajor>& centers,
450       const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
451       Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
452       Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
453     CHECK_LE(k, centers.rows());
454     const int64 num_points = points.rows();
455     const MatrixXfRowMajor inner_product = points * centers.transpose();
456     // Find nearest neighbors.
457     if (k == 1) {
458       for (int i = 0; i < num_points; ++i) {
459         int64 index;
460         nearest_center_distances(i, 0) =
461             2.0 *
462             (points_half_squared_norm(i) +
463              (centers_half_squared_norm.transpose() - inner_product.row(i))
464                  .minCoeff(&index));
465         nearest_center_indices(i, 0) = index;
466       }
467     } else {
468       // Select k nearest centers for each point.
469       using Center = std::pair<float, int64>;
470       const int64 num_centers = centers.rows();
471       gtl::TopN<Center, std::less<Center>> selector(k);
472       std::unique_ptr<std::vector<Center>> nearest_centers;
473       for (int i = 0; i < num_points; ++i) {
474         selector.reserve(num_centers);
475         for (int j = 0; j < num_centers; ++j) {
476           const float partial_distance =
477               centers_half_squared_norm(j) - inner_product(i, j);
478           selector.push(Center(partial_distance, j));
479         }
480         nearest_centers.reset(selector.Extract());
481         selector.Reset();
482         const float point_half_squared_norm = points_half_squared_norm(i);
483         for (int j = 0; j < k; ++j) {
484           const Center& center = (*nearest_centers)[j];
485           nearest_center_distances(i, j) =
486               2.0 * (point_half_squared_norm + center.first);
487           nearest_center_indices(i, j) = center.second;
488         }
489       }
490     }
491   }
492 
FindKNearestCentersBlockwise(int64 k,const Eigen::Ref<const MatrixXfRowMajor> & points,const Eigen::Ref<const Eigen::VectorXf> & points_half_squared_norm,const Eigen::Ref<const MatrixXfRowMajor> & centers,const Eigen::Ref<const Eigen::VectorXf> & centers_half_squared_norm,Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,Eigen::Ref<MatrixXfRowMajor> nearest_center_distances)493   static void FindKNearestCentersBlockwise(
494       int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
495       const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
496       const Eigen::Ref<const MatrixXfRowMajor>& centers,
497       const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
498       Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
499       Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
500     const int64 num_points = points.rows();
501     const int64 num_centers = centers.rows();
502     CHECK_LE(k, num_centers);
503     CHECK_GT(num_centers, kNearestNeighborsCentersMaxBlockSize);
504     // Store nearest neighbors with first block of centers directly into the
505     // output matrices.
506     int64 out_k = std::min(k, kNearestNeighborsCentersMaxBlockSize);
507     FindKNearestCentersOneBlock(
508         out_k, points, points_half_squared_norm,
509         centers.topRows(kNearestNeighborsCentersMaxBlockSize),
510         centers_half_squared_norm.head(kNearestNeighborsCentersMaxBlockSize),
511         nearest_center_indices, nearest_center_distances);
512     // Iteratively compute nearest neighbors with other blocks of centers, and
513     // update the output matrices.
514     MatrixXi64RowMajor block_nearest_center_indices(num_points, k);
515     MatrixXfRowMajor block_nearest_center_distances(num_points, k);
516     Eigen::Matrix<int64, 1, Eigen::Dynamic> merged_indices(k);
517     Eigen::Matrix<float, 1, Eigen::Dynamic> merged_distances(k);
518     for (int64 centers_start = kNearestNeighborsCentersMaxBlockSize;
519          centers_start < num_centers;
520          centers_start += kNearestNeighborsCentersMaxBlockSize) {
521       const int64 centers_block_size = std::min(
522           kNearestNeighborsCentersMaxBlockSize, num_centers - centers_start);
523       const int64 block_k = std::min(k, centers_block_size);
524       FindKNearestCentersOneBlock(
525           block_k, points, points_half_squared_norm,
526           centers.middleRows(centers_start, centers_block_size),
527           centers_half_squared_norm.segment(centers_start, centers_block_size),
528           block_nearest_center_indices, block_nearest_center_distances);
529       if (k == 1) {
530         for (int i = 0; i < num_points; ++i) {
531           if (block_nearest_center_distances(i, 0) <
532               nearest_center_distances(i, 0)) {
533             nearest_center_indices(i, 0) =
534                 block_nearest_center_indices(i, 0) + centers_start;
535             nearest_center_distances(i, 0) =
536                 block_nearest_center_distances(i, 0);
537           }
538         }
539       } else {
540         for (int i = 0; i < num_points; ++i) {
541           // Merge and accumulate top-k list from block_nearest_center_indices
542           // into nearest_center_indices.
543           for (int64 j_out = 0, j_block = 0, j_merged = 0;
544                (j_out < out_k || j_block < block_k) && j_merged < k;
545                ++j_merged) {
546             const float distance_out =
547                 j_out < out_k ? nearest_center_distances(i, j_out)
548                               : std::numeric_limits<float>::infinity();
549             const float distance_block =
550                 j_block < block_k ? block_nearest_center_distances(i, j_block)
551                                   : std::numeric_limits<float>::infinity();
552             if (distance_out <= distance_block) {
553               merged_indices(j_merged) = nearest_center_indices(i, j_out);
554               merged_distances(j_merged) = distance_out;
555               ++j_out;
556             } else {
557               merged_indices(j_merged) =
558                   block_nearest_center_indices(i, j_block) + centers_start;
559               merged_distances(j_merged) = distance_block;
560               ++j_block;
561             }
562           }
563           nearest_center_indices.row(i) = merged_indices;
564           nearest_center_distances.row(i) = merged_distances;
565           out_k = std::min(k, out_k + block_k);
566         }
567       }
568     }
569   }
570 };
571 
572 REGISTER_KERNEL_BUILDER(Name("NearestNeighbors").Device(DEVICE_CPU),
573                         NearestNeighborsOp);
574 
575 }  // namespace tensorflow
576