// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License.  You may obtain a copy
// of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
// License for the specific language governing permissions and limitations under
// the License.
// ==============================================================================

#define EIGEN_USE_THREADS

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <numeric>
#include <tuple>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace {
using errors::InvalidArgument;

template <typename Scalar>
using RowMajorMatrix =
    Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

using MatrixXfRowMajor = RowMajorMatrix<float>;
using MatrixXi64RowMajor = RowMajorMatrix<int64>;

// Ideally this should be computed by dividing L3 cache size by the number of
// physical CPUs. Since there isn't a portable method to do this, we are using
// a conservative estimate here.
const int64 kDefaultL3CachePerCpu = 1 << 20;

// These values were determined by performing a parameter sweep on the
// NearestNeighborsOp benchmark.
const int64 kNearestNeighborsCentersMaxBlockSize = 1024;
const int64 kNearestNeighborsPointsMinBlockSize = 16;

// Returns the smallest multiple of a that is not smaller than b.
int64 NextMultiple(int64 a, int64 b) {
  const int64 remainder = b % a;
  return remainder == 0 ? b : (b + a - remainder);
}

// Returns a / b rounded up to the next higher integer.
int64 CeilOfRatio(int64 a, int64 b) { return (a + b - 1) / b; }

}  // namespace

// Implementation of K-means++ initialization. Samples points iteratively in
// proportion to the squared distances from selected points.
// TODO(ands): Add support for other distance metrics.
class KmeansPlusPlusInitializationOp : public OpKernel {
 public:
  explicit KmeansPlusPlusInitializationOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->MatchSignature(
                       {DT_FLOAT, DT_INT64, DT_INT64, DT_INT64}, {DT_FLOAT}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& points_tensor = context->input(0);
    const Tensor& num_to_sample_tensor = context->input(1);
    const Tensor& seed_tensor = context->input(2);
    const Tensor& num_retries_per_sample_tensor = context->input(3);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
                InvalidArgument("Input points should be a matrix."));
    OP_REQUIRES(context,
                TensorShapeUtils::IsScalar(num_to_sample_tensor.shape()),
                InvalidArgument("Input num_to_sample should be a scalar."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
                InvalidArgument("Input seed should be a scalar."));
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsScalar(num_retries_per_sample_tensor.shape()),
        InvalidArgument("Input num_retries_per_sample should be a scalar."));

    const int64 num_points = points_tensor.dim_size(0);
    const int64 point_dimensions = points_tensor.dim_size(1);
    const int64 num_to_sample = num_to_sample_tensor.scalar<int64>()();
    const int64 seed = seed_tensor.scalar<int64>()();
    const int64 num_retries_per_sample = [&]() {
      const int64 value = num_retries_per_sample_tensor.scalar<int64>()();
      return value >= 0 ? value
                        : 2 + static_cast<int64>(std::log(num_to_sample));
    }();

    OP_REQUIRES(context, num_points > 0,
                InvalidArgument("Expected points.rows() > 0."));
    OP_REQUIRES(context, num_to_sample > 0,
                InvalidArgument("Expected num_to_sample > 0."));
    OP_REQUIRES(context, num_to_sample <= num_points,
                InvalidArgument("Expected num_to_sample <= points.rows(). ",
                                num_to_sample, " vs ", num_points, "."));

    Tensor* output_sampled_points_tensor;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({num_to_sample, point_dimensions}),
                       &output_sampled_points_tensor));

    const Eigen::Map<const MatrixXfRowMajor> points(
        points_tensor.matrix<float>().data(), num_points, point_dimensions);
    const Eigen::VectorXf points_half_squared_norm =
        0.5 * points.rowwise().squaredNorm();

    Eigen::Map<MatrixXfRowMajor> sampled_points(
        output_sampled_points_tensor->matrix<float>().data(), num_to_sample,
        point_dimensions);
    std::unordered_set<int64> sampled_indices;

    random::PhiloxRandom random(seed);
    random::SimplePhilox rng(&random);

    auto add_one_point = [&](int64 from, int64 to) {
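      // draw_one_sample below can return num_points when the remaining
      // distances are all zero, or when floating-point rounding pushes the
      // sampled value past the end of the cumulative sum, so clamp the index.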
      from = std::min(from, num_points - 1);
      sampled_points.row(to) = points.row(from);
      sampled_indices.insert(from);
    };

    // Distances from all points to the nearest selected point. Initialized to
    // infinity so that the distances to the first selected point take effect
    // through cwiseMin below.
    Eigen::VectorXf min_distances(num_points);
    min_distances.fill(std::numeric_limits<float>::infinity());
    Eigen::VectorXf min_distances_cumsum(num_points);

    auto draw_one_sample = [&]() -> int64 {
      if (sampled_indices.empty()) return rng.Uniform64(num_points);
      int64 index = 0;
      do {
        // If v is drawn from Uniform[0, distances.sum()), then
        // Prob[cumsum(distances)(i - 1) <= v < cumsum(distances)(i)] is
        // proportional to distances(i).
        index = std::upper_bound(
                    min_distances_cumsum.data(),
                    min_distances_cumsum.data() + num_points,
                    rng.RandFloat() * min_distances_cumsum(num_points - 1)) -
                min_distances_cumsum.data();
      } while (sampled_indices.find(index) != sampled_indices.end());
      return index;
    };

    auto sample_one_point = [&]() {
      const int64 sampled_index = draw_one_sample();
      min_distances = min_distances.cwiseMin(GetHalfSquaredDistancesToY(
          points, points_half_squared_norm, points.row(sampled_index),
          points_half_squared_norm(sampled_index)));
      return sampled_index;
    };

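    // Like sample_one_point, but draws 1 + num_retries_per_sample candidates
    // and keeps the one that yields the smallest potential, i.e. the smallest
    // sum of per-point distances to the nearest selected point.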
    auto sample_one_point_with_retries = [&]() {
      Eigen::VectorXf best_new_min_distances(num_points);
      float best_potential = std::numeric_limits<float>::infinity();
      int64 best_sampled_index = 0;
      for (int i = 1 + num_retries_per_sample; i > 0; --i) {
        const int64 sampled_index = draw_one_sample();
        Eigen::VectorXf new_min_distances =
            min_distances.cwiseMin(GetHalfSquaredDistancesToY(
                points, points_half_squared_norm, points.row(sampled_index),
                points_half_squared_norm(sampled_index)));
        const float potential = new_min_distances.sum();
        if (potential < best_potential) {
          best_potential = potential;
          best_sampled_index = sampled_index;
          best_new_min_distances.swap(new_min_distances);
        }
      }
      min_distances.swap(best_new_min_distances);
      return best_sampled_index;
    };

    for (int64 i = 0; i < num_to_sample; ++i) {
      if (i > 0) {
        std::partial_sum(min_distances.data(),
                         min_distances.data() + num_points,
                         min_distances_cumsum.data());
      }
      int64 next = num_retries_per_sample == 0
                       ? sample_one_point()
                       : sample_one_point_with_retries();
      add_one_point(next, i);
    }
  }

 private:
  // Returns a column vector with the i-th element set to half the squared
  // Euclidean distance between the i-th row of xs and y. The precomputed half
  // squared norms of each row of xs and of y must be provided for efficiency.
  // TODO(ands): Parallelize this for large xs.
  static Eigen::VectorXf GetHalfSquaredDistancesToY(
      const Eigen::Ref<const MatrixXfRowMajor>& xs,
      const Eigen::Ref<const Eigen::VectorXf>& xs_half_squared_norm,
      const Eigen::Ref<const Eigen::RowVectorXf>& y,
      float y_half_squared_norm) {
    // Squared distance between points xs_i and y is:
    //   || xs_i ||^2 - 2 <xs_i, y> + || y ||^2
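    // Halving gives || xs_i ||^2 / 2 - <xs_i, y> + || y ||^2 / 2, which is the
    // expression below, built from the precomputed half squared norms.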
    return (xs_half_squared_norm - xs * y.transpose()).array() +
           y_half_squared_norm;
  }
};

REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU),
                        KmeansPlusPlusInitializationOp);

// Implementation of a single Markov chain for the k-MC^2 algorithm.
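// Given a vector of distances (one per candidate point) and a seed, the op
// runs one chain over the candidates and returns the index of the chain's
// final state, which is used as the next center.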
class KMC2ChainInitializationOp : public OpKernel {
 public:
  explicit KMC2ChainInitializationOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& distances_tensor = context->input(0);
    const Tensor& seed_tensor = context->input(1);
    OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()),
                InvalidArgument("Input distances should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
                InvalidArgument("Input seed should be a scalar."));
    const int64 num_points = distances_tensor.dim_size(0);
    const int64 seed = seed_tensor.scalar<int64>()();
    OP_REQUIRES(context, num_points > 0,
                InvalidArgument("Expected distances_tensor.size() > 0."));

    random::PhiloxRandom random(seed);
    random::SimplePhilox rng(&random);

    auto distances = distances_tensor.flat<float>();
    // Set the initial state of the Markov chain to be the first candidate.
    int64 selected_index = 0;
    float selected_distance = distances(selected_index);
    // Build a Markov chain of length num_points.
    for (int64 i = 1; i < num_points; ++i) {
      const float candidate_distance = distances(i);
      // Set the next state of the Markov chain to be the candidate with
      // probability min(1, candidate_distance/selected_distance).
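      // Equivalently: accept iff u * selected_distance < candidate_distance
      // for u drawn uniformly from [0, 1).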
      if (candidate_distance > rng.RandFloat() * selected_distance) {
        selected_index = i;
        selected_distance = candidate_distance;
      }
    }

    Tensor* output_sampled_index_tensor;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}),
                                            &output_sampled_index_tensor));
    auto output = output_sampled_index_tensor->scalar<int64>();
    // Return the last state of the Markov chain as the new center.
    output() = selected_index;
  }
};

REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU),
                        KMC2ChainInitializationOp);

// Operator for computing the nearest neighbors for a set of points.
class NearestNeighborsOp : public OpKernel {
 public:
  explicit NearestNeighborsOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->MatchSignature({DT_FLOAT, DT_FLOAT, DT_INT64},
                                           {DT_INT64, DT_FLOAT}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& points_tensor = context->input(0);
    const Tensor& centers_tensor = context->input(1);
    const Tensor& k_tensor = context->input(2);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
                InvalidArgument("Input points should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(centers_tensor.shape()),
                InvalidArgument("Input centers should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_tensor.shape()),
                InvalidArgument("Input k should be a scalar."));

    const int64 num_points = points_tensor.dim_size(0);
    const int64 point_dimensions = points_tensor.dim_size(1);
    const int64 num_centers = centers_tensor.dim_size(0);
    const int64 center_dimensions = centers_tensor.dim_size(1);

    OP_REQUIRES(context, num_points > 0,
                InvalidArgument("Expected points.rows() > 0."));
    OP_REQUIRES(
        context, point_dimensions == center_dimensions,
        InvalidArgument("Expected point_dimensions == center_dimensions: ",
                        point_dimensions, " vs ", center_dimensions, "."));

    const Eigen::Map<const MatrixXfRowMajor> points(
        points_tensor.matrix<float>().data(), num_points, point_dimensions);
    const Eigen::Map<const MatrixXfRowMajor> centers(
        centers_tensor.matrix<float>().data(), num_centers, center_dimensions);
    const int64 k = std::min<int64>(num_centers, k_tensor.scalar<int64>()());

    Tensor* output_nearest_center_indices_tensor;
    Tensor* output_nearest_center_distances_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, TensorShape({num_points, k}),
                                &output_nearest_center_indices_tensor));
    OP_REQUIRES_OK(context, context->allocate_output(
                                1, TensorShape({num_points, k}),
                                &output_nearest_center_distances_tensor));

    if (k == 0) return;

    Eigen::Map<MatrixXi64RowMajor> nearest_center_indices(
        output_nearest_center_indices_tensor->matrix<int64>().data(),
        num_points, k);
    Eigen::Map<MatrixXfRowMajor> nearest_center_distances(
        output_nearest_center_distances_tensor->matrix<float>().data(),
        num_points, k);

    const Eigen::VectorXf centers_half_squared_norm =
        0.5 * centers.rowwise().squaredNorm();

    // The distance computation is sharded to take advantage of multiple cores
    // and to allow intermediate values to reside in L3 cache. This is done by
    // sharding the points and centers as follows:
    //
    // 1. Centers are sharded such that each block of centers has at most
    //    kNearestNeighborsCentersMaxBlockSize rows.
    // 2. Points are sharded, and each block of points is multiplied with each
    //    block of centers. The block size of points is chosen such that the
    //    point coordinates (point_dimensions) and the matrix of distances to
    //    each center in one block -- the intermediate data -- fits in L3 cache.
    // 3. After performing each block-block distance computation, the results
    //    are reduced to a set of k nearest centers as soon as possible. This
    //    decreases total memory I/O.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    const int64 num_threads = worker_threads.num_threads;
    // This kernel might be configured to use fewer than the total number of
    // available CPUs on the host machine. To avoid destructive interference
    // with other jobs running on the host machine, we must only use a fraction
    // of total available L3 cache. Unfortunately, we cannot query the host
    // machine to get the number of physical CPUs. So, we use a fixed per-CPU
    // budget and scale it by the number of CPUs available to this operation.
    const int64 total_memory_budget =
        kDefaultL3CachePerCpu * port::NumSchedulableCPUs();
    // Compute the number of blocks into which rows of points must be split so
    // that the distance matrix and the block of points can fit in cache. One
    // row of points will yield a vector of distances to each center in a block.
    const int64 bytes_per_row =
        (std::min(kNearestNeighborsCentersMaxBlockSize,
                  num_centers) /* centers in a block */
         + point_dimensions /* coordinates of one point */) *
        sizeof(float);
    // The memory needed for storing the centers being processed. This is shared
    // by all workers. Adding slack to the number of threads to avoid incorrect
    // cache eviction when a new block of centers is loaded.
    const int64 bytes_for_centers =
        std::min(num_centers,
                 (num_threads + 2) * kNearestNeighborsCentersMaxBlockSize) *
        point_dimensions * sizeof(float);
    // The memory budget available for workers to store their distance matrices.
    const int64 available_memory_budget =
        total_memory_budget - bytes_for_centers;
    // That memory budget is shared by all threads.
    const int64 rows_per_block =
        std::max<int64>(kNearestNeighborsPointsMinBlockSize,
                        available_memory_budget / num_threads / bytes_per_row);
    // Divide rows into almost uniformly-sized units of work that are small
    // enough for the memory budget (rows_per_block). Round up to a multiple of
    // the number of threads.
    const int64 num_units =
        NextMultiple(num_threads, CeilOfRatio(num_points, rows_per_block));
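    // Example: with 8 threads, 10000 points, and rows_per_block = 3000,
    // CeilOfRatio yields 4 units and NextMultiple rounds this up to 8, so
    // each thread handles exactly one unit of 1250 rows.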
    auto work = [&](int64 start, int64 limit) {
      for (; start < limit; ++start) {
        const int64 start_row = num_points * start / num_units;
        const int64 limit_row = num_points * (start + 1) / num_units;
        DCHECK_LE(limit_row, num_points);
        const int64 num_rows = limit_row - start_row;
        auto points_shard = points.middleRows(start_row, num_rows);
        const Eigen::VectorXf points_half_squared_norm =
            0.5 * points_shard.rowwise().squaredNorm();
        auto nearest_center_indices_shard =
            nearest_center_indices.middleRows(start_row, num_rows);
        auto nearest_center_distances_shard =
            nearest_center_distances.middleRows(start_row, num_rows);
        FindKNearestCenters(k, points_shard, points_half_squared_norm, centers,
                            centers_half_squared_norm,
                            nearest_center_indices_shard,
                            nearest_center_distances_shard);
      }
    };

    const int64 units_per_thread = num_units / num_threads;
    BlockingCounter counter(num_threads - 1);
    for (int64 i = 1; i < num_threads; ++i) {
      const int64 start = i * units_per_thread;
      const int64 limit = start + units_per_thread;
      worker_threads.workers->Schedule([work, &counter, start, limit]() {
        work(start, limit);
        counter.DecrementCount();
      });
    }
    work(0, units_per_thread);
    counter.Wait();
  }

 private:
  static void FindKNearestCenters(
      int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
      const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
      const Eigen::Ref<const MatrixXfRowMajor>& centers,
      const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
      const Eigen::Ref<MatrixXi64RowMajor>& nearest_center_indices,
      const Eigen::Ref<MatrixXfRowMajor>& nearest_center_distances) {
    DCHECK_LE(k, centers.rows());
    if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) {
      FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers,
                                  centers_half_squared_norm,
                                  nearest_center_indices,
                                  nearest_center_distances);
    } else {
      FindKNearestCentersBlockwise(k, points, points_half_squared_norm, centers,
                                   centers_half_squared_norm,
                                   nearest_center_indices,
                                   nearest_center_distances);
    }
  }

  static void FindKNearestCentersOneBlock(
      int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
      const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
      const Eigen::Ref<const MatrixXfRowMajor>& centers,
      const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
      Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
      Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
    DCHECK_LE(k, centers.rows());
    const int64 num_points = points.rows();
    const MatrixXfRowMajor inner_product = points * centers.transpose();
    // Find nearest neighbors.
    if (k == 1) {
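      // ||x - c||^2 = 2 * (||x||^2 / 2 + ||c||^2 / 2 - <x, c>), so minimizing
      // the bracketed expression over the centers finds the nearest one.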
      for (int i = 0; i < num_points; ++i) {
        int64 index;
        nearest_center_distances(i, 0) =
            2.0 *
            (points_half_squared_norm(i) +
             (centers_half_squared_norm.transpose() - inner_product.row(i))
                 .minCoeff(&index));
        nearest_center_indices(i, 0) = index;
      }
    } else {
      // Select k nearest centers for each point.
      using Center = std::pair<float, int64>;
      const int64 num_centers = centers.rows();
      gtl::TopN<Center, std::less<Center>> selector(k);
      std::unique_ptr<std::vector<Center>> nearest_centers;
      for (int i = 0; i < num_points; ++i) {
        selector.reserve(num_centers);
        for (int j = 0; j < num_centers; ++j) {
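          // The 0.5 * ||x||^2 term is the same for every center, so ranking
          // centers by this partial distance matches ranking by the true
          // squared distance.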
          const float partial_distance =
              centers_half_squared_norm(j) - inner_product(i, j);
          selector.push(Center(partial_distance, j));
        }
        nearest_centers.reset(selector.Extract());
        selector.Reset();
        const float point_half_squared_norm = points_half_squared_norm(i);
        for (int j = 0; j < k; ++j) {
          const Center& center = (*nearest_centers)[j];
          nearest_center_distances(i, j) =
              2.0 * (point_half_squared_norm + center.first);
          nearest_center_indices(i, j) = center.second;
        }
      }
    }
  }

  static void FindKNearestCentersBlockwise(
      int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
      const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
      const Eigen::Ref<const MatrixXfRowMajor>& centers,
      const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
      Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
      Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
    const int64 num_points = points.rows();
    const int64 num_centers = centers.rows();
    DCHECK_LE(k, num_centers);
    DCHECK_GT(num_centers, kNearestNeighborsCentersMaxBlockSize);
    // Store nearest neighbors with first block of centers directly into the
    // output matrices.
    int64 out_k = std::min(k, kNearestNeighborsCentersMaxBlockSize);
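    // out_k tracks how many leading columns of each output row hold valid
    // entries so far; it grows as additional center blocks are merged in.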
    FindKNearestCentersOneBlock(
        out_k, points, points_half_squared_norm,
        centers.topRows(kNearestNeighborsCentersMaxBlockSize),
        centers_half_squared_norm.head(kNearestNeighborsCentersMaxBlockSize),
        nearest_center_indices, nearest_center_distances);
    // Iteratively compute nearest neighbors with other blocks of centers, and
    // update the output matrices.
    MatrixXi64RowMajor block_nearest_center_indices(num_points, k);
    MatrixXfRowMajor block_nearest_center_distances(num_points, k);
    Eigen::Matrix<int64, 1, Eigen::Dynamic> merged_indices(k);
    Eigen::Matrix<float, 1, Eigen::Dynamic> merged_distances(k);
    for (int64 centers_start = kNearestNeighborsCentersMaxBlockSize;
         centers_start < num_centers;
         centers_start += kNearestNeighborsCentersMaxBlockSize) {
      const int64 centers_block_size = std::min(
          kNearestNeighborsCentersMaxBlockSize, num_centers - centers_start);
      const int64 block_k = std::min(k, centers_block_size);
      FindKNearestCentersOneBlock(
          block_k, points, points_half_squared_norm,
          centers.middleRows(centers_start, centers_block_size),
          centers_half_squared_norm.segment(centers_start, centers_block_size),
          block_nearest_center_indices, block_nearest_center_distances);
      if (k == 1) {
        for (int i = 0; i < num_points; ++i) {
          if (block_nearest_center_distances(i, 0) <
              nearest_center_distances(i, 0)) {
            nearest_center_indices(i, 0) =
                block_nearest_center_indices(i, 0) + centers_start;
            nearest_center_distances(i, 0) =
                block_nearest_center_distances(i, 0);
          }
        }
      } else {
        for (int i = 0; i < num_points; ++i) {
          // Merge and accumulate top-k list from block_nearest_center_indices
          // into nearest_center_indices.
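          // Both lists are sorted by increasing distance (TopN extracts its
          // results in order), so a single two-way merge yields the combined
          // top-k.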
          for (int64 j_out = 0, j_block = 0, j_merged = 0;
               (j_out < out_k || j_block < block_k) && j_merged < k;
               ++j_merged) {
            const float distance_out =
                j_out < out_k ? nearest_center_distances(i, j_out)
                              : std::numeric_limits<float>::infinity();
            const float distance_block =
                j_block < block_k ? block_nearest_center_distances(i, j_block)
                                  : std::numeric_limits<float>::infinity();
            if (distance_out <= distance_block) {
              merged_indices(j_merged) = nearest_center_indices(i, j_out);
              merged_distances(j_merged) = distance_out;
              ++j_out;
            } else {
              merged_indices(j_merged) =
                  block_nearest_center_indices(i, j_block) + centers_start;
              merged_distances(j_merged) = distance_block;
              ++j_block;
            }
          }
          nearest_center_indices.row(i) = merged_indices;
          nearest_center_distances.row(i) = merged_distances;
        }
        // Only bump out_k after every row has been merged against this block;
        // it must reflect the entries that were valid when the merges ran.
        out_k = std::min(k, out_k + block_k);
      }
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("NearestNeighbors").Device(DEVICE_CPU),
                        NearestNeighborsOp);

}  // namespace tensorflow