1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/boosted_trees/resources.h"
17
18 #include "tensorflow/core/framework/resource_mgr.h"
19 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
20 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
21 #include "tensorflow/core/platform/mutex.h"
22 #include "tensorflow/core/platform/protobuf.h"
23
24 namespace tensorflow {
25
26 // Constructor.
BoostedTreesEnsembleResource()27 BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
28 : tree_ensemble_(
29 protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
30 &arena_)) {}
31
DebugString() const32 string BoostedTreesEnsembleResource::DebugString() const {
33 return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
34 "]");
35 }
36
InitFromSerialized(const string & serialized,const int64 stamp_token)37 bool BoostedTreesEnsembleResource::InitFromSerialized(const string& serialized,
38 const int64 stamp_token) {
39 CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
40 if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
41 set_stamp(stamp_token);
42 return true;
43 }
44 return false;
45 }
46
SerializeAsString() const47 string BoostedTreesEnsembleResource::SerializeAsString() const {
48 return tree_ensemble_->SerializeAsString();
49 }
50
num_trees() const51 int32 BoostedTreesEnsembleResource::num_trees() const {
52 return tree_ensemble_->trees_size();
53 }
54
// Returns the id of the child that the example at row `index_in_batch` moves
// to from node `node_id` of tree `tree_id`.
//  - Bucketized split: left child when the feature value is <= the threshold,
//    right child otherwise.
//  - Categorical split: left child when the feature value equals the split
//    value, right child otherwise.
// Any other node type trips a DCHECK and yields -1.
int32 BoostedTreesEnsembleResource::next_node(
    const int32 tree_id, const int32 node_id, const int32 index_in_batch,
    const std::vector<TTypes<int32>::ConstMatrix>& bucketized_features) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  const auto& tree = tree_ensemble_->trees(tree_id);
  DCHECK_LT(node_id, tree.nodes_size());
  const auto& node = tree.nodes(node_id);

  switch (node.node_case()) {
    case boosted_trees::Node::kBucketizedSplit: {
      const auto& split = node.bucketized_split();
      const auto feature_matrix = bucketized_features[split.feature_id()];
      const auto bucket = feature_matrix(index_in_batch, split.dimension_id());
      if (bucket <= split.threshold()) {
        return split.left_id();
      }
      return split.right_id();
    }
    case boosted_trees::Node::kCategoricalSplit: {
      const auto& split = node.categorical_split();
      const auto feature_matrix = bucketized_features[split.feature_id()];
      const auto category = feature_matrix(index_in_batch, split.dimension_id());
      if (category == split.value()) {
        return split.left_id();
      }
      return split.right_id();
    }
    default:
      DCHECK(false) << "Node type " << node.node_case() << " not supported.";
      break;
  }
  return -1;
}
84
node_value(const int32 tree_id,const int32 node_id) const85 std::vector<float> BoostedTreesEnsembleResource::node_value(
86 const int32 tree_id, const int32 node_id) const {
87 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
88 DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
89 const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
90 if (node.node_case() == boosted_trees::Node::kLeaf) {
91 if (node.leaf().has_vector()) {
92 std::vector<float> leaf_values;
93 const auto& leaf_value_vector = node.leaf().vector();
94 const int size = leaf_value_vector.value_size();
95 leaf_values.reserve(size);
96 for (int i = 0; i < size; ++i) {
97 leaf_values.push_back(leaf_value_vector.value(i));
98 }
99 return leaf_values;
100 } else {
101 return {node.leaf().scalar()};
102 }
103 } else {
104 if (node.metadata().original_leaf().has_vector()) {
105 std::vector<float> node_values;
106 const auto& leaf_value_vector = node.metadata().original_leaf().vector();
107 const int size = leaf_value_vector.value_size();
108 node_values.reserve(size);
109 for (int i = 0; i < size; ++i) {
110 node_values.push_back(leaf_value_vector.value(i));
111 }
112 return node_values;
113 } else if (node.metadata().has_original_leaf()) {
114 return {node.metadata().original_leaf().scalar()};
115 } else {
116 return {};
117 }
118 }
119 }
120
set_node_value(const int32 tree_id,const int32 node_id,const float logits)121 void BoostedTreesEnsembleResource::set_node_value(const int32 tree_id,
122 const int32 node_id,
123 const float logits) {
124 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
125 DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
126 auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
127 DCHECK(node->node_case() == boosted_trees::Node::kLeaf);
128 node->mutable_leaf()->set_scalar(logits);
129 }
130
// Returns num_layers_grown from the metadata of tree `tree_id`.
int32 BoostedTreesEnsembleResource::GetNumLayersGrown(
    const int32 tree_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  const auto& metadata = tree_ensemble_->tree_metadata(tree_id);
  return metadata.num_layers_grown();
}
136
SetNumLayersGrown(const int32 tree_id,int32 new_num_layers) const137 void BoostedTreesEnsembleResource::SetNumLayersGrown(
138 const int32 tree_id, int32 new_num_layers) const {
139 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
140 tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
141 new_num_layers);
142 }
143
UpdateLastLayerNodesRange(const int32 node_range_start,int32 node_range_end) const144 void BoostedTreesEnsembleResource::UpdateLastLayerNodesRange(
145 const int32 node_range_start, int32 node_range_end) const {
146 tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start(
147 node_range_start);
148 tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end(
149 node_range_end);
150 }
151
GetLastLayerNodesRange(int32 * node_range_start,int32 * node_range_end) const152 void BoostedTreesEnsembleResource::GetLastLayerNodesRange(
153 int32* node_range_start, int32* node_range_end) const {
154 *node_range_start =
155 tree_ensemble_->growing_metadata().last_layer_node_start();
156 *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end();
157 }
158
GetNumNodes(const int32 tree_id)159 int64 BoostedTreesEnsembleResource::GetNumNodes(const int32 tree_id) {
160 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
161 return tree_ensemble_->trees(tree_id).nodes_size();
162 }
163
GetNumLayersAttempted()164 int32 BoostedTreesEnsembleResource::GetNumLayersAttempted() {
165 return tree_ensemble_->growing_metadata().num_layers_attempted();
166 }
167
is_leaf(const int32 tree_id,const int32 node_id) const168 bool BoostedTreesEnsembleResource::is_leaf(const int32 tree_id,
169 const int32 node_id) const {
170 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
171 DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
172 const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
173 return node.node_case() == boosted_trees::Node::kLeaf;
174 }
175
// Returns the feature id of the bucketized split stored at node `node_id` of
// tree `tree_id`. The node is expected (DCHECK) to be a bucketized split.
int32 BoostedTreesEnsembleResource::feature_id(const int32 tree_id,
                                               const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference: copying the whole Node proto just to read one field is
  // wasted work.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().feature_id();
}
182
// Returns the threshold of the bucketized split stored at node `node_id` of
// tree `tree_id`. The node is expected (DCHECK) to be a bucketized split.
int32 BoostedTreesEnsembleResource::bucket_threshold(
    const int32 tree_id, const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference: copying the whole Node proto just to read one field is
  // wasted work.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().threshold();
}
189
// Returns the left child id of the bucketized split stored at node `node_id`
// of tree `tree_id`. The node is expected (DCHECK) to be a bucketized split.
int32 BoostedTreesEnsembleResource::left_id(const int32 tree_id,
                                            const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference: copying the whole Node proto just to read one field is
  // wasted work.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().left_id();
}
196
// Returns the right child id of the bucketized split stored at node `node_id`
// of tree `tree_id`. The node is expected (DCHECK) to be a bucketized split.
int32 BoostedTreesEnsembleResource::right_id(const int32 tree_id,
                                             const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference: copying the whole Node proto just to read one field is
  // wasted work.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().right_id();
}
203
GetTreeWeights() const204 std::vector<float> BoostedTreesEnsembleResource::GetTreeWeights() const {
205 return {tree_ensemble_->tree_weights().begin(),
206 tree_ensemble_->tree_weights().end()};
207 }
208
GetTreeWeight(const int32 tree_id) const209 float BoostedTreesEnsembleResource::GetTreeWeight(const int32 tree_id) const {
210 return tree_ensemble_->tree_weights(tree_id);
211 }
212
IsTreeFinalized(const int32 tree_id) const213 float BoostedTreesEnsembleResource::IsTreeFinalized(const int32 tree_id) const {
214 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
215 return tree_ensemble_->tree_metadata(tree_id).is_finalized();
216 }
217
IsTreePostPruned(const int32 tree_id) const218 float BoostedTreesEnsembleResource::IsTreePostPruned(
219 const int32 tree_id) const {
220 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
221 return tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size() >
222 0;
223 }
224
SetIsFinalized(const int32 tree_id,const bool is_finalized)225 void BoostedTreesEnsembleResource::SetIsFinalized(const int32 tree_id,
226 const bool is_finalized) {
227 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
228 return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
229 is_finalized);
230 }
231
232 // Sets the weight of i'th tree.
SetTreeWeight(const int32 tree_id,const float weight)233 void BoostedTreesEnsembleResource::SetTreeWeight(const int32 tree_id,
234 const float weight) {
235 DCHECK_GE(tree_id, 0);
236 DCHECK_LT(tree_id, num_trees());
237 tree_ensemble_->set_tree_weights(tree_id, weight);
238 }
239
UpdateGrowingMetadata() const240 void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
241 tree_ensemble_->mutable_growing_metadata()->set_num_layers_attempted(
242 tree_ensemble_->growing_metadata().num_layers_attempted() + 1);
243
244 const int n_trees = num_trees();
245
246 if (n_trees <= 0 ||
247 // Checks if we are building the first layer of the dummy empty tree
248 ((n_trees == 1 || IsTreeFinalized(n_trees - 2)) &&
249 (tree_ensemble_->trees(n_trees - 1).nodes_size() == 1))) {
250 tree_ensemble_->mutable_growing_metadata()->set_num_trees_attempted(
251 tree_ensemble_->growing_metadata().num_trees_attempted() + 1);
252 }
253 }
254
255 // Add a tree to the ensemble and returns a new tree_id.
AddNewTree(const float weight,const int32 logits_dimension)256 int32 BoostedTreesEnsembleResource::AddNewTree(const float weight,
257 const int32 logits_dimension) {
258 const std::vector<float> empty_leaf(logits_dimension);
259 return AddNewTreeWithLogits(weight, empty_leaf, logits_dimension);
260 }
261
// Appends a new tree consisting of one root leaf initialized from `logits`,
// together with its weight and an empty tree-metadata entry. Returns the id
// (index) of the new tree.
// When `logits_dimension` is 1 the leaf stores logits[0] as a scalar;
// otherwise the first `logits_dimension` entries are stored as a vector.
int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(
    const float weight, const std::vector<float>& logits,
    const int32 logits_dimension) {
  // The new tree is appended, so its id is the current tree count.
  const int32 new_tree_id = tree_ensemble_->trees_size();
  auto* root = tree_ensemble_->add_trees()->add_nodes();
  if (logits_dimension != 1) {
    for (int32 dim = 0; dim < logits_dimension; ++dim) {
      root->mutable_leaf()->mutable_vector()->add_value(logits[dim]);
    }
  } else {
    root->mutable_leaf()->set_scalar(logits[0]);
  }
  tree_ensemble_->add_tree_weights(weight);
  tree_ensemble_->add_tree_metadata();

  return new_tree_id;
}
279
AddBucketizedSplitNode(const int32 tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32 logits_dimension,int32 * left_node_id,int32 * right_node_id)280 void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
281 const int32 tree_id,
282 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
283 const int32 logits_dimension, int32* left_node_id, int32* right_node_id) {
284 const auto candidate = split_entry.second;
285 auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
286 left_node_id, right_node_id);
287 auto* new_split = node->mutable_bucketized_split();
288 new_split->set_feature_id(candidate.feature_id);
289 new_split->set_threshold(candidate.threshold);
290 new_split->set_dimension_id(candidate.dimension_id);
291 new_split->set_left_id(*left_node_id);
292 new_split->set_right_id(*right_node_id);
293
294 boosted_trees::SplitTypeWithDefault split_type_with_default;
295 bool parsed = boosted_trees::SplitTypeWithDefault_Parse(
296 candidate.split_type, &split_type_with_default);
297 DCHECK(parsed);
298 if (split_type_with_default == boosted_trees::INEQUALITY_DEFAULT_RIGHT) {
299 new_split->set_default_direction(boosted_trees::DEFAULT_RIGHT);
300 } else {
301 new_split->set_default_direction(boosted_trees::DEFAULT_LEFT);
302 }
303 }
304
AddCategoricalSplitNode(const int32 tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32 logits_dimension,int32 * left_node_id,int32 * right_node_id)305 void BoostedTreesEnsembleResource::AddCategoricalSplitNode(
306 const int32 tree_id,
307 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
308 const int32 logits_dimension, int32* left_node_id, int32* right_node_id) {
309 const auto candidate = split_entry.second;
310 auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
311 left_node_id, right_node_id);
312 auto* new_split = node->mutable_categorical_split();
313 new_split->set_feature_id(candidate.feature_id);
314 new_split->set_value(candidate.threshold);
315 new_split->set_dimension_id(candidate.dimension_id);
316 new_split->set_left_id(*left_node_id);
317 new_split->set_right_id(*right_node_id);
318 }
319
AddLeafNodes(const int32 tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32 logits_dimension,int32 * left_node_id,int32 * right_node_id)320 boosted_trees::Node* BoostedTreesEnsembleResource::AddLeafNodes(
321 const int32 tree_id,
322 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
323 const int32 logits_dimension, int32* left_node_id, int32* right_node_id) {
324 auto* tree = tree_ensemble_->mutable_trees(tree_id);
325 const auto node_id = split_entry.first;
326 const auto candidate = split_entry.second;
327 auto* node = tree->mutable_nodes(node_id);
328 DCHECK_EQ(node->node_case(), boosted_trees::Node::kLeaf);
329 *left_node_id = tree->nodes_size();
330 *right_node_id = *left_node_id + 1;
331 auto* left_node = tree->add_nodes();
332 auto* right_node = tree->add_nodes();
333 const bool has_leaf_value =
334 node->has_leaf() &&
335 ((logits_dimension == 1 && (node->leaf().scalar() != 0)) ||
336 node->leaf().has_vector());
337 if (node_id != 0 || has_leaf_value) {
338 // Save previous leaf value if it is not the first leaf in the tree.
339 node->mutable_metadata()->mutable_original_leaf()->Swap(
340 node->mutable_leaf());
341 }
342 node->mutable_metadata()->set_gain(candidate.gain);
343 // TODO(nponomareva): this is LAYER-BY-LAYER boosting; add WHOLE-TREE.
344 if (logits_dimension == 1) {
345 const float prev_logit_value = node->metadata().original_leaf().scalar();
346 left_node->mutable_leaf()->set_scalar(prev_logit_value +
347 candidate.left_node_contribs[0]);
348 right_node->mutable_leaf()->set_scalar(prev_logit_value +
349 candidate.right_node_contribs[0]);
350 } else {
351 if (has_leaf_value) {
352 DCHECK_EQ(logits_dimension,
353 node->metadata().original_leaf().vector().value_size());
354 }
355 float prev_logit_value = 0.0;
356 for (int32 i = 0; i < logits_dimension; ++i) {
357 if (has_leaf_value) {
358 prev_logit_value = node->metadata().original_leaf().vector().value(i);
359 }
360 left_node->mutable_leaf()->mutable_vector()->add_value(
361 prev_logit_value + candidate.left_node_contribs[i]);
362 right_node->mutable_leaf()->mutable_vector()->add_value(
363 prev_logit_value + candidate.right_node_contribs[i]);
364 }
365 }
366 return node;
367 }
368
Reset()369 void BoostedTreesEnsembleResource::Reset() {
370 // Reset stamp.
371 set_stamp(-1);
372
373 // Clear tree ensemle.
374 arena_.Reset();
375 tree_ensemble_ =
376 protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(&arena_);
377 }
378
PostPruneTree(const int32 current_tree,const int32 logits_dimension)379 void BoostedTreesEnsembleResource::PostPruneTree(const int32 current_tree,
380 const int32 logits_dimension) {
381 // No-op if tree is empty.
382 auto* tree = tree_ensemble_->mutable_trees(current_tree);
383 int32 num_nodes = tree->nodes_size();
384 if (num_nodes == 0) {
385 return;
386 }
387
388 std::vector<int32> nodes_to_delete;
389 // If a node was pruned, we need to save the change of the prediction from
390 // this node to its parent, as well as the parent id.
391 std::vector<std::pair<int32, std::vector<float>>> nodes_changes;
392 nodes_changes.reserve(num_nodes);
393 for (int32 i = 0; i < num_nodes; ++i) {
394 std::vector<float> prune_logit_changes;
395 nodes_changes.emplace_back(i, prune_logit_changes);
396 }
397 // Prune the tree recursively starting from the root. Each node that has
398 // negative gain and only leaf children will be pruned recursively up from
399 // the bottom of the tree. This method returns the list of nodes pruned, and
400 // updates the nodes in the tree not to refer to those pruned nodes.
401 RecursivelyDoPostPrunePreparation(current_tree, 0, &nodes_to_delete,
402 &nodes_changes);
403
404 if (nodes_to_delete.empty()) {
405 // No pruning happened, and no post-processing needed.
406 return;
407 }
408
409 // Sort node ids so they are in asc order.
410 std::sort(nodes_to_delete.begin(), nodes_to_delete.end());
411
412 // We need to
413 // - update split left and right children ids with new indices
414 // - actually remove the nodes that need to be removed
415 // - save the information about pruned node so we could recover the
416 // predictions from cache. Build a map for old node index=>new node index.
417 // nodes_to_delete contains nodes who's indices should be skipped, in
418 // ascending order. Save the information about new indices into meta.
419 std::map<int32, int32> old_to_new_ids;
420 int32 new_index = 0;
421 int32 index_for_deleted = 0;
422 auto* post_prune_meta = tree_ensemble_->mutable_tree_metadata(current_tree)
423 ->mutable_post_pruned_nodes_meta();
424
425 for (int32 i = 0; i < num_nodes; ++i) {
426 const int64 nodes_to_delete_size = nodes_to_delete.size();
427 if (index_for_deleted < nodes_to_delete_size &&
428 i == nodes_to_delete[index_for_deleted]) {
429 // Node i will get removed,
430 ++index_for_deleted;
431 // Update meta info that will allow us to use cached predictions from
432 // those nodes.
433 int32 new_id;
434 std::vector<float> logit_changes;
435 logit_changes.reserve(logits_dimension);
436 CalculateParentAndLogitUpdate(i, nodes_changes, &new_id, &logit_changes);
437 auto* meta = post_prune_meta->Add();
438 meta->set_new_node_id(old_to_new_ids[new_id]);
439 for (int32 j = 0; j < logits_dimension; ++j) {
440 meta->add_logit_change(logit_changes[j]);
441 }
442 } else {
443 old_to_new_ids[i] = new_index++;
444 auto* meta = post_prune_meta->Add();
445 // Update meta info that will allow us to use cached predictions from
446 // those nodes.
447 meta->set_new_node_id(old_to_new_ids[i]);
448 for (int32 i = 0; i < logits_dimension; ++i) {
449 meta->add_logit_change(0.0);
450 }
451 }
452 }
453 index_for_deleted = 0;
454 int32 i = 0;
455 protobuf::RepeatedPtrField<boosted_trees::Node> new_nodes;
456 new_nodes.Reserve(old_to_new_ids.size());
457 for (auto node : *(tree->mutable_nodes())) {
458 const int64 nodes_to_delete_size = nodes_to_delete.size();
459 if (index_for_deleted < nodes_to_delete_size &&
460 i == nodes_to_delete[index_for_deleted]) {
461 ++index_for_deleted;
462 ++i;
463 continue;
464 } else {
465 if (node.node_case() == boosted_trees::Node::kBucketizedSplit) {
466 node.mutable_bucketized_split()->set_left_id(
467 old_to_new_ids[node.bucketized_split().left_id()]);
468 node.mutable_bucketized_split()->set_right_id(
469 old_to_new_ids[node.bucketized_split().right_id()]);
470 }
471 *new_nodes.Add() = std::move(node);
472 }
473 ++i;
474 }
475 // Replace all the nodes in a tree with the ones we keep.
476 *tree->mutable_nodes() = std::move(new_nodes);
477
478 // Note that if the whole tree got pruned, we will end up with one node.
479 // We can't remove that tree because it will cause problems with cache.
480 }
481
GetPostPruneCorrection(const int32 tree_id,const int32 initial_node_id,int32 * current_node_id,std::vector<float> * logit_updates) const482 void BoostedTreesEnsembleResource::GetPostPruneCorrection(
483 const int32 tree_id, const int32 initial_node_id, int32* current_node_id,
484 std::vector<float>* logit_updates) const {
485 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
486 if (IsTreeFinalized(tree_id) && IsTreePostPruned(tree_id)) {
487 DCHECK_LT(
488 initial_node_id,
489 tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size());
490 const auto& meta =
491 tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta(
492 initial_node_id);
493 *current_node_id = meta.new_node_id();
494 DCHECK_EQ(meta.logit_change().size(), logit_updates->size());
495 for (int32 i = 0; i < meta.logit_change().size(); ++i) {
496 logit_updates->at(i) = meta.logit_change(i);
497 }
498 }
499 }
500
IsTerminalSplitNode(const int32 tree_id,const int32 node_id) const501 bool BoostedTreesEnsembleResource::IsTerminalSplitNode(
502 const int32 tree_id, const int32 node_id) const {
503 const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
504 DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
505 const int32 left_id = node.bucketized_split().left_id();
506 const int32 right_id = node.bucketized_split().right_id();
507 return is_leaf(tree_id, left_id) && is_leaf(tree_id, right_id);
508 }
509
510 // For each pruned node, finds the leaf where it finally ended up and
511 // calculates the total update from that pruned node prediction.
CalculateParentAndLogitUpdate(const int32 start_node_id,const std::vector<std::pair<int32,std::vector<float>>> & nodes_changes,int32 * parent_id,std::vector<float> * changes) const512 void BoostedTreesEnsembleResource::CalculateParentAndLogitUpdate(
513 const int32 start_node_id,
514 const std::vector<std::pair<int32, std::vector<float>>>& nodes_changes,
515 int32* parent_id, std::vector<float>* changes) const {
516 const int logits_dimension = nodes_changes[start_node_id].second.size();
517 for (int i = 0; i < logits_dimension; ++i) {
518 changes->emplace_back(0.0);
519 }
520 int32 node_id = start_node_id;
521 int32 parent = nodes_changes[node_id].first;
522 while (parent != node_id) {
523 for (int i = 0; i < logits_dimension; ++i) {
524 changes->at(i) += nodes_changes[node_id].second[i];
525 }
526 node_id = parent;
527 parent = nodes_changes[node_id].first;
528 }
529 *parent_id = parent;
530 }
531
RecursivelyDoPostPrunePreparation(const int32 tree_id,const int32 node_id,std::vector<int32> * nodes_to_delete,std::vector<std::pair<int32,std::vector<float>>> * nodes_meta)532 void BoostedTreesEnsembleResource::RecursivelyDoPostPrunePreparation(
533 const int32 tree_id, const int32 node_id,
534 std::vector<int32>* nodes_to_delete,
535 std::vector<std::pair<int32, std::vector<float>>>* nodes_meta) {
536 auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
537 DCHECK_NE(node->node_case(), boosted_trees::Node::NODE_NOT_SET);
538 // Base case when we reach a leaf.
539 if (node->node_case() == boosted_trees::Node::kLeaf) {
540 return;
541 }
542
543 // Traverse node children first and recursively prune their sub-trees.
544 RecursivelyDoPostPrunePreparation(tree_id, node->bucketized_split().left_id(),
545 nodes_to_delete, nodes_meta);
546 RecursivelyDoPostPrunePreparation(tree_id,
547 node->bucketized_split().right_id(),
548 nodes_to_delete, nodes_meta);
549
550 // Two conditions must be satisfied to prune the node:
551 // 1- The split gain is negative.
552 // 2- After depth-first pruning, the node only has leaf children.
553 const auto& node_metadata = node->metadata();
554 if (node_metadata.gain() < 0 && IsTerminalSplitNode(tree_id, node_id)) {
555 const int32 left_id = node->bucketized_split().left_id();
556 const int32 right_id = node->bucketized_split().right_id();
557
558 // Save children that need to be deleted.
559 nodes_to_delete->push_back(left_id);
560 nodes_to_delete->push_back(right_id);
561
562 // Change node back into leaf.
563 *node->mutable_leaf() = node_metadata.original_leaf();
564
565 // Save the old values of weights of children.
566 nodes_meta->at(left_id).first = node_id;
567 nodes_meta->at(right_id).first = node_id;
568 const auto& left_child_values = node_value(tree_id, left_id);
569 const auto& right_child_values = node_value(tree_id, right_id);
570 std::vector<float> parent_values(left_child_values.size(), 0.0);
571 if (node_metadata.has_original_leaf()) {
572 parent_values = node_value(tree_id, node_id);
573 }
574 for (int32 i = 0, end = parent_values.size(); i < end; ++i) {
575 nodes_meta->at(left_id).second.emplace_back(parent_values[i] -
576 left_child_values[i]);
577 nodes_meta->at(right_id).second.emplace_back(parent_values[i] -
578 right_child_values[i]);
579 }
580 // Clear gain for leaf node.
581 node->clear_metadata();
582 }
583 }
584
585 } // namespace tensorflow
586