• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include <algorithm>
16 #include <iterator>
17 #include <string>
18 #include <vector>
19 
20 #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h"
21 #include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
22 #include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
23 #include "tensorflow/contrib/boosted_trees/proto/quantiles.pb.h"
24 #include "tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/resource_mgr.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_shape.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/strings/stringprintf.h"
33 #include "tensorflow/core/platform/types.h"
34 #include "tensorflow/core/util/work_sharder.h"
35 
36 namespace tensorflow {
37 
38 using ::boosted_trees::QuantileConfig;
39 using boosted_trees::QuantileStreamResource;
40 using boosted_trees::utils::TensorUtils;
41 
42 namespace {
43 const char* const kExampleWeightsName = "example_weights";
44 const char* const kMaxElementsName = "max_elements";
45 const char* const kNextStampTokenName = "next_stamp_token";
46 const char* const kStampTokenName = "stamp_token";
47 const char* const kAreBucketsReadyName = "are_buckets_ready";
48 const char* const kGenerateQuantiles = "generate_quantiles";
49 // Names for sparse arguments.
50 const char* const kNumSparseFeaturesName = "num_sparse_features";
51 const char* const kSparseBucketsName = "sparse_buckets";
52 const char* const kSparseValuesName = "sparse_values";
53 const char* const kSparseIndicesName = "sparse_indices";
54 const char* const kSparseSummariesName = "sparse_summaries";
55 const char* const kSparseConfigName = "sparse_config";
56 const char* const kSparseOutputTensorName = "sparse_quantiles";
57 // Names for dense arguments.
58 const char* const kDenseBucketsName = "dense_buckets";
59 const char* const kDenseConfigName = "dense_config";
60 const char* const kDenseOutputTensorName = "dense_quantiles";
61 const char* const kDenseSummariesName = "dense_summaries";
62 const char* const kDenseValuesName = "dense_values";
63 const char* const kNumDenseFeaturesName = "num_dense_features";
64 const char* const kResourceHandlesName = "quantile_accumulator_handles";
65 const char* const kNumQuantilesName = "num_quantiles";
66 const char* const kEpsilonName = "epsilon";
67 const char* const kBucketsName = "buckets";
68 const char* const kStreamStateName = "stream_state";
69 const char* const kSummariesName = "summaries";
70 
71 using QuantileStream =
72     boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
73 using QuantileSummary =
74     boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
75 using QuantileSummaryEntry =
76     boosted_trees::quantiles::WeightedQuantilesSummary<float,
77                                                        float>::SummaryEntry;
78 
GetBuckets(const int32 feature,const OpInputList & buckets_list)79 std::vector<float> GetBuckets(const int32 feature,
80                               const OpInputList& buckets_list) {
81   const auto& buckets = buckets_list[feature].flat<float>();
82   std::vector<float> buckets_vector(buckets.data(),
83                                     buckets.data() + buckets.size());
84   return buckets_vector;
85 }
86 
// Returns the dimension id of entry `instance` of feature `feature_index`.
// For sparse features (non-null `indices_list`) this is column 1 of the
// feature's int64 indices matrix; without indices the feature is treated as
// one-dimensional and dimension 0 is returned.
int32 GetFeatureDimension(const int32 feature_index, const int64 instance,
                          const OpInputList* const indices_list) {
  if (indices_list == nullptr) {
    // No indices, assume one-dimensional tensor.
    return 0;
  }
  // Sparse multidimensional.
  const auto indices = (*indices_list)[feature_index].matrix<int64>();
  return indices(instance, 1);
}
96 
97 // Allows quantization for each of multiple dimensions of a sparse feature.
QuantizeFeatures(const string & output_name,const OpInputList & values_list,const OpInputList & buckets_list,const OpInputList * const indices_list,OpKernelContext * const context)98 void QuantizeFeatures(
99     const string& output_name, const OpInputList& values_list,
100     const OpInputList& buckets_list,
101     const OpInputList* const
102         indices_list /** Optional, provide for sparse features **/,
103     OpKernelContext* const context) {
104   if (values_list.size() == 0) {
105     return;
106   }
107   OpOutputList output_list;
108   OP_REQUIRES_OK(context, context->output_list(output_name, &output_list));
109 
110   for (int32 feature_index = 0; feature_index < values_list.size();
111        ++feature_index) {
112     const Tensor& values_tensor = values_list[feature_index];
113     const int64 num_values = values_tensor.dim_size(0);
114 
115     Tensor* output_t = nullptr;
116     // Output will have bucket id and dimension of the features for that bucket.
117     OP_REQUIRES_OK(
118         context, output_list.allocate(feature_index,
119                                       TensorShape({num_values, 2}), &output_t));
120 
121     auto output = output_t->matrix<int32>();
122 
123     const std::vector<float>& buckets_vector =
124         GetBuckets(feature_index, buckets_list);
125     auto flat_values = values_tensor.flat<float>();
126     for (int64 instance = 0; instance < num_values; ++instance) {
127       const float value = flat_values(instance);
128       CHECK(!buckets_vector.empty())
129           << "Got empty buckets for feature " << feature_index;
130       auto bucket_iter =
131           std::lower_bound(buckets_vector.begin(), buckets_vector.end(), value);
132       if (bucket_iter == buckets_vector.end()) {
133         --bucket_iter;
134       }
135       const int32 bucket =
136           static_cast<int32>(bucket_iter - buckets_vector.begin());
137       // Bucket id.
138       output(instance, 0) = bucket;
139       // Dimension.
140       output(instance, 1) =
141           GetFeatureDimension(feature_index, instance, indices_list);
142     }
143   }
144 }
145 
146 // Validates attributes for the quantile ops.
ReadAndValidateAttributes(OpKernelConstruction * const context,int * num_dense_features,int * num_sparse_features)147 Status ReadAndValidateAttributes(OpKernelConstruction* const context,
148                                  int* num_dense_features,
149                                  int* num_sparse_features) {
150   TF_RETURN_IF_ERROR(
151       context->GetAttr(kNumDenseFeaturesName, num_dense_features));
152   TF_RETURN_IF_ERROR(
153       context->GetAttr(kNumSparseFeaturesName, num_sparse_features));
154   if ((*num_dense_features) + (*num_sparse_features) == 0) {
155     return errors::InvalidArgument(
156         "Please provide at least sparse or dense features.");
157   }
158   return Status::OK();
159 }
160 
ParseConfig(OpKernelConstruction * const context,const string & name,std::vector<QuantileConfig> * output)161 void ParseConfig(OpKernelConstruction* const context, const string& name,
162                  std::vector<QuantileConfig>* output) {
163   std::vector<string> serialized_config;
164   OP_REQUIRES_OK(context, context->GetAttr(name, &serialized_config));
165   output->reserve(serialized_config.size());
166   QuantileConfig tmp;
167   for (const auto& serialized_string : serialized_config) {
168     OP_REQUIRES(context, tmp.ParseFromString(serialized_string),
169                 errors::InvalidArgument("Malformed QuantileConfig passed in."));
170     output->push_back(tmp);
171   }
172 }
173 
174 // Generates quantiles on a finalized QuantileStream.
GenerateBoundaries(const QuantileStream & stream,int num_boundaries)175 std::vector<float> GenerateBoundaries(const QuantileStream& stream,
176                                       int num_boundaries) {
177   std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);
178 
179   // Uniquify elements as we may get dupes.
180   auto end_it = std::unique(boundaries.begin(), boundaries.end());
181   boundaries.resize(std::distance(boundaries.begin(), end_it));
182   return boundaries;
183 }
184 
185 // Generates quantiles on a finalized QuantileStream.
GenerateQuantiles(const QuantileStream & stream,int num_quantiles)186 std::vector<float> GenerateQuantiles(const QuantileStream& stream,
187                                      int num_quantiles) {
188   // Do not de-dup boundaries. Exactly num_quantiles+1 boundary values
189   // will be returned.
190   std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles);
191   CHECK_EQ(boundaries.size(), num_quantiles + 1);
192   return boundaries;
193 }
194 
195 // Copies quantiles to output list.
CopyBoundaries(OpKernelContext * const context,const std::vector<float> & boundaries,const int64 index,OpOutputList * output_list)196 void CopyBoundaries(OpKernelContext* const context,
197                     const std::vector<float>& boundaries, const int64 index,
198                     OpOutputList* output_list) {
199   // Output to tensor.
200   Tensor* output_t = nullptr;
201   OP_REQUIRES_OK(
202       context, output_list->allocate(
203                    index, {static_cast<int64>(boundaries.size())}, &output_t));
204   auto* quantiles_flat = output_t->flat<float>().data();
205   memcpy(quantiles_flat, boundaries.data(), sizeof(float) * boundaries.size());
206 }
207 
CopySummaryToProto(const QuantileSummary & summary,::boosted_trees::QuantileSummaryState * summary_proto)208 void CopySummaryToProto(const QuantileSummary& summary,
209                         ::boosted_trees::QuantileSummaryState* summary_proto) {
210   summary_proto->mutable_entries()->Reserve(summary.Size());
211   for (const auto& entry : summary.GetEntryList()) {
212     auto* new_entry = summary_proto->add_entries();
213     new_entry->set_value(entry.value);
214     new_entry->set_weight(entry.weight);
215     new_entry->set_min_rank(entry.min_rank);
216     new_entry->set_max_rank(entry.max_rank);
217   }
218 }
219 
220 }  // namespace
221 
222 // Accumulator for Quantile Summaries.
223 REGISTER_RESOURCE_HANDLE_KERNEL(QuantileStreamResource);
224 
225 REGISTER_KERNEL_BUILDER(
226     Name("QuantileAccumulatorIsInitialized").Device(DEVICE_CPU),
227     IsResourceInitialized<QuantileStreamResource>);
228 
229 class CreateQuantileAccumulatorOp : public OpKernel {
230  public:
CreateQuantileAccumulatorOp(OpKernelConstruction * const context)231   explicit CreateQuantileAccumulatorOp(OpKernelConstruction* const context)
232       : OpKernel(context) {
233     OP_REQUIRES_OK(context, context->GetAttr(kEpsilonName, &epsilon_));
234     OP_REQUIRES_OK(context,
235                    context->GetAttr(kNumQuantilesName, &num_quantiles_));
236     OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
237     OP_REQUIRES_OK(context,
238                    context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
239   }
240 
Compute(OpKernelContext * context)241   void Compute(OpKernelContext* context) override {
242     // Only create one, if one does not exist already. Report status for all
243     // other exceptions. If one already exists, it unrefs the new one.
244     const Tensor* stamp_token_t;
245     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
246     // An epsilon value of zero could cause perfoamance issues and is therefore,
247     // disallowed.
248     OP_REQUIRES(
249         context, epsilon_ > 0,
250         errors::InvalidArgument("An epsilon value of zero is not allowed."));
251     auto result = new QuantileStreamResource(epsilon_, num_quantiles_,
252                                              max_elements_, generate_quantiles_,
253                                              stamp_token_t->scalar<int64>()());
254     auto status = CreateResource(context, HandleFromInput(context, 0), result);
255     if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
256       OP_REQUIRES(context, false, status);
257     }
258   }
259 
260  private:
261   float epsilon_;
262   int32 num_quantiles_;
263   // An upper bound on the number of entries that the summaries might have
264   // for a feature.
265   int64 max_elements_;
266   bool generate_quantiles_;
267 };
268 
269 REGISTER_KERNEL_BUILDER(Name("CreateQuantileAccumulator").Device(DEVICE_CPU),
270                         CreateQuantileAccumulatorOp);
271 
272 // Adds a summary to the quantile summary stream.
273 class QuantileAccumulatorAddSummariesOp : public OpKernel {
274  public:
QuantileAccumulatorAddSummariesOp(OpKernelConstruction * const context)275   explicit QuantileAccumulatorAddSummariesOp(
276       OpKernelConstruction* const context)
277       : OpKernel(context) {}
278 
Compute(OpKernelContext * context)279   void Compute(OpKernelContext* context) override {
280     OpInputList resource_handle_list;
281     OP_REQUIRES_OK(context, context->input_list(kResourceHandlesName,
282                                                 &resource_handle_list));
283     OpInputList summary_list;
284     OP_REQUIRES_OK(context, context->input_list(kSummariesName, &summary_list));
285 
286     const Tensor* stamp_token_t;
287     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
288     int64 stamp_token = stamp_token_t->scalar<int64>()();
289 
290     thread::ThreadPool* const worker_threads =
291         context->device()->tensorflow_cpu_worker_threads()->workers;
292     boosted_trees::utils::ParallelFor(
293         resource_handle_list.size(), worker_threads->NumThreads(),
294         worker_threads,
295         [&context, &resource_handle_list, &summary_list, stamp_token](
296             int64 start, int64 end) {
297           for (int resource_handle_idx = start; resource_handle_idx < end;
298                ++resource_handle_idx) {
299             const ResourceHandle& handle =
300                 resource_handle_list[resource_handle_idx]
301                     .flat<ResourceHandle>()(0);
302             QuantileStreamResource* streams_resource;
303             // Create a reference to the underlying resource using the handle.
304             OP_REQUIRES_OK(context,
305                            LookupResource(context, handle, &streams_resource));
306             // Remove the reference at the end of this scope.
307             mutex_lock l(*streams_resource->mutex());
308             core::ScopedUnref unref_me(streams_resource);
309 
310             // If the stamp is invalid we drop the update.
311             if (!streams_resource->is_stamp_valid(stamp_token)) {
312               VLOG(1)
313                   << "Invalid stamp token in QuantileAccumulatorAddSummariesOp."
314                   << " Passed stamp token: " << stamp_token << " "
315                   << "Current token: " << streams_resource->stamp();
316               return;
317             }
318 
319             protobuf::Arena arena;
320             ::boosted_trees::QuantileSummaryState* summary_proto =
321                 protobuf::Arena::CreateMessage<
322                     ::boosted_trees::QuantileSummaryState>(&arena);
323             OP_REQUIRES(
324                 context,
325                 ParseProtoUnlimited(
326                     summary_proto,
327                     summary_list[resource_handle_idx].scalar<string>()()),
328                 errors::InvalidArgument("Unable to parse quantile summary."));
329             std::vector<QuantileSummaryEntry> entries;
330             entries.reserve(summary_proto->entries_size());
331             for (const auto& entry : summary_proto->entries()) {
332               entries.emplace_back(entry.value(), entry.weight(),
333                                    entry.min_rank(), entry.max_rank());
334             }
335 
336             // Add the summary to the quantile stream.
337             streams_resource->stream(stamp_token)->PushSummary(entries);
338           }
339         });
340   }
341 };
342 
343 REGISTER_KERNEL_BUILDER(
344     Name("QuantileAccumulatorAddSummaries").Device(DEVICE_CPU),
345     QuantileAccumulatorAddSummariesOp);
346 
347 // Generates summaries for given set of float values, and the given config.
348 class MakeQuantileSummariesOp : public OpKernel {
349  public:
MakeQuantileSummariesOp(OpKernelConstruction * const context)350   explicit MakeQuantileSummariesOp(OpKernelConstruction* const context)
351       : OpKernel(context) {
352     OP_REQUIRES_OK(context,
353                    ReadAndValidateAttributes(context, &num_dense_features_,
354                                              &num_sparse_features_));
355     OP_REQUIRES_OK(context, context->GetAttr(kEpsilonName, &epsilon_));
356   }
357 
Compute(OpKernelContext * const context)358   void Compute(OpKernelContext* const context) override {
359     // Read dense float features list;
360     OpInputList dense_float_features_list;
361     OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
362                                 context, &dense_float_features_list));
363 
364     // Read sparse float features list;
365     OpInputList sparse_float_feature_indices_list;
366     OpInputList sparse_float_feature_values_list;
367     OpInputList sparse_float_feature_shapes_list;
368     OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures(
369                                 context, &sparse_float_feature_indices_list,
370                                 &sparse_float_feature_values_list,
371                                 &sparse_float_feature_shapes_list));
372 
373     // Parse example weights and get batch size.
374     const Tensor* example_weights_t;
375     OP_REQUIRES_OK(context,
376                    context->input(kExampleWeightsName, &example_weights_t));
377     auto example_weights = example_weights_t->flat<float>();
378     const int64 batch_size = example_weights.size();
379 
380     OpOutputList sparse_summaries_output_list;
381     OP_REQUIRES_OK(context,
382                    context->output_list(kSparseSummariesName,
383                                         &sparse_summaries_output_list));
384     OpOutputList dense_summaries_output_list;
385     OP_REQUIRES_OK(context, context->output_list(kDenseSummariesName,
386                                                  &dense_summaries_output_list));
387 
388     auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) {
389       auto copy_over_summaries = [&](const QuantileStream& stream,
390                                      const int64 index,
391                                      OpOutputList* output_list) {
392         protobuf::Arena arena;
393         ::boosted_trees::QuantileSummaryState* summary_proto =
394             protobuf::Arena::CreateMessage<
395                 ::boosted_trees::QuantileSummaryState>(&arena);
396         const auto& summary = stream.GetFinalSummary();
397         CopySummaryToProto(summary, summary_proto);
398         // Output to tensor.
399         Tensor* output_t = nullptr;
400         OP_REQUIRES_OK(context, output_list->allocate(index, {}, &output_t));
401         summary_proto->SerializeToString(&output_t->scalar<string>()());
402       };
403 
404       // These are blocks of ranges. We are iterating over both sparse and
405       // dense features i.e. [0, sparse_features.size() + dense_features.size()]
406       for (int64 i = begin; i < end; ++i) {
407         if (i < num_dense_features_) {
408           const int64 dense_index = i;
409           const auto dense_values =
410               dense_float_features_list[dense_index].flat<float>();
411           QuantileStream stream(epsilon_, batch_size + 1);
412           // Run quantile summary generation.
413           for (int64 j = 0; j < batch_size; ++j) {
414             stream.PushEntry(dense_values(j), example_weights(j));
415           }
416           stream.Finalize();
417           // Copy summaries to output.
418           copy_over_summaries(stream, dense_index,
419                               &dense_summaries_output_list);
420         } else {
421           const int64 sparse_index = i - num_dense_features_;
422           const auto sparse_values =
423               sparse_float_feature_values_list[sparse_index].flat<float>();
424           const auto sparse_indices =
425               sparse_float_feature_indices_list[sparse_index].matrix<int64>();
426           const auto dense_shape =
427               sparse_float_feature_shapes_list[sparse_index].flat<int64>();
428           OP_REQUIRES(context, batch_size == dense_shape(0),
429                       errors::InvalidArgument(
430                           "Sparse column shape doesn't match the batch size."));
431           QuantileStream stream(epsilon_, batch_size + 1);
432           // Run quantile summary generation.
433           const int64 num_sparse_rows =
434               sparse_float_feature_indices_list[sparse_index].dim_size(0);
435           for (int64 j = 0; j < num_sparse_rows; ++j) {
436             const int64 example_id = sparse_indices(j, 0);
437             stream.PushEntry(sparse_values(j), example_weights(example_id));
438           }
439           stream.Finalize();
440           // Copy summaries to output.
441           copy_over_summaries(stream, sparse_index,
442                               &sparse_summaries_output_list);
443         }
444       }
445     };
446     const int64 kCostPerUnit = 500 * batch_size;
447     const int64 num_features = num_sparse_features_ + num_dense_features_;
448     const DeviceBase::CpuWorkerThreads& worker_threads =
449         *context->device()->tensorflow_cpu_worker_threads();
450     Shard(worker_threads.num_threads, worker_threads.workers, num_features,
451           kCostPerUnit, do_quantile_summary_gen);
452   }
453 
454  private:
455   int num_dense_features_;
456   int num_sparse_features_;
457   float epsilon_;
458 };
459 
460 REGISTER_KERNEL_BUILDER(Name("MakeQuantileSummaries").Device(DEVICE_CPU),
461                         MakeQuantileSummariesOp);
462 
463 // Serializes the state of streams.
464 class QuantileAccumulatorSerializeOp : public OpKernel {
465  public:
QuantileAccumulatorSerializeOp(OpKernelConstruction * const context)466   explicit QuantileAccumulatorSerializeOp(OpKernelConstruction* const context)
467       : OpKernel(context) {}
468 
Compute(OpKernelContext * context)469   void Compute(OpKernelContext* context) override {
470     QuantileStreamResource* streams_resource;
471     // Create a reference to the underlying resource using the handle.
472     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
473                                            &streams_resource));
474     // Remove the reference at the end of this scope.
475     mutex_lock l(*streams_resource->mutex());
476     core::ScopedUnref unref_me(streams_resource);
477 
478     int64 stamp_token = streams_resource->stamp();
479     Tensor* stream_state_t;
480     OP_REQUIRES_OK(context,
481                    context->allocate_output(kStreamStateName, TensorShape({}),
482                                             &stream_state_t));
483     bool are_buckets_ready = streams_resource->are_buckets_ready();
484 
485     // We are iterating over both dense and sparse features. First we go
486     // through the dense features and then the sparse features.
487     const QuantileStream& stream = *streams_resource->stream(stamp_token);
488     const std::vector<float>& boundaries =
489         are_buckets_ready ? streams_resource->boundaries(stamp_token)
490                           : std::vector<float>();
491     protobuf::Arena arena;
492     ::boosted_trees::QuantileStreamState* stream_proto =
493         protobuf::Arena::CreateMessage<::boosted_trees::QuantileStreamState>(
494             &arena);
495     for (const auto& summary : stream.SerializeInternalSummaries()) {
496       CopySummaryToProto(summary, stream_proto->add_summaries());
497     }
498     stream_proto->SerializeToString(&stream_state_t->scalar<string>()());
499     Tensor* buckets_t = nullptr;
500     OP_REQUIRES_OK(
501         context,
502         context->allocate_output(
503             kBucketsName, {static_cast<int64>(boundaries.size())}, &buckets_t));
504     auto* quantiles_flat = buckets_t->flat<float>().data();
505     memcpy(quantiles_flat, boundaries.data(),
506            sizeof(float) * boundaries.size());
507     Tensor* stamp_token_t = nullptr;
508     OP_REQUIRES_OK(context,
509                    context->allocate_output(kStampTokenName, TensorShape({}),
510                                             &stamp_token_t));
511     stamp_token_t->scalar<int64>()() = stamp_token;
512     Tensor* are_buckets_ready_t = nullptr;
513     OP_REQUIRES_OK(context, context->allocate_output(kAreBucketsReadyName, {},
514                                                      &are_buckets_ready_t));
515     are_buckets_ready_t->scalar<bool>()() = are_buckets_ready;
516   }
517 };
518 
519 REGISTER_KERNEL_BUILDER(Name("QuantileAccumulatorSerialize").Device(DEVICE_CPU),
520                         QuantileAccumulatorSerializeOp);
521 
522 // Serializes the state of streams.
523 class QuantileAccumulatorDeserializeOp : public OpKernel {
524  public:
QuantileAccumulatorDeserializeOp(OpKernelConstruction * const context)525   explicit QuantileAccumulatorDeserializeOp(OpKernelConstruction* const context)
526       : OpKernel(context) {}
527 
Compute(OpKernelContext * context)528   void Compute(OpKernelContext* context) override {
529     QuantileStreamResource* streams_resource;
530     // Create a reference to the underlying resource using the handle.
531     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
532                                            &streams_resource));
533     // Remove the reference at the end of this scope.
534     mutex_lock l(*streams_resource->mutex());
535     core::ScopedUnref unref_me(streams_resource);
536 
537     int64 old_stamp_token = streams_resource->stamp();
538 
539     const Tensor* stream_state_t;
540     OP_REQUIRES_OK(context, context->input(kStreamStateName, &stream_state_t));
541     const Tensor* buckets_t;
542     OP_REQUIRES_OK(context, context->input(kBucketsName, &buckets_t));
543 
544     QuantileStream* stream = streams_resource->stream(old_stamp_token);
545     ::boosted_trees::QuantileStreamState state_proto;
546     OP_REQUIRES(
547         context,
548         ParseProtoUnlimited(&state_proto, stream_state_t->scalar<string>()()),
549         errors::InvalidArgument("Unabnle to parse quantile stream state."));
550     std::vector<QuantileSummary> summaries;
551     summaries.reserve(state_proto.summaries_size());
552     std::vector<QuantileSummaryEntry> entries;
553     for (const auto& summary : state_proto.summaries()) {
554       entries.clear();
555       entries.reserve(summary.entries_size());
556       for (const auto& entry : summary.entries()) {
557         entries.emplace_back(entry.value(), entry.weight(), entry.min_rank(),
558                              entry.max_rank());
559       }
560       summaries.emplace_back();
561       summaries[summaries.size() - 1].BuildFromSummaryEntries(entries);
562     }
563     stream->DeserializeInternalSummaries(summaries);
564 
565     const auto& buckets = buckets_t->vec<float>();
566     std::vector<float> result;
567     result.reserve(buckets.size());
568 
569     for (size_t i = 0; i < buckets.size(); ++i) {
570       result.push_back(buckets(i));
571     }
572     streams_resource->set_boundaries(old_stamp_token, result);
573 
574     // Reset the stamp token.
575     const Tensor* stamp_token_t = nullptr;
576     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
577     int64 stamp_token = stamp_token_t->scalar<int64>()();
578     streams_resource->set_stamp(stamp_token);
579 
580     const Tensor* are_buckets_ready_t = nullptr;
581     OP_REQUIRES_OK(context,
582                    context->input(kAreBucketsReadyName, &are_buckets_ready_t));
583     streams_resource->set_buckets_ready(are_buckets_ready_t->scalar<bool>()());
584   }
585 };
586 
587 REGISTER_KERNEL_BUILDER(
588     Name("QuantileAccumulatorDeserialize").Device(DEVICE_CPU),
589     QuantileAccumulatorDeserializeOp);
590 
591 // Flushes the quantile summary stream resource.
592 class QuantileAccumulatorFlushOp : public OpKernel {
593  public:
QuantileAccumulatorFlushOp(OpKernelConstruction * const context)594   explicit QuantileAccumulatorFlushOp(OpKernelConstruction* const context)
595       : OpKernel(context) {}
596 
Compute(OpKernelContext * context)597   void Compute(OpKernelContext* context) override {
598     QuantileStreamResource* streams_resource;
599     // Create a reference to the underlying resource using the handle.
600     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
601                                            &streams_resource));
602     // Remove the reference at the end of this scope.
603     mutex_lock l(*streams_resource->mutex());
604     core::ScopedUnref unref_me(streams_resource);
605 
606     const Tensor* next_stamp_token_t;
607     OP_REQUIRES_OK(context,
608                    context->input(kNextStampTokenName, &next_stamp_token_t));
609     int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
610 
611     const Tensor* stamp_token_t;
612     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
613     int64 stamp_token = stamp_token_t->scalar<int64>()();
614     CHECK(streams_resource->is_stamp_valid(stamp_token))
615         << "Invalid stamp token in QuantileAccumulatorFlushOp. "
616         << "Passed stamp token: " << stamp_token << " "
617         << "Current token: " << streams_resource->stamp();
618     QuantileStream* stream = streams_resource->stream(stamp_token);
619     bool generate_quantiles = streams_resource->generate_quantiles();
620     stream->Finalize();
621 
622     streams_resource->set_boundaries(
623         stamp_token,
624         generate_quantiles
625             ? GenerateQuantiles(*stream, streams_resource->num_quantiles())
626             : GenerateBoundaries(*stream, streams_resource->num_quantiles()));
627 
628     streams_resource->Reset(next_stamp_token);
629   }
630 };
631 
632 REGISTER_KERNEL_BUILDER(Name("QuantileAccumulatorFlush").Device(DEVICE_CPU),
633                         QuantileAccumulatorFlushOp);
634 
635 // Flushes the quantile summary stream resource. This version computes the
636 // summary.
637 class QuantileAccumulatorFlushSummaryOp : public OpKernel {
638  public:
QuantileAccumulatorFlushSummaryOp(OpKernelConstruction * const context)639   explicit QuantileAccumulatorFlushSummaryOp(
640       OpKernelConstruction* const context)
641       : OpKernel(context) {}
642 
Compute(OpKernelContext * context)643   void Compute(OpKernelContext* context) override {
644     QuantileStreamResource* streams_resource;
645     // Create a reference to the underlying resource using the handle.
646     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
647                                            &streams_resource));
648     // Remove the reference at the end of this scope.
649     mutex_lock l(*streams_resource->mutex());
650     core::ScopedUnref unref_me(streams_resource);
651 
652     const Tensor* next_stamp_token_t;
653     OP_REQUIRES_OK(context,
654                    context->input(kNextStampTokenName, &next_stamp_token_t));
655     int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
656 
657     const Tensor* stamp_token_t;
658     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
659     int64 stamp_token = stamp_token_t->scalar<int64>()();
660     CHECK(streams_resource->is_stamp_valid(stamp_token))
661         << "Invalid stamp token in QuantileAccumulatorFlushSummaryOp. "
662         << "Passed stamp token: " << stamp_token << " "
663         << "Current token: " << streams_resource->stamp();
664     QuantileStream* stream = streams_resource->stream(stamp_token);
665     stream->Finalize();
666     protobuf::Arena arena;
667     ::boosted_trees::QuantileSummaryState* summary_proto =
668         protobuf::Arena::CreateMessage<::boosted_trees::QuantileSummaryState>(
669             &arena);
670     const auto& summary = stream->GetFinalSummary();
671     CopySummaryToProto(summary, summary_proto);
672     // Output to tensor.
673     Tensor* output_t = nullptr;
674     OP_REQUIRES_OK(context,
675                    context->allocate_output(0, TensorShape({}), &output_t));
676     summary_proto->SerializeToString(&output_t->scalar<string>()());
677     streams_resource->Reset(next_stamp_token);
678   }
679 };
680 
681 REGISTER_KERNEL_BUILDER(
682     Name("QuantileAccumulatorFlushSummary").Device(DEVICE_CPU),
683     QuantileAccumulatorFlushSummaryOp);
684 
685 // Get bucket boundaries from summaries.
686 class QuantileAccumulatorGetBucketsOp : public OpKernel {
687  public:
QuantileAccumulatorGetBucketsOp(OpKernelConstruction * const context)688   explicit QuantileAccumulatorGetBucketsOp(OpKernelConstruction* const context)
689       : OpKernel(context) {}
690 
Compute(OpKernelContext * const context)691   void Compute(OpKernelContext* const context) override {
692     OpInputList resource_handle_list;
693     OP_REQUIRES_OK(context, context->input_list(kResourceHandlesName,
694                                                 &resource_handle_list));
695     OpOutputList are_buckets_ready_list;
696     OP_REQUIRES_OK(context, context->output_list(kAreBucketsReadyName,
697                                                  &are_buckets_ready_list));
698     OpOutputList buckets_list;
699     OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list));
700     const Tensor* stamp_token_t;
701     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
702     int64 stamp_token = stamp_token_t->scalar<int64>()();
703 
704     thread::ThreadPool* const worker_threads =
705         context->device()->tensorflow_cpu_worker_threads()->workers;
706     boosted_trees::utils::ParallelFor(
707         resource_handle_list.size(), worker_threads->NumThreads(),
708         worker_threads,
709         [&context, &resource_handle_list, &are_buckets_ready_list,
710          &buckets_list, stamp_token](int64 start, int64 end) {
711           for (int resource_handle_idx = start; resource_handle_idx < end;
712                ++resource_handle_idx) {
713             const ResourceHandle& handle =
714                 resource_handle_list[resource_handle_idx]
715                     .flat<ResourceHandle>()(0);
716             QuantileStreamResource* streams_resource;
717             OP_REQUIRES_OK(context,
718                            LookupResource(context, handle, &streams_resource));
719             // Remove the reference at the end of this scope.
720             mutex_lock l(*streams_resource->mutex());
721             core::ScopedUnref unref_me(streams_resource);
722 
723             bool are_buckets_ready =
724                 streams_resource->is_stamp_valid(stamp_token) &&
725                 streams_resource->are_buckets_ready();
726 
727             Tensor* are_buckets_ready_t = nullptr;
728             OP_REQUIRES_OK(context,
729                            are_buckets_ready_list.allocate(
730                                resource_handle_idx, {}, &are_buckets_ready_t));
731             are_buckets_ready_t->scalar<bool>()() = are_buckets_ready;
732 
733             const std::vector<float>& boundaries =
734                 are_buckets_ready ? streams_resource->boundaries(stamp_token)
735                                   : std::vector<float>();
736             Tensor* output_t = nullptr;
737             OP_REQUIRES_OK(context, buckets_list.allocate(
738                                         resource_handle_idx,
739                                         {static_cast<int64>(boundaries.size())},
740                                         &output_t));
741             auto* quantiles_flat = output_t->flat<float>().data();
742             memcpy(quantiles_flat, boundaries.data(),
743                    sizeof(float) * boundaries.size());
744           }
745         });
746   }
747 };
748 
749 REGISTER_KERNEL_BUILDER(
750     Name("QuantileAccumulatorGetBuckets").Device(DEVICE_CPU),
751     QuantileAccumulatorGetBucketsOp);
752 
753 // Generates buckets for given set of float values, and the given config.
754 class QuantileBucketsOp : public OpKernel {
755  public:
QuantileBucketsOp(OpKernelConstruction * const context)756   explicit QuantileBucketsOp(OpKernelConstruction* const context)
757       : OpKernel(context) {
758     OP_REQUIRES_OK(context,
759                    ReadAndValidateAttributes(context, &num_dense_features_,
760                                              &num_sparse_features_));
761 
762     ParseConfig(context, kDenseConfigName, &dense_configs_);
763     OP_REQUIRES(context, dense_configs_.size() == num_dense_features_,
764                 errors::InvalidArgument(
765                     "Mismatch in number of dense quantile configs."));
766     ParseConfig(context, kSparseConfigName, &sparse_configs_);
767     OP_REQUIRES(context, sparse_configs_.size() == num_sparse_features_,
768                 errors::InvalidArgument(
769                     "Mismatch in number of sparse quantile configs."));
770   }
771 
Compute(OpKernelContext * const context)772   void Compute(OpKernelContext* const context) override {
773     // Read dense float features list;
774     OpInputList dense_float_features_list;
775     OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
776                                 context, &dense_float_features_list));
777 
778     // Read sparse float features list;
779     OpInputList sparse_float_feature_indices_list;
780     OpInputList sparse_float_feature_values_list;
781     OpInputList sparse_float_feature_shapes_list;
782     OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures(
783                                 context, &sparse_float_feature_indices_list,
784                                 &sparse_float_feature_values_list,
785                                 &sparse_float_feature_shapes_list));
786 
787     // Parse example weights and get batch size.
788     const Tensor* example_weights_t;
789     OP_REQUIRES_OK(context,
790                    context->input(kExampleWeightsName, &example_weights_t));
791     auto example_weights = example_weights_t->flat<float>();
792     const int64 batch_size = example_weights.size();
793 
794     OpOutputList sparse_buckets_output_list;
795     OP_REQUIRES_OK(context, context->output_list(kSparseBucketsName,
796                                                  &sparse_buckets_output_list));
797     OpOutputList dense_buckets_output_list;
798     OP_REQUIRES_OK(context, context->output_list(kDenseBucketsName,
799                                                  &dense_buckets_output_list));
800 
801     auto do_quantile_bucket_gen = [&](const int64 begin, const int64 end) {
802       // These are blocks of ranges. We are iterating over both sparse and
803       // dense features i.e. [0, sparse_features.size() + dense_features.size()]
804       for (int64 i = begin; i < end; ++i) {
805         if (i < sparse_configs_.size()) {
806           const int64 sparse_index = i;
807           const auto sparse_values =
808               sparse_float_feature_values_list[sparse_index].flat<float>();
809           const auto sparse_indices =
810               sparse_float_feature_indices_list[sparse_index].matrix<int64>();
811           QuantileStream stream(sparse_configs_[sparse_index].eps(),
812                                 batch_size);
813           // Run quantile summary generation.
814           const int64 num_sparse_rows =
815               sparse_float_feature_indices_list[sparse_index].dim_size(0);
816           for (int64 j = 0; j < num_sparse_rows; ++j) {
817             const int64 example_id = sparse_indices(j, 0);
818             stream.PushEntry(sparse_values(j), example_weights(example_id));
819           }
820           stream.Finalize();
821           // Create buckets.
822           const auto boundaries = GenerateBoundaries(
823               stream, sparse_configs_[sparse_index].num_quantiles());
824           CopyBoundaries(context, boundaries, sparse_index,
825                          &sparse_buckets_output_list);
826 
827         } else {
828           const int64 dense_index = i - sparse_configs_.size();
829           const auto dense_values =
830               dense_float_features_list[dense_index].flat<float>();
831           QuantileStream stream(dense_configs_[dense_index].eps(), batch_size);
832           // Run quantile summary generation.
833           for (int64 j = 0; j < batch_size; ++j) {
834             stream.PushEntry(dense_values(j), example_weights(j));
835           }
836           stream.Finalize();
837           // Create buckets.
838           const auto boundaries = GenerateBoundaries(
839               stream, dense_configs_[dense_index].num_quantiles());
840           CopyBoundaries(context, boundaries, dense_index,
841                          &dense_buckets_output_list);
842         }
843       }
844     };
845 
846     const int64 kCostPerUnit = 500 * batch_size;
847     const int64 num_features = sparse_configs_.size() + dense_configs_.size();
848     const DeviceBase::CpuWorkerThreads& worker_threads =
849         *context->device()->tensorflow_cpu_worker_threads();
850     Shard(worker_threads.num_threads, worker_threads.workers, num_features,
851           kCostPerUnit, do_quantile_bucket_gen);
852   }
853 
854  private:
855   int num_dense_features_;
856   int num_sparse_features_;
857   std::vector<QuantileConfig> dense_configs_;
858   std::vector<QuantileConfig> sparse_configs_;
859 };
860 
861 REGISTER_KERNEL_BUILDER(Name("QuantileBuckets").Device(DEVICE_CPU),
862                         QuantileBucketsOp);
863 
864 // Given the calculated quantiles thresholds and input data, this operation
865 // converts the input features into the buckets (categorical values), depending
866 // on which quantile they fall into.
867 class QuantilesOp : public OpKernel {
868  public:
QuantilesOp(OpKernelConstruction * const context)869   explicit QuantilesOp(OpKernelConstruction* const context)
870       : OpKernel(context) {
871     int num_dense_features;
872     int num_sparse_features;
873     OP_REQUIRES_OK(context,
874                    ReadAndValidateAttributes(context, &num_dense_features,
875                                              &num_sparse_features));
876   }
877 
Compute(OpKernelContext * const context)878   void Compute(OpKernelContext* const context) override {
879     // Dense features inputs
880     OpInputList dense_float_features_list;
881     OP_REQUIRES_OK(context, context->input_list(kDenseValuesName,
882                                                 &dense_float_features_list));
883     OpInputList dense_buckets_list;
884     OP_REQUIRES_OK(context,
885                    context->input_list(kDenseBucketsName, &dense_buckets_list));
886 
887     if (dense_buckets_list.size() > 0) {
888       // Check the first tensor to make sure it is the right shape
889       OP_REQUIRES(
890           context,
891           tensorflow::TensorShapeUtils::IsVector(dense_buckets_list[0].shape()),
892           errors::InvalidArgument(
893               strings::Printf("Dense buckets should be flat vectors")));
894     }
895 
896     // Sparse features inputs
897     OpInputList sparse_float_feature_values_list;
898     OP_REQUIRES_OK(context,
899                    context->input_list(kSparseValuesName,
900                                        &sparse_float_feature_values_list));
901 
902     OpInputList sparse_float_indices_list;
903     OP_REQUIRES_OK(context, context->input_list(kSparseIndicesName,
904                                                 &sparse_float_indices_list));
905 
906     OpInputList sparse_buckets_list;
907     OP_REQUIRES_OK(
908         context, context->input_list(kSparseBucketsName, &sparse_buckets_list));
909 
910     if (sparse_buckets_list.size() > 0) {
911       OP_REQUIRES(
912           context,
913           tensorflow::TensorShapeUtils::IsVector(
914               sparse_buckets_list[0].shape()),
915           errors::InvalidArgument("Sparse buckets should be flat vectors"));
916     }
917 
918     // Quantize the feature values
919     QuantizeFeatures(kDenseOutputTensorName, dense_float_features_list,
920                      dense_buckets_list, nullptr, context);
921 
922     QuantizeFeatures(kSparseOutputTensorName, sparse_float_feature_values_list,
923                      sparse_buckets_list, &sparse_float_indices_list, context);
924   }
925 };
926 
927 REGISTER_KERNEL_BUILDER(Name("Quantiles").Device(DEVICE_CPU), QuantilesOp);
928 
929 template <typename T>
930 class BucketizeWithInputBoundariesOp : public OpKernel {
931  public:
BucketizeWithInputBoundariesOp(OpKernelConstruction * context)932   explicit BucketizeWithInputBoundariesOp(OpKernelConstruction* context)
933       : OpKernel(context) {}
934 
Compute(OpKernelContext * context)935   void Compute(OpKernelContext* context) override {
936     const Tensor& boundaries_tensor = context->input(1);
937     VLOG(1) << "boundaries has shape: "
938             << boundaries_tensor.shape().DebugString();
939     auto boundaries = boundaries_tensor.flat<float>();
940     std::vector<T> boundaries_vector;
941     boundaries_vector.reserve(boundaries.size());
942     for (size_t i = 0; i < boundaries.size(); i++) {
943       boundaries_vector.push_back(boundaries(i));
944       VLOG(1) << "boundaries(" << i << ") : " << boundaries(i);
945     }
946     OP_REQUIRES(
947         context,
948         std::is_sorted(boundaries_vector.begin(), boundaries_vector.end()),
949         errors::InvalidArgument("Expected sorted boundaries"));
950 
951     const Tensor& input_tensor = context->input(0);
952     VLOG(1) << "Inputs has shape: " << input_tensor.shape().DebugString()
953             << " Dtype: " << tensorflow::DataTypeString(input_tensor.dtype());
954     auto input = input_tensor.flat<T>();
955 
956     Tensor* output_tensor = nullptr;
957     OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
958                                                      &output_tensor));
959     auto output = output_tensor->template flat<int32>();
960 
961     for (size_t i = 0; i < input.size(); i++) {
962       output(i) = CalculateBucketIndex(input(i), boundaries_vector);
963     }
964   }
965 
966  private:
CalculateBucketIndex(const T value,std::vector<T> & boundaries_vector)967   int32 CalculateBucketIndex(const T value, std::vector<T>& boundaries_vector) {
968     auto first_bigger_it = std::upper_bound(boundaries_vector.begin(),
969                                             boundaries_vector.end(), value);
970     int32 index = first_bigger_it - boundaries_vector.begin();
971     CHECK(index >= 0 && index <= boundaries_vector.size())
972         << "Invalid bucket index: " << index
973         << " boundaries_vector.size(): " << boundaries_vector.size();
974     return index;
975   }
976 };
977 
978 #define REGISTER_KERNEL(T)                                     \
979   REGISTER_KERNEL_BUILDER(Name("BucketizeWithInputBoundaries") \
980                               .Device(DEVICE_CPU)              \
981                               .TypeConstraint<T>("T"),         \
982                           BucketizeWithInputBoundariesOp<T>);
983 
984 REGISTER_KERNEL(int32);
985 REGISTER_KERNEL(int64);
986 REGISTER_KERNEL(float);
987 REGISTER_KERNEL(double);
988 #undef REGISTER_KERNEL
989 
990 }  // namespace tensorflow
991