1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/boosted_trees/resources.h"
17
18 #include "tensorflow/core/framework/resource_mgr.h"
19 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
20 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
21 #include "tensorflow/core/platform/mutex.h"
22 #include "tensorflow/core/platform/protobuf.h"
23
24 namespace tensorflow {
25
26 // Constructor.
BoostedTreesEnsembleResource()27 BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
28 : tree_ensemble_(
29 protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
30 &arena_)) {}
31
DebugString() const32 string BoostedTreesEnsembleResource::DebugString() const {
33 return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
34 "]");
35 }
36
InitFromSerialized(const string & serialized,const int64 stamp_token)37 bool BoostedTreesEnsembleResource::InitFromSerialized(const string& serialized,
38 const int64 stamp_token) {
39 CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
40 if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
41 set_stamp(stamp_token);
42 return true;
43 }
44 return false;
45 }
46
SerializeAsString() const47 string BoostedTreesEnsembleResource::SerializeAsString() const {
48 return tree_ensemble_->SerializeAsString();
49 }
50
num_trees() const51 int32 BoostedTreesEnsembleResource::num_trees() const {
52 return tree_ensemble_->trees_size();
53 }
54
// Evaluates the split stored at (`tree_id`, `node_id`) for one example and
// returns the id of the child node to visit next.
//
// `bucketized_features` is indexed by the split's feature id; the selected
// matrix is then read at (`index_in_batch`, split dimension id).
//   * Bucketized split: goes left when the feature value <= threshold.
//   * Categorical split: goes left when the feature value == split value.
// Any other node type trips a DCHECK and yields -1.
int32 BoostedTreesEnsembleResource::next_node(
    const int32 tree_id, const int32 node_id, const int32 index_in_batch,
    const std::vector<TTypes<int32>::ConstMatrix>& bucketized_features) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  const auto& tree = tree_ensemble_->trees(tree_id);
  DCHECK_LT(node_id, tree.nodes_size());
  const auto& node = tree.nodes(node_id);
  const auto split_kind = node.node_case();

  if (split_kind == boosted_trees::Node::kBucketizedSplit) {
    const auto& split = node.bucketized_split();
    const auto feature_column = bucketized_features[split.feature_id()];
    const auto bucket = feature_column(index_in_batch, split.dimension_id());
    if (bucket <= split.threshold()) {
      return split.left_id();
    }
    return split.right_id();
  }

  if (split_kind == boosted_trees::Node::kCategoricalSplit) {
    const auto& split = node.categorical_split();
    const auto feature_column = bucketized_features[split.feature_id()];
    const auto category = feature_column(index_in_batch, split.dimension_id());
    if (category == split.value()) {
      return split.left_id();
    }
    return split.right_id();
  }

  DCHECK(false) << "Node type " << node.node_case() << " not supported.";
  return -1;
}
84
node_value(const int32 tree_id,const int32 node_id) const85 std::vector<float> BoostedTreesEnsembleResource::node_value(
86 const int32 tree_id, const int32 node_id) const {
87 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
88 DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
89 const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
90 if (node.node_case() == boosted_trees::Node::kLeaf) {
91 if (node.leaf().has_vector()) {
92 std::vector<float> leaf_values;
93 const auto& leaf_value_vector = node.leaf().vector();
94 const int size = leaf_value_vector.value_size();
95 leaf_values.reserve(size);
96 for (int i = 0; i < size; ++i) {
97 leaf_values.push_back(leaf_value_vector.value(i));
98 }
99 return leaf_values;
100 } else {
101 return {node.leaf().scalar()};
102 }
103 } else {
104 if (node.metadata().original_leaf().has_vector()) {
105 std::vector<float> node_values;
106 const auto& leaf_value_vector = node.metadata().original_leaf().vector();
107 const int size = leaf_value_vector.value_size();
108 node_values.reserve(size);
109 for (int i = 0; i < size; ++i) {
110 node_values.push_back(leaf_value_vector.value(i));
111 }
112 return node_values;
113 } else if (node.metadata().has_original_leaf()) {
114 return {node.metadata().original_leaf().scalar()};
115 } else {
116 return {};
117 }
118 }
119 }
120
set_node_value(const int32 tree_id,const int32 node_id,const float logits)121 void BoostedTreesEnsembleResource::set_node_value(const int32 tree_id,
122 const int32 node_id,
123 const float logits) {
124 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
125 DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
126 auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
127 DCHECK(node->node_case() == boosted_trees::Node::kLeaf);
128 node->mutable_leaf()->set_scalar(logits);
129 }
130
// Returns the `num_layers_grown` counter stored in the metadata of tree
// `tree_id`. DCHECKs that `tree_id` is below the number of trees.
int32 BoostedTreesEnsembleResource::GetNumLayersGrown(
    const int32 tree_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  const auto& metadata = tree_ensemble_->tree_metadata(tree_id);
  return metadata.num_layers_grown();
}
136
SetNumLayersGrown(const int32 tree_id,int32 new_num_layers) const137 void BoostedTreesEnsembleResource::SetNumLayersGrown(
138 const int32 tree_id, int32 new_num_layers) const {
139 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
140 tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
141 new_num_layers);
142 }
143
UpdateLastLayerNodesRange(const int32 node_range_start,int32 node_range_end) const144 void BoostedTreesEnsembleResource::UpdateLastLayerNodesRange(
145 const int32 node_range_start, int32 node_range_end) const {
146 tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start(
147 node_range_start);
148 tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end(
149 node_range_end);
150 }
151
GetLastLayerNodesRange(int32 * node_range_start,int32 * node_range_end) const152 void BoostedTreesEnsembleResource::GetLastLayerNodesRange(
153 int32* node_range_start, int32* node_range_end) const {
154 *node_range_start =
155 tree_ensemble_->growing_metadata().last_layer_node_start();
156 *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end();
157 }
158
GetNumNodes(const int32 tree_id)159 int64 BoostedTreesEnsembleResource::GetNumNodes(const int32 tree_id) {
160 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
161 return tree_ensemble_->trees(tree_id).nodes_size();
162 }
163
GetNumLayersAttempted()164 int32 BoostedTreesEnsembleResource::GetNumLayersAttempted() {
165 return tree_ensemble_->growing_metadata().num_layers_attempted();
166 }
167
is_leaf(const int32 tree_id,const int32 node_id) const168 bool BoostedTreesEnsembleResource::is_leaf(const int32 tree_id,
169 const int32 node_id) const {
170 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
171 DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
172 const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
173 return node.node_case() == boosted_trees::Node::kLeaf;
174 }
175
// Returns the feature id of the bucketized split stored at
// (`tree_id`, `node_id`). DCHECKs that the node is a bucketized split.
int32 BoostedTreesEnsembleResource::feature_id(const int32 tree_id,
                                               const int32 node_id) const {
  // Bind by reference: a plain `auto` would deep-copy the whole Node message
  // just to read one field.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().feature_id();
}
182
// Returns the threshold of the bucketized split stored at
// (`tree_id`, `node_id`). DCHECKs that the node is a bucketized split.
int32 BoostedTreesEnsembleResource::bucket_threshold(
    const int32 tree_id, const int32 node_id) const {
  // Bind by reference: a plain `auto` would deep-copy the whole Node message
  // just to read one field.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().threshold();
}
189
// Returns the left-child id of the bucketized split stored at
// (`tree_id`, `node_id`). DCHECKs that the node is a bucketized split.
int32 BoostedTreesEnsembleResource::left_id(const int32 tree_id,
                                            const int32 node_id) const {
  // Bind by reference: a plain `auto` would deep-copy the whole Node message
  // just to read one field.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().left_id();
}
196
// Returns the right-child id of the bucketized split stored at
// (`tree_id`, `node_id`). DCHECKs that the node is a bucketized split.
int32 BoostedTreesEnsembleResource::right_id(const int32 tree_id,
                                             const int32 node_id) const {
  // Bind by reference: a plain `auto` would deep-copy the whole Node message
  // just to read one field.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().right_id();
}
203
GetTreeWeights() const204 std::vector<float> BoostedTreesEnsembleResource::GetTreeWeights() const {
205 return {tree_ensemble_->tree_weights().begin(),
206 tree_ensemble_->tree_weights().end()};
207 }
208
GetTreeWeight(const int32 tree_id) const209 float BoostedTreesEnsembleResource::GetTreeWeight(const int32 tree_id) const {
210 return tree_ensemble_->tree_weights(tree_id);
211 }
212
IsTreeFinalized(const int32 tree_id) const213 float BoostedTreesEnsembleResource::IsTreeFinalized(const int32 tree_id) const {
214 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
215 return tree_ensemble_->tree_metadata(tree_id).is_finalized();
216 }
217
IsTreePostPruned(const int32 tree_id) const218 float BoostedTreesEnsembleResource::IsTreePostPruned(
219 const int32 tree_id) const {
220 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
221 return tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size() >
222 0;
223 }
224
SetIsFinalized(const int32 tree_id,const bool is_finalized)225 void BoostedTreesEnsembleResource::SetIsFinalized(const int32 tree_id,
226 const bool is_finalized) {
227 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
228 return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
229 is_finalized);
230 }
231
232 // Sets the weight of i'th tree.
SetTreeWeight(const int32 tree_id,const float weight)233 void BoostedTreesEnsembleResource::SetTreeWeight(const int32 tree_id,
234 const float weight) {
235 DCHECK_GE(tree_id, 0);
236 DCHECK_LT(tree_id, num_trees());
237 tree_ensemble_->set_tree_weights(tree_id, weight);
238 }
239
UpdateGrowingMetadata() const240 void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
241 tree_ensemble_->mutable_growing_metadata()->set_num_layers_attempted(
242 tree_ensemble_->growing_metadata().num_layers_attempted() + 1);
243
244 const int n_trees = num_trees();
245
246 if (n_trees <= 0 ||
247 // Checks if we are building the first layer of the dummy empty tree
248 ((n_trees == 1 || IsTreeFinalized(n_trees - 2)) &&
249 (tree_ensemble_->trees(n_trees - 1).nodes_size() == 1))) {
250 tree_ensemble_->mutable_growing_metadata()->set_num_trees_attempted(
251 tree_ensemble_->growing_metadata().num_trees_attempted() + 1);
252 }
253 }
254
255 // Add a tree to the ensemble and returns a new tree_id.
AddNewTree(const float weight,const int32 logits_dimension)256 int32 BoostedTreesEnsembleResource::AddNewTree(const float weight,
257 const int32 logits_dimension) {
258 const std::vector<float> empty_leaf(logits_dimension);
259 return AddNewTreeWithLogits(weight, empty_leaf, logits_dimension);
260 }
261
// Appends a new tree made of a single leaf initialised from `logits`, together
// with its weight and an empty metadata entry, and returns the new tree's id.
//
// With `logits_dimension == 1` the leaf stores `logits[0]` as a scalar;
// otherwise the first `logits_dimension` entries of `logits` are appended to
// the leaf's vector. `logits` must hold at least that many elements; this is
// not checked here.
int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(
    const float weight, const std::vector<float>& logits,
    const int32 logits_dimension) {
  // The id of the new tree is the tree count before the tree is appended.
  const int32 new_tree_id = tree_ensemble_->trees_size();
  auto* root = tree_ensemble_->add_trees()->add_nodes();

  if (logits_dimension != 1) {
    for (int32 dim = 0; dim < logits_dimension; ++dim) {
      root->mutable_leaf()->mutable_vector()->add_value(logits[dim]);
    }
  } else {
    root->mutable_leaf()->set_scalar(logits[0]);
  }

  tree_ensemble_->add_tree_weights(weight);
  tree_ensemble_->add_tree_metadata();
  return new_tree_id;
}
279
AddBucketizedSplitNode(const int32 tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32 logits_dimension,int32 * left_node_id,int32 * right_node_id)280 void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
281 const int32 tree_id,
282 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
283 const int32 logits_dimension, int32* left_node_id, int32* right_node_id) {
284 const auto candidate = split_entry.second;
285 auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
286 left_node_id, right_node_id);
287 auto* new_split = node->mutable_bucketized_split();
288 new_split->set_feature_id(candidate.feature_id);
289 new_split->set_threshold(candidate.threshold);
290 new_split->set_dimension_id(candidate.dimension_id);
291 new_split->set_left_id(*left_node_id);
292 new_split->set_right_id(*right_node_id);
293
294 boosted_trees::SplitTypeWithDefault split_type_with_default;
295 bool parsed = boosted_trees::SplitTypeWithDefault_Parse(
296 candidate.split_type, &split_type_with_default);
297 DCHECK(parsed);
298 if (split_type_with_default == boosted_trees::INEQUALITY_DEFAULT_RIGHT) {
299 new_split->set_default_direction(boosted_trees::DEFAULT_RIGHT);
300 } else {
301 new_split->set_default_direction(boosted_trees::DEFAULT_LEFT);
302 }
303 }
304
AddCategoricalSplitNode(const int32 tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32 logits_dimension,int32 * left_node_id,int32 * right_node_id)305 void BoostedTreesEnsembleResource::AddCategoricalSplitNode(
306 const int32 tree_id,
307 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
308 const int32 logits_dimension, int32* left_node_id, int32* right_node_id) {
309 const auto candidate = split_entry.second;
310 auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
311 left_node_id, right_node_id);
312 auto* new_split = node->mutable_categorical_split();
313 new_split->set_feature_id(candidate.feature_id);
314 new_split->set_value(candidate.threshold);
315 new_split->set_dimension_id(candidate.dimension_id);
316 new_split->set_left_id(*left_node_id);
317 new_split->set_right_id(*right_node_id);
318 }
319
AddLeafNodes(const int32 tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32 logits_dimension,int32 * left_node_id,int32 * right_node_id)320 boosted_trees::Node* BoostedTreesEnsembleResource::AddLeafNodes(
321 const int32 tree_id,
322 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
323 const int32 logits_dimension, int32* left_node_id, int32* right_node_id) {
324 auto* tree = tree_ensemble_->mutable_trees(tree_id);
325 const auto node_id = split_entry.first;
326 const auto candidate = split_entry.second;
327 auto* node = tree->mutable_nodes(node_id);
328 DCHECK_EQ(node->node_case(), boosted_trees::Node::kLeaf);
329 *left_node_id = tree->nodes_size();
330 *right_node_id = *left_node_id + 1;
331 auto* left_node = tree->add_nodes();
332 auto* right_node = tree->add_nodes();
333 const bool has_leaf_value =
334 node->has_leaf() &&
335 ((logits_dimension == 1 && (node->leaf().scalar() != 0)) ||
336 node->leaf().has_vector());
337 if (node_id != 0 || has_leaf_value) {
338 // Save previous leaf value if it is not the first leaf in the tree.
339 node->mutable_metadata()->mutable_original_leaf()->Swap(
340 node->mutable_leaf());
341 }
342 node->mutable_metadata()->set_gain(candidate.gain);
343 // TODO(npononareva): this is LAYER-BY-LAYER boosting; add WHOLE-TREE.
344 if (logits_dimension == 1) {
345 const float prev_logit_value = node->metadata().original_leaf().scalar();
346 left_node->mutable_leaf()->set_scalar(prev_logit_value +
347 candidate.left_node_contribs[0]);
348 right_node->mutable_leaf()->set_scalar(prev_logit_value +
349 candidate.right_node_contribs[0]);
350 } else {
351 if (has_leaf_value) {
352 DCHECK_EQ(logits_dimension,
353 node->metadata().original_leaf().vector().value_size());
354 }
355 float prev_logit_value = 0.0;
356 for (int32 i = 0; i < logits_dimension; ++i) {
357 if (has_leaf_value) {
358 prev_logit_value = node->metadata().original_leaf().vector().value(i);
359 }
360 left_node->mutable_leaf()->mutable_vector()->add_value(
361 prev_logit_value + candidate.left_node_contribs[i]);
362 right_node->mutable_leaf()->mutable_vector()->add_value(
363 prev_logit_value + candidate.right_node_contribs[i]);
364 }
365 }
366 return node;
367 }
368
Reset()369 void BoostedTreesEnsembleResource::Reset() {
370 // Reset stamp.
371 set_stamp(-1);
372
373 // Clear tree ensemle.
374 arena_.Reset();
375 CHECK_EQ(0, arena_.SpaceAllocated());
376 tree_ensemble_ =
377 protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(&arena_);
378 }
379
PostPruneTree(const int32 current_tree,const int32 logits_dimension)380 void BoostedTreesEnsembleResource::PostPruneTree(const int32 current_tree,
381 const int32 logits_dimension) {
382 // No-op if tree is empty.
383 auto* tree = tree_ensemble_->mutable_trees(current_tree);
384 int32 num_nodes = tree->nodes_size();
385 if (num_nodes == 0) {
386 return;
387 }
388
389 std::vector<int32> nodes_to_delete;
390 // If a node was pruned, we need to save the change of the prediction from
391 // this node to its parent, as well as the parent id.
392 std::vector<std::pair<int32, std::vector<float>>> nodes_changes;
393 nodes_changes.reserve(num_nodes);
394 for (int32 i = 0; i < num_nodes; ++i) {
395 std::vector<float> prune_logit_changes;
396 nodes_changes.emplace_back(i, prune_logit_changes);
397 }
398 // Prune the tree recursively starting from the root. Each node that has
399 // negative gain and only leaf children will be pruned recursively up from
400 // the bottom of the tree. This method returns the list of nodes pruned, and
401 // updates the nodes in the tree not to refer to those pruned nodes.
402 RecursivelyDoPostPrunePreparation(current_tree, 0, &nodes_to_delete,
403 &nodes_changes);
404
405 if (nodes_to_delete.empty()) {
406 // No pruning happened, and no post-processing needed.
407 return;
408 }
409
410 // Sort node ids so they are in asc order.
411 std::sort(nodes_to_delete.begin(), nodes_to_delete.end());
412
413 // We need to
414 // - update split left and right children ids with new indices
415 // - actually remove the nodes that need to be removed
416 // - save the information about pruned node so we could recover the
417 // predictions from cache. Build a map for old node index=>new node index.
418 // nodes_to_delete contains nodes who's indices should be skipped, in
419 // ascending order. Save the information about new indices into meta.
420 std::map<int32, int32> old_to_new_ids;
421 int32 new_index = 0;
422 int32 index_for_deleted = 0;
423 auto* post_prune_meta = tree_ensemble_->mutable_tree_metadata(current_tree)
424 ->mutable_post_pruned_nodes_meta();
425
426 for (int32 i = 0; i < num_nodes; ++i) {
427 if (index_for_deleted < nodes_to_delete.size() &&
428 i == nodes_to_delete[index_for_deleted]) {
429 // Node i will get removed,
430 ++index_for_deleted;
431 // Update meta info that will allow us to use cached predictions from
432 // those nodes.
433 int32 new_id;
434 std::vector<float> logit_changes;
435 logit_changes.reserve(logits_dimension);
436 CalculateParentAndLogitUpdate(i, nodes_changes, &new_id, &logit_changes);
437 auto* meta = post_prune_meta->Add();
438 meta->set_new_node_id(old_to_new_ids[new_id]);
439 for (int32 j = 0; j < logits_dimension; ++j) {
440 meta->add_logit_change(logit_changes[j]);
441 }
442 } else {
443 old_to_new_ids[i] = new_index++;
444 auto* meta = post_prune_meta->Add();
445 // Update meta info that will allow us to use cached predictions from
446 // those nodes.
447 meta->set_new_node_id(old_to_new_ids[i]);
448 for (int32 i = 0; i < logits_dimension; ++i) {
449 meta->add_logit_change(0.0);
450 }
451 }
452 }
453 index_for_deleted = 0;
454 int32 i = 0;
455 protobuf::RepeatedPtrField<boosted_trees::Node> new_nodes;
456 new_nodes.Reserve(old_to_new_ids.size());
457 for (auto node : *(tree->mutable_nodes())) {
458 if (index_for_deleted < nodes_to_delete.size() &&
459 i == nodes_to_delete[index_for_deleted]) {
460 ++index_for_deleted;
461 ++i;
462 continue;
463 } else {
464 if (node.node_case() == boosted_trees::Node::kBucketizedSplit) {
465 node.mutable_bucketized_split()->set_left_id(
466 old_to_new_ids[node.bucketized_split().left_id()]);
467 node.mutable_bucketized_split()->set_right_id(
468 old_to_new_ids[node.bucketized_split().right_id()]);
469 }
470 *new_nodes.Add() = std::move(node);
471 }
472 ++i;
473 }
474 // Replace all the nodes in a tree with the ones we keep.
475 *tree->mutable_nodes() = std::move(new_nodes);
476
477 // Note that if the whole tree got pruned, we will end up with one node.
478 // We can't remove that tree because it will cause problems with cache.
479 }
480
GetPostPruneCorrection(const int32 tree_id,const int32 initial_node_id,int32 * current_node_id,std::vector<float> * logit_updates) const481 void BoostedTreesEnsembleResource::GetPostPruneCorrection(
482 const int32 tree_id, const int32 initial_node_id, int32* current_node_id,
483 std::vector<float>* logit_updates) const {
484 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
485 if (IsTreeFinalized(tree_id) && IsTreePostPruned(tree_id)) {
486 DCHECK_LT(
487 initial_node_id,
488 tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size());
489 const auto& meta =
490 tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta(
491 initial_node_id);
492 *current_node_id = meta.new_node_id();
493 DCHECK_EQ(meta.logit_change().size(), logit_updates->size());
494 for (int32 i = 0; i < meta.logit_change().size(); ++i) {
495 logit_updates->at(i) = meta.logit_change(i);
496 }
497 }
498 }
499
IsTerminalSplitNode(const int32 tree_id,const int32 node_id) const500 bool BoostedTreesEnsembleResource::IsTerminalSplitNode(
501 const int32 tree_id, const int32 node_id) const {
502 const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
503 DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
504 const int32 left_id = node.bucketized_split().left_id();
505 const int32 right_id = node.bucketized_split().right_id();
506 return is_leaf(tree_id, left_id) && is_leaf(tree_id, right_id);
507 }
508
509 // For each pruned node, finds the leaf where it finally ended up and
510 // calculates the total update from that pruned node prediction.
CalculateParentAndLogitUpdate(const int32 start_node_id,const std::vector<std::pair<int32,std::vector<float>>> & nodes_changes,int32 * parent_id,std::vector<float> * changes) const511 void BoostedTreesEnsembleResource::CalculateParentAndLogitUpdate(
512 const int32 start_node_id,
513 const std::vector<std::pair<int32, std::vector<float>>>& nodes_changes,
514 int32* parent_id, std::vector<float>* changes) const {
515 const int logits_dimension = nodes_changes[start_node_id].second.size();
516 for (int i = 0; i < logits_dimension; ++i) {
517 changes->emplace_back(0.0);
518 }
519 int32 node_id = start_node_id;
520 int32 parent = nodes_changes[node_id].first;
521 while (parent != node_id) {
522 for (int i = 0; i < logits_dimension; ++i) {
523 changes->at(i) += nodes_changes[node_id].second[i];
524 }
525 node_id = parent;
526 parent = nodes_changes[node_id].first;
527 }
528 *parent_id = parent;
529 }
530
RecursivelyDoPostPrunePreparation(const int32 tree_id,const int32 node_id,std::vector<int32> * nodes_to_delete,std::vector<std::pair<int32,std::vector<float>>> * nodes_meta)531 void BoostedTreesEnsembleResource::RecursivelyDoPostPrunePreparation(
532 const int32 tree_id, const int32 node_id,
533 std::vector<int32>* nodes_to_delete,
534 std::vector<std::pair<int32, std::vector<float>>>* nodes_meta) {
535 auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
536 DCHECK_NE(node->node_case(), boosted_trees::Node::NODE_NOT_SET);
537 // Base case when we reach a leaf.
538 if (node->node_case() == boosted_trees::Node::kLeaf) {
539 return;
540 }
541
542 // Traverse node children first and recursively prune their sub-trees.
543 RecursivelyDoPostPrunePreparation(tree_id, node->bucketized_split().left_id(),
544 nodes_to_delete, nodes_meta);
545 RecursivelyDoPostPrunePreparation(tree_id,
546 node->bucketized_split().right_id(),
547 nodes_to_delete, nodes_meta);
548
549 // Two conditions must be satisfied to prune the node:
550 // 1- The split gain is negative.
551 // 2- After depth-first pruning, the node only has leaf children.
552 const auto& node_metadata = node->metadata();
553 if (node_metadata.gain() < 0 && IsTerminalSplitNode(tree_id, node_id)) {
554 const int32 left_id = node->bucketized_split().left_id();
555 const int32 right_id = node->bucketized_split().right_id();
556
557 // Save children that need to be deleted.
558 nodes_to_delete->push_back(left_id);
559 nodes_to_delete->push_back(right_id);
560
561 // Change node back into leaf.
562 *node->mutable_leaf() = node_metadata.original_leaf();
563
564 // Save the old values of weights of children.
565 nodes_meta->at(left_id).first = node_id;
566 nodes_meta->at(right_id).first = node_id;
567 const auto& left_child_values = node_value(tree_id, left_id);
568 const auto& right_child_values = node_value(tree_id, right_id);
569 std::vector<float> parent_values(left_child_values.size(), 0.0);
570 if (node_metadata.has_original_leaf()) {
571 parent_values = node_value(tree_id, node_id);
572 }
573 for (int32 i = 0; i < parent_values.size(); ++i) {
574 nodes_meta->at(left_id).second.emplace_back(parent_values[i] -
575 left_child_values[i]);
576 nodes_meta->at(right_id).second.emplace_back(parent_values[i] -
577 right_child_values[i]);
578 }
579 // Clear gain for leaf node.
580 node->clear_metadata();
581 }
582 }
583
584 } // namespace tensorflow
585