// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <algorithm>
#include <cstring>
#include <iterator>
#include <string>
#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
#include "tensorflow/contrib/boosted_trees/proto/quantiles.pb.h"
#include "tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

using ::boosted_trees::QuantileConfig;
using boosted_trees::QuantileStreamResource;
using boosted_trees::utils::TensorUtils;

namespace {
const char* const kExampleWeightsName = "example_weights";
const char* const kMaxElementsName = "max_elements";
const char* const kNextStampTokenName = "next_stamp_token";
const char* const kStampTokenName = "stamp_token";
const char* const kAreBucketsReadyName = "are_buckets_ready";
const char* const kGenerateQuantiles = "generate_quantiles";
// Names for sparse arguments.
const char* const kNumSparseFeaturesName = "num_sparse_features";
const char* const kSparseBucketsName = "sparse_buckets";
const char* const kSparseValuesName = "sparse_values";
const char* const kSparseIndicesName = "sparse_indices";
const char* const kSparseSummariesName = "sparse_summaries";
const char* const kSparseConfigName = "sparse_config";
const char* const kSparseOutputTensorName = "sparse_quantiles";
// Names for dense arguments.
const char* const kDenseBucketsName = "dense_buckets";
const char* const kDenseConfigName = "dense_config";
const char* const kDenseOutputTensorName = "dense_quantiles";
const char* const kDenseSummariesName = "dense_summaries";
const char* const kDenseValuesName = "dense_values";
const char* const kNumDenseFeaturesName = "num_dense_features";
const char* const kResourceHandlesName = "quantile_accumulator_handles";
const char* const kNumQuantilesName = "num_quantiles";
const char* const kEpsilonName = "epsilon";
const char* const kBucketsName = "buckets";
const char* const kStreamStateName = "stream_state";
const char* const kSummariesName = "summaries";

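// Aliases for the approximate weighted quantile stream and summary types used
// throughout this file; the two float template arguments are the value and
// weight types.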
using QuantileStream =
    boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
using QuantileSummary =
    boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
using QuantileSummaryEntry =
    boosted_trees::quantiles::WeightedQuantilesSummary<float,
                                                       float>::SummaryEntry;

std::vector<float> GetBuckets(const int32 feature,
                              const OpInputList& buckets_list) {
  const auto& buckets = buckets_list[feature].flat<float>();
  std::vector<float> buckets_vector(buckets.data(),
                                    buckets.data() + buckets.size());
  return buckets_vector;
}

int32 GetFeatureDimension(const int32 feature_index, const int64 instance,
                          const OpInputList* const indices_list) {
  if (indices_list != nullptr) {
    // Sparse multidimensional.
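    // Column 0 of the indices matrix holds the example id; column 1 holds the
    // dimension within the sparse feature.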
    return (*indices_list)[feature_index].matrix<int64>()(instance, 1);
  }
  // No indices, assume one-dimensional tensor.
  return 0;
}

// Allows quantization for each of multiple dimensions of a sparse feature.
void QuantizeFeatures(
    const string& output_name, const OpInputList& values_list,
    const OpInputList& buckets_list,
    const OpInputList* const
        indices_list /* Optional; provided for sparse features. */,
    OpKernelContext* const context) {
  if (values_list.size() == 0) {
    return;
  }
  OpOutputList output_list;
  OP_REQUIRES_OK(context, context->output_list(output_name, &output_list));

  for (int32 feature_index = 0; feature_index < values_list.size();
       ++feature_index) {
    const Tensor& values_tensor = values_list[feature_index];
    const int64 num_values = values_tensor.dim_size(0);

    Tensor* output_t = nullptr;
    // The output holds the bucket id and the feature dimension for each value.
    OP_REQUIRES_OK(
        context, output_list.allocate(feature_index,
                                      TensorShape({num_values, 2}), &output_t));

    auto output = output_t->matrix<int32>();

    const std::vector<float>& buckets_vector =
        GetBuckets(feature_index, buckets_list);
    auto flat_values = values_tensor.flat<float>();
    for (int64 instance = 0; instance < num_values; ++instance) {
      const float value = flat_values(instance);
      CHECK(!buckets_vector.empty())
          << "Got empty buckets for feature " << feature_index;
      auto bucket_iter =
          std::lower_bound(buckets_vector.begin(), buckets_vector.end(), value);
      if (bucket_iter == buckets_vector.end()) {
        --bucket_iter;
      }
      const int32 bucket =
          static_cast<int32>(bucket_iter - buckets_vector.begin());
      // Bucket id.
      output(instance, 0) = bucket;
      // Dimension.
      output(instance, 1) =
          GetFeatureDimension(feature_index, instance, indices_list);
    }
  }
}

// Validates attributes for the quantile ops.
Status ReadAndValidateAttributes(OpKernelConstruction* const context,
                                 int* num_dense_features,
                                 int* num_sparse_features) {
  TF_RETURN_IF_ERROR(
      context->GetAttr(kNumDenseFeaturesName, num_dense_features));
  TF_RETURN_IF_ERROR(
      context->GetAttr(kNumSparseFeaturesName, num_sparse_features));
  if ((*num_dense_features) + (*num_sparse_features) == 0) {
    return errors::InvalidArgument(
        "Please provide at least one sparse or dense feature.");
  }
  return Status::OK();
}

void ParseConfig(OpKernelConstruction* const context, const string& name,
                 std::vector<QuantileConfig>* output) {
  std::vector<string> serialized_config;
  OP_REQUIRES_OK(context, context->GetAttr(name, &serialized_config));
  output->reserve(serialized_config.size());
  QuantileConfig tmp;
  for (const auto& serialized_string : serialized_config) {
    OP_REQUIRES(context, tmp.ParseFromString(serialized_string),
                errors::InvalidArgument("Malformed QuantileConfig passed in."));
    output->push_back(tmp);
  }
}

// Generates boundary values on a finalized QuantileStream.
std::vector<float> GenerateBoundaries(const QuantileStream& stream,
                                      int num_boundaries) {
  std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);

  // De-duplicate the boundaries, as the stream may return duplicates.
  auto end_it = std::unique(boundaries.begin(), boundaries.end());
  boundaries.resize(std::distance(boundaries.begin(), end_it));
  return boundaries;
}

// Generates quantiles on a finalized QuantileStream.
std::vector<float> GenerateQuantiles(const QuantileStream& stream,
                                     int num_quantiles) {
  // Do not de-dup boundaries. Exactly num_quantiles+1 boundary values
  // will be returned.
  std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles);
  CHECK_EQ(boundaries.size(), num_quantiles + 1);
  return boundaries;
}

// Copies quantiles to the output list.
void CopyBoundaries(OpKernelContext* const context,
                    const std::vector<float>& boundaries, const int64 index,
                    OpOutputList* output_list) {
  // Output to tensor.
  Tensor* output_t = nullptr;
  OP_REQUIRES_OK(
      context, output_list->allocate(
                   index, {static_cast<int64>(boundaries.size())}, &output_t));
  auto* quantiles_flat = output_t->flat<float>().data();
  memcpy(quantiles_flat, boundaries.data(), sizeof(float) * boundaries.size());
}

void CopySummaryToProto(const QuantileSummary& summary,
                        ::boosted_trees::QuantileSummaryState* summary_proto) {
  summary_proto->mutable_entries()->Reserve(summary.Size());
  for (const auto& entry : summary.GetEntryList()) {
    auto* new_entry = summary_proto->add_entries();
    new_entry->set_value(entry.value);
    new_entry->set_weight(entry.weight);
    new_entry->set_min_rank(entry.min_rank);
    new_entry->set_max_rank(entry.max_rank);
  }
}

}  // namespace

// Accumulator for Quantile Summaries.
REGISTER_RESOURCE_HANDLE_KERNEL(QuantileStreamResource);

REGISTER_KERNEL_BUILDER(
    Name("QuantileAccumulatorIsInitialized").Device(DEVICE_CPU),
    IsResourceInitialized<QuantileStreamResource>);

class CreateQuantileAccumulatorOp : public OpKernel {
 public:
  explicit CreateQuantileAccumulatorOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr(kEpsilonName, &epsilon_));
    OP_REQUIRES_OK(context,
                   context->GetAttr(kNumQuantilesName, &num_quantiles_));
    OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
    OP_REQUIRES_OK(context,
                   context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
  }

  void Compute(OpKernelContext* context) override {
    // Only create a new resource if one does not already exist; report the
    // status for all other errors. If one already exists, the new resource
    // is unreffed.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    // An epsilon value of zero could cause performance issues and is
    // therefore disallowed.
    OP_REQUIRES(
        context, epsilon_ > 0,
        errors::InvalidArgument("An epsilon value of zero is not allowed."));
    auto result = new QuantileStreamResource(epsilon_, num_quantiles_,
                                             max_elements_, generate_quantiles_,
                                             stamp_token_t->scalar<int64>()());
    auto status = CreateResource(context, HandleFromInput(context, 0), result);
    if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
      OP_REQUIRES(context, false, status);
    }
  }

 private:
  float epsilon_;
  int32 num_quantiles_;
  // An upper bound on the number of entries that the summaries might have
  // for a feature.
  int64 max_elements_;
  bool generate_quantiles_;
};

REGISTER_KERNEL_BUILDER(Name("CreateQuantileAccumulator").Device(DEVICE_CPU),
                        CreateQuantileAccumulatorOp);

// Adds a summary to the quantile summary stream.
class QuantileAccumulatorAddSummariesOp : public OpKernel {
 public:
  explicit QuantileAccumulatorAddSummariesOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OpInputList resource_handle_list;
    OP_REQUIRES_OK(context, context->input_list(kResourceHandlesName,
                                                &resource_handle_list));
    OpInputList summary_list;
    OP_REQUIRES_OK(context, context->input_list(kSummariesName, &summary_list));

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    thread::ThreadPool* const worker_threads =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    boosted_trees::utils::ParallelFor(
        resource_handle_list.size(), worker_threads->NumThreads(),
        worker_threads,
        [&context, &resource_handle_list, &summary_list, stamp_token](
            int64 start, int64 end) {
          for (int resource_handle_idx = start; resource_handle_idx < end;
               ++resource_handle_idx) {
            const ResourceHandle& handle =
                resource_handle_list[resource_handle_idx]
                    .flat<ResourceHandle>()(0);
            QuantileStreamResource* streams_resource;
            // Create a reference to the underlying resource using the handle.
            OP_REQUIRES_OK(context,
                           LookupResource(context, handle, &streams_resource));
            // Remove the reference at the end of this scope.
            mutex_lock l(*streams_resource->mutex());
            core::ScopedUnref unref_me(streams_resource);

            // If the stamp is invalid we drop the update.
            if (!streams_resource->is_stamp_valid(stamp_token)) {
              VLOG(1)
                  << "Invalid stamp token in QuantileAccumulatorAddSummariesOp."
                  << " Passed stamp token: " << stamp_token << " "
                  << "Current token: " << streams_resource->stamp();
              return;
            }

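            // Parse the serialized summary into an arena-allocated proto;
            // arena allocation amortizes the cost of the many small entry
            // messages.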
            protobuf::Arena arena;
            ::boosted_trees::QuantileSummaryState* summary_proto =
                protobuf::Arena::CreateMessage<
                    ::boosted_trees::QuantileSummaryState>(&arena);
            OP_REQUIRES(
                context,
                ParseProtoUnlimited(
                    summary_proto,
                    summary_list[resource_handle_idx].scalar<string>()()),
                errors::InvalidArgument("Unable to parse quantile summary."));
            std::vector<QuantileSummaryEntry> entries;
            entries.reserve(summary_proto->entries_size());
            for (const auto& entry : summary_proto->entries()) {
              entries.emplace_back(entry.value(), entry.weight(),
                                   entry.min_rank(), entry.max_rank());
            }

            // Add the summary to the quantile stream.
            streams_resource->stream(stamp_token)->PushSummary(entries);
          }
        });
  }
};

REGISTER_KERNEL_BUILDER(
    Name("QuantileAccumulatorAddSummaries").Device(DEVICE_CPU),
    QuantileAccumulatorAddSummariesOp);

// Generates summaries for a given set of float values and the given config.
class MakeQuantileSummariesOp : public OpKernel {
 public:
  explicit MakeQuantileSummariesOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   ReadAndValidateAttributes(context, &num_dense_features_,
                                             &num_sparse_features_));
    OP_REQUIRES_OK(context, context->GetAttr(kEpsilonName, &epsilon_));
  }

  void Compute(OpKernelContext* const context) override {
    // Read the dense float features list.
    OpInputList dense_float_features_list;
    OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
                                context, &dense_float_features_list));

    // Read the sparse float features list.
    OpInputList sparse_float_feature_indices_list;
    OpInputList sparse_float_feature_values_list;
    OpInputList sparse_float_feature_shapes_list;
    OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures(
                                context, &sparse_float_feature_indices_list,
                                &sparse_float_feature_values_list,
                                &sparse_float_feature_shapes_list));

    // Parse example weights and get the batch size.
    const Tensor* example_weights_t;
    OP_REQUIRES_OK(context,
                   context->input(kExampleWeightsName, &example_weights_t));
    auto example_weights = example_weights_t->flat<float>();
    const int64 batch_size = example_weights.size();

    OpOutputList sparse_summaries_output_list;
    OP_REQUIRES_OK(context,
                   context->output_list(kSparseSummariesName,
                                        &sparse_summaries_output_list));
    OpOutputList dense_summaries_output_list;
    OP_REQUIRES_OK(context, context->output_list(kDenseSummariesName,
                                                 &dense_summaries_output_list));

    auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) {
      auto copy_over_summaries = [&](const QuantileStream& stream,
                                     const int64 index,
                                     OpOutputList* output_list) {
        protobuf::Arena arena;
        ::boosted_trees::QuantileSummaryState* summary_proto =
            protobuf::Arena::CreateMessage<
                ::boosted_trees::QuantileSummaryState>(&arena);
        const auto& summary = stream.GetFinalSummary();
        CopySummaryToProto(summary, summary_proto);
        // Output to tensor.
        Tensor* output_t = nullptr;
        OP_REQUIRES_OK(context, output_list->allocate(index, {}, &output_t));
        summary_proto->SerializeToString(&output_t->scalar<string>()());
      };

      // These are blocks of ranges. We are iterating over both sparse and
      // dense features, i.e. [0, sparse_features.size() + dense_features.size()).
      for (int64 i = begin; i < end; ++i) {
        if (i < num_dense_features_) {
          const int64 dense_index = i;
          const auto dense_values =
              dense_float_features_list[dense_index].flat<float>();
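          // batch_size + 1 serves as the stream's max_elements bound: at most
          // batch_size values are pushed below.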
          QuantileStream stream(epsilon_, batch_size + 1);
          // Run quantile summary generation.
          for (int64 j = 0; j < batch_size; ++j) {
            stream.PushEntry(dense_values(j), example_weights(j));
          }
          stream.Finalize();
          // Copy summaries to output.
          copy_over_summaries(stream, dense_index,
                              &dense_summaries_output_list);
        } else {
          const int64 sparse_index = i - num_dense_features_;
          const auto sparse_values =
              sparse_float_feature_values_list[sparse_index].flat<float>();
          const auto sparse_indices =
              sparse_float_feature_indices_list[sparse_index].matrix<int64>();
          const auto dense_shape =
              sparse_float_feature_shapes_list[sparse_index].flat<int64>();
          OP_REQUIRES(context, batch_size == dense_shape(0),
                      errors::InvalidArgument(
                          "Sparse column shape doesn't match the batch size."));
          QuantileStream stream(epsilon_, batch_size + 1);
          // Run quantile summary generation.
          const int64 num_sparse_rows =
              sparse_float_feature_indices_list[sparse_index].dim_size(0);
          for (int64 j = 0; j < num_sparse_rows; ++j) {
            const int64 example_id = sparse_indices(j, 0);
            stream.PushEntry(sparse_values(j), example_weights(example_id));
          }
          stream.Finalize();
          // Copy summaries to output.
          copy_over_summaries(stream, sparse_index,
                              &sparse_summaries_output_list);
        }
      }
    };
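    // A rough per-feature cost estimate for the work sharder; Shard treats
    // cost_per_unit as an approximation of CPU cycles per unit of work.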
    const int64 kCostPerUnit = 500 * batch_size;
    const int64 num_features = num_sparse_features_ + num_dense_features_;
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *context->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, num_features,
          kCostPerUnit, do_quantile_summary_gen);
  }

 private:
  int num_dense_features_;
  int num_sparse_features_;
  float epsilon_;
};

REGISTER_KERNEL_BUILDER(Name("MakeQuantileSummaries").Device(DEVICE_CPU),
                        MakeQuantileSummariesOp);

// Serializes the state of streams.
class QuantileAccumulatorSerializeOp : public OpKernel {
 public:
  explicit QuantileAccumulatorSerializeOp(OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    QuantileStreamResource* streams_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
    // Remove the reference at the end of this scope.
    mutex_lock l(*streams_resource->mutex());
    core::ScopedUnref unref_me(streams_resource);

    int64 stamp_token = streams_resource->stamp();
    Tensor* stream_state_t;
    OP_REQUIRES_OK(context,
                   context->allocate_output(kStreamStateName, TensorShape({}),
                                            &stream_state_t));
    bool are_buckets_ready = streams_resource->are_buckets_ready();

    // Serialize the internal summaries of the stream; boundaries are only
    // serialized if they have already been computed.
    const QuantileStream& stream = *streams_resource->stream(stamp_token);
    const std::vector<float>& boundaries =
        are_buckets_ready ? streams_resource->boundaries(stamp_token)
                          : std::vector<float>();
    protobuf::Arena arena;
    ::boosted_trees::QuantileStreamState* stream_proto =
        protobuf::Arena::CreateMessage<::boosted_trees::QuantileStreamState>(
            &arena);
    for (const auto& summary : stream.SerializeInternalSummaries()) {
      CopySummaryToProto(summary, stream_proto->add_summaries());
    }
    stream_proto->SerializeToString(&stream_state_t->scalar<string>()());
    Tensor* buckets_t = nullptr;
    OP_REQUIRES_OK(
        context,
        context->allocate_output(
            kBucketsName, {static_cast<int64>(boundaries.size())}, &buckets_t));
    auto* quantiles_flat = buckets_t->flat<float>().data();
    memcpy(quantiles_flat, boundaries.data(),
           sizeof(float) * boundaries.size());
    Tensor* stamp_token_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(kStampTokenName, TensorShape({}),
                                            &stamp_token_t));
    stamp_token_t->scalar<int64>()() = stamp_token;
    Tensor* are_buckets_ready_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(kAreBucketsReadyName, {},
                                                     &are_buckets_ready_t));
    are_buckets_ready_t->scalar<bool>()() = are_buckets_ready;
  }
};

REGISTER_KERNEL_BUILDER(Name("QuantileAccumulatorSerialize").Device(DEVICE_CPU),
                        QuantileAccumulatorSerializeOp);

// Deserializes the state of streams.
class QuantileAccumulatorDeserializeOp : public OpKernel {
 public:
  explicit QuantileAccumulatorDeserializeOp(OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    QuantileStreamResource* streams_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
    // Remove the reference at the end of this scope.
    mutex_lock l(*streams_resource->mutex());
    core::ScopedUnref unref_me(streams_resource);

    int64 old_stamp_token = streams_resource->stamp();

    const Tensor* stream_state_t;
    OP_REQUIRES_OK(context, context->input(kStreamStateName, &stream_state_t));
    const Tensor* buckets_t;
    OP_REQUIRES_OK(context, context->input(kBucketsName, &buckets_t));

    QuantileStream* stream = streams_resource->stream(old_stamp_token);
    ::boosted_trees::QuantileStreamState state_proto;
    OP_REQUIRES(
        context,
        ParseProtoUnlimited(&state_proto, stream_state_t->scalar<string>()()),
        errors::InvalidArgument("Unable to parse quantile stream state."));
    std::vector<QuantileSummary> summaries;
    summaries.reserve(state_proto.summaries_size());
    std::vector<QuantileSummaryEntry> entries;
    for (const auto& summary : state_proto.summaries()) {
      entries.clear();
      entries.reserve(summary.entries_size());
      for (const auto& entry : summary.entries()) {
        entries.emplace_back(entry.value(), entry.weight(), entry.min_rank(),
                             entry.max_rank());
      }
      summaries.emplace_back();
      summaries.back().BuildFromSummaryEntries(entries);
    }
    stream->DeserializeInternalSummaries(summaries);

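    // Restore the bucket boundaries that were serialized alongside the stream
    // state.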
    const auto& buckets = buckets_t->vec<float>();
    std::vector<float> result;
    result.reserve(buckets.size());

    for (size_t i = 0; i < buckets.size(); ++i) {
      result.push_back(buckets(i));
    }
    streams_resource->set_boundaries(old_stamp_token, result);

    // Reset the stamp token.
    const Tensor* stamp_token_t = nullptr;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();
    streams_resource->set_stamp(stamp_token);

    const Tensor* are_buckets_ready_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->input(kAreBucketsReadyName, &are_buckets_ready_t));
    streams_resource->set_buckets_ready(are_buckets_ready_t->scalar<bool>()());
  }
};

REGISTER_KERNEL_BUILDER(
    Name("QuantileAccumulatorDeserialize").Device(DEVICE_CPU),
    QuantileAccumulatorDeserializeOp);

// Flushes the quantile summary stream resource.
class QuantileAccumulatorFlushOp : public OpKernel {
 public:
  explicit QuantileAccumulatorFlushOp(OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    QuantileStreamResource* streams_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
    // Remove the reference at the end of this scope.
    mutex_lock l(*streams_resource->mutex());
    core::ScopedUnref unref_me(streams_resource);

    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
                   context->input(kNextStampTokenName, &next_stamp_token_t));
    int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();
    CHECK(streams_resource->is_stamp_valid(stamp_token))
        << "Invalid stamp token in QuantileAccumulatorFlushOp. "
        << "Passed stamp token: " << stamp_token << " "
        << "Current token: " << streams_resource->stamp();
    QuantileStream* stream = streams_resource->stream(stamp_token);
    bool generate_quantiles = streams_resource->generate_quantiles();
    stream->Finalize();

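    // When generate_quantiles is set, GenerateQuantiles returns exactly
    // num_quantiles + 1 values without de-duplication; otherwise
    // GenerateBoundaries de-duplicates and may return fewer values.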
    streams_resource->set_boundaries(
        stamp_token,
        generate_quantiles
            ? GenerateQuantiles(*stream, streams_resource->num_quantiles())
            : GenerateBoundaries(*stream, streams_resource->num_quantiles()));

    streams_resource->Reset(next_stamp_token);
  }
};

REGISTER_KERNEL_BUILDER(Name("QuantileAccumulatorFlush").Device(DEVICE_CPU),
                        QuantileAccumulatorFlushOp);

// Flushes the quantile summary stream resource. This version computes the
// summary.
class QuantileAccumulatorFlushSummaryOp : public OpKernel {
 public:
  explicit QuantileAccumulatorFlushSummaryOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    QuantileStreamResource* streams_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
    // Remove the reference at the end of this scope.
    mutex_lock l(*streams_resource->mutex());
    core::ScopedUnref unref_me(streams_resource);

    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
                   context->input(kNextStampTokenName, &next_stamp_token_t));
    int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();
    CHECK(streams_resource->is_stamp_valid(stamp_token))
        << "Invalid stamp token in QuantileAccumulatorFlushSummaryOp. "
        << "Passed stamp token: " << stamp_token << " "
        << "Current token: " << streams_resource->stamp();
    QuantileStream* stream = streams_resource->stream(stamp_token);
    stream->Finalize();
    protobuf::Arena arena;
    ::boosted_trees::QuantileSummaryState* summary_proto =
        protobuf::Arena::CreateMessage<::boosted_trees::QuantileSummaryState>(
            &arena);
    const auto& summary = stream->GetFinalSummary();
    CopySummaryToProto(summary, summary_proto);
    // Output to tensor.
    Tensor* output_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output_t));
    summary_proto->SerializeToString(&output_t->scalar<string>()());
    streams_resource->Reset(next_stamp_token);
  }
};

REGISTER_KERNEL_BUILDER(
    Name("QuantileAccumulatorFlushSummary").Device(DEVICE_CPU),
    QuantileAccumulatorFlushSummaryOp);

// Get bucket boundaries from summaries.
class QuantileAccumulatorGetBucketsOp : public OpKernel {
 public:
  explicit QuantileAccumulatorGetBucketsOp(OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* const context) override {
    OpInputList resource_handle_list;
    OP_REQUIRES_OK(context, context->input_list(kResourceHandlesName,
                                                &resource_handle_list));
    OpOutputList are_buckets_ready_list;
    OP_REQUIRES_OK(context, context->output_list(kAreBucketsReadyName,
                                                 &are_buckets_ready_list));
    OpOutputList buckets_list;
    OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list));
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    thread::ThreadPool* const worker_threads =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    boosted_trees::utils::ParallelFor(
        resource_handle_list.size(), worker_threads->NumThreads(),
        worker_threads,
        [&context, &resource_handle_list, &are_buckets_ready_list,
         &buckets_list, stamp_token](int64 start, int64 end) {
          for (int resource_handle_idx = start; resource_handle_idx < end;
               ++resource_handle_idx) {
            const ResourceHandle& handle =
                resource_handle_list[resource_handle_idx]
                    .flat<ResourceHandle>()(0);
            QuantileStreamResource* streams_resource;
            OP_REQUIRES_OK(context,
                           LookupResource(context, handle, &streams_resource));
            // Remove the reference at the end of this scope.
            mutex_lock l(*streams_resource->mutex());
            core::ScopedUnref unref_me(streams_resource);

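            // Buckets are reported as ready only if the stamp token is
            // current; a stale stamp yields an empty bucket list.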
            bool are_buckets_ready =
                streams_resource->is_stamp_valid(stamp_token) &&
                streams_resource->are_buckets_ready();

            Tensor* are_buckets_ready_t = nullptr;
            OP_REQUIRES_OK(context,
                           are_buckets_ready_list.allocate(
                               resource_handle_idx, {}, &are_buckets_ready_t));
            are_buckets_ready_t->scalar<bool>()() = are_buckets_ready;

            const std::vector<float>& boundaries =
                are_buckets_ready ? streams_resource->boundaries(stamp_token)
                                  : std::vector<float>();
            Tensor* output_t = nullptr;
            OP_REQUIRES_OK(context, buckets_list.allocate(
                                        resource_handle_idx,
                                        {static_cast<int64>(boundaries.size())},
                                        &output_t));
            auto* quantiles_flat = output_t->flat<float>().data();
            memcpy(quantiles_flat, boundaries.data(),
                   sizeof(float) * boundaries.size());
          }
        });
  }
};

REGISTER_KERNEL_BUILDER(
    Name("QuantileAccumulatorGetBuckets").Device(DEVICE_CPU),
    QuantileAccumulatorGetBucketsOp);

// Generates buckets for a given set of float values, and the given config.
class QuantileBucketsOp : public OpKernel {
 public:
  explicit QuantileBucketsOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   ReadAndValidateAttributes(context, &num_dense_features_,
                                             &num_sparse_features_));

    ParseConfig(context, kDenseConfigName, &dense_configs_);
    OP_REQUIRES(context, dense_configs_.size() == num_dense_features_,
                errors::InvalidArgument(
                    "Mismatch in number of dense quantile configs."));
    ParseConfig(context, kSparseConfigName, &sparse_configs_);
    OP_REQUIRES(context, sparse_configs_.size() == num_sparse_features_,
                errors::InvalidArgument(
                    "Mismatch in number of sparse quantile configs."));
  }

  void Compute(OpKernelContext* const context) override {
    // Read the dense float features list.
    OpInputList dense_float_features_list;
    OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
                                context, &dense_float_features_list));

    // Read the sparse float features list.
    OpInputList sparse_float_feature_indices_list;
    OpInputList sparse_float_feature_values_list;
    OpInputList sparse_float_feature_shapes_list;
    OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures(
                                context, &sparse_float_feature_indices_list,
                                &sparse_float_feature_values_list,
                                &sparse_float_feature_shapes_list));

    // Parse example weights and get the batch size.
    const Tensor* example_weights_t;
    OP_REQUIRES_OK(context,
                   context->input(kExampleWeightsName, &example_weights_t));
    auto example_weights = example_weights_t->flat<float>();
    const int64 batch_size = example_weights.size();

    OpOutputList sparse_buckets_output_list;
    OP_REQUIRES_OK(context, context->output_list(kSparseBucketsName,
                                                 &sparse_buckets_output_list));
    OpOutputList dense_buckets_output_list;
    OP_REQUIRES_OK(context, context->output_list(kDenseBucketsName,
                                                 &dense_buckets_output_list));

    auto do_quantile_bucket_gen = [&](const int64 begin, const int64 end) {
      // These are blocks of ranges. We are iterating over both sparse and
      // dense features, i.e. [0, sparse_features.size() + dense_features.size()).
      for (int64 i = begin; i < end; ++i) {
        if (i < sparse_configs_.size()) {
          const int64 sparse_index = i;
          const auto sparse_values =
              sparse_float_feature_values_list[sparse_index].flat<float>();
          const auto sparse_indices =
              sparse_float_feature_indices_list[sparse_index].matrix<int64>();
          QuantileStream stream(sparse_configs_[sparse_index].eps(),
                                batch_size);
          // Run quantile summary generation.
          const int64 num_sparse_rows =
              sparse_float_feature_indices_list[sparse_index].dim_size(0);
          for (int64 j = 0; j < num_sparse_rows; ++j) {
            const int64 example_id = sparse_indices(j, 0);
            stream.PushEntry(sparse_values(j), example_weights(example_id));
          }
          stream.Finalize();
          // Create buckets.
          const auto boundaries = GenerateBoundaries(
              stream, sparse_configs_[sparse_index].num_quantiles());
          CopyBoundaries(context, boundaries, sparse_index,
                         &sparse_buckets_output_list);

        } else {
          const int64 dense_index = i - sparse_configs_.size();
          const auto dense_values =
              dense_float_features_list[dense_index].flat<float>();
          QuantileStream stream(dense_configs_[dense_index].eps(), batch_size);
          // Run quantile summary generation.
          for (int64 j = 0; j < batch_size; ++j) {
            stream.PushEntry(dense_values(j), example_weights(j));
          }
          stream.Finalize();
          // Create buckets.
          const auto boundaries = GenerateBoundaries(
              stream, dense_configs_[dense_index].num_quantiles());
          CopyBoundaries(context, boundaries, dense_index,
                         &dense_buckets_output_list);
        }
      }
    };

    const int64 kCostPerUnit = 500 * batch_size;
    const int64 num_features = sparse_configs_.size() + dense_configs_.size();
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *context->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, num_features,
          kCostPerUnit, do_quantile_bucket_gen);
  }

 private:
  int num_dense_features_;
  int num_sparse_features_;
  std::vector<QuantileConfig> dense_configs_;
  std::vector<QuantileConfig> sparse_configs_;
};

REGISTER_KERNEL_BUILDER(Name("QuantileBuckets").Device(DEVICE_CPU),
                        QuantileBucketsOp);

// Given the calculated quantiles thresholds and input data, this operation
// converts the input features into the buckets (categorical values), depending
// on which quantile they fall into.
class QuantilesOp : public OpKernel {
 public:
  explicit QuantilesOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    int num_dense_features;
    int num_sparse_features;
    OP_REQUIRES_OK(context,
                   ReadAndValidateAttributes(context, &num_dense_features,
                                             &num_sparse_features));
  }

  void Compute(OpKernelContext* const context) override {
    // Dense features inputs.
    OpInputList dense_float_features_list;
    OP_REQUIRES_OK(context, context->input_list(kDenseValuesName,
                                                &dense_float_features_list));
    OpInputList dense_buckets_list;
    OP_REQUIRES_OK(context,
                   context->input_list(kDenseBucketsName, &dense_buckets_list));

    if (dense_buckets_list.size() > 0) {
      // Check the first tensor to make sure it is the right shape.
      OP_REQUIRES(
          context,
          tensorflow::TensorShapeUtils::IsVector(dense_buckets_list[0].shape()),
          errors::InvalidArgument(
              strings::Printf("Dense buckets should be flat vectors")));
    }

    // Sparse features inputs.
    OpInputList sparse_float_feature_values_list;
    OP_REQUIRES_OK(context,
                   context->input_list(kSparseValuesName,
                                       &sparse_float_feature_values_list));

    OpInputList sparse_float_indices_list;
    OP_REQUIRES_OK(context, context->input_list(kSparseIndicesName,
                                                &sparse_float_indices_list));

    OpInputList sparse_buckets_list;
    OP_REQUIRES_OK(
        context, context->input_list(kSparseBucketsName, &sparse_buckets_list));

    if (sparse_buckets_list.size() > 0) {
      OP_REQUIRES(
          context,
          tensorflow::TensorShapeUtils::IsVector(
              sparse_buckets_list[0].shape()),
          errors::InvalidArgument("Sparse buckets should be flat vectors"));
    }

    // Quantize the feature values.
    QuantizeFeatures(kDenseOutputTensorName, dense_float_features_list,
                     dense_buckets_list, nullptr, context);

    QuantizeFeatures(kSparseOutputTensorName, sparse_float_feature_values_list,
                     sparse_buckets_list, &sparse_float_indices_list, context);
  }
};

REGISTER_KERNEL_BUILDER(Name("Quantiles").Device(DEVICE_CPU), QuantilesOp);

template <typename T>
class BucketizeWithInputBoundariesOp : public OpKernel {
 public:
  explicit BucketizeWithInputBoundariesOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& boundaries_tensor = context->input(1);
    VLOG(1) << "boundaries has shape: "
            << boundaries_tensor.shape().DebugString();
    auto boundaries = boundaries_tensor.flat<float>();
    std::vector<T> boundaries_vector;
    boundaries_vector.reserve(boundaries.size());
    for (size_t i = 0; i < boundaries.size(); i++) {
      boundaries_vector.push_back(boundaries(i));
      VLOG(1) << "boundaries(" << i << ") : " << boundaries(i);
    }
    OP_REQUIRES(
        context,
        std::is_sorted(boundaries_vector.begin(), boundaries_vector.end()),
        errors::InvalidArgument("Expected sorted boundaries"));

    const Tensor& input_tensor = context->input(0);
    VLOG(1) << "Inputs has shape: " << input_tensor.shape().DebugString()
            << " Dtype: " << tensorflow::DataTypeString(input_tensor.dtype());
    auto input = input_tensor.flat<T>();

    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output = output_tensor->template flat<int32>();

    for (size_t i = 0; i < input.size(); i++) {
      output(i) = CalculateBucketIndex(input(i), boundaries_vector);
    }
  }

 private:
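  // Returns the number of boundaries that are <= value: std::upper_bound
  // finds the first boundary strictly greater than value, so values below the
  // first boundary map to index 0 and values beyond the last map to
  // boundaries_vector.size().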
  int32 CalculateBucketIndex(const T value, std::vector<T>& boundaries_vector) {
    auto first_bigger_it = std::upper_bound(boundaries_vector.begin(),
                                            boundaries_vector.end(), value);
    int32 index = first_bigger_it - boundaries_vector.begin();
    CHECK(index >= 0 && index <= boundaries_vector.size())
        << "Invalid bucket index: " << index
        << " boundaries_vector.size(): " << boundaries_vector.size();
    return index;
  }
};

#define REGISTER_KERNEL(T)                                     \
  REGISTER_KERNEL_BUILDER(Name("BucketizeWithInputBoundaries") \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<T>("T"),         \
                          BucketizeWithInputBoundariesOp<T>);

REGISTER_KERNEL(int32);
REGISTER_KERNEL(int64);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL

}  // namespace tensorflow