// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <cmath>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;

namespace boosted_trees {

namespace {

using boosted_trees::learner::LearnerConfig;
using boosted_trees::learner::LearningRateConfig;
using boosted_trees::trees::Leaf;
using boosted_trees::trees::TreeNode;
using boosted_trees::trees::TreeNodeMetadata;
using boosted_trees::utils::DropoutUtils;

// SplitCandidate holds the split candidate node along with the stats.
struct SplitCandidate {
  // Id of the handler that generated the split candidate.
  int64 handler_id;

  // Split gain.
  float gain;

  // Split info.
  learner::SplitInfo split_info;

  // Oblivious split info.
  learner::ObliviousSplitInfo oblivious_split_info;
};

// Checks that the leaf is not empty.
bool IsLeafWellFormed(const Leaf& leaf) {
  return leaf.has_sparse_vector() || leaf.has_vector();
}

// Helper method to update the best split per partition given
// a current candidate.
void UpdateBestSplit(
    const boosted_trees::learner::LearnerConfig& learner_config,
    int32 partition_id, SplitCandidate* split,
    std::map<int32, SplitCandidate>* best_splits) {
  // Don't consider nodeless splits.
  if (TF_PREDICT_FALSE(split->split_info.split_node().node_case() ==
                       TreeNode::NODE_NOT_SET)) {
    return;
  }

  // Don't consider negative splits if we're pre-pruning the tree.
  // Note that zero-gain splits are acceptable, as they do roughly as well as
  // bias centering in that partition would.
  if (learner_config.pruning_mode() ==
          boosted_trees::learner::LearnerConfig::PRE_PRUNE &&
      split->gain < 0) {
    return;
  }

  // If the current node is pure, one of the leaves will be empty, so the
  // split is meaningless and we should not split.
  if (!(IsLeafWellFormed(split->split_info.right_child()) &&
        IsLeafWellFormed(split->split_info.left_child()))) {
    VLOG(1) << "Split does not actually split anything";
    return;
  }

  // Take the split if we don't have a candidate yet.
  auto best_split_it = best_splits->find(partition_id);
  if (best_split_it == best_splits->end()) {
    best_splits->insert(std::make_pair(partition_id, std::move(*split)));
    return;
  }

  // Determine if the best split so far needs to be replaced.
  SplitCandidate& best_split = best_split_it->second;
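  // Exact gain ties are broken deterministically: prefer the simpler node
  // type (smaller node case), then the higher handler id, so the chosen
  // split does not depend on the order in which candidates are processed.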
  if (TF_PREDICT_FALSE(split->gain == best_split.gain)) {
    // Tie break on node case preferring simpler tree node types.
    VLOG(2) << "Attempting to tie break with smaller node case. "
            << "(current split: " << split->split_info.split_node().node_case()
            << ", best split: "
            << best_split.split_info.split_node().node_case() << ")";
    if (split->split_info.split_node().node_case() <
        best_split.split_info.split_node().node_case()) {
      best_split = std::move(*split);
    } else if (split->split_info.split_node().node_case() ==
               best_split.split_info.split_node().node_case()) {
      // Tie break on handler Id.
      VLOG(2) << "Tie breaking with higher handler Id. "
              << "(current split: " << split->handler_id
              << ", best split: " << best_split.handler_id << ")";
      if (split->handler_id > best_split.handler_id) {
        best_split = std::move(*split);
      }
    }
  } else if (split->gain > best_split.gain) {
    best_split = std::move(*split);
  }
}

// Helper method to check whether a node is a terminal node, i.e. it only
// has leaf nodes as children.
bool IsTerminalSplitNode(const size_t node_id,
                         const std::vector<int32>& children,
                         const std::vector<TreeNode>& nodes) {
  for (const int32 child_id : children) {
    const auto& child_node = nodes[child_id];
    CHECK(child_node.node_case() != TreeNode::NODE_NOT_SET);
    if (child_node.node_case() != TreeNode::kLeaf) {
      return false;
    }
  }
  return true;
}

// Helper method to recursively prune the tree in a depth-first fashion.
void RecursivePruneTree(const size_t node_id, std::vector<TreeNode>* nodes) {
  // Base case when we reach a leaf.
  TreeNode& tree_node = (*nodes)[node_id];
  CHECK(tree_node.node_case() != TreeNode::NODE_NOT_SET);
  if (tree_node.node_case() == TreeNode::kLeaf) {
    return;
  }

  // Traverse node children first and recursively prune their sub-trees.
  const std::vector<int32> children =
      boosted_trees::trees::DecisionTree::GetChildren(tree_node);
  for (const int32 child_id : children) {
    RecursivePruneTree(child_id, nodes);
  }

  // Two conditions must be satisfied to prune the node:
  // 1- The split gain is negative.
  // 2- After depth-first pruning, the node only has leaf children.
  TreeNodeMetadata* node_metadata = tree_node.mutable_node_metadata();
  if (node_metadata->gain() < 0 &&
      IsTerminalSplitNode(node_id, children, (*nodes))) {
    // Clear node children.
    for (const int32 child_id : children) {
      auto& child_node = (*nodes)[child_id];
      child_node.Clear();
    }

    // Change node back into leaf.
    (*tree_node.mutable_leaf()) = *node_metadata->mutable_original_leaf();

    // Clear gain for leaf node.
    tree_node.clear_node_metadata();
  } else {
    // Clear original leaf as it's no longer needed for back-track pruning.
    node_metadata->clear_original_leaf();
  }
}

}  // namespace

class CenterTreeEnsembleBiasOp : public OpKernel {
 public:
  explicit CenterTreeEnsembleBiasOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    // Read learner config.
    string serialized_learner_config;
    OP_REQUIRES_OK(context, context->GetAttr("learner_config",
                                             &serialized_learner_config));
    OP_REQUIRES(context,
                learner_config_.ParseFromString(serialized_learner_config),
                errors::InvalidArgument("Unable to parse learner config."));

    // Read centering epsilon.
    OP_REQUIRES_OK(context,
                   context->GetAttr("centering_epsilon", &centering_epsilon_));
  }

  void Compute(OpKernelContext* const context) override {
    // Get decision tree ensemble.
    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
    core::ScopedUnref unref_me(ensemble_resource);
    mutex_lock l(*ensemble_resource->get_mutex());

    // Get the stamp token.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    // Only the Chief should run this Op, and it is guaranteed to be in
    // a consistent state, so the stamps must always match.
    CHECK(ensemble_resource->is_stamp_valid(stamp_token));

    // Get the next stamp token.
    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
                   context->input("next_stamp_token", &next_stamp_token_t));
    int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
    CHECK(stamp_token != next_stamp_token);

    // Update the ensemble stamp.
    ensemble_resource->set_stamp(next_stamp_token);

    // Get the delta updates.
    const Tensor* delta_updates_t;
    OP_REQUIRES_OK(context, context->input("delta_updates", &delta_updates_t));
    auto delta_updates = delta_updates_t->vec<float>();
    const int64 logits_dimension = delta_updates_t->dim_size(0);

    // Get the bias.
    boosted_trees::trees::Leaf* const bias =
        RetrieveBias(ensemble_resource, logits_dimension);
    CHECK(bias->has_vector());

    // Update the bias.
    float total_delta = 0;
    auto* bias_vec = bias->mutable_vector();
    for (int idx = 0; idx < bias->vector().value_size(); ++idx) {
      float delta = delta_updates(idx);
      bias_vec->set_value(idx, bias_vec->value(idx) + delta);
      total_delta += std::abs(delta);
    }

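    // |total_delta| is the L1 norm of this update; centering continues until
    // it drops to centering_epsilon_ or below.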
    // Make a centering continuation decision based on the current update.
    bool continue_centering = total_delta > centering_epsilon_;
    if (continue_centering) {
      VLOG(1) << "Continuing to center bias, delta=" << total_delta;
    } else {
      VLOG(1) << "Done centering bias, delta=" << total_delta;
      ensemble_resource->LastTreeMetadata()->set_is_finalized(true);
    }
    Tensor* continue_centering_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("continue_centering",
                                            TensorShape({}),
                                            &continue_centering_t));
    continue_centering_t->scalar<bool>()() = continue_centering;
  }

 private:
  // Helper method to retrieve the bias from the tree ensemble.
  boosted_trees::trees::Leaf* RetrieveBias(
      boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource,
      int64 logits_dimension) {
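    // Three cases are possible:
    //   1) Empty ensemble: add a brand-new tree whose single leaf holds the
    //      bias, initialized to zeros.
    //   2) Exactly one tree: it must be the bias tree (a single leaf), whose
    //      leaf is returned for further centering.
    //   3) Anything else: centering a bias on an already grown ensemble is
    //      invalid.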
    const int32 num_trees = ensemble_resource->num_trees();
    if (num_trees <= 0) {
      // Add a new bias leaf.
      ensemble_resource->IncrementAttempts();
      boosted_trees::trees::DecisionTreeConfig* const tree_config =
          ensemble_resource->AddNewTree(1.0);
      auto* const leaf = tree_config->add_nodes()->mutable_leaf();
      for (int64 idx = 0; idx < logits_dimension; ++idx) {
        leaf->mutable_vector()->add_value(0.0);
      }
      return leaf;
    } else if (num_trees == 1) {
      // Confirms that the only tree is a bias and returns its leaf.
      boosted_trees::trees::DecisionTreeConfig* const tree_config =
          ensemble_resource->LastTree();
      CHECK_EQ(tree_config->nodes_size(), 1);
      CHECK_EQ(tree_config->nodes(0).node_case(), TreeNode::kLeaf);
      return tree_config->mutable_nodes(0)->mutable_leaf();
    } else {
      LOG(FATAL) << "Unable to center bias on an already grown ensemble";
    }
  }

  boosted_trees::learner::LearnerConfig learner_config_;
  float centering_epsilon_;
};

REGISTER_KERNEL_BUILDER(Name("CenterTreeEnsembleBias").Device(DEVICE_CPU),
                        CenterTreeEnsembleBiasOp);

class GrowTreeEnsembleOp : public OpKernel {
 public:
  explicit GrowTreeEnsembleOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    // Read the number of handlers. Note that this is the static number of
    // all handlers; any subset of them may be active at a given time.
    OP_REQUIRES_OK(context, context->GetAttr("num_handlers", &num_handlers_));

    OP_REQUIRES_OK(context, context->GetAttr("center_bias", &center_bias_));

    // Read learner config.
    string serialized_learner_config;
    OP_REQUIRES_OK(context, context->GetAttr("learner_config",
                                             &serialized_learner_config));
    OP_REQUIRES(context,
                learner_config_.ParseFromString(serialized_learner_config),
                errors::InvalidArgument("Unable to parse learner config."));

    // Determine whether dropout was used when building this tree.
    if (learner_config_.has_learning_rate_tuner() &&
        learner_config_.learning_rate_tuner().tuner_case() ==
            LearningRateConfig::kDropout) {
      dropout_config_ = learner_config_.learning_rate_tuner().dropout();
      dropout_was_applied_ = true;
    } else {
      dropout_was_applied_ = false;
    }
  }

  void Compute(OpKernelContext* const context) override {
    // Get decision tree ensemble.
    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
    core::ScopedUnref unref_me(ensemble_resource);
    mutex_lock l(*ensemble_resource->get_mutex());

    // Get the stamp token.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    // Only the Chief should run this Op, and it is guaranteed to be in
    // a consistent state, so the stamps must always match.
    CHECK(ensemble_resource->is_stamp_valid(stamp_token));

    // Get the next stamp token.
    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
                   context->input("next_stamp_token", &next_stamp_token_t));
    int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
    CHECK(stamp_token != next_stamp_token);

    // Update the ensemble stamp regardless of whether a layer
    // or tree is actually grown.
    ensemble_resource->set_stamp(next_stamp_token);

    // Read the learning_rate.
    const Tensor* learning_rate_t;
    OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t));
    float learning_rate = learning_rate_t->scalar<float>()();

    // Read the weak learner type to use.
    const Tensor* weak_learner_type_t;
    OP_REQUIRES_OK(context,
                   context->input("weak_learner_type", &weak_learner_type_t));
    const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();

    const Tensor* seed_t;
    OP_REQUIRES_OK(context, context->input("dropout_seed", &seed_t));
    // Cast the seed to uint64.
    const uint64 dropout_seed = seed_t->scalar<int64>()();

    // Read partition Ids, gains and split candidates.
    OpInputList partition_ids_list;
    OpInputList gains_list;
    OpInputList splits_list;
    OP_REQUIRES_OK(context,
                   context->input_list("partition_ids", &partition_ids_list));
    OP_REQUIRES_OK(context, context->input_list("gains", &gains_list));
    OP_REQUIRES_OK(context, context->input_list("splits", &splits_list));

    // Increment attempt stats.
    ensemble_resource->IncrementAttempts();

    // Find the best splits for each active partition.
    std::map<int32, SplitCandidate> best_splits;
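    // For NORMAL_DECISION_TREE the map is keyed by partition id; for
    // OBLIVIOUS_DECISION_TREE a single layer-wide split is stored under key 0.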
    switch (weak_learner_type) {
      case LearnerConfig::NORMAL_DECISION_TREE: {
        FindBestSplitsPerPartitionNormal(context, partition_ids_list,
                                         gains_list, splits_list,
                                         &best_splits);
        break;
      }
      case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
        FindBestSplitOblivious(context, gains_list, splits_list, &best_splits);
        break;
      }
    }
    // No-op if no new splits can be considered.
    if (best_splits.empty()) {
      LOG(WARNING)
          << "Not growing tree ensemble as no good splits were found.";
      return;
    }

    // Get the max tree depth.
    const Tensor* max_tree_depth_t;
    OP_REQUIRES_OK(context,
                   context->input("max_tree_depth", &max_tree_depth_t));
    const int32 max_tree_depth = max_tree_depth_t->scalar<int32>()();
    // Update and retrieve the growable tree.
    // If the tree is fully built and dropout was applied, this also adjusts
    // the weights of the dropped trees and of the last tree.
    boosted_trees::trees::DecisionTreeConfig* const tree_config =
        UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate,
                                      dropout_seed, max_tree_depth,
                                      weak_learner_type);
    // Split tree nodes.
    switch (weak_learner_type) {
      case LearnerConfig::NORMAL_DECISION_TREE: {
        for (auto& split_entry : best_splits) {
          SplitTreeNode(split_entry.first, &split_entry.second, tree_config,
                        ensemble_resource);
        }
        break;
      }
      case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
        SplitTreeLayer(&best_splits[0], tree_config, ensemble_resource);
        break;
      }
    }
    // Post-prune the finalized tree if needed.
    if (learner_config_.pruning_mode() ==
            boosted_trees::learner::LearnerConfig::POST_PRUNE &&
        ensemble_resource->LastTreeMetadata()->is_finalized()) {
      VLOG(2) << "Post-pruning finalized tree.";
      if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE) {
        LOG(FATAL) << "Post-pruning is not implemented for oblivious trees.";
      }
      PruneTree(tree_config);

      // If after post-pruning the whole tree has no gain, remove the tree
      // altogether from the ensemble.
      if (tree_config->nodes_size() <= 0) {
        ensemble_resource->RemoveLastTree();
      }
    }
  }

 private:
  // Helper method which effectively does a reduce over all split candidates
  // and finds the best split for each partition.
  void FindBestSplitsPerPartitionNormal(
      OpKernelContext* const context, const OpInputList& partition_ids_list,
      const OpInputList& gains_list, const OpInputList& splits_list,
      std::map<int32, SplitCandidate>* best_splits) {
    // Find the best split per partition going through every feature candidate.
    // TODO(salehay): Is this worth parallelizing?
    for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
      const auto& partition_ids = partition_ids_list[handler_id].vec<int32>();
      const auto& gains = gains_list[handler_id].vec<float>();
      const auto& splits = splits_list[handler_id].vec<string>();
      OP_REQUIRES(context, partition_ids.size() == gains.size(),
                  errors::InvalidArgument(
                      "Inconsistent partition Ids and gains tensors: ",
                      partition_ids.size(), " != ", gains.size()));
      OP_REQUIRES(context, partition_ids.size() == splits.size(),
                  errors::InvalidArgument(
                      "Inconsistent partition Ids and splits tensors: ",
                      partition_ids.size(), " != ", splits.size()));
      for (size_t candidate_idx = 0; candidate_idx < splits.size();
           ++candidate_idx) {
        // Get the current split candidate.
        const auto& partition_id = partition_ids(candidate_idx);
        const auto& gain = gains(candidate_idx);
        const auto& serialized_split = splits(candidate_idx);
        SplitCandidate split;
        split.handler_id = handler_id;
        split.gain = gain;
        OP_REQUIRES(context,
                    split.split_info.ParseFromString(serialized_split),
                    errors::InvalidArgument("Unable to parse split info."));

        // Update the best split for the partition based on the current
        // candidate.
        UpdateBestSplit(learner_config_, partition_id, &split, best_splits);
      }
    }
  }

  void FindBestSplitOblivious(OpKernelContext* const context,
                              const OpInputList& gains_list,
                              const OpInputList& splits_list,
                              std::map<int32, SplitCandidate>* best_splits) {
    // Find the best split over all feature candidates; an oblivious tree
    // shares a single split across all partitions of a layer.
    for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
      const auto& gains = gains_list[handler_id].vec<float>();
      const auto& splits = splits_list[handler_id].vec<string>();
      OP_REQUIRES(context, gains.size() == 1,
                  errors::InvalidArgument(
                      "Gains size must be one for oblivious weak learner: ",
                      gains.size(), " != ", 1));
      OP_REQUIRES(context, splits.size() == 1,
                  errors::InvalidArgument(
                      "Splits size must be one for oblivious weak learner: ",
                      splits.size(), " != ", 1));
      // Get the current split candidate.
      const auto& gain = gains(0);
      const auto& serialized_split = splits(0);
      SplitCandidate split;
      split.handler_id = handler_id;
      split.gain = gain;
      OP_REQUIRES(
          context,
          split.oblivious_split_info.ParseFromString(serialized_split),
          errors::InvalidArgument("Unable to parse oblivious split info."));

      auto split_info = split.oblivious_split_info;
      CHECK(split_info.children_size() % 2 == 0)
          << "The oblivious split should generate an even number of children: "
          << split_info.children_size();
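      // Children are laid out as consecutive (left, right) pairs, one pair
      // per parent leaf listed in children_parent_id.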

      // If every node is pure, then we shouldn't split.
      bool only_pure_nodes = true;
      for (int idx = 0; idx < split_info.children_size(); idx += 2) {
        if (IsLeafWellFormed(*split_info.mutable_children(idx)) &&
            IsLeafWellFormed(*split_info.mutable_children(idx + 1))) {
          only_pure_nodes = false;
          break;
        }
      }
      if (only_pure_nodes) {
        VLOG(1) << "The oblivious split does not actually split anything.";
        continue;
      }

      // Don't consider negative splits if we're pre-pruning the tree.
      if (learner_config_.pruning_mode() ==
              learner::LearnerConfig::PRE_PRUNE &&
          gain < 0) {
        continue;
      }

      // Take the split if we don't have a candidate yet.
      auto best_split_it = best_splits->find(0);
      if (best_split_it == best_splits->end()) {
        best_splits->insert(std::make_pair(0, std::move(split)));
        continue;
      }

      // Determine if we should update the best split.
      SplitCandidate& best_split = best_split_it->second;
      trees::TreeNode current_node = split_info.split_node();
      trees::TreeNode best_node = best_split.oblivious_split_info.split_node();
      if (TF_PREDICT_FALSE(gain == best_split.gain)) {
        // Tie break on node case preferring simpler tree node types.
        VLOG(2) << "Attempting to tie break with smaller node case. "
                << "(current split: " << current_node.node_case()
                << ", best split: " << best_node.node_case() << ")";
        if (current_node.node_case() < best_node.node_case()) {
          best_split = std::move(split);
        } else if (current_node.node_case() == best_node.node_case()) {
          // Tie break on handler Id.
          VLOG(2) << "Tie breaking with higher handler Id. "
                  << "(current split: " << handler_id
                  << ", best split: " << best_split.handler_id << ")";
          if (handler_id > best_split.handler_id) {
            best_split = std::move(split);
          }
        }
      } else if (gain > best_split.gain) {
        best_split = std::move(split);
      }
    }
  }

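  // Helper method to adjust tree weights when dropout was applied while the
  // last tree (now finalized) was being built.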
  void UpdateTreeWeightsIfDropout(
      boosted_trees::models::DecisionTreeEnsembleResource* const
          ensemble_resource,
      const uint64 dropout_seed) {
    // It is possible that the tree was built with dropout. If so, we need to
    // adjust the tree weights, or bail out.
    if (!dropout_was_applied_ ||
        !ensemble_resource->LastTreeMetadata()->is_finalized()) {
      return;
    }
    const int32 num_trees = ensemble_resource->num_trees();

    // Based on the seed, figure out which trees were dropped before.
    std::unordered_set<int32> trees_not_to_drop;
    if (center_bias_) {
      trees_not_to_drop.insert(0);
    }
    // The last tree is the current tree being built.
    const int32 current_tree = num_trees - 1;
    trees_not_to_drop.insert(current_tree);

    // Since only the Chief builds the trees, we are sure that the other tree
    // weights didn't change.
    std::vector<float> weights = ensemble_resource->GetTreeWeights();
    std::vector<int32> dropped_trees;
    std::vector<float> dropped_trees_weights;
    const auto dropout_status = DropoutUtils::DropOutTrees(
        dropout_seed, dropout_config_, trees_not_to_drop, weights,
        &dropped_trees, &dropped_trees_weights);
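    // Re-running dropout with the same seed and config deterministically
    // reproduces the selection that was made when the tree was grown.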
    CHECK(dropout_status.ok())
        << "Can't figure out what trees were dropped out before, error is "
        << dropout_status.error_message();

    // Now that we have the dropped trees, update their weights and the
    // current tree's weight.
    if (!dropped_trees.empty()) {
      std::vector<int32> increment_num_updates(num_trees, 0);
      DropoutUtils::GetTreesWeightsForAddingTrees(
          dropped_trees, dropped_trees_weights, current_tree,
          1 /* only 1 tree was added */, &weights, &increment_num_updates);

      // Update the weights and number of updates for the trees.
      for (int i = 0; i < num_trees; ++i) {
        ensemble_resource->SetTreeWeight(i, weights[i],
                                         increment_num_updates[i]);
      }
    }
  }

  // Helper method to update the growable tree, which is by definition the
  // last tree in the ensemble.
  boosted_trees::trees::DecisionTreeConfig* UpdateAndRetrieveGrowableTree(
      boosted_trees::models::DecisionTreeEnsembleResource* const
          ensemble_resource,
      const float learning_rate, const uint64 dropout_seed,
      const int32 max_tree_depth, const int32 weak_learner_type) {
    const auto num_trees = ensemble_resource->num_trees();
    if (num_trees <= 0 ||
        ensemble_resource->LastTreeMetadata()->is_finalized()) {
      // Create a new tree with a no-op leaf.
      boosted_trees::trees::DecisionTreeConfig* const tree_config =
          ensemble_resource->AddNewTree(learning_rate);
      VLOG(1) << "Adding layer #0 to tree #" << num_trees << " of ensemble of "
              << num_trees + 1 << " trees.";
      tree_config->add_nodes()->mutable_leaf();
      boosted_trees::trees::DecisionTreeMetadata* const tree_metadata =
          ensemble_resource->LastTreeMetadata();
      tree_metadata->set_is_finalized(max_tree_depth <= 1);
      tree_metadata->set_num_tree_weight_updates(1);
    } else {
      // The growable tree is by definition the last tree in the ensemble.
      boosted_trees::trees::DecisionTreeMetadata* const tree_metadata =
          ensemble_resource->LastTreeMetadata();
      const auto new_num_layers = tree_metadata->num_layers_grown() + 1;
      VLOG(1) << "Adding layer #" << new_num_layers - 1 << " to tree #"
              << num_trees - 1 << " of ensemble of " << num_trees << " trees.";
      // Update growable tree metadata.
      tree_metadata->set_num_layers_grown(new_num_layers);
      tree_metadata->set_is_finalized(new_num_layers >= max_tree_depth);
    }
    UpdateTreeWeightsIfDropout(ensemble_resource, dropout_seed);
    return ensemble_resource->LastTree();
  }

  // Helper method to merge leaf weights as the tree is being grown.
  boosted_trees::trees::Leaf* MergeLeafWeights(
      const boosted_trees::trees::Leaf& source,
      boosted_trees::trees::Leaf* dest) {
    // Resolve the leaf merging method based on how the trees are being grown.
    if (learner_config_.growing_mode() ==
        boosted_trees::learner::LearnerConfig::WHOLE_TREE) {
      // No merging occurs when building a whole tree at a time.
      return dest;
    }

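    // In layer-by-layer growing, the new child leaves produced by a split
    // carry only this layer's delta, so the accumulated weights of the parent
    // leaf (the source) are folded into them (the dest) below.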
    if (dest->leaf_case() == boosted_trees::trees::Leaf::LEAF_NOT_SET) {
      // No merging is required. Just copy the source weights.
      *dest = source;
      return dest;
    }

    // Handle leaf merging based on type.
    switch (source.leaf_case()) {
      case boosted_trees::trees::Leaf::kVector: {
        // No-op if the source is empty.
        const auto& src_vec = source.vector();
        if (src_vec.value_size() == 0) {
          break;
        }
        CHECK(source.leaf_case() == dest->leaf_case());

        // Dense-add the leaf vectors.
        auto* dst_vec = dest->mutable_vector();
        CHECK(src_vec.value_size() == dst_vec->value_size());
        for (int idx = 0; idx < source.vector().value_size(); ++idx) {
          (*dst_vec->mutable_value()->Mutable(idx)) += src_vec.value(idx);
        }
        break;
      }
      case boosted_trees::trees::Leaf::kSparseVector: {
        // No-op if the source is empty.
        const auto& src_vec = source.sparse_vector();
        CHECK(src_vec.value_size() == src_vec.index_size());
        if (src_vec.value_size() == 0) {
          break;
        }
        CHECK(source.leaf_case() == dest->leaf_case());

        // Get a mapping of dimension to value for the destination.
        std::unordered_map<int32, float> dst_map;
        auto* dst_vec = dest->mutable_sparse_vector();
        CHECK(dst_vec->value_size() == dst_vec->index_size());
        dst_map.reserve(dst_vec->value_size());
        for (int idx = 0; idx < dst_vec->value_size(); ++idx) {
          dst_map[dst_vec->index(idx)] = dst_vec->value(idx);
        }
        // Sparse-add the source vector to the destination vector.
        for (int idx = 0; idx < src_vec.value_size(); ++idx) {
          dst_map[src_vec.index(idx)] += src_vec.value(idx);
        }
        // Rebuild the merged destination leaf.
        dst_vec->clear_index();
        dst_vec->clear_value();
        for (const auto& entry : dst_map) {
          dst_vec->add_index(entry.first);
          dst_vec->add_value(entry.second);
        }
        break;
      }
      case boosted_trees::trees::Leaf::LEAF_NOT_SET: {
        // No-op as there is nothing to merge.
        break;
      }
    }
    return dest;
  }

  // Helper method to split a tree node and append its respective
  // leaf children given the split candidate.
  void SplitTreeNode(
      const int32 node_id, SplitCandidate* split,
      boosted_trees::trees::DecisionTreeConfig* tree_config,
      boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
    // Fail if the node to split is out of range.
    CHECK(node_id < tree_config->nodes_size())
        << "Invalid node " << node_id << " to split.";
    // Ensure the new split node is valid.
    CHECK(split->split_info.split_node().node_case() != TreeNode::NODE_NOT_SET);
    CHECK(tree_config->nodes(node_id).node_case() == TreeNode::kLeaf)
        << "Unexpected node type to split "
        << tree_config->nodes(node_id).node_case() << " for node_id "
        << node_id << ". Tree config: " << tree_config->DebugString();

    // Add the left leaf.
    int32 left_id = tree_config->nodes_size();
    (*tree_config->add_nodes()->mutable_leaf()) =
        *MergeLeafWeights(tree_config->nodes(node_id).leaf(),
                          split->split_info.mutable_left_child());

    // Add the right leaf.
    int32 right_id = tree_config->nodes_size();
    (*tree_config->add_nodes()->mutable_leaf()) =
        *MergeLeafWeights(tree_config->nodes(node_id).leaf(),
                          split->split_info.mutable_right_child());

    // Link the new children to the split node.
    boosted_trees::trees::DecisionTree::LinkChildren(
        {left_id, right_id}, split->split_info.mutable_split_node());

    // Add the split gain and, if needed, the original leaf to the node
    // metadata.
    TreeNodeMetadata* node_metadata =
        split->split_info.mutable_split_node()->mutable_node_metadata();
    node_metadata->set_gain(split->gain);
    if (learner_config_.pruning_mode() ==
        boosted_trees::learner::LearnerConfig::POST_PRUNE) {
      (*node_metadata->mutable_original_leaf()) =
          *tree_config->mutable_nodes(node_id)->mutable_leaf();
    }

    // Replace the node in the tree.
    (*tree_config->mutable_nodes(node_id)) =
        *split->split_info.mutable_split_node();
    if (learner_config_.constraints().max_number_of_unique_feature_columns()) {
      ensemble_resource->MaybeAddUsedHandler(split->handler_id);
    }
  }

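  // Helper method to grow an oblivious tree by one layer: every current leaf
  // is split with the same shared split node, so the number of leaves
  // doubles. Nodes are stored in layer order: indices [0, depth) hold the
  // per-layer split nodes and the remaining indices hold the current leaves.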
  void SplitTreeLayer(
      SplitCandidate* split,
      boosted_trees::trees::DecisionTreeConfig* tree_config,
      boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
    int depth = 0;
    while (depth < tree_config->nodes_size() &&
           tree_config->nodes(depth).node_case() != TreeNode::kLeaf) {
      depth++;
    }
    CHECK(tree_config->nodes_size() > 0)
        << "A tree must have at least one dummy leaf.";
    // The number of new children.
    int num_children = 1 << (depth + 1);
    auto split_info = split->oblivious_split_info;
    CHECK(num_children >= split_info.children_size())
        << "Too many new children, expected <= " << num_children << " and got "
        << split_info.children_size();
    std::vector<trees::Leaf> new_leaves;
    new_leaves.reserve(num_children);
    int next_id = 0;
    for (int idx = 0; idx < num_children / 2; idx++) {
      trees::Leaf old_leaf =
          *tree_config->mutable_nodes(depth + idx)->mutable_leaf();
      // Check if a split was made for this leaf.
      if (next_id < split_info.children_parent_id_size() &&
          depth + idx == split_info.children_parent_id(next_id)) {
        // Add the left leaf.
        new_leaves.push_back(*MergeLeafWeights(
            old_leaf, split_info.mutable_children(2 * next_id)));
        // Add the right leaf.
        new_leaves.push_back(*MergeLeafWeights(
            old_leaf, split_info.mutable_children(2 * next_id + 1)));
        next_id++;
      } else {
        // If there is no split for this leaf, just duplicate it.
        new_leaves.push_back(old_leaf);
        new_leaves.push_back(old_leaf);
      }
    }
    CHECK(next_id == split_info.children_parent_id_size());
    TreeNodeMetadata* split_metadata =
        split_info.mutable_split_node()->mutable_node_metadata();
    split_metadata->set_gain(split->gain);

    TreeNode new_split = *split_info.mutable_split_node();
    // Move the old children to the metadata.
    for (int idx = depth; idx < tree_config->nodes_size(); idx++) {
      *new_split.mutable_node_metadata()->add_original_oblivious_leaves() =
          *tree_config->mutable_nodes(idx)->mutable_leaf();
    }
    // Add the new split to the tree_config in place, before the children
    // start.
    *tree_config->mutable_nodes(depth) = new_split;
    // Add the new children.
    int nodes_size = tree_config->nodes_size();
    for (int idx = 0; idx < num_children; idx++) {
      if (idx + depth + 1 < nodes_size) {
        // Update leaves that were already there.
        *tree_config->mutable_nodes(idx + depth + 1)->mutable_leaf() =
            new_leaves[idx];
      } else {
        // Add new leaves.
        *tree_config->add_nodes()->mutable_leaf() = new_leaves[idx];
      }
    }
  }
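
  // Helper method to post-prune a finalized tree: split nodes with negative
  // gain whose children are all leaves are collapsed back into their original
  // leaves, and the surviving nodes are compacted into a fresh node list.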
  void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) {
    // No-op if the tree is empty.
    if (tree_config->nodes_size() <= 0) {
      return;
    }

    // Copy the nodes to a temp vector and clear the original tree.
    std::vector<TreeNode> tree_nodes;
    tree_nodes.reserve(tree_config->nodes_size());
    for (auto& node : (*tree_config->mutable_nodes())) {
      tree_nodes.push_back(node);
      node.Clear();
    }
    tree_config->clear_nodes();

    // Prune the tree recursively starting from the root.
    RecursivePruneTree(0, &tree_nodes);

    // Rebuild the compacted tree.
    (*tree_config->add_nodes()) = tree_nodes[0];
    std::unordered_map<size_t, size_t> nodes_map;
    nodes_map[0] = 0;
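    // nodes_map maps an original node index to its index in the compacted
    // tree; children are appended as they are visited and re-linked below.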
    for (size_t node_idx = 0; node_idx < tree_nodes.size(); ++node_idx) {
      // Skip pruned nodes.
      auto& original_node = tree_nodes[node_idx];
      if (original_node.node_case() == TreeNode::NODE_NOT_SET) {
        continue;
      }

      // Find the node mapped in the tree ensemble.
      auto mapped_node_it = nodes_map.find(node_idx);
      CHECK(mapped_node_it != nodes_map.end());
      auto& mapped_node = (*tree_config->mutable_nodes(mapped_node_it->second));

      // Get the node children.
      auto children =
          boosted_trees::trees::DecisionTree::GetChildren(original_node);
      for (int32& child_idx : children) {
        auto new_idx = tree_config->nodes_size();
        (*tree_config->add_nodes()) = tree_nodes[child_idx];
        nodes_map[child_idx] = new_idx;
        child_idx = new_idx;
      }
      boosted_trees::trees::DecisionTree::LinkChildren(children, &mapped_node);
    }

    // Check if there are any nodes with gain left.
    if (tree_config->nodes_size() == 1 &&
        tree_config->nodes(0).node_metadata().gain() <= 0) {
      // The whole tree should be pruned.
      VLOG(2) << "No useful nodes left after post-pruning tree.";
      tree_config->clear_nodes();
    }
  }

 private:
  boosted_trees::learner::LearnerConfig learner_config_;
  int64 num_handlers_;
  LearningRateDropoutDrivenConfig dropout_config_;
  bool dropout_was_applied_;
  bool center_bias_;
};

REGISTER_KERNEL_BUILDER(Name("GrowTreeEnsemble").Device(DEVICE_CPU),
                        GrowTreeEnsembleOp);

class TreeEnsembleStatsOp : public OpKernel {
 public:
  explicit TreeEnsembleStatsOp(OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* const context) override {
    // Get decision tree ensemble.
    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
    core::ScopedUnref unref_me(ensemble_resource);
    tf_shared_lock l(*ensemble_resource->get_mutex());

    // Get the stamp token.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    // Only the Chief should run this Op, and it is guaranteed to be in
    // a consistent state, so the stamps must always match.
    CHECK(ensemble_resource->is_stamp_valid(stamp_token));
    const boosted_trees::trees::DecisionTreeEnsembleConfig& ensemble_config =
        ensemble_resource->decision_tree_ensemble();

    // Set tree stats.
    Tensor* num_trees_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "num_trees", TensorShape({}), &num_trees_t));
    Tensor* active_tree_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("active_tree", TensorShape({}),
                                            &active_tree_t));
    Tensor* attempted_tree_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("attempted_trees", TensorShape({}),
                                            &attempted_tree_t));

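    // A still-growing last tree is excluded from num_trees; active_tree
    // always reports the total tree count.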
    const int num_trees = ensemble_resource->num_trees();
    active_tree_t->scalar<int64>()() = num_trees;
    num_trees_t->scalar<int64>()() =
        (num_trees <= 0 ||
         ensemble_resource->LastTreeMetadata()->is_finalized())
            ? num_trees
            : num_trees - 1;
    attempted_tree_t->scalar<int64>()() =
        ensemble_config.growing_metadata().num_trees_attempted();

    // Set layer stats.
    Tensor* num_layers_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "num_layers", TensorShape({}), &num_layers_t));
    Tensor* active_layer_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("active_layer", TensorShape({}),
                                            &active_layer_t));
    Tensor* attempted_layers_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("attempted_layers",
                                            TensorShape({}),
                                            &attempted_layers_t));

    int64 num_layers = 0;
    for (const auto& tree_metadata : ensemble_config.tree_metadata()) {
      num_layers += tree_metadata.num_layers_grown();
    }
    num_layers_t->scalar<int64>()() = num_layers;
    int tree_metadata_size = ensemble_config.tree_metadata_size();
    active_layer_t->scalar<int64>()() =
        tree_metadata_size > 0
            ? ensemble_config.tree_metadata(tree_metadata_size - 1)
                  .num_layers_grown()
            : 0;
    attempted_layers_t->scalar<int64>()() =
        ensemble_config.growing_metadata().num_layers_attempted();
  }
};

REGISTER_KERNEL_BUILDER(Name("TreeEnsembleStats").Device(DEVICE_CPU),
                        TreeEnsembleStatsOp);

}  // namespace boosted_trees
}  // namespace tensorflow