1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include <algorithm>
16 #include <iterator>
17 #include <string>
18 #include <vector>
19
20 #include "tensorflow/core/framework/device_base.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h"
27 #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
28 #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/refcount.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/strings/stringprintf.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/types.h"
35 #include "tensorflow/core/util/work_sharder.h"
36
37 namespace tensorflow {
38
// Input/output/attribute names shared by the quantile op kernels below.
const char* const kExampleWeightsName = "example_weights";
const char* const kMaxElementsName = "max_elements";
const char* const kGenerateQuantiles = "generate_quantiles";
const char* const kNumBucketsName = "num_buckets";
const char* const kEpsilonName = "epsilon";
const char* const kBucketBoundariesName = "bucket_boundaries";
const char* const kBucketsName = "buckets";
const char* const kSummariesName = "summaries";
const char* const kNumStreamsName = "num_streams";
const char* const kNumFeaturesName = "num_features";
const char* const kFloatFeaturesName = "float_values";
const char* const kResourceHandleName = "quantile_stream_resource_handle";

// Convenience aliases for the weighted-quantiles implementation types, all
// specialized to float values with float weights.
using QuantileStreamResource = BoostedTreesQuantileStreamResource;
using QuantileStream =
    boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
using QuantileSummary =
    boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
using QuantileSummaryEntry =
    boosted_trees::quantiles::WeightedQuantilesSummary<float,
                                                       float>::SummaryEntry;
60
61 // Generates quantiles on a finalized QuantileStream.
GenerateBoundaries(const QuantileStream & stream,const int64_t num_boundaries)62 std::vector<float> GenerateBoundaries(const QuantileStream& stream,
63 const int64_t num_boundaries) {
64 std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);
65
66 // Uniquify elements as we may get dupes.
67 auto end_it = std::unique(boundaries.begin(), boundaries.end());
68 boundaries.resize(std::distance(boundaries.begin(), end_it));
69 return boundaries;
70 }
71
72 // Generates quantiles on a finalized QuantileStream.
GenerateQuantiles(const QuantileStream & stream,const int64_t num_quantiles)73 std::vector<float> GenerateQuantiles(const QuantileStream& stream,
74 const int64_t num_quantiles) {
75 // Do not de-dup boundaries. Exactly num_quantiles+1 boundary values
76 // will be returned.
77 std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles - 1);
78 CHECK_EQ(boundaries.size(), num_quantiles);
79 return boundaries;
80 }
81
GetBuckets(const int32_t feature,const OpInputList & buckets_list)82 std::vector<float> GetBuckets(const int32_t feature,
83 const OpInputList& buckets_list) {
84 const auto& buckets = buckets_list[feature].flat<float>();
85 std::vector<float> buckets_vector(buckets.data(),
86 buckets.data() + buckets.size());
87 return buckets_vector;
88 }
89
// Registers the handle-producing op and the "is initialized" predicate op for
// BoostedTreesQuantileStreamResource.
REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesQuantileStreamResource);

REGISTER_KERNEL_BUILDER(
    Name("IsBoostedTreesQuantileStreamResourceInitialized").Device(DEVICE_CPU),
    IsResourceInitialized<BoostedTreesQuantileStreamResource>);
