• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include <algorithm>
16 #include <iterator>
17 #include <string>
18 #include <vector>
19 
20 #include "tensorflow/core/framework/device_base.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h"
27 #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
28 #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/refcount.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/strings/stringprintf.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/types.h"
35 #include "tensorflow/core/util/work_sharder.h"
36 
37 namespace tensorflow {
38 
// Names of the inputs/outputs/attributes shared by the quantile ops below.
const char* const kExampleWeightsName = "example_weights";
const char* const kMaxElementsName = "max_elements";
const char* const kGenerateQuantiles = "generate_quantiles";
const char* const kNumBucketsName = "num_buckets";
const char* const kEpsilonName = "epsilon";
const char* const kBucketBoundariesName = "bucket_boundaries";
const char* const kBucketsName = "buckets";
const char* const kSummariesName = "summaries";
const char* const kNumStreamsName = "num_streams";
const char* const kNumFeaturesName = "num_features";
const char* const kFloatFeaturesName = "float_values";
const char* const kResourceHandleName = "quantile_stream_resource_handle";

// Short aliases for the weighted-quantiles machinery; every stream in this
// file operates on float values with float weights.
using QuantileStreamResource = BoostedTreesQuantileStreamResource;
using QuantileStream =
    boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
using QuantileSummary =
    boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
using QuantileSummaryEntry =
    boosted_trees::quantiles::WeightedQuantilesSummary<float,
                                                       float>::SummaryEntry;
60 
61 // Generates quantiles on a finalized QuantileStream.
GenerateBoundaries(const QuantileStream & stream,const int64_t num_boundaries)62 std::vector<float> GenerateBoundaries(const QuantileStream& stream,
63                                       const int64_t num_boundaries) {
64   std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);
65 
66   // Uniquify elements as we may get dupes.
67   auto end_it = std::unique(boundaries.begin(), boundaries.end());
68   boundaries.resize(std::distance(boundaries.begin(), end_it));
69   return boundaries;
70 }
71 
72 // Generates quantiles on a finalized QuantileStream.
GenerateQuantiles(const QuantileStream & stream,const int64_t num_quantiles)73 std::vector<float> GenerateQuantiles(const QuantileStream& stream,
74                                      const int64_t num_quantiles) {
75   // Do not de-dup boundaries. Exactly num_quantiles+1 boundary values
76   // will be returned.
77   std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles - 1);
78   CHECK_EQ(boundaries.size(), num_quantiles);
79   return boundaries;
80 }
81 
GetBuckets(const int32_t feature,const OpInputList & buckets_list)82 std::vector<float> GetBuckets(const int32_t feature,
83                               const OpInputList& buckets_list) {
84   const auto& buckets = buckets_list[feature].flat<float>();
85   std::vector<float> buckets_vector(buckets.data(),
86                                     buckets.data() + buckets.size());
87   return buckets_vector;
88 }
89 
// Registers the handle-producing op and the "is initialized" predicate op
// for the quantile stream resource.
REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesQuantileStreamResource);

REGISTER_KERNEL_BUILDER(
    Name("IsBoostedTreesQuantileStreamResourceInitialized").Device(DEVICE_CPU),
    IsResourceInitialized<BoostedTreesQuantileStreamResource>);
