1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy
5 // of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 // ==============================================================================
15
16 #define EIGEN_USE_THREADS
17
18 #include <algorithm>
19 #include <memory>
20 #include <numeric>
21 #include <tuple>
22 #include <unordered_set>
23 #include <vector>
24
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/blocking_counter.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/threadpool.h"
32 #include "tensorflow/core/lib/gtl/top_n.h"
33 #include "tensorflow/core/lib/random/philox_random.h"
34 #include "tensorflow/core/lib/random/simple_philox.h"
35 #include "tensorflow/core/platform/cpu_info.h"
36 #include "tensorflow/core/platform/logging.h"
37
38 namespace tensorflow {
39 namespace {
40 using errors::InvalidArgument;
41
42 template <typename Scalar>
43 using RowMajorMatrix =
44 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
45
46 using MatrixXfRowMajor = RowMajorMatrix<float>;
47 using MatrixXi64RowMajor = RowMajorMatrix<int64>;
48
// Assumed share of L3 cache available per CPU, in bytes (1 MiB). Ideally this
// would be the L3 cache size divided by the number of physical CPUs, but there
// is no portable way to query that, so a conservative estimate is used.
constexpr int64 kDefaultL3CachePerCpu = 1 << 20;

// Blocking parameters for NearestNeighborsOp: the maximum number of centers
// processed together, and the minimum number of points per unit of work. The
// values were determined by a parameter sweep on the NearestNeighborsOp
// benchmark.
constexpr int64 kNearestNeighborsCentersMaxBlockSize = 1024;
constexpr int64 kNearestNeighborsPointsMinBlockSize = 16;
58
// Rounds b up to a multiple of a: returns b itself when a divides b evenly,
// otherwise b plus whatever is missing to reach the next multiple of a.
int64 NextMultiple(int64 a, int64 b) {
  const int64 excess = b % a;
  if (excess == 0) {
    return b;
  }
  return b + a - excess;
}
64
// Integer division of a by b that rounds the quotient up instead of down.
int64 CeilOfRatio(int64 a, int64 b) {
  const int64 biased_numerator = a + b - 1;
  return biased_numerator / b;
}
67
68 } // namespace
69
70 // Implementation of K-means++ initialization. Samples points iteratively in
71 // proportion to the squared distances from selected points.
72 // TODO(ands): Add support for other distance metrics.
73 class KmeansPlusPlusInitializationOp : public OpKernel {
74 public:
KmeansPlusPlusInitializationOp(OpKernelConstruction * context)75 explicit KmeansPlusPlusInitializationOp(OpKernelConstruction* context)
76 : OpKernel(context) {
77 OP_REQUIRES_OK(context,
78 context->MatchSignature(
79 {DT_FLOAT, DT_INT64, DT_INT64, DT_INT64}, {DT_FLOAT}));
80 }
81
Compute(OpKernelContext * context)82 void Compute(OpKernelContext* context) override {
83 const Tensor& points_tensor = context->input(0);
84 const Tensor& num_to_sample_tensor = context->input(1);
85 const Tensor& seed_tensor = context->input(2);
86 const Tensor& num_retries_per_sample_tensor = context->input(3);
87
88 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
89 InvalidArgument("Input points should be a matrix."));
90 OP_REQUIRES(context,
91 TensorShapeUtils::IsScalar(num_to_sample_tensor.shape()),
92 InvalidArgument("Input num_to_sample should be a scalar."));
93 OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
94 InvalidArgument("Input seed should be a scalar."));
95 OP_REQUIRES(
96 context,
97 TensorShapeUtils::IsScalar(num_retries_per_sample_tensor.shape()),
98 InvalidArgument("Input num_retries_per_sample should be a scalar."));
99
100 const int64 num_points = points_tensor.dim_size(0);
101 const int64 point_dimensions = points_tensor.dim_size(1);
102 const int64 num_to_sample = num_to_sample_tensor.scalar<int64>()();
103 const int64 seed = seed_tensor.scalar<int64>()();
104 const int64 num_retries_per_sample = [&]() {
105 const int64 value = num_retries_per_sample_tensor.scalar<int64>()();
106 return value >= 0 ? value
107 : 2 + static_cast<int64>(std::log(num_to_sample));
108 }();
109
110 OP_REQUIRES(context, num_points > 0,
111 InvalidArgument("Expected points.rows() > 0."));
112 OP_REQUIRES(context, num_to_sample > 0,
113 InvalidArgument("Expected num_to_sample > 0."));
114 OP_REQUIRES(context, num_to_sample <= num_points,
115 InvalidArgument("Expected num_to_sample <= points.rows(). ",
116 num_to_sample, " vs ", num_points, "."));
117
118 Tensor* output_sampled_points_tensor;
119 OP_REQUIRES_OK(context,
120 context->allocate_output(
121 0, TensorShape({num_to_sample, point_dimensions}),
122 &output_sampled_points_tensor));
123
124 const Eigen::Map<const MatrixXfRowMajor> points(
125 points_tensor.matrix<float>().data(), num_points, point_dimensions);
126 const Eigen::VectorXf points_half_squared_norm =
127 0.5 * points.rowwise().squaredNorm();
128
129 Eigen::Map<MatrixXfRowMajor> sampled_points(
130 output_sampled_points_tensor->matrix<float>().data(), num_to_sample,
131 point_dimensions);
132 std::unordered_set<int64> sampled_indices;
133
134 random::PhiloxRandom random(seed);
135 random::SimplePhilox rng(&random);
136
137 auto add_one_point = [&](int64 from, int64 to) {
138 from = std::min(from, num_points - 1);
139 sampled_points.row(to) = points.row(from);
140 sampled_indices.insert(from);
141 };
142
143 // Distances from all points to nearest selected point. Initialize with
144 // distances to first selected point.
145 Eigen::VectorXf min_distances(num_points);
146 min_distances.fill(std::numeric_limits<float>::infinity());
147 Eigen::VectorXf min_distances_cumsum(num_points);
148
149 auto draw_one_sample = [&]() -> int64 {
150 if (sampled_indices.empty()) return rng.Uniform64(num_points);
151 int64 index = 0;
152 do {
153 // If v is drawn from Uniform[0, distances.sum()), then
154 // Prob[cumsum(distances)(i - 1) <= v < cumsum(distances)(i)] is
155 // proportional to distances(i).
156 index = std::upper_bound(
157 min_distances_cumsum.data(),
158 min_distances_cumsum.data() + num_points,
159 rng.RandFloat() * min_distances_cumsum(num_points - 1)) -
160 min_distances_cumsum.data();
161 } while (sampled_indices.find(index) != sampled_indices.end());
162 return index;
163 };
164
165 auto sample_one_point = [&]() {
166 const int64 sampled_index = draw_one_sample();
167 min_distances = min_distances.cwiseMin(GetHalfSquaredDistancesToY(
168 points, points_half_squared_norm, points.row(sampled_index),
169 points_half_squared_norm(sampled_index)));
170 return sampled_index;
171 };
172
173 auto sample_one_point_with_retries = [&]() {
174 Eigen::VectorXf best_new_min_distances(num_points);
175 float best_potential = std::numeric_limits<float>::infinity();
176 int64 best_sampled_index = 0;
177 for (int i = 1 + num_retries_per_sample; i > 0; --i) {
178 const int64 sampled_index = draw_one_sample();
179 Eigen::VectorXf new_min_distances =
180 min_distances.cwiseMin(GetHalfSquaredDistancesToY(
181 points, points_half_squared_norm, points.row(sampled_index),
182 points_half_squared_norm(sampled_index)));
183 const float potential = new_min_distances.sum();
184 if (potential < best_potential) {
185 best_potential = potential;
186 best_sampled_index = sampled_index;
187 best_new_min_distances.swap(new_min_distances);
188 }
189 }
190 min_distances.swap(best_new_min_distances);
191 return best_sampled_index;
192 };
193
194 for (int64 i = 0; i < num_to_sample; ++i) {
195 if (i > 0) {
196 std::partial_sum(min_distances.data(),
197 min_distances.data() + num_points,
198 min_distances_cumsum.data());
199 }
200 int64 next = num_retries_per_sample == 0
201 ? sample_one_point()
202 : sample_one_point_with_retries();
203 add_one_point(next, i);
204 }
205 }
206
207 private:
208 // Returns a column vector with the i-th element set to half the squared
209 // euclidean distance between the i-th row of xs, and y. Precomputed norms for
210 // each row of xs and y must be provided for efficiency.
211 // TODO(ands): Parallelize this for large xs.
GetHalfSquaredDistancesToY(const Eigen::Ref<const MatrixXfRowMajor> & xs,const Eigen::Ref<const Eigen::VectorXf> & xs_half_squared_norm,const Eigen::Ref<const Eigen::RowVectorXf> & y,float y_half_squared_norm)212 static Eigen::VectorXf GetHalfSquaredDistancesToY(
213 const Eigen::Ref<const MatrixXfRowMajor>& xs,
214 const Eigen::Ref<const Eigen::VectorXf>& xs_half_squared_norm,
215 const Eigen::Ref<const Eigen::RowVectorXf>& y,
216 float y_half_squared_norm) {
217 // Squared distance between points xs_i and y is:
218 // || xs_i ||^2 - 2 <xs_i, y> + || y ||^2
219 return (xs_half_squared_norm - xs * y.transpose()).array() +
220 y_half_squared_norm;
221 }
222 };
223
224 REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU),
225 KmeansPlusPlusInitializationOp);
226
227 // Implementation of one single Markov Chain for the k-MC^2 algorithm
228 class KMC2ChainInitializationOp : public OpKernel {
229 public:
KMC2ChainInitializationOp(OpKernelConstruction * context)230 explicit KMC2ChainInitializationOp(OpKernelConstruction* context)
231 : OpKernel(context) {
232 OP_REQUIRES_OK(context,
233 context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64}));
234 }
235
Compute(OpKernelContext * context)236 void Compute(OpKernelContext* context) override {
237 const Tensor& distances_tensor = context->input(0);
238 const Tensor& seed_tensor = context->input(1);
239 OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()),
240 InvalidArgument("Input distances should be a vector."));
241 OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
242 InvalidArgument("Input seed should be a scalar."));
243 const int64 num_points = distances_tensor.dim_size(0);
244 const int64 seed = seed_tensor.scalar<int64>()();
245 OP_REQUIRES(context, num_points > 0,
246 InvalidArgument("Expected distances_tensor.size() > 0."));
247
248 random::PhiloxRandom random(seed);
249 random::SimplePhilox rng(&random);
250
251 auto distances = distances_tensor.flat<float>();
252 // Set the initial state of the Markov chain to be the first candidate.
253 int64 selected_index = 0;
254 float selected_distance = distances(selected_index);
255 // Build a Markov chain of length num_points.
256 for (int64 i = 1; i < num_points; ++i) {
257 const float candidate_distance = distances(i);
258 // Set the next state of the Markov chain to be the candidate with
259 // probability min(1, candidate_distance/selected_distance).
260 if (candidate_distance > rng.RandFloat() * selected_distance) {
261 selected_index = i;
262 selected_distance = candidate_distance;
263 }
264 }
265
266 Tensor* output_sampled_index_tensor;
267 OP_REQUIRES_OK(context,
268 context->allocate_output(0, TensorShape({}),
269 &output_sampled_index_tensor));
270 auto output = output_sampled_index_tensor->scalar<int64>();
271 // Return the last state of the Markov chain as the new center.
272 output() = selected_index;
273 }
274 };
275
276 REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU),
277 KMC2ChainInitializationOp);
278
279 // Operator for computing the nearest neighbors for a set of points.
280 class NearestNeighborsOp : public OpKernel {
281 public:
NearestNeighborsOp(OpKernelConstruction * context)282 explicit NearestNeighborsOp(OpKernelConstruction* context)
283 : OpKernel(context) {
284 OP_REQUIRES_OK(context,
285 context->MatchSignature({DT_FLOAT, DT_FLOAT, DT_INT64},
286 {DT_INT64, DT_FLOAT}));
287 }
288
Compute(OpKernelContext * context)289 void Compute(OpKernelContext* context) override {
290 const Tensor& points_tensor = context->input(0);
291 const Tensor& centers_tensor = context->input(1);
292 const Tensor& k_tensor = context->input(2);
293
294 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
295 InvalidArgument("Input points should be a matrix."));
296 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(centers_tensor.shape()),
297 InvalidArgument("Input centers should be a matrix."));
298 OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_tensor.shape()),
299 InvalidArgument("Input k should be a scalar."));
300
301 const int64 num_points = points_tensor.dim_size(0);
302 const int64 point_dimensions = points_tensor.dim_size(1);
303 const int64 num_centers = centers_tensor.dim_size(0);
304 const int64 center_dimensions = centers_tensor.dim_size(1);
305
306 OP_REQUIRES(context, num_points > 0,
307 InvalidArgument("Expected points.rows() > 0."));
308 OP_REQUIRES(
309 context, point_dimensions == center_dimensions,
310 InvalidArgument("Expected point_dimensions == center_dimensions: ",
311 point_dimensions, " vs ", center_dimensions, "."));
312
313 const Eigen::Map<const MatrixXfRowMajor> points(
314 points_tensor.matrix<float>().data(), num_points, point_dimensions);
315 const Eigen::Map<const MatrixXfRowMajor> centers(
316 centers_tensor.matrix<float>().data(), num_centers, center_dimensions);
317 const int64 k = std::min<int64>(num_centers, k_tensor.scalar<int64>()());
318
319 Tensor* output_nearest_center_indices_tensor;
320 Tensor* output_nearest_center_distances_tensor;
321 OP_REQUIRES_OK(context, context->allocate_output(
322 0, TensorShape({num_points, k}),
323 &output_nearest_center_indices_tensor));
324 OP_REQUIRES_OK(context, context->allocate_output(
325 1, TensorShape({num_points, k}),
326 &output_nearest_center_distances_tensor));
327
328 if (k == 0) return;
329
330 Eigen::Map<MatrixXi64RowMajor> nearest_center_indices(
331 output_nearest_center_indices_tensor->matrix<int64>().data(),
332 num_points, k);
333 Eigen::Map<MatrixXfRowMajor> nearest_center_distances(
334 output_nearest_center_distances_tensor->matrix<float>().data(),
335 num_points, k);
336
337 const Eigen::VectorXf centers_half_squared_norm =
338 0.5 * centers.rowwise().squaredNorm();
339
340 // The distance computation is sharded to take advantage of multiple cores
341 // and to allow intermediate values to reside in L3 cache. This is done by
342 // sharding the points and centers as follows:
343 //
344 // 1. Centers are sharded such that each block of centers has at most
345 // kNearestNeighborsCentersMaxBlockSize rows.
346 // 2. Points are sharded, and each block of points is multiplied with each
347 // block of centers. The block size of points is chosen such that the
348 // point coordinates (point_dimensions) and the matrix of distances to
349 // each center in one block -- the intermediate data -- fits in L3 cache.
350 // 3. After performing each block-block distance computation, the results
351 // are reduced to a set of k nearest centers as soon as possible. This
352 // decreases total memory I/O.
353 auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
354 const int64 num_threads = worker_threads.num_threads;
355 // This kernel might be configured to use fewer than the total number of
356 // available CPUs on the host machine. To avoid descructive interference
357 // with other jobs running on the host machine, we must only use a fraction
358 // of total available L3 cache. Unfortunately, we cannot query the host
359 // machine to get the number of physical CPUs. So, we use a fixed per-CPU
360 // budget and scale it by the number of CPUs available to this operation.
361 const int64 total_memory_budget =
362 kDefaultL3CachePerCpu * port::NumSchedulableCPUs();
363 // Compute the number of blocks into which rows of points must be split so
364 // that the distance matrix and the block of points can fit in cache. One
365 // row of points will yield a vector of distances to each center in a block.
366 const int64 bytes_per_row =
367 (std::min(kNearestNeighborsCentersMaxBlockSize,
368 num_centers) /* centers in a block */
369 + point_dimensions /* coordinates of one point */) *
370 sizeof(float);
371 // The memory needed for storing the centers being processed. This is shared
372 // by all workers. Adding slack to the number of threads to avoid incorrect
373 // cache eviction when a new block of centers is loaded.
374 const int64 bytes_for_centers =
375 std::min(num_centers,
376 (num_threads + 2) * kNearestNeighborsCentersMaxBlockSize) *
377 point_dimensions * sizeof(float);
378 // The memory budget available for workers to store their distance matrices.
379 const int64 available_memory_budget =
380 total_memory_budget - bytes_for_centers;
381 // That memory budget is shared by all threads.
382 const int64 rows_per_block =
383 std::max<int64>(kNearestNeighborsPointsMinBlockSize,
384 available_memory_budget / num_threads / bytes_per_row);
385 // Divide rows into almost uniformly-sized units of work that are small
386 // enough for the memory budget (rows_per_block). Round up to a multiple of
387 // the number of threads.
388 const int64 num_units =
389 NextMultiple(num_threads, CeilOfRatio(num_points, rows_per_block));
390 auto work = [&](int64 start, int64 limit) {
391 for (; start < limit; ++start) {
392 const int64 start_row = num_points * start / num_units;
393 const int64 limit_row = num_points * (start + 1) / num_units;
394 CHECK_LE(limit_row, num_points);
395 const int64 num_rows = limit_row - start_row;
396 auto points_shard = points.middleRows(start_row, num_rows);
397 const Eigen::VectorXf points_half_squared_norm =
398 0.5 * points_shard.rowwise().squaredNorm();
399 auto nearest_center_indices_shard =
400 nearest_center_indices.middleRows(start_row, num_rows);
401 auto nearest_center_distances_shard =
402 nearest_center_distances.middleRows(start_row, num_rows);
403 FindKNearestCenters(k, points_shard, points_half_squared_norm, centers,
404 centers_half_squared_norm,
405 nearest_center_indices_shard,
406 nearest_center_distances_shard);
407 }
408 };
409
410 const int64 units_per_thread = num_units / num_threads;
411 BlockingCounter counter(num_threads - 1);
412 for (int64 i = 1; i < num_threads; ++i) {
413 const int64 start = i * units_per_thread;
414 const int64 limit = start + units_per_thread;
415 worker_threads.workers->Schedule([work, &counter, start, limit]() {
416 work(start, limit);
417 counter.DecrementCount();
418 });
419 }
420 work(0, units_per_thread);
421 counter.Wait();
422 }
423
424 private:
FindKNearestCenters(int64 k,const Eigen::Ref<const MatrixXfRowMajor> & points,const Eigen::Ref<const Eigen::VectorXf> & points_half_squared_norm,const Eigen::Ref<const MatrixXfRowMajor> & centers,const Eigen::Ref<const Eigen::VectorXf> & centers_half_squared_norm,const Eigen::Ref<MatrixXi64RowMajor> & nearest_center_indices,const Eigen::Ref<MatrixXfRowMajor> & nearest_center_distances)425 static void FindKNearestCenters(
426 int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
427 const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
428 const Eigen::Ref<const MatrixXfRowMajor>& centers,
429 const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
430 const Eigen::Ref<MatrixXi64RowMajor>& nearest_center_indices,
431 const Eigen::Ref<MatrixXfRowMajor>& nearest_center_distances) {
432 CHECK_LE(k, centers.rows());
433 if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) {
434 FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers,
435 centers_half_squared_norm,
436 nearest_center_indices,
437 nearest_center_distances);
438 } else {
439 FindKNearestCentersBlockwise(k, points, points_half_squared_norm, centers,
440 centers_half_squared_norm,
441 nearest_center_indices,
442 nearest_center_distances);
443 }
444 }
445
FindKNearestCentersOneBlock(int64 k,const Eigen::Ref<const MatrixXfRowMajor> & points,const Eigen::Ref<const Eigen::VectorXf> & points_half_squared_norm,const Eigen::Ref<const MatrixXfRowMajor> & centers,const Eigen::Ref<const Eigen::VectorXf> & centers_half_squared_norm,Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,Eigen::Ref<MatrixXfRowMajor> nearest_center_distances)446 static void FindKNearestCentersOneBlock(
447 int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
448 const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
449 const Eigen::Ref<const MatrixXfRowMajor>& centers,
450 const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
451 Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
452 Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
453 CHECK_LE(k, centers.rows());
454 const int64 num_points = points.rows();
455 const MatrixXfRowMajor inner_product = points * centers.transpose();
456 // Find nearest neighbors.
457 if (k == 1) {
458 for (int i = 0; i < num_points; ++i) {
459 int64 index;
460 nearest_center_distances(i, 0) =
461 2.0 *
462 (points_half_squared_norm(i) +
463 (centers_half_squared_norm.transpose() - inner_product.row(i))
464 .minCoeff(&index));
465 nearest_center_indices(i, 0) = index;
466 }
467 } else {
468 // Select k nearest centers for each point.
469 using Center = std::pair<float, int64>;
470 const int64 num_centers = centers.rows();
471 gtl::TopN<Center, std::less<Center>> selector(k);
472 std::unique_ptr<std::vector<Center>> nearest_centers;
473 for (int i = 0; i < num_points; ++i) {
474 selector.reserve(num_centers);
475 for (int j = 0; j < num_centers; ++j) {
476 const float partial_distance =
477 centers_half_squared_norm(j) - inner_product(i, j);
478 selector.push(Center(partial_distance, j));
479 }
480 nearest_centers.reset(selector.Extract());
481 selector.Reset();
482 const float point_half_squared_norm = points_half_squared_norm(i);
483 for (int j = 0; j < k; ++j) {
484 const Center& center = (*nearest_centers)[j];
485 nearest_center_distances(i, j) =
486 2.0 * (point_half_squared_norm + center.first);
487 nearest_center_indices(i, j) = center.second;
488 }
489 }
490 }
491 }
492
FindKNearestCentersBlockwise(int64 k,const Eigen::Ref<const MatrixXfRowMajor> & points,const Eigen::Ref<const Eigen::VectorXf> & points_half_squared_norm,const Eigen::Ref<const MatrixXfRowMajor> & centers,const Eigen::Ref<const Eigen::VectorXf> & centers_half_squared_norm,Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,Eigen::Ref<MatrixXfRowMajor> nearest_center_distances)493 static void FindKNearestCentersBlockwise(
494 int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
495 const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
496 const Eigen::Ref<const MatrixXfRowMajor>& centers,
497 const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
498 Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
499 Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
500 const int64 num_points = points.rows();
501 const int64 num_centers = centers.rows();
502 CHECK_LE(k, num_centers);
503 CHECK_GT(num_centers, kNearestNeighborsCentersMaxBlockSize);
504 // Store nearest neighbors with first block of centers directly into the
505 // output matrices.
506 int64 out_k = std::min(k, kNearestNeighborsCentersMaxBlockSize);
507 FindKNearestCentersOneBlock(
508 out_k, points, points_half_squared_norm,
509 centers.topRows(kNearestNeighborsCentersMaxBlockSize),
510 centers_half_squared_norm.head(kNearestNeighborsCentersMaxBlockSize),
511 nearest_center_indices, nearest_center_distances);
512 // Iteratively compute nearest neighbors with other blocks of centers, and
513 // update the output matrices.
514 MatrixXi64RowMajor block_nearest_center_indices(num_points, k);
515 MatrixXfRowMajor block_nearest_center_distances(num_points, k);
516 Eigen::Matrix<int64, 1, Eigen::Dynamic> merged_indices(k);
517 Eigen::Matrix<float, 1, Eigen::Dynamic> merged_distances(k);
518 for (int64 centers_start = kNearestNeighborsCentersMaxBlockSize;
519 centers_start < num_centers;
520 centers_start += kNearestNeighborsCentersMaxBlockSize) {
521 const int64 centers_block_size = std::min(
522 kNearestNeighborsCentersMaxBlockSize, num_centers - centers_start);
523 const int64 block_k = std::min(k, centers_block_size);
524 FindKNearestCentersOneBlock(
525 block_k, points, points_half_squared_norm,
526 centers.middleRows(centers_start, centers_block_size),
527 centers_half_squared_norm.segment(centers_start, centers_block_size),
528 block_nearest_center_indices, block_nearest_center_distances);
529 if (k == 1) {
530 for (int i = 0; i < num_points; ++i) {
531 if (block_nearest_center_distances(i, 0) <
532 nearest_center_distances(i, 0)) {
533 nearest_center_indices(i, 0) =
534 block_nearest_center_indices(i, 0) + centers_start;
535 nearest_center_distances(i, 0) =
536 block_nearest_center_distances(i, 0);
537 }
538 }
539 } else {
540 for (int i = 0; i < num_points; ++i) {
541 // Merge and accumulate top-k list from block_nearest_center_indices
542 // into nearest_center_indices.
543 for (int64 j_out = 0, j_block = 0, j_merged = 0;
544 (j_out < out_k || j_block < block_k) && j_merged < k;
545 ++j_merged) {
546 const float distance_out =
547 j_out < out_k ? nearest_center_distances(i, j_out)
548 : std::numeric_limits<float>::infinity();
549 const float distance_block =
550 j_block < block_k ? block_nearest_center_distances(i, j_block)
551 : std::numeric_limits<float>::infinity();
552 if (distance_out <= distance_block) {
553 merged_indices(j_merged) = nearest_center_indices(i, j_out);
554 merged_distances(j_merged) = distance_out;
555 ++j_out;
556 } else {
557 merged_indices(j_merged) =
558 block_nearest_center_indices(i, j_block) + centers_start;
559 merged_distances(j_merged) = distance_block;
560 ++j_block;
561 }
562 }
563 nearest_center_indices.row(i) = merged_indices;
564 nearest_center_distances.row(i) = merged_distances;
565 out_k = std::min(k, out_k + block_k);
566 }
567 }
568 }
569 }
570 };
571
572 REGISTER_KERNEL_BUILDER(Name("NearestNeighbors").Device(DEVICE_CPU),
573 NearestNeighborsOp);
574
575 } // namespace tensorflow
576