95
96 class BoostedTreesCreateQuantileStreamResourceOp : public OpKernel {
97 public:
BoostedTreesCreateQuantileStreamResourceOp(OpKernelConstruction * const context)98 explicit BoostedTreesCreateQuantileStreamResourceOp(
99 OpKernelConstruction* const context)
100 : OpKernel(context) {
101 OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
102 }
103
Compute(OpKernelContext * context)104 void Compute(OpKernelContext* context) override {
105 // Only create one, if one does not exist already. Report status for all
106 // other exceptions. If one already exists, it unrefs the new one.
107 // An epsilon value of zero could cause performance issues and is therefore,
108 // disallowed.
109 const Tensor* epsilon_t;
110 OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
111 float epsilon = epsilon_t->scalar<float>()();
112 OP_REQUIRES(
113 context, epsilon > 0,
114 errors::InvalidArgument("An epsilon value of zero is not allowed."));
115
116 const Tensor* num_streams_t;
117 OP_REQUIRES_OK(context, context->input(kNumStreamsName, &num_streams_t));
118 int64_t num_streams = num_streams_t->scalar<int64>()();
119 OP_REQUIRES(context, num_streams >= 0,
120 errors::InvalidArgument(
121 "Num_streams input cannot be a negative integer"));
122
123 auto result =
124 new QuantileStreamResource(epsilon, max_elements_, num_streams);
125 auto status = CreateResource(context, HandleFromInput(context, 0), result);
126 if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
127 OP_REQUIRES(context, false, status);
128 }
129 }
130
131 private:
132 // An upper bound on the number of entries that the summaries might have
133 // for a feature.
134 int64 max_elements_;
135 };
136
137 REGISTER_KERNEL_BUILDER(
138 Name("BoostedTreesCreateQuantileStreamResource").Device(DEVICE_CPU),
139 BoostedTreesCreateQuantileStreamResourceOp);
140
141 class BoostedTreesMakeQuantileSummariesOp : public OpKernel {
142 public:
BoostedTreesMakeQuantileSummariesOp(OpKernelConstruction * const context)143 explicit BoostedTreesMakeQuantileSummariesOp(
144 OpKernelConstruction* const context)
145 : OpKernel(context) {
146 OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
147 }
148
Compute(OpKernelContext * const context)149 void Compute(OpKernelContext* const context) override {
150 // Read float features list;
151 OpInputList float_features_list;
152 OP_REQUIRES_OK(
153 context, context->input_list(kFloatFeaturesName, &float_features_list));
154
155 // Parse example weights and get batch size.
156 const Tensor* example_weights_t;
157 OP_REQUIRES_OK(context,
158 context->input(kExampleWeightsName, &example_weights_t));
159 DCHECK(float_features_list.size() > 0) << "Got empty feature list";
160 auto example_weights = example_weights_t->flat<float>();
161 const int64_t weight_size = example_weights.size();
162 const int64_t batch_size = float_features_list[0].flat<float>().size();
163 OP_REQUIRES(
164 context, weight_size == 1 || weight_size == batch_size,
165 errors::InvalidArgument(strings::Printf(
166 "Weights should be a single value or same size as features.")));
167 const Tensor* epsilon_t;
168 OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
169 float epsilon = epsilon_t->scalar<float>()();
170
171 OpOutputList summaries_output_list;
172 OP_REQUIRES_OK(
173 context, context->output_list(kSummariesName, &summaries_output_list));
174
175 auto do_quantile_summary_gen = [&](const int64_t begin, const int64_t end) {
176 // Iterating features.
177 for (int64_t index = begin; index < end; index++) {
178 const auto feature_values = float_features_list[index].flat<float>();
179 QuantileStream stream(epsilon, batch_size + 1);
180 // Run quantile summary generation.
181 for (int64_t j = 0; j < batch_size; j++) {
182 stream.PushEntry(feature_values(j), (weight_size > 1)
183 ? example_weights(j)
184 : example_weights(0));
185 }
186 stream.Finalize();
187 const auto summary_entry_list = stream.GetFinalSummary().GetEntryList();
188 Tensor* output_t;
189 OP_REQUIRES_OK(
190 context,
191 summaries_output_list.allocate(
192 index,
193 TensorShape({static_cast<int64>(summary_entry_list.size()), 4}),
194 &output_t));
195 auto output = output_t->matrix<float>();
196 for (auto row = 0; row < summary_entry_list.size(); row++) {
197 const auto& entry = summary_entry_list[row];
198 output(row, 0) = entry.value;
199 output(row, 1) = entry.weight;
200 output(row, 2) = entry.min_rank;
201 output(row, 3) = entry.max_rank;
202 }
203 }
204 };
205 // TODO(tanzheny): comment on the magic number.
206 const int64_t kCostPerUnit = 500 * batch_size;
207 const DeviceBase::CpuWorkerThreads& worker_threads =
208 *context->device()->tensorflow_cpu_worker_threads();
209 Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
210 kCostPerUnit, do_quantile_summary_gen);
211 }
212
213 private:
214 int64 num_features_;
215 };
216
217 REGISTER_KERNEL_BUILDER(
218 Name("BoostedTreesMakeQuantileSummaries").Device(DEVICE_CPU),
219 BoostedTreesMakeQuantileSummariesOp);
220
// Finalizes every stream held by the quantile stream resource and emits each
// stream's final summary as a [num_entries, 4] float matrix of
// (value, weight, min_rank, max_rank) rows. Streams are reset afterwards.
class BoostedTreesFlushQuantileSummariesOp : public OpKernel {
 public:
  explicit BoostedTreesFlushQuantileSummariesOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
  }

  void Compute(OpKernelContext* const context) override {
    ResourceHandle handle;
    OP_REQUIRES_OK(context,
                   HandleFromInput(context, kResourceHandleName, &handle));
    core::RefCountPtr<QuantileStreamResource> stream_resource;
    OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
    // Remove the reference at the end of this scope.
    // Serializes access to the streams with other ops on the same resource.
    mutex_lock l(*stream_resource->mutex());

    OpOutputList summaries_output_list;
    OP_REQUIRES_OK(
        context, context->output_list(kSummariesName, &summaries_output_list));

    auto do_quantile_summary_gen = [&](const int64_t begin, const int64_t end) {
      // Iterating features.
      for (int64_t index = begin; index < end; index++) {
        QuantileStream* stream = stream_resource->stream(index);
        // A stream must be finalized before its summary can be read.
        stream->Finalize();

        const auto summary_list = stream->GetFinalSummary().GetEntryList();
        Tensor* output_t;
        const int64_t summary_list_size =
            static_cast<int64>(summary_list.size());
        OP_REQUIRES_OK(context, summaries_output_list.allocate(
                                    index, TensorShape({summary_list_size, 4}),
                                    &output_t));
        auto output = output_t->matrix<float>();
        for (auto row = 0; row < summary_list_size; row++) {
          const auto& entry = summary_list[row];
          output(row, 0) = entry.value;
          output(row, 1) = entry.weight;
          output(row, 2) = entry.min_rank;
          output(row, 3) = entry.max_rank;
        }
      }
    };
    // TODO(tanzheny): comment on the magic number.
    const int64_t kCostPerUnit = 500 * num_features_;
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *context->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
          kCostPerUnit, do_quantile_summary_gen);
    // The streams were consumed by finalization; reset them for future pushes.
    stream_resource->ResetStreams();
  }

 private:
  int64 num_features_;  // Number of streams/features expected (attr).
};

REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesFlushQuantileSummaries").Device(DEVICE_CPU),
    BoostedTreesFlushQuantileSummariesOp);
281
282 class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
283 public:
BoostedTreesQuantileStreamResourceAddSummariesOp(OpKernelConstruction * const context)284 explicit BoostedTreesQuantileStreamResourceAddSummariesOp(
285 OpKernelConstruction* const context)
286 : OpKernel(context) {}
287
Compute(OpKernelContext * context)288 void Compute(OpKernelContext* context) override {
289 ResourceHandle handle;
290 OP_REQUIRES_OK(context,
291 HandleFromInput(context, kResourceHandleName, &handle));
292 core::RefCountPtr<QuantileStreamResource> stream_resource;
293 // Create a reference to the underlying resource using the handle.
294 OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
295 // Remove the reference at the end of this scope.
296 mutex_lock l(*stream_resource->mutex());
297
298 OpInputList summaries_list;
299 OP_REQUIRES_OK(context,
300 context->input_list(kSummariesName, &summaries_list));
301 int32_t num_streams = stream_resource->num_streams();
302 CHECK_EQ(static_cast<int>(num_streams), summaries_list.size());
303
304 auto do_quantile_add_summary = [&](const int64_t begin, const int64_t end) {
305 // Iterating all features.
306 for (int64_t feature_idx = begin; feature_idx < end; ++feature_idx) {
307 QuantileStream* stream = stream_resource->stream(feature_idx);
308 if (stream->IsFinalized()) {
309 VLOG(1) << "QuantileStream has already been finalized for feature"
310 << feature_idx << ".";
311 continue;
312 }
313 const Tensor& summaries = summaries_list[feature_idx];
314 const auto summary_values = summaries.matrix<float>();
315 const auto& tensor_shape = summaries.shape();
316 const int64_t entries_size = tensor_shape.dim_size(0);
317 CHECK_EQ(tensor_shape.dim_size(1), 4);
318 std::vector<QuantileSummaryEntry> summary_entries;
319 summary_entries.reserve(entries_size);
320 for (int64_t i = 0; i < entries_size; i++) {
321 float value = summary_values(i, 0);
322 float weight = summary_values(i, 1);
323 float min_rank = summary_values(i, 2);
324 float max_rank = summary_values(i, 3);
325 QuantileSummaryEntry entry(value, weight, min_rank, max_rank);
326 summary_entries.push_back(entry);
327 }
328 stream_resource->stream(feature_idx)->PushSummary(summary_entries);
329 }
330 };
331
332 // TODO(tanzheny): comment on the magic number.
333 const int64_t kCostPerUnit = 500 * num_streams;
334 const DeviceBase::CpuWorkerThreads& worker_threads =
335 *context->device()->tensorflow_cpu_worker_threads();
336 Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
337 kCostPerUnit, do_quantile_add_summary);
338 }
339 };
340
341 REGISTER_KERNEL_BUILDER(
342 Name("BoostedTreesQuantileStreamResourceAddSummaries").Device(DEVICE_CPU),
343 BoostedTreesQuantileStreamResourceAddSummariesOp);
344
// Restores per-stream bucket boundaries into the quantile stream resource,
// e.g. when deserializing previously saved resource state.
class BoostedTreesQuantileStreamResourceDeserializeOp : public OpKernel {
 public:
  explicit BoostedTreesQuantileStreamResourceDeserializeOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    // The attr is named "num_streams"; it is stored as num_features_ since
    // there is exactly one stream per feature.
    OP_REQUIRES_OK(context, context->GetAttr(kNumStreamsName, &num_features_));
  }

  void Compute(OpKernelContext* context) override {
    core::RefCountPtr<QuantileStreamResource> streams_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
    // Remove the reference at the end of this scope.
    mutex_lock l(*streams_resource->mutex());

    OpInputList bucket_boundaries_list;
    OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
                                                &bucket_boundaries_list));

    auto do_quantile_deserialize = [&](const int64_t begin, const int64_t end) {
      // Iterating over all streams.
      for (int64_t stream_idx = begin; stream_idx < end; stream_idx++) {
        const Tensor& bucket_boundaries_t = bucket_boundaries_list[stream_idx];
        const auto& bucket_boundaries = bucket_boundaries_t.vec<float>();
        // Copy the tensor's boundary values into an owned vector for the
        // resource.
        std::vector<float> result;
        result.reserve(bucket_boundaries.size());
        for (size_t i = 0; i < bucket_boundaries.size(); ++i) {
          result.push_back(bucket_boundaries(i));
        }
        streams_resource->set_boundaries(result, stream_idx);
      }
    };

    // TODO(tanzheny): comment on the magic number.
    const int64_t kCostPerUnit = 500 * num_features_;
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *context->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
          kCostPerUnit, do_quantile_deserialize);
  }

 private:
  int64 num_features_;  // One stream per feature; set from the num_streams attr.
};

REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesQuantileStreamResourceDeserialize").Device(DEVICE_CPU),
    BoostedTreesQuantileStreamResourceDeserializeOp);
394
// Finalizes all streams and converts them into bucket boundaries stored on
// the resource. Depending on the generate_quantiles attr, boundaries are
// either exact-count quantiles (duplicates kept) or de-duplicated boundaries.
class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel {
 public:
  explicit BoostedTreesQuantileStreamResourceFlushOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
  }

  void Compute(OpKernelContext* context) override {
    ResourceHandle handle;
    OP_REQUIRES_OK(context,
                   HandleFromInput(context, kResourceHandleName, &handle));
    core::RefCountPtr<QuantileStreamResource> stream_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
    // Remove the reference at the end of this scope.
    mutex_lock l(*stream_resource->mutex());

    const Tensor* num_buckets_t;
    OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t));
    const int64_t num_buckets = num_buckets_t->scalar<int64>()();
    const int64_t num_streams = stream_resource->num_streams();

    auto do_quantile_flush = [&](const int64_t begin, const int64_t end) {
      // Iterating over all streams.
      for (int64_t stream_idx = begin; stream_idx < end; ++stream_idx) {
        QuantileStream* stream = stream_resource->stream(stream_idx);
        // Streams must be finalized before boundaries can be generated.
        stream->Finalize();
        stream_resource->set_boundaries(
            generate_quantiles_ ? GenerateQuantiles(*stream, num_buckets)
                                : GenerateBoundaries(*stream, num_buckets),
            stream_idx);
      }
    };

    // TODO(tanzheny): comment on the magic number.
    const int64_t kCostPerUnit = 500 * num_streams;
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *context->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
          kCostPerUnit, do_quantile_flush);

    // Streams were consumed by the flush; reset them and publish the
    // boundaries to readers.
    stream_resource->ResetStreams();
    stream_resource->set_buckets_ready(true);
  }

 private:
  bool generate_quantiles_;  // Attr: exact quantiles vs de-duped boundaries.
};

REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesQuantileStreamResourceFlush").Device(DEVICE_CPU),
    BoostedTreesQuantileStreamResourceFlushOp);
449
450 class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp
451 : public OpKernel {
452 public:
BoostedTreesQuantileStreamResourceGetBucketBoundariesOp(OpKernelConstruction * const context)453 explicit BoostedTreesQuantileStreamResourceGetBucketBoundariesOp(
454 OpKernelConstruction* const context)
455 : OpKernel(context) {
456 OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
457 }
458
Compute(OpKernelContext * const context)459 void Compute(OpKernelContext* const context) override {
460 ResourceHandle handle;
461 OP_REQUIRES_OK(context,
462 HandleFromInput(context, kResourceHandleName, &handle));
463 core::RefCountPtr<QuantileStreamResource> stream_resource;
464 // Create a reference to the underlying resource using the handle.
465 OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
466 // Remove the reference at the end of this scope.
467 mutex_lock l(*stream_resource->mutex());
468
469 const int64_t num_streams = stream_resource->num_streams();
470 CHECK_EQ(num_features_, num_streams);
471 OpOutputList bucket_boundaries_list;
472 OP_REQUIRES_OK(context, context->output_list(kBucketBoundariesName,
473 &bucket_boundaries_list));
474
475 auto do_quantile_get_buckets = [&](const int64_t begin, const int64_t end) {
476 // Iterating over all streams.
477 for (int64_t stream_idx = begin; stream_idx < end; stream_idx++) {
478 const auto& boundaries = stream_resource->boundaries(stream_idx);
479 Tensor* bucket_boundaries_t = nullptr;
480 OP_REQUIRES_OK(context,
481 bucket_boundaries_list.allocate(
482 stream_idx, {static_cast<int64>(boundaries.size())},
483 &bucket_boundaries_t));
484 auto* quantiles_flat = bucket_boundaries_t->flat<float>().data();
485 memcpy(quantiles_flat, boundaries.data(),
486 sizeof(float) * boundaries.size());
487 }
488 };
489
490 // TODO(tanzheny): comment on the magic number.
491 const int64_t kCostPerUnit = 500 * num_streams;
492 const DeviceBase::CpuWorkerThreads& worker_threads =
493 *context->device()->tensorflow_cpu_worker_threads();
494 Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
495 kCostPerUnit, do_quantile_get_buckets);
496 }
497
498 private:
499 int64 num_features_;
500 };
501
502 REGISTER_KERNEL_BUILDER(
503 Name("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
504 .Device(DEVICE_CPU),
505 BoostedTreesQuantileStreamResourceGetBucketBoundariesOp);
506
507 // Given the calculated quantiles thresholds and input data, this operation
508 // converts the input features into the buckets (categorical values), depending
509 // on which quantile they fall into.
510 class BoostedTreesBucketizeOp : public OpKernel {
511 public:
BoostedTreesBucketizeOp(OpKernelConstruction * const context)512 explicit BoostedTreesBucketizeOp(OpKernelConstruction* const context)
513 : OpKernel(context) {
514 OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
515 }
516
Compute(OpKernelContext * const context)517 void Compute(OpKernelContext* const context) override {
518 // Read float features list;
519 OpInputList float_features_list;
520 OP_REQUIRES_OK(
521 context, context->input_list(kFloatFeaturesName, &float_features_list));
522 OpInputList bucket_boundaries_list;
523 OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
524 &bucket_boundaries_list));
525 OP_REQUIRES(context,
526 tensorflow::TensorShapeUtils::IsVector(
527 bucket_boundaries_list[0].shape()),
528 errors::InvalidArgument(
529 strings::Printf("Buckets should be flat vectors.")));
530 OpOutputList buckets_list;
531 OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list));
532
533 auto do_quantile_get_quantiles = [&](const int64_t begin,
534 const int64_t end) {
535 // Iterating over all resources
536 for (int64_t feature_idx = begin; feature_idx < end; feature_idx++) {
537 const Tensor& values_tensor = float_features_list[feature_idx];
538 const int64_t num_values = values_tensor.dim_size(0);
539
540 Tensor* output_t = nullptr;
541 OP_REQUIRES_OK(context,
542 buckets_list.allocate(
543 feature_idx, TensorShape({num_values}), &output_t));
544 auto output = output_t->flat<int32>();
545
546 const std::vector<float>& bucket_boundaries_vector =
547 GetBuckets(feature_idx, bucket_boundaries_list);
548 auto flat_values = values_tensor.flat<float>();
549 const auto& iter_begin = bucket_boundaries_vector.begin();
550 const auto& iter_end = bucket_boundaries_vector.end();
551 for (int64_t instance = 0; instance < num_values; instance++) {
552 if (iter_begin == iter_end) {
553 output(instance) = 0;
554 continue;
555 }
556 const float value = flat_values(instance);
557 auto bucket_iter = std::lower_bound(iter_begin, iter_end, value);
558 if (bucket_iter == iter_end) {
559 --bucket_iter;
560 }
561 const int32_t bucket = static_cast<int32>(bucket_iter - iter_begin);
562 // Bucket id.
563 output(instance) = bucket;
564 }
565 }
566 };
567
568 // TODO(tanzheny): comment on the magic number.
569 const int64_t kCostPerUnit = 500 * num_features_;
570 const DeviceBase::CpuWorkerThreads& worker_threads =
571 *context->device()->tensorflow_cpu_worker_threads();
572 Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
573 kCostPerUnit, do_quantile_get_quantiles);
574 }
575
576 private:
577 int64 num_features_;
578 };
579
580 REGISTER_KERNEL_BUILDER(Name("BoostedTreesBucketize").Device(DEVICE_CPU),
581 BoostedTreesBucketizeOp);
582
583 } // namespace tensorflow
584