95 
96 class BoostedTreesCreateQuantileStreamResourceOp : public OpKernel {
97  public:
BoostedTreesCreateQuantileStreamResourceOp(OpKernelConstruction * const context)98   explicit BoostedTreesCreateQuantileStreamResourceOp(
99       OpKernelConstruction* const context)
100       : OpKernel(context) {
101     OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
102   }
103 
Compute(OpKernelContext * context)104   void Compute(OpKernelContext* context) override {
105     // Only create one, if one does not exist already. Report status for all
106     // other exceptions. If one already exists, it unrefs the new one.
107     // An epsilon value of zero could cause performance issues and is therefore,
108     // disallowed.
109     const Tensor* epsilon_t;
110     OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
111     float epsilon = epsilon_t->scalar<float>()();
112     OP_REQUIRES(
113         context, epsilon > 0,
114         errors::InvalidArgument("An epsilon value of zero is not allowed."));
115 
116     const Tensor* num_streams_t;
117     OP_REQUIRES_OK(context, context->input(kNumStreamsName, &num_streams_t));
118     int64_t num_streams = num_streams_t->scalar<int64>()();
119     OP_REQUIRES(context, num_streams >= 0,
120                 errors::InvalidArgument(
121                     "Num_streams input cannot be a negative integer"));
122 
123     auto result =
124         new QuantileStreamResource(epsilon, max_elements_, num_streams);
125     auto status = CreateResource(context, HandleFromInput(context, 0), result);
126     if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
127       OP_REQUIRES(context, false, status);
128     }
129   }
130 
131  private:
132   // An upper bound on the number of entries that the summaries might have
133   // for a feature.
134   int64 max_elements_;
135 };
136 
137 REGISTER_KERNEL_BUILDER(
138     Name("BoostedTreesCreateQuantileStreamResource").Device(DEVICE_CPU),
139     BoostedTreesCreateQuantileStreamResourceOp);
140 
141 class BoostedTreesMakeQuantileSummariesOp : public OpKernel {
142  public:
BoostedTreesMakeQuantileSummariesOp(OpKernelConstruction * const context)143   explicit BoostedTreesMakeQuantileSummariesOp(
144       OpKernelConstruction* const context)
145       : OpKernel(context) {
146     OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
147   }
148 
Compute(OpKernelContext * const context)149   void Compute(OpKernelContext* const context) override {
150     // Read float features list;
151     OpInputList float_features_list;
152     OP_REQUIRES_OK(
153         context, context->input_list(kFloatFeaturesName, &float_features_list));
154 
155     // Parse example weights and get batch size.
156     const Tensor* example_weights_t;
157     OP_REQUIRES_OK(context,
158                    context->input(kExampleWeightsName, &example_weights_t));
159     DCHECK(float_features_list.size() > 0) << "Got empty feature list";
160     auto example_weights = example_weights_t->flat<float>();
161     const int64_t weight_size = example_weights.size();
162     const int64_t batch_size = float_features_list[0].flat<float>().size();
163     OP_REQUIRES(
164         context, weight_size == 1 || weight_size == batch_size,
165         errors::InvalidArgument(strings::Printf(
166             "Weights should be a single value or same size as features.")));
167     const Tensor* epsilon_t;
168     OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
169     float epsilon = epsilon_t->scalar<float>()();
170 
171     OpOutputList summaries_output_list;
172     OP_REQUIRES_OK(
173         context, context->output_list(kSummariesName, &summaries_output_list));
174 
175     auto do_quantile_summary_gen = [&](const int64_t begin, const int64_t end) {
176       // Iterating features.
177       for (int64_t index = begin; index < end; index++) {
178         const auto feature_values = float_features_list[index].flat<float>();
179         QuantileStream stream(epsilon, batch_size + 1);
180         // Run quantile summary generation.
181         for (int64_t j = 0; j < batch_size; j++) {
182           stream.PushEntry(feature_values(j), (weight_size > 1)
183                                                   ? example_weights(j)
184                                                   : example_weights(0));
185         }
186         stream.Finalize();
187         const auto summary_entry_list = stream.GetFinalSummary().GetEntryList();
188         Tensor* output_t;
189         OP_REQUIRES_OK(
190             context,
191             summaries_output_list.allocate(
192                 index,
193                 TensorShape({static_cast<int64>(summary_entry_list.size()), 4}),
194                 &output_t));
195         auto output = output_t->matrix<float>();
196         for (auto row = 0; row < summary_entry_list.size(); row++) {
197           const auto& entry = summary_entry_list[row];
198           output(row, 0) = entry.value;
199           output(row, 1) = entry.weight;
200           output(row, 2) = entry.min_rank;
201           output(row, 3) = entry.max_rank;
202         }
203       }
204     };
205     // TODO(tanzheny): comment on the magic number.
206     const int64_t kCostPerUnit = 500 * batch_size;
207     const DeviceBase::CpuWorkerThreads& worker_threads =
208         *context->device()->tensorflow_cpu_worker_threads();
209     Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
210           kCostPerUnit, do_quantile_summary_gen);
211   }
212 
213  private:
214   int64 num_features_;
215 };
216 
217 REGISTER_KERNEL_BUILDER(
218     Name("BoostedTreesMakeQuantileSummaries").Device(DEVICE_CPU),
219     BoostedTreesMakeQuantileSummariesOp);
220 
// Finalizes every stream held by the quantile stream resource and emits its
// accumulated summary as a dense [num_entries, 4] tensor per feature, with
// rows of (value, weight, min_rank, max_rank). Streams are reset afterwards.
class BoostedTreesFlushQuantileSummariesOp : public OpKernel {
 public:
  explicit BoostedTreesFlushQuantileSummariesOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
  }

  void Compute(OpKernelContext* const context) override {
    ResourceHandle handle;
    OP_REQUIRES_OK(context,
                   HandleFromInput(context, kResourceHandleName, &handle));
    core::RefCountPtr<QuantileStreamResource> stream_resource;
    OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
    // Remove the reference at the end of this scope.
    // The lock serializes all mutations of the resource's streams.
    mutex_lock l(*stream_resource->mutex());

    OpOutputList summaries_output_list;
    OP_REQUIRES_OK(
        context, context->output_list(kSummariesName, &summaries_output_list));

    auto do_quantile_summary_gen = [&](const int64_t begin, const int64_t end) {
      // Iterating features.
      for (int64_t index = begin; index < end; index++) {
        QuantileStream* stream = stream_resource->stream(index);
        stream->Finalize();

        const auto summary_list = stream->GetFinalSummary().GetEntryList();
        Tensor* output_t;
        const int64_t summary_list_size =
            static_cast<int64>(summary_list.size());
        OP_REQUIRES_OK(context, summaries_output_list.allocate(
                                    index, TensorShape({summary_list_size, 4}),
                                    &output_t));
        // Column layout: 0=value, 1=weight, 2=min_rank, 3=max_rank.
        auto output = output_t->matrix<float>();
        for (auto row = 0; row < summary_list_size; row++) {
          const auto& entry = summary_list[row];
          output(row, 0) = entry.value;
          output(row, 1) = entry.weight;
          output(row, 2) = entry.min_rank;
          output(row, 3) = entry.max_rank;
        }
      }
    };
    // TODO(tanzheny): comment on the magic number.
    const int64_t kCostPerUnit = 500 * num_features_;
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *context->device()->tensorflow_cpu_worker_threads();
    // Shard the per-feature work across the CPU thread pool.
    Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
          kCostPerUnit, do_quantile_summary_gen);
    // Streams were finalized above; reset them so they can accept new data.
    stream_resource->ResetStreams();
  }

 private:
  int64 num_features_;  // Attr: expected number of feature streams.
};

REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesFlushQuantileSummaries").Device(DEVICE_CPU),
    BoostedTreesFlushQuantileSummariesOp);
281 
282 class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
283  public:
BoostedTreesQuantileStreamResourceAddSummariesOp(OpKernelConstruction * const context)284   explicit BoostedTreesQuantileStreamResourceAddSummariesOp(
285       OpKernelConstruction* const context)
286       : OpKernel(context) {}
287 
Compute(OpKernelContext * context)288   void Compute(OpKernelContext* context) override {
289     ResourceHandle handle;
290     OP_REQUIRES_OK(context,
291                    HandleFromInput(context, kResourceHandleName, &handle));
292     core::RefCountPtr<QuantileStreamResource> stream_resource;
293     // Create a reference to the underlying resource using the handle.
294     OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
295     // Remove the reference at the end of this scope.
296     mutex_lock l(*stream_resource->mutex());
297 
298     OpInputList summaries_list;
299     OP_REQUIRES_OK(context,
300                    context->input_list(kSummariesName, &summaries_list));
301     int32_t num_streams = stream_resource->num_streams();
302     CHECK_EQ(static_cast<int>(num_streams), summaries_list.size());
303 
304     auto do_quantile_add_summary = [&](const int64_t begin, const int64_t end) {
305       // Iterating all features.
306       for (int64_t feature_idx = begin; feature_idx < end; ++feature_idx) {
307         QuantileStream* stream = stream_resource->stream(feature_idx);
308         if (stream->IsFinalized()) {
309           VLOG(1) << "QuantileStream has already been finalized for feature"
310                   << feature_idx << ".";
311           continue;
312         }
313         const Tensor& summaries = summaries_list[feature_idx];
314         const auto summary_values = summaries.matrix<float>();
315         const auto& tensor_shape = summaries.shape();
316         const int64_t entries_size = tensor_shape.dim_size(0);
317         CHECK_EQ(tensor_shape.dim_size(1), 4);
318         std::vector<QuantileSummaryEntry> summary_entries;
319         summary_entries.reserve(entries_size);
320         for (int64_t i = 0; i < entries_size; i++) {
321           float value = summary_values(i, 0);
322           float weight = summary_values(i, 1);
323           float min_rank = summary_values(i, 2);
324           float max_rank = summary_values(i, 3);
325           QuantileSummaryEntry entry(value, weight, min_rank, max_rank);
326           summary_entries.push_back(entry);
327         }
328         stream_resource->stream(feature_idx)->PushSummary(summary_entries);
329       }
330     };
331 
332     // TODO(tanzheny): comment on the magic number.
333     const int64_t kCostPerUnit = 500 * num_streams;
334     const DeviceBase::CpuWorkerThreads& worker_threads =
335         *context->device()->tensorflow_cpu_worker_threads();
336     Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
337           kCostPerUnit, do_quantile_add_summary);
338   }
339 };
340 
341 REGISTER_KERNEL_BUILDER(
342     Name("BoostedTreesQuantileStreamResourceAddSummaries").Device(DEVICE_CPU),
343     BoostedTreesQuantileStreamResourceAddSummariesOp);
344 
345 class BoostedTreesQuantileStreamResourceDeserializeOp : public OpKernel {
346  public:
BoostedTreesQuantileStreamResourceDeserializeOp(OpKernelConstruction * const context)347   explicit BoostedTreesQuantileStreamResourceDeserializeOp(
348       OpKernelConstruction* const context)
349       : OpKernel(context) {
350     OP_REQUIRES_OK(context, context->GetAttr(kNumStreamsName, &num_features_));
351   }
352 
Compute(OpKernelContext * context)353   void Compute(OpKernelContext* context) override {
354     core::RefCountPtr<QuantileStreamResource> streams_resource;
355     // Create a reference to the underlying resource using the handle.
356     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
357                                            &streams_resource));
358     // Remove the reference at the end of this scope.
359     mutex_lock l(*streams_resource->mutex());
360 
361     OpInputList bucket_boundaries_list;
362     OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
363                                                 &bucket_boundaries_list));
364 
365     auto do_quantile_deserialize = [&](const int64_t begin, const int64_t end) {
366       // Iterating over all streams.
367       for (int64_t stream_idx = begin; stream_idx < end; stream_idx++) {
368         const Tensor& bucket_boundaries_t = bucket_boundaries_list[stream_idx];
369         const auto& bucket_boundaries = bucket_boundaries_t.vec<float>();
370         std::vector<float> result;
371         result.reserve(bucket_boundaries.size());
372         for (size_t i = 0; i < bucket_boundaries.size(); ++i) {
373           result.push_back(bucket_boundaries(i));
374         }
375         streams_resource->set_boundaries(result, stream_idx);
376       }
377     };
378 
379     // TODO(tanzheny): comment on the magic number.
380     const int64_t kCostPerUnit = 500 * num_features_;
381     const DeviceBase::CpuWorkerThreads& worker_threads =
382         *context->device()->tensorflow_cpu_worker_threads();
383     Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
384           kCostPerUnit, do_quantile_deserialize);
385   }
386 
387  private:
388   int64 num_features_;
389 };
390 
391 REGISTER_KERNEL_BUILDER(
392     Name("BoostedTreesQuantileStreamResourceDeserialize").Device(DEVICE_CPU),
393     BoostedTreesQuantileStreamResourceDeserializeOp);
394 
// Finalizes every stream in the resource and materializes its bucket
// boundaries, either as de-duplicated boundaries or (when the
// generate_quantiles attr is set) as exactly num_buckets quantile values.
class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel {
 public:
  explicit BoostedTreesQuantileStreamResourceFlushOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
  }

  void Compute(OpKernelContext* context) override {
    ResourceHandle handle;
    OP_REQUIRES_OK(context,
                   HandleFromInput(context, kResourceHandleName, &handle));
    core::RefCountPtr<QuantileStreamResource> stream_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
    // Remove the reference at the end of this scope.
    // The lock serializes all mutations of the resource's streams.
    mutex_lock l(*stream_resource->mutex());

    const Tensor* num_buckets_t;
    OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t));
    const int64_t num_buckets = num_buckets_t->scalar<int64>()();
    const int64_t num_streams = stream_resource->num_streams();

    auto do_quantile_flush = [&](const int64_t begin, const int64_t end) {
      // Iterating over all streams.
      for (int64_t stream_idx = begin; stream_idx < end; ++stream_idx) {
        QuantileStream* stream = stream_resource->stream(stream_idx);
        stream->Finalize();
        // GenerateQuantiles yields exactly num_buckets values, while
        // GenerateBoundaries may yield fewer after de-duplication.
        stream_resource->set_boundaries(
            generate_quantiles_ ? GenerateQuantiles(*stream, num_buckets)
                                : GenerateBoundaries(*stream, num_buckets),
            stream_idx);
      }
    };

    // TODO(tanzheny): comment on the magic number.
    const int64_t kCostPerUnit = 500 * num_streams;
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *context->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
          kCostPerUnit, do_quantile_flush);

    // Streams were finalized above; reset them for further accumulation and
    // mark the boundaries as queryable by GetBucketBoundaries.
    stream_resource->ResetStreams();
    stream_resource->set_buckets_ready(true);
  }

 private:
  bool generate_quantiles_;  // Attr: exact quantiles vs. boundaries.
};

REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesQuantileStreamResourceFlush").Device(DEVICE_CPU),
    BoostedTreesQuantileStreamResourceFlushOp);
449 
450 class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp
451     : public OpKernel {
452  public:
BoostedTreesQuantileStreamResourceGetBucketBoundariesOp(OpKernelConstruction * const context)453   explicit BoostedTreesQuantileStreamResourceGetBucketBoundariesOp(
454       OpKernelConstruction* const context)
455       : OpKernel(context) {
456     OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
457   }
458 
Compute(OpKernelContext * const context)459   void Compute(OpKernelContext* const context) override {
460     ResourceHandle handle;
461     OP_REQUIRES_OK(context,
462                    HandleFromInput(context, kResourceHandleName, &handle));
463     core::RefCountPtr<QuantileStreamResource> stream_resource;
464     // Create a reference to the underlying resource using the handle.
465     OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
466     // Remove the reference at the end of this scope.
467     mutex_lock l(*stream_resource->mutex());
468 
469     const int64_t num_streams = stream_resource->num_streams();
470     CHECK_EQ(num_features_, num_streams);
471     OpOutputList bucket_boundaries_list;
472     OP_REQUIRES_OK(context, context->output_list(kBucketBoundariesName,
473                                                  &bucket_boundaries_list));
474 
475     auto do_quantile_get_buckets = [&](const int64_t begin, const int64_t end) {
476       // Iterating over all streams.
477       for (int64_t stream_idx = begin; stream_idx < end; stream_idx++) {
478         const auto& boundaries = stream_resource->boundaries(stream_idx);
479         Tensor* bucket_boundaries_t = nullptr;
480         OP_REQUIRES_OK(context,
481                        bucket_boundaries_list.allocate(
482                            stream_idx, {static_cast<int64>(boundaries.size())},
483                            &bucket_boundaries_t));
484         auto* quantiles_flat = bucket_boundaries_t->flat<float>().data();
485         memcpy(quantiles_flat, boundaries.data(),
486                sizeof(float) * boundaries.size());
487       }
488     };
489 
490     // TODO(tanzheny): comment on the magic number.
491     const int64_t kCostPerUnit = 500 * num_streams;
492     const DeviceBase::CpuWorkerThreads& worker_threads =
493         *context->device()->tensorflow_cpu_worker_threads();
494     Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
495           kCostPerUnit, do_quantile_get_buckets);
496   }
497 
498  private:
499   int64 num_features_;
500 };
501 
502 REGISTER_KERNEL_BUILDER(
503     Name("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
504         .Device(DEVICE_CPU),
505     BoostedTreesQuantileStreamResourceGetBucketBoundariesOp);
506 
507 // Given the calculated quantiles thresholds and input data, this operation
508 // converts the input features into the buckets (categorical values), depending
509 // on which quantile they fall into.
510 class BoostedTreesBucketizeOp : public OpKernel {
511  public:
BoostedTreesBucketizeOp(OpKernelConstruction * const context)512   explicit BoostedTreesBucketizeOp(OpKernelConstruction* const context)
513       : OpKernel(context) {
514     OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
515   }
516 
Compute(OpKernelContext * const context)517   void Compute(OpKernelContext* const context) override {
518     // Read float features list;
519     OpInputList float_features_list;
520     OP_REQUIRES_OK(
521         context, context->input_list(kFloatFeaturesName, &float_features_list));
522     OpInputList bucket_boundaries_list;
523     OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
524                                                 &bucket_boundaries_list));
525     OP_REQUIRES(context,
526                 tensorflow::TensorShapeUtils::IsVector(
527                     bucket_boundaries_list[0].shape()),
528                 errors::InvalidArgument(
529                     strings::Printf("Buckets should be flat vectors.")));
530     OpOutputList buckets_list;
531     OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list));
532 
533     auto do_quantile_get_quantiles = [&](const int64_t begin,
534                                          const int64_t end) {
535       // Iterating over all resources
536       for (int64_t feature_idx = begin; feature_idx < end; feature_idx++) {
537         const Tensor& values_tensor = float_features_list[feature_idx];
538         const int64_t num_values = values_tensor.dim_size(0);
539 
540         Tensor* output_t = nullptr;
541         OP_REQUIRES_OK(context,
542                        buckets_list.allocate(
543                            feature_idx, TensorShape({num_values}), &output_t));
544         auto output = output_t->flat<int32>();
545 
546         const std::vector<float>& bucket_boundaries_vector =
547             GetBuckets(feature_idx, bucket_boundaries_list);
548         auto flat_values = values_tensor.flat<float>();
549         const auto& iter_begin = bucket_boundaries_vector.begin();
550         const auto& iter_end = bucket_boundaries_vector.end();
551         for (int64_t instance = 0; instance < num_values; instance++) {
552           if (iter_begin == iter_end) {
553             output(instance) = 0;
554             continue;
555           }
556           const float value = flat_values(instance);
557           auto bucket_iter = std::lower_bound(iter_begin, iter_end, value);
558           if (bucket_iter == iter_end) {
559             --bucket_iter;
560           }
561           const int32_t bucket = static_cast<int32>(bucket_iter - iter_begin);
562           // Bucket id.
563           output(instance) = bucket;
564         }
565       }
566     };
567 
568     // TODO(tanzheny): comment on the magic number.
569     const int64_t kCostPerUnit = 500 * num_features_;
570     const DeviceBase::CpuWorkerThreads& worker_threads =
571         *context->device()->tensorflow_cpu_worker_threads();
572     Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
573           kCostPerUnit, do_quantile_get_quantiles);
574   }
575 
576  private:
577   int64 num_features_;
578 };
579 
580 REGISTER_KERNEL_BUILDER(Name("BoostedTreesBucketize").Device(DEVICE_CPU),
581                         BoostedTreesBucketizeOp);
582 
583 }  // namespace tensorflow
584