1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/boosted_trees/resources.h"
17
18 #include "tensorflow/core/framework/resource_mgr.h"
19 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
20 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
21 #include "tensorflow/core/platform/mutex.h"
22 #include "tensorflow/core/platform/protobuf.h"
23
24 namespace tensorflow {
25
26 // Constructor.
BoostedTreesEnsembleResource()27 BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
28 : tree_ensemble_(
29 protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
30 &arena_)) {}
31
DebugString() const32 string BoostedTreesEnsembleResource::DebugString() const {
33 return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
34 "]");
35 }
36
InitFromSerialized(const string & serialized,const int64_t stamp_token)37 bool BoostedTreesEnsembleResource::InitFromSerialized(
38 const string& serialized, const int64_t stamp_token) {
39 CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
40 if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
41 set_stamp(stamp_token);
42 return true;
43 }
44 return false;
45 }
46
SerializeAsString() const47 string BoostedTreesEnsembleResource::SerializeAsString() const {
48 return tree_ensemble_->SerializeAsString();
49 }
50
num_trees() const51 int32 BoostedTreesEnsembleResource::num_trees() const {
52 return tree_ensemble_->trees_size();
53 }
54
next_node(const int32_t tree_id,const int32_t node_id,const int32_t index_in_batch,const std::vector<TTypes<int32>::ConstMatrix> & bucketized_features) const55 int32 BoostedTreesEnsembleResource::next_node(
56 const int32_t tree_id, const int32_t node_id, const int32_t index_in_batch,
57 const std::vector<TTypes<int32>::ConstMatrix>& bucketized_features) const {
58 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
59 DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
60 const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
61
62 switch (node.node_case()) {
63 case boosted_trees::Node::kBucketizedSplit: {
64 const auto& split = node.bucketized_split();
65 const auto bucketized_feature = bucketized_features[split.feature_id()];
66 return bucketized_feature(index_in_batch, split.dimension_id()) <=
67 split.threshold()
68 ? split.left_id()
69 : split.right_id();
70 }
71 case boosted_trees::Node::kCategoricalSplit: {
72 const auto& split = node.categorical_split();
73 const auto bucketized_feature = bucketized_features[split.feature_id()];
74 return bucketized_feature(index_in_batch, split.dimension_id()) ==
75 split.value()
76 ? split.left_id()
77 : split.right_id();
78 }
79 default:
80 DCHECK(false) << "Node type " << node.node_case() << " not supported.";
81 }
82 return -1;
83 }
84
node_value(const int32_t tree_id,const int32_t node_id) const85 std::vector<float> BoostedTreesEnsembleResource::node_value(
86 const int32_t tree_id, const int32_t node_id) const {
87 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
88 DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
89 const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
90 if (node.node_case() == boosted_trees::Node::kLeaf) {
91 if (node.leaf().has_vector()) {
92 std::vector<float> leaf_values;
93 const auto& leaf_value_vector = node.leaf().vector();
94 const int size = leaf_value_vector.value_size();
95 leaf_values.reserve(size);
96 for (int i = 0; i < size; ++i) {
97 leaf_values.push_back(leaf_value_vector.value(i));
98 }
99 return leaf_values;
100 } else {
101 return {node.leaf().scalar()};
102 }
103 } else {
104 if (node.metadata().original_leaf().has_vector()) {
105 std::vector<float> node_values;
106 const auto& leaf_value_vector = node.metadata().original_leaf().vector();
107 const int size = leaf_value_vector.value_size();
108 node_values.reserve(size);
109 for (int i = 0; i < size; ++i) {
110 node_values.push_back(leaf_value_vector.value(i));
111 }
112 return node_values;
113 } else if (node.metadata().has_original_leaf()) {
114 return {node.metadata().original_leaf().scalar()};
115 } else {
116 return {};
117 }
118 }
119 }
120
set_node_value(const int32_t tree_id,const int32_t node_id,const float logits)121 void BoostedTreesEnsembleResource::set_node_value(const int32_t tree_id,
122 const int32_t node_id,
123 const float logits) {
124 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
125 DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
126 auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
127 DCHECK(node->node_case() == boosted_trees::Node::kLeaf);
128 node->mutable_leaf()->set_scalar(logits);
129 }
130
GetNumLayersGrown(const int32_t tree_id) const131 int32 BoostedTreesEnsembleResource::GetNumLayersGrown(
132 const int32_t tree_id) const {
133 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
134 return tree_ensemble_->tree_metadata(tree_id).num_layers_grown();
135 }
136
SetNumLayersGrown(const int32_t tree_id,int32_t new_num_layers) const137 void BoostedTreesEnsembleResource::SetNumLayersGrown(
138 const int32_t tree_id, int32_t new_num_layers) const {
139 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
140 tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
141 new_num_layers);
142 }
143
UpdateLastLayerNodesRange(const int32_t node_range_start,int32_t node_range_end) const144 void BoostedTreesEnsembleResource::UpdateLastLayerNodesRange(
145 const int32_t node_range_start, int32_t node_range_end) const {
146 tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start(
147 node_range_start);
148 tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end(
149 node_range_end);
150 }
151
GetLastLayerNodesRange(int32 * node_range_start,int32 * node_range_end) const152 void BoostedTreesEnsembleResource::GetLastLayerNodesRange(
153 int32* node_range_start, int32* node_range_end) const {
154 *node_range_start =
155 tree_ensemble_->growing_metadata().last_layer_node_start();
156 *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end();
157 }
158
GetNumNodes(const int32_t tree_id)159 int64 BoostedTreesEnsembleResource::GetNumNodes(const int32_t tree_id) {
160 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
161 return tree_ensemble_->trees(tree_id).nodes_size();
162 }
163
GetNumLayersAttempted()164 int32 BoostedTreesEnsembleResource::GetNumLayersAttempted() {
165 return tree_ensemble_->growing_metadata().num_layers_attempted();
166 }
167
is_leaf(const int32_t tree_id,const int32_t node_id) const168 bool BoostedTreesEnsembleResource::is_leaf(const int32_t tree_id,
169 const int32_t node_id) const {
170 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
171 DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
172 const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
173 return node.node_case() == boosted_trees::Node::kLeaf;
174 }
175
feature_id(const int32_t tree_id,const int32_t node_id) const176 int32 BoostedTreesEnsembleResource::feature_id(const int32_t tree_id,
177 const int32_t node_id) const {
178 const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
179 DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
180 return node.bucketized_split().feature_id();
181 }
182
bucket_threshold(const int32_t tree_id,const int32_t node_id) const183 int32 BoostedTreesEnsembleResource::bucket_threshold(
184 const int32_t tree_id, const int32_t node_id) const {
185 const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
186 DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
187 return node.bucketized_split().threshold();
188 }
189
left_id(const int32_t tree_id,const int32_t node_id) const190 int32 BoostedTreesEnsembleResource::left_id(const int32_t tree_id,
191 const int32_t node_id) const {
192 const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
193 DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
194 return node.bucketized_split().left_id();
195 }
196
right_id(const int32_t tree_id,const int32_t node_id) const197 int32 BoostedTreesEnsembleResource::right_id(const int32_t tree_id,
198 const int32_t node_id) const {
199 const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
200 DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
201 return node.bucketized_split().right_id();
202 }
203
GetTreeWeights() const204 std::vector<float> BoostedTreesEnsembleResource::GetTreeWeights() const {
205 return {tree_ensemble_->tree_weights().begin(),
206 tree_ensemble_->tree_weights().end()};
207 }
208
GetTreeWeight(const int32_t tree_id) const209 float BoostedTreesEnsembleResource::GetTreeWeight(const int32_t tree_id) const {
210 return tree_ensemble_->tree_weights(tree_id);
211 }
212
IsTreeFinalized(const int32_t tree_id) const213 float BoostedTreesEnsembleResource::IsTreeFinalized(
214 const int32_t tree_id) const {
215 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
216 return tree_ensemble_->tree_metadata(tree_id).is_finalized();
217 }
218
IsTreePostPruned(const int32_t tree_id) const219 float BoostedTreesEnsembleResource::IsTreePostPruned(
220 const int32_t tree_id) const {
221 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
222 return tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size() >
223 0;
224 }
225
SetIsFinalized(const int32_t tree_id,const bool is_finalized)226 void BoostedTreesEnsembleResource::SetIsFinalized(const int32_t tree_id,
227 const bool is_finalized) {
228 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
229 return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
230 is_finalized);
231 }
232
233 // Sets the weight of i'th tree.
SetTreeWeight(const int32_t tree_id,const float weight)234 void BoostedTreesEnsembleResource::SetTreeWeight(const int32_t tree_id,
235 const float weight) {
236 DCHECK_GE(tree_id, 0);
237 DCHECK_LT(tree_id, num_trees());
238 tree_ensemble_->set_tree_weights(tree_id, weight);
239 }
240
UpdateGrowingMetadata() const241 void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
242 tree_ensemble_->mutable_growing_metadata()->set_num_layers_attempted(
243 tree_ensemble_->growing_metadata().num_layers_attempted() + 1);
244
245 const int n_trees = num_trees();
246
247 if (n_trees <= 0 ||
248 // Checks if we are building the first layer of the dummy empty tree
249 ((n_trees == 1 || IsTreeFinalized(n_trees - 2)) &&
250 (tree_ensemble_->trees(n_trees - 1).nodes_size() == 1))) {
251 tree_ensemble_->mutable_growing_metadata()->set_num_trees_attempted(
252 tree_ensemble_->growing_metadata().num_trees_attempted() + 1);
253 }
254 }
255
256 // Add a tree to the ensemble and returns a new tree_id.
AddNewTree(const float weight,const int32_t logits_dimension)257 int32 BoostedTreesEnsembleResource::AddNewTree(const float weight,
258 const int32_t logits_dimension) {
259 const std::vector<float> empty_leaf(logits_dimension);
260 return AddNewTreeWithLogits(weight, empty_leaf, logits_dimension);
261 }
262
AddNewTreeWithLogits(const float weight,const std::vector<float> & logits,const int32_t logits_dimension)263 int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(
264 const float weight, const std::vector<float>& logits,
265 const int32_t logits_dimension) {
266 const int32_t new_tree_id = tree_ensemble_->trees_size();
267 auto* node = tree_ensemble_->add_trees()->add_nodes();
268 if (logits_dimension == 1) {
269 node->mutable_leaf()->set_scalar(logits[0]);
270 } else {
271 for (int32_t i = 0; i < logits_dimension; ++i) {
272 node->mutable_leaf()->mutable_vector()->add_value(logits[i]);
273 }
274 }
275 tree_ensemble_->add_tree_weights(weight);
276 tree_ensemble_->add_tree_metadata();
277
278 return new_tree_id;
279 }
280
AddBucketizedSplitNode(const int32_t tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32_t logits_dimension,int32 * left_node_id,int32 * right_node_id)281 void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
282 const int32_t tree_id,
283 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
284 const int32_t logits_dimension, int32* left_node_id, int32* right_node_id) {
285 const auto candidate = split_entry.second;
286 auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
287 left_node_id, right_node_id);
288 auto* new_split = node->mutable_bucketized_split();
289 new_split->set_feature_id(candidate.feature_id);
290 new_split->set_threshold(candidate.threshold);
291 new_split->set_dimension_id(candidate.dimension_id);
292 new_split->set_left_id(*left_node_id);
293 new_split->set_right_id(*right_node_id);
294
295 boosted_trees::SplitTypeWithDefault split_type_with_default;
296 bool parsed = boosted_trees::SplitTypeWithDefault_Parse(
297 candidate.split_type, &split_type_with_default);
298 DCHECK(parsed);
299 if (split_type_with_default == boosted_trees::INEQUALITY_DEFAULT_RIGHT) {
300 new_split->set_default_direction(boosted_trees::DEFAULT_RIGHT);
301 } else {
302 new_split->set_default_direction(boosted_trees::DEFAULT_LEFT);
303 }
304 }
305
AddCategoricalSplitNode(const int32_t tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32_t logits_dimension,int32 * left_node_id,int32 * right_node_id)306 void BoostedTreesEnsembleResource::AddCategoricalSplitNode(
307 const int32_t tree_id,
308 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
309 const int32_t logits_dimension, int32* left_node_id, int32* right_node_id) {
310 const auto candidate = split_entry.second;
311 auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
312 left_node_id, right_node_id);
313 auto* new_split = node->mutable_categorical_split();
314 new_split->set_feature_id(candidate.feature_id);
315 new_split->set_value(candidate.threshold);
316 new_split->set_dimension_id(candidate.dimension_id);
317 new_split->set_left_id(*left_node_id);
318 new_split->set_right_id(*right_node_id);
319 }
320
AddLeafNodes(const int32_t tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32_t logits_dimension,int32 * left_node_id,int32 * right_node_id)321 boosted_trees::Node* BoostedTreesEnsembleResource::AddLeafNodes(
322 const int32_t tree_id,
323 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
324 const int32_t logits_dimension, int32* left_node_id, int32* right_node_id) {
325 auto* tree = tree_ensemble_->mutable_trees(tree_id);
326 const auto node_id = split_entry.first;
327 const auto candidate = split_entry.second;
328 auto* node = tree->mutable_nodes(node_id);
329 DCHECK_EQ(node->node_case(), boosted_trees::Node::kLeaf);
330 *left_node_id = tree->nodes_size();
331 *right_node_id = *left_node_id + 1;
332 auto* left_node = tree->add_nodes();
333 auto* right_node = tree->add_nodes();
334 const bool has_leaf_value =
335 node->has_leaf() &&
336 ((logits_dimension == 1 && (node->leaf().scalar() != 0)) ||
337 node->leaf().has_vector());
338 if (node_id != 0 || has_leaf_value) {
339 // Save previous leaf value if it is not the first leaf in the tree.
340 node->mutable_metadata()->mutable_original_leaf()->Swap(
341 node->mutable_leaf());
342 }
343 node->mutable_metadata()->set_gain(candidate.gain);
344 // TODO(nponomareva): this is LAYER-BY-LAYER boosting; add WHOLE-TREE.
345 if (logits_dimension == 1) {
346 const float prev_logit_value = node->metadata().original_leaf().scalar();
347 left_node->mutable_leaf()->set_scalar(prev_logit_value +
348 candidate.left_node_contribs[0]);
349 right_node->mutable_leaf()->set_scalar(prev_logit_value +
350 candidate.right_node_contribs[0]);
351 } else {
352 if (has_leaf_value) {
353 DCHECK_EQ(logits_dimension,
354 node->metadata().original_leaf().vector().value_size());
355 }
356 float prev_logit_value = 0.0;
357 for (int32_t i = 0; i < logits_dimension; ++i) {
358 if (has_leaf_value) {
359 prev_logit_value = node->metadata().original_leaf().vector().value(i);
360 }
361 left_node->mutable_leaf()->mutable_vector()->add_value(
362 prev_logit_value + candidate.left_node_contribs[i]);
363 right_node->mutable_leaf()->mutable_vector()->add_value(
364 prev_logit_value + candidate.right_node_contribs[i]);
365 }
366 }
367 return node;
368 }
369
Reset()370 void BoostedTreesEnsembleResource::Reset() {
371 // Reset stamp.
372 set_stamp(-1);
373
374 // Clear tree ensemle.
375 arena_.Reset();
376 tree_ensemble_ =
377 protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(&arena_);
378 }
379
PostPruneTree(const int32_t current_tree,const int32_t logits_dimension)380 void BoostedTreesEnsembleResource::PostPruneTree(
381 const int32_t current_tree, const int32_t logits_dimension) {
382 // No-op if tree is empty.
383 auto* tree = tree_ensemble_->mutable_trees(current_tree);
384 int32_t num_nodes = tree->nodes_size();
385 if (num_nodes == 0) {
386 return;
387 }
388
389 std::vector<int32> nodes_to_delete;
390 // If a node was pruned, we need to save the change of the prediction from
391 // this node to its parent, as well as the parent id.
392 std::vector<std::pair<int32, std::vector<float>>> nodes_changes;
393 nodes_changes.reserve(num_nodes);
394 for (int32_t i = 0; i < num_nodes; ++i) {
395 std::vector<float> prune_logit_changes;
396 nodes_changes.emplace_back(i, prune_logit_changes);
397 }
398 // Prune the tree recursively starting from the root. Each node that has
399 // negative gain and only leaf children will be pruned recursively up from
400 // the bottom of the tree. This method returns the list of nodes pruned, and
401 // updates the nodes in the tree not to refer to those pruned nodes.
402 RecursivelyDoPostPrunePreparation(current_tree, 0, &nodes_to_delete,
403 &nodes_changes);
404
405 if (nodes_to_delete.empty()) {
406 // No pruning happened, and no post-processing needed.
407 return;
408 }
409
410 // Sort node ids so they are in asc order.
411 std::sort(nodes_to_delete.begin(), nodes_to_delete.end());
412
413 // We need to
414 // - update split left and right children ids with new indices
415 // - actually remove the nodes that need to be removed
416 // - save the information about pruned node so we could recover the
417 // predictions from cache. Build a map for old node index=>new node index.
418 // nodes_to_delete contains nodes who's indices should be skipped, in
419 // ascending order. Save the information about new indices into meta.
420 std::map<int32, int32> old_to_new_ids;
421 int32_t new_index = 0;
422 int32_t index_for_deleted = 0;
423 auto* post_prune_meta = tree_ensemble_->mutable_tree_metadata(current_tree)
424 ->mutable_post_pruned_nodes_meta();
425
426 for (int32_t i = 0; i < num_nodes; ++i) {
427 const int64_t nodes_to_delete_size = nodes_to_delete.size();
428 if (index_for_deleted < nodes_to_delete_size &&
429 i == nodes_to_delete[index_for_deleted]) {
430 // Node i will get removed,
431 ++index_for_deleted;
432 // Update meta info that will allow us to use cached predictions from
433 // those nodes.
434 int32_t new_id;
435 std::vector<float> logit_changes;
436 logit_changes.reserve(logits_dimension);
437 CalculateParentAndLogitUpdate(i, nodes_changes, &new_id, &logit_changes);
438 auto* meta = post_prune_meta->Add();
439 meta->set_new_node_id(old_to_new_ids[new_id]);
440 for (int32_t j = 0; j < logits_dimension; ++j) {
441 meta->add_logit_change(logit_changes[j]);
442 }
443 } else {
444 old_to_new_ids[i] = new_index++;
445 auto* meta = post_prune_meta->Add();
446 // Update meta info that will allow us to use cached predictions from
447 // those nodes.
448 meta->set_new_node_id(old_to_new_ids[i]);
449 for (int32_t i = 0; i < logits_dimension; ++i) {
450 meta->add_logit_change(0.0);
451 }
452 }
453 }
454 index_for_deleted = 0;
455 int32_t i = 0;
456 protobuf::RepeatedPtrField<boosted_trees::Node> new_nodes;
457 new_nodes.Reserve(old_to_new_ids.size());
458 for (auto node : *(tree->mutable_nodes())) {
459 const int64_t nodes_to_delete_size = nodes_to_delete.size();
460 if (index_for_deleted < nodes_to_delete_size &&
461 i == nodes_to_delete[index_for_deleted]) {
462 ++index_for_deleted;
463 ++i;
464 continue;
465 } else {
466 if (node.node_case() == boosted_trees::Node::kBucketizedSplit) {
467 node.mutable_bucketized_split()->set_left_id(
468 old_to_new_ids[node.bucketized_split().left_id()]);
469 node.mutable_bucketized_split()->set_right_id(
470 old_to_new_ids[node.bucketized_split().right_id()]);
471 }
472 *new_nodes.Add() = std::move(node);
473 }
474 ++i;
475 }
476 // Replace all the nodes in a tree with the ones we keep.
477 *tree->mutable_nodes() = std::move(new_nodes);
478
479 // Note that if the whole tree got pruned, we will end up with one node.
480 // We can't remove that tree because it will cause problems with cache.
481 }
482
GetPostPruneCorrection(const int32_t tree_id,const int32_t initial_node_id,int32 * current_node_id,std::vector<float> * logit_updates) const483 void BoostedTreesEnsembleResource::GetPostPruneCorrection(
484 const int32_t tree_id, const int32_t initial_node_id,
485 int32* current_node_id, std::vector<float>* logit_updates) const {
486 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
487 if (IsTreeFinalized(tree_id) && IsTreePostPruned(tree_id)) {
488 DCHECK_LT(
489 initial_node_id,
490 tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size());
491 const auto& meta =
492 tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta(
493 initial_node_id);
494 *current_node_id = meta.new_node_id();
495 DCHECK_EQ(meta.logit_change().size(), logit_updates->size());
496 for (int32_t i = 0; i < meta.logit_change().size(); ++i) {
497 logit_updates->at(i) = meta.logit_change(i);
498 }
499 }
500 }
501
IsTerminalSplitNode(const int32_t tree_id,const int32_t node_id) const502 bool BoostedTreesEnsembleResource::IsTerminalSplitNode(
503 const int32_t tree_id, const int32_t node_id) const {
504 const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
505 DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
506 const int32_t left_id = node.bucketized_split().left_id();
507 const int32_t right_id = node.bucketized_split().right_id();
508 return is_leaf(tree_id, left_id) && is_leaf(tree_id, right_id);
509 }
510
511 // For each pruned node, finds the leaf where it finally ended up and
512 // calculates the total update from that pruned node prediction.
CalculateParentAndLogitUpdate(const int32_t start_node_id,const std::vector<std::pair<int32,std::vector<float>>> & nodes_changes,int32 * parent_id,std::vector<float> * changes) const513 void BoostedTreesEnsembleResource::CalculateParentAndLogitUpdate(
514 const int32_t start_node_id,
515 const std::vector<std::pair<int32, std::vector<float>>>& nodes_changes,
516 int32* parent_id, std::vector<float>* changes) const {
517 const int logits_dimension = nodes_changes[start_node_id].second.size();
518 for (int i = 0; i < logits_dimension; ++i) {
519 changes->emplace_back(0.0);
520 }
521 int32_t node_id = start_node_id;
522 int32_t parent = nodes_changes[node_id].first;
523 while (parent != node_id) {
524 for (int i = 0; i < logits_dimension; ++i) {
525 changes->at(i) += nodes_changes[node_id].second[i];
526 }
527 node_id = parent;
528 parent = nodes_changes[node_id].first;
529 }
530 *parent_id = parent;
531 }
532
RecursivelyDoPostPrunePreparation(const int32_t tree_id,const int32_t node_id,std::vector<int32> * nodes_to_delete,std::vector<std::pair<int32,std::vector<float>>> * nodes_meta)533 void BoostedTreesEnsembleResource::RecursivelyDoPostPrunePreparation(
534 const int32_t tree_id, const int32_t node_id,
535 std::vector<int32>* nodes_to_delete,
536 std::vector<std::pair<int32, std::vector<float>>>* nodes_meta) {
537 auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
538 DCHECK_NE(node->node_case(), boosted_trees::Node::NODE_NOT_SET);
539 // Base case when we reach a leaf.
540 if (node->node_case() == boosted_trees::Node::kLeaf) {
541 return;
542 }
543
544 // Traverse node children first and recursively prune their sub-trees.
545 RecursivelyDoPostPrunePreparation(tree_id, node->bucketized_split().left_id(),
546 nodes_to_delete, nodes_meta);
547 RecursivelyDoPostPrunePreparation(tree_id,
548 node->bucketized_split().right_id(),
549 nodes_to_delete, nodes_meta);
550
551 // Two conditions must be satisfied to prune the node:
552 // 1- The split gain is negative.
553 // 2- After depth-first pruning, the node only has leaf children.
554 const auto& node_metadata = node->metadata();
555 if (node_metadata.gain() < 0 && IsTerminalSplitNode(tree_id, node_id)) {
556 const int32_t left_id = node->bucketized_split().left_id();
557 const int32_t right_id = node->bucketized_split().right_id();
558
559 // Save children that need to be deleted.
560 nodes_to_delete->push_back(left_id);
561 nodes_to_delete->push_back(right_id);
562
563 // Change node back into leaf.
564 *node->mutable_leaf() = node_metadata.original_leaf();
565
566 // Save the old values of weights of children.
567 nodes_meta->at(left_id).first = node_id;
568 nodes_meta->at(right_id).first = node_id;
569 const auto& left_child_values = node_value(tree_id, left_id);
570 const auto& right_child_values = node_value(tree_id, right_id);
571 std::vector<float> parent_values(left_child_values.size(), 0.0);
572 if (node_metadata.has_original_leaf()) {
573 parent_values = node_value(tree_id, node_id);
574 }
575 for (int32_t i = 0, end = parent_values.size(); i < end; ++i) {
576 nodes_meta->at(left_id).second.emplace_back(parent_values[i] -
577 left_child_values[i]);
578 nodes_meta->at(right_id).second.emplace_back(parent_values[i] -
579 right_child_values[i]);
580 }
581 // Clear gain for leaf node.
582 node->clear_metadata();
583 }
584 }
585
586 } // namespace tensorflow
587