1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy
5 // of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 // ==============================================================================
15
16 #define EIGEN_USE_THREADS
17
#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
38
39 namespace tensorflow {
40 namespace {
41 using errors::InvalidArgument;
42
43 template <typename Scalar>
44 using RowMajorMatrix =
45 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
46
47 using MatrixXfRowMajor = RowMajorMatrix<float>;
48 using MatrixXi64RowMajor = RowMajorMatrix<int64>;
49
50 // Ideally this should be computed by dividing L3 cache size by the number of
51 // physical CPUs. Since there isn't a portable method to do this, we are using
52 // a conservative estimate here.
53 const int64 kDefaultL3CachePerCpu = 1 << 20;
54
55 // These values were determined by performing a parameter sweep on the
56 // NearestNeighborsOp benchmark.
57 const int64 kNearestNeighborsCentersMaxBlockSize = 1024;
58 const int64 kNearestNeighborsPointsMinBlockSize = 16;
59
60 // Returns the smallest multiple of a that is not smaller than b.
NextMultiple(int64 a,int64 b)61 int64 NextMultiple(int64 a, int64 b) {
62 const int64 remainder = b % a;
63 return remainder == 0 ? b : (b + a - remainder);
64 }
65
66 // Returns a / b rounded up to the next higher integer.
CeilOfRatio(int64 a,int64 b)67 int64 CeilOfRatio(int64 a, int64 b) { return (a + b - 1) / b; }
68
69 } // namespace
70
71 // Implementation of K-means++ initialization. Samples points iteratively in
72 // proportion to the squared distances from selected points.
73 // TODO(ands): Add support for other distance metrics.
74 class KmeansPlusPlusInitializationOp : public OpKernel {
75 public:
KmeansPlusPlusInitializationOp(OpKernelConstruction * context)76 explicit KmeansPlusPlusInitializationOp(OpKernelConstruction* context)
77 : OpKernel(context) {
78 OP_REQUIRES_OK(context,
79 context->MatchSignature(
80 {DT_FLOAT, DT_INT64, DT_INT64, DT_INT64}, {DT_FLOAT}));
81 }
82
Compute(OpKernelContext * context)83 void Compute(OpKernelContext* context) override {
84 const Tensor& points_tensor = context->input(0);
85 const Tensor& num_to_sample_tensor = context->input(1);
86 const Tensor& seed_tensor = context->input(2);
87 const Tensor& num_retries_per_sample_tensor = context->input(3);
88
89 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
90 InvalidArgument("Input points should be a matrix."));
91 OP_REQUIRES(context,
92 TensorShapeUtils::IsScalar(num_to_sample_tensor.shape()),
93 InvalidArgument("Input num_to_sample should be a scalar."));
94 OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
95 InvalidArgument("Input seed should be a scalar."));
96 OP_REQUIRES(
97 context,
98 TensorShapeUtils::IsScalar(num_retries_per_sample_tensor.shape()),
99 InvalidArgument("Input num_retries_per_sample should be a scalar."));
100
101 const int64 num_points = points_tensor.dim_size(0);
102 const int64 point_dimensions = points_tensor.dim_size(1);
103 const int64 num_to_sample = num_to_sample_tensor.scalar<int64>()();
104 const int64 seed = seed_tensor.scalar<int64>()();
105 const int64 num_retries_per_sample = [&]() {
106 const int64 value = num_retries_per_sample_tensor.scalar<int64>()();
107 return value >= 0 ? value
108 : 2 + static_cast<int64>(std::log(num_to_sample));
109 }();
110
111 OP_REQUIRES(context, num_points > 0,
112 InvalidArgument("Expected points.rows() > 0."));
113 OP_REQUIRES(context, num_to_sample > 0,
114 InvalidArgument("Expected num_to_sample > 0."));
115 OP_REQUIRES(context, num_to_sample <= num_points,
116 InvalidArgument("Expected num_to_sample <= points.rows(). ",
117 num_to_sample, " vs ", num_points, "."));
118
119 Tensor* output_sampled_points_tensor;
120 OP_REQUIRES_OK(context,
121 context->allocate_output(
122 0, TensorShape({num_to_sample, point_dimensions}),
123 &output_sampled_points_tensor));
124
125 const Eigen::Map<const MatrixXfRowMajor> points(
126 points_tensor.matrix<float>().data(), num_points, point_dimensions);
127 const Eigen::VectorXf points_half_squared_norm =
128 0.5 * points.rowwise().squaredNorm();
129
130 Eigen::Map<MatrixXfRowMajor> sampled_points(
131 output_sampled_points_tensor->matrix<float>().data(), num_to_sample,
132 point_dimensions);
133 std::unordered_set<int64> sampled_indices;
134
135 random::PhiloxRandom random(seed);
136 random::SimplePhilox rng(&random);
137
138 auto add_one_point = [&](int64 from, int64 to) {
139 from = std::min(from, num_points - 1);
140 sampled_points.row(to) = points.row(from);
141 sampled_indices.insert(from);
142 };
143
144 // Distances from all points to nearest selected point. Initialize with
145 // distances to first selected point.
146 Eigen::VectorXf min_distances(num_points);
147 min_distances.fill(std::numeric_limits<float>::infinity());
148 Eigen::VectorXf min_distances_cumsum(num_points);
149
150 auto draw_one_sample = [&]() -> int64 {
151 if (sampled_indices.empty()) return rng.Uniform64(num_points);
152 int64 index = 0;
153 do {
154 // If v is drawn from Uniform[0, distances.sum()), then
155 // Prob[cumsum(distances)(i - 1) <= v < cumsum(distances)(i)] is
156 // proportional to distances(i).
157 index = std::upper_bound(
158 min_distances_cumsum.data(),
159 min_distances_cumsum.data() + num_points,
160 rng.RandFloat() * min_distances_cumsum(num_points - 1)) -
161 min_distances_cumsum.data();
162 } while (sampled_indices.find(index) != sampled_indices.end());
163 return index;
164 };
165
166 auto sample_one_point = [&]() {
167 const int64 sampled_index = draw_one_sample();
168 min_distances = min_distances.cwiseMin(GetHalfSquaredDistancesToY(
169 points, points_half_squared_norm, points.row(sampled_index),
170 points_half_squared_norm(sampled_index)));
171 return sampled_index;
172 };
173
174 auto sample_one_point_with_retries = [&]() {
175 Eigen::VectorXf best_new_min_distances(num_points);
176 float best_potential = std::numeric_limits<float>::infinity();
177 int64 best_sampled_index = 0;
178 for (int i = 1 + num_retries_per_sample; i > 0; --i) {
179 const int64 sampled_index = draw_one_sample();
180 Eigen::VectorXf new_min_distances =
181 min_distances.cwiseMin(GetHalfSquaredDistancesToY(
182 points, points_half_squared_norm, points.row(sampled_index),
183 points_half_squared_norm(sampled_index)));
184 const float potential = new_min_distances.sum();
185 if (potential < best_potential) {
186 best_potential = potential;
187 best_sampled_index = sampled_index;
188 best_new_min_distances.swap(new_min_distances);
189 }
190 }
191 min_distances.swap(best_new_min_distances);
192 return best_sampled_index;
193 };
194
195 for (int64 i = 0; i < num_to_sample; ++i) {
196 if (i > 0) {
197 std::partial_sum(min_distances.data(),
198 min_distances.data() + num_points,
199 min_distances_cumsum.data());
200 }
201 int64 next = num_retries_per_sample == 0
202 ? sample_one_point()
203 : sample_one_point_with_retries();
204 add_one_point(next, i);
205 }
206 }
207
208 private:
209 // Returns a column vector with the i-th element set to half the squared
210 // euclidean distance between the i-th row of xs, and y. Precomputed norms for
211 // each row of xs and y must be provided for efficiency.
212 // TODO(ands): Parallelize this for large xs.
GetHalfSquaredDistancesToY(const Eigen::Ref<const MatrixXfRowMajor> & xs,const Eigen::Ref<const Eigen::VectorXf> & xs_half_squared_norm,const Eigen::Ref<const Eigen::RowVectorXf> & y,float y_half_squared_norm)213 static Eigen::VectorXf GetHalfSquaredDistancesToY(
214 const Eigen::Ref<const MatrixXfRowMajor>& xs,
215 const Eigen::Ref<const Eigen::VectorXf>& xs_half_squared_norm,
216 const Eigen::Ref<const Eigen::RowVectorXf>& y,
217 float y_half_squared_norm) {
218 // Squared distance between points xs_i and y is:
219 // || xs_i ||^2 - 2 <xs_i, y> + || y ||^2
220 return (xs_half_squared_norm - xs * y.transpose()).array() +
221 y_half_squared_norm;
222 }
223 };
224
225 REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU),
226 KmeansPlusPlusInitializationOp);
227
228 // Implementation of one single Markov Chain for the k-MC^2 algorithm
229 class KMC2ChainInitializationOp : public OpKernel {
230 public:
KMC2ChainInitializationOp(OpKernelConstruction * context)231 explicit KMC2ChainInitializationOp(OpKernelConstruction* context)
232 : OpKernel(context) {
233 OP_REQUIRES_OK(context,
234 context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64}));
235 }
236
Compute(OpKernelContext * context)237 void Compute(OpKernelContext* context) override {
238 const Tensor& distances_tensor = context->input(0);
239 const Tensor& seed_tensor = context->input(1);
240 OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()),
241 InvalidArgument("Input distances should be a vector."));
242 OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
243 InvalidArgument("Input seed should be a scalar."));
244 const int64 num_points = distances_tensor.dim_size(0);
245 const int64 seed = seed_tensor.scalar<int64>()();
246 OP_REQUIRES(context, num_points > 0,
247 InvalidArgument("Expected distances_tensor.size() > 0."));
248
249 random::PhiloxRandom random(seed);
250 random::SimplePhilox rng(&random);
251
252 auto distances = distances_tensor.flat<float>();
253 // Set the initial state of the Markov chain to be the first candidate.
254 int64 selected_index = 0;
255 float selected_distance = distances(selected_index);
256 // Build a Markov chain of length num_points.
257 for (int64 i = 1; i < num_points; ++i) {
258 const float candidate_distance = distances(i);
259 // Set the next state of the Markov chain to be the candidate with
260 // probability min(1, candidate_distance/selected_distance).
261 if (candidate_distance > rng.RandFloat() * selected_distance) {
262 selected_index = i;
263 selected_distance = candidate_distance;
264 }
265 }
266
267 Tensor* output_sampled_index_tensor;
268 OP_REQUIRES_OK(context,
269 context->allocate_output(0, TensorShape({}),
270 &output_sampled_index_tensor));
271 auto output = output_sampled_index_tensor->scalar<int64>();
272 // Return the last state of the Markov chain as the new center.
273 output() = selected_index;
274 }
275 };
276
277 REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU),
278 KMC2ChainInitializationOp);
279
280 // Operator for computing the nearest neighbors for a set of points.
281 class NearestNeighborsOp : public OpKernel {
282 public:
NearestNeighborsOp(OpKernelConstruction * context)283 explicit NearestNeighborsOp(OpKernelConstruction* context)
284 : OpKernel(context) {
285 OP_REQUIRES_OK(context,
286 context->MatchSignature({DT_FLOAT, DT_FLOAT, DT_INT64},
287 {DT_INT64, DT_FLOAT}));
288 }
289
Compute(OpKernelContext * context)290 void Compute(OpKernelContext* context) override {
291 const Tensor& points_tensor = context->input(0);
292 const Tensor& centers_tensor = context->input(1);
293 const Tensor& k_tensor = context->input(2);
294
295 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
296 InvalidArgument("Input points should be a matrix."));
297 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(centers_tensor.shape()),
298 InvalidArgument("Input centers should be a matrix."));
299 OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_tensor.shape()),
300 InvalidArgument("Input k should be a scalar."));
301
302 const int64 num_points = points_tensor.dim_size(0);
303 const int64 point_dimensions = points_tensor.dim_size(1);
304 const int64 num_centers = centers_tensor.dim_size(0);
305 const int64 center_dimensions = centers_tensor.dim_size(1);
306
307 OP_REQUIRES(context, num_points > 0,
308 InvalidArgument("Expected points.rows() > 0."));
309 OP_REQUIRES(
310 context, point_dimensions == center_dimensions,
311 InvalidArgument("Expected point_dimensions == center_dimensions: ",
312 point_dimensions, " vs ", center_dimensions, "."));
313
314 const Eigen::Map<const MatrixXfRowMajor> points(
315 points_tensor.matrix<float>().data(), num_points, point_dimensions);
316 const Eigen::Map<const MatrixXfRowMajor> centers(
317 centers_tensor.matrix<float>().data(), num_centers, center_dimensions);
318 const int64 k = std::min<int64>(num_centers, k_tensor.scalar<int64>()());
319
320 Tensor* output_nearest_center_indices_tensor;
321 Tensor* output_nearest_center_distances_tensor;
322 OP_REQUIRES_OK(context, context->allocate_output(
323 0, TensorShape({num_points, k}),
324 &output_nearest_center_indices_tensor));
325 OP_REQUIRES_OK(context, context->allocate_output(
326 1, TensorShape({num_points, k}),
327 &output_nearest_center_distances_tensor));
328
329 if (k == 0) return;
330
331 Eigen::Map<MatrixXi64RowMajor> nearest_center_indices(
332 output_nearest_center_indices_tensor->matrix<int64>().data(),
333 num_points, k);
334 Eigen::Map<MatrixXfRowMajor> nearest_center_distances(
335 output_nearest_center_distances_tensor->matrix<float>().data(),
336 num_points, k);
337
338 const Eigen::VectorXf centers_half_squared_norm =
339 0.5 * centers.rowwise().squaredNorm();
340
341 // The distance computation is sharded to take advantage of multiple cores
342 // and to allow intermediate values to reside in L3 cache. This is done by
343 // sharding the points and centers as follows:
344 //
345 // 1. Centers are sharded such that each block of centers has at most
346 // kNearestNeighborsCentersMaxBlockSize rows.
347 // 2. Points are sharded, and each block of points is multiplied with each
348 // block of centers. The block size of points is chosen such that the
349 // point coordinates (point_dimensions) and the matrix of distances to
350 // each center in one block -- the intermediate data -- fits in L3 cache.
351 // 3. After performing each block-block distance computation, the results
352 // are reduced to a set of k nearest centers as soon as possible. This
353 // decreases total memory I/O.
354 auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
355 const int64 num_threads = worker_threads.num_threads;
356 // This kernel might be configured to use fewer than the total number of
357 // available CPUs on the host machine. To avoid destructive interference
358 // with other jobs running on the host machine, we must only use a fraction
359 // of total available L3 cache. Unfortunately, we cannot query the host
360 // machine to get the number of physical CPUs. So, we use a fixed per-CPU
361 // budget and scale it by the number of CPUs available to this operation.
362 const int64 total_memory_budget =
363 kDefaultL3CachePerCpu * port::NumSchedulableCPUs();
364 // Compute the number of blocks into which rows of points must be split so
365 // that the distance matrix and the block of points can fit in cache. One
366 // row of points will yield a vector of distances to each center in a block.
367 const int64 bytes_per_row =
368 (std::min(kNearestNeighborsCentersMaxBlockSize,
369 num_centers) /* centers in a block */
370 + point_dimensions /* coordinates of one point */) *
371 sizeof(float);
372 // The memory needed for storing the centers being processed. This is shared
373 // by all workers. Adding slack to the number of threads to avoid incorrect
374 // cache eviction when a new block of centers is loaded.
375 const int64 bytes_for_centers =
376 std::min(num_centers,
377 (num_threads + 2) * kNearestNeighborsCentersMaxBlockSize) *
378 point_dimensions * sizeof(float);
379 // The memory budget available for workers to store their distance matrices.
380 const int64 available_memory_budget =
381 total_memory_budget - bytes_for_centers;
382 // That memory budget is shared by all threads.
383 const int64 rows_per_block =
384 std::max<int64>(kNearestNeighborsPointsMinBlockSize,
385 available_memory_budget / num_threads / bytes_per_row);
386 // Divide rows into almost uniformly-sized units of work that are small
387 // enough for the memory budget (rows_per_block). Round up to a multiple of
388 // the number of threads.
389 const int64 num_units =
390 NextMultiple(num_threads, CeilOfRatio(num_points, rows_per_block));
391 auto work = [&](int64 start, int64 limit) {
392 for (; start < limit; ++start) {
393 const int64 start_row = num_points * start / num_units;
394 const int64 limit_row = num_points * (start + 1) / num_units;
395 DCHECK_LE(limit_row, num_points);
396 const int64 num_rows = limit_row - start_row;
397 auto points_shard = points.middleRows(start_row, num_rows);
398 const Eigen::VectorXf points_half_squared_norm =
399 0.5 * points_shard.rowwise().squaredNorm();
400 auto nearest_center_indices_shard =
401 nearest_center_indices.middleRows(start_row, num_rows);
402 auto nearest_center_distances_shard =
403 nearest_center_distances.middleRows(start_row, num_rows);
404 FindKNearestCenters(k, points_shard, points_half_squared_norm, centers,
405 centers_half_squared_norm,
406 nearest_center_indices_shard,
407 nearest_center_distances_shard);
408 }
409 };
410
411 const int64 units_per_thread = num_units / num_threads;
412 BlockingCounter counter(num_threads - 1);
413 for (int64 i = 1; i < num_threads; ++i) {
414 const int64 start = i * units_per_thread;
415 const int64 limit = start + units_per_thread;
416 worker_threads.workers->Schedule([work, &counter, start, limit]() {
417 work(start, limit);
418 counter.DecrementCount();
419 });
420 }
421 work(0, units_per_thread);
422 counter.Wait();
423 }
424
425 private:
FindKNearestCenters(int64 k,const Eigen::Ref<const MatrixXfRowMajor> & points,const Eigen::Ref<const Eigen::VectorXf> & points_half_squared_norm,const Eigen::Ref<const MatrixXfRowMajor> & centers,const Eigen::Ref<const Eigen::VectorXf> & centers_half_squared_norm,const Eigen::Ref<MatrixXi64RowMajor> & nearest_center_indices,const Eigen::Ref<MatrixXfRowMajor> & nearest_center_distances)426 static void FindKNearestCenters(
427 int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
428 const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
429 const Eigen::Ref<const MatrixXfRowMajor>& centers,
430 const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
431 const Eigen::Ref<MatrixXi64RowMajor>& nearest_center_indices,
432 const Eigen::Ref<MatrixXfRowMajor>& nearest_center_distances) {
433 DCHECK_LE(k, centers.rows());
434 if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) {
435 FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers,
436 centers_half_squared_norm,
437 nearest_center_indices,
438 nearest_center_distances);
439 } else {
440 FindKNearestCentersBlockwise(k, points, points_half_squared_norm, centers,
441 centers_half_squared_norm,
442 nearest_center_indices,
443 nearest_center_distances);
444 }
445 }
446
FindKNearestCentersOneBlock(int64 k,const Eigen::Ref<const MatrixXfRowMajor> & points,const Eigen::Ref<const Eigen::VectorXf> & points_half_squared_norm,const Eigen::Ref<const MatrixXfRowMajor> & centers,const Eigen::Ref<const Eigen::VectorXf> & centers_half_squared_norm,Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,Eigen::Ref<MatrixXfRowMajor> nearest_center_distances)447 static void FindKNearestCentersOneBlock(
448 int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
449 const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
450 const Eigen::Ref<const MatrixXfRowMajor>& centers,
451 const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
452 Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
453 Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
454 DCHECK_LE(k, centers.rows());
455 const int64 num_points = points.rows();
456 const MatrixXfRowMajor inner_product = points * centers.transpose();
457 // Find nearest neighbors.
458 if (k == 1) {
459 for (int i = 0; i < num_points; ++i) {
460 int64 index;
461 nearest_center_distances(i, 0) =
462 2.0 *
463 (points_half_squared_norm(i) +
464 (centers_half_squared_norm.transpose() - inner_product.row(i))
465 .minCoeff(&index));
466 nearest_center_indices(i, 0) = index;
467 }
468 } else {
469 // Select k nearest centers for each point.
470 using Center = std::pair<float, int64>;
471 const int64 num_centers = centers.rows();
472 gtl::TopN<Center, std::less<Center>> selector(k);
473 std::unique_ptr<std::vector<Center>> nearest_centers;
474 for (int i = 0; i < num_points; ++i) {
475 selector.reserve(num_centers);
476 for (int j = 0; j < num_centers; ++j) {
477 const float partial_distance =
478 centers_half_squared_norm(j) - inner_product(i, j);
479 selector.push(Center(partial_distance, j));
480 }
481 nearest_centers.reset(selector.Extract());
482 selector.Reset();
483 const float point_half_squared_norm = points_half_squared_norm(i);
484 for (int j = 0; j < k; ++j) {
485 const Center& center = (*nearest_centers)[j];
486 nearest_center_distances(i, j) =
487 2.0 * (point_half_squared_norm + center.first);
488 nearest_center_indices(i, j) = center.second;
489 }
490 }
491 }
492 }
493
FindKNearestCentersBlockwise(int64 k,const Eigen::Ref<const MatrixXfRowMajor> & points,const Eigen::Ref<const Eigen::VectorXf> & points_half_squared_norm,const Eigen::Ref<const MatrixXfRowMajor> & centers,const Eigen::Ref<const Eigen::VectorXf> & centers_half_squared_norm,Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,Eigen::Ref<MatrixXfRowMajor> nearest_center_distances)494 static void FindKNearestCentersBlockwise(
495 int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
496 const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
497 const Eigen::Ref<const MatrixXfRowMajor>& centers,
498 const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
499 Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
500 Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
501 const int64 num_points = points.rows();
502 const int64 num_centers = centers.rows();
503 DCHECK_LE(k, num_centers);
504 DCHECK_GT(num_centers, kNearestNeighborsCentersMaxBlockSize);
505 // Store nearest neighbors with first block of centers directly into the
506 // output matrices.
507 int64 out_k = std::min(k, kNearestNeighborsCentersMaxBlockSize);
508 FindKNearestCentersOneBlock(
509 out_k, points, points_half_squared_norm,
510 centers.topRows(kNearestNeighborsCentersMaxBlockSize),
511 centers_half_squared_norm.head(kNearestNeighborsCentersMaxBlockSize),
512 nearest_center_indices, nearest_center_distances);
513 // Iteratively compute nearest neighbors with other blocks of centers, and
514 // update the output matrices.
515 MatrixXi64RowMajor block_nearest_center_indices(num_points, k);
516 MatrixXfRowMajor block_nearest_center_distances(num_points, k);
517 Eigen::Matrix<int64, 1, Eigen::Dynamic> merged_indices(k);
518 Eigen::Matrix<float, 1, Eigen::Dynamic> merged_distances(k);
519 for (int64 centers_start = kNearestNeighborsCentersMaxBlockSize;
520 centers_start < num_centers;
521 centers_start += kNearestNeighborsCentersMaxBlockSize) {
522 const int64 centers_block_size = std::min(
523 kNearestNeighborsCentersMaxBlockSize, num_centers - centers_start);
524 const int64 block_k = std::min(k, centers_block_size);
525 FindKNearestCentersOneBlock(
526 block_k, points, points_half_squared_norm,
527 centers.middleRows(centers_start, centers_block_size),
528 centers_half_squared_norm.segment(centers_start, centers_block_size),
529 block_nearest_center_indices, block_nearest_center_distances);
530 if (k == 1) {
531 for (int i = 0; i < num_points; ++i) {
532 if (block_nearest_center_distances(i, 0) <
533 nearest_center_distances(i, 0)) {
534 nearest_center_indices(i, 0) =
535 block_nearest_center_indices(i, 0) + centers_start;
536 nearest_center_distances(i, 0) =
537 block_nearest_center_distances(i, 0);
538 }
539 }
540 } else {
541 for (int i = 0; i < num_points; ++i) {
542 // Merge and accumulate top-k list from block_nearest_center_indices
543 // into nearest_center_indices.
544 for (int64 j_out = 0, j_block = 0, j_merged = 0;
545 (j_out < out_k || j_block < block_k) && j_merged < k;
546 ++j_merged) {
547 const float distance_out =
548 j_out < out_k ? nearest_center_distances(i, j_out)
549 : std::numeric_limits<float>::infinity();
550 const float distance_block =
551 j_block < block_k ? block_nearest_center_distances(i, j_block)
552 : std::numeric_limits<float>::infinity();
553 if (distance_out <= distance_block) {
554 merged_indices(j_merged) = nearest_center_indices(i, j_out);
555 merged_distances(j_merged) = distance_out;
556 ++j_out;
557 } else {
558 merged_indices(j_merged) =
559 block_nearest_center_indices(i, j_block) + centers_start;
560 merged_distances(j_merged) = distance_block;
561 ++j_block;
562 }
563 }
564 nearest_center_indices.row(i) = merged_indices;
565 nearest_center_distances.row(i) = merged_distances;
566 out_k = std::min(k, out_k + block_k);
567 }
568 }
569 }
570 }
571 };
572
573 REGISTER_KERNEL_BUILDER(Name("NearestNeighbors").Device(DEVICE_CPU),
574 NearestNeighborsOp);
575
576 } // namespace tensorflow
577