• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/boosted_trees/resources.h"
17 
18 #include "tensorflow/core/framework/resource_mgr.h"
19 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
20 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
21 #include "tensorflow/core/platform/mutex.h"
22 #include "tensorflow/core/platform/protobuf.h"
23 
24 namespace tensorflow {
25 
26 // Constructor.
BoostedTreesEnsembleResource()27 BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
28     : tree_ensemble_(
29           protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
30               &arena_)) {}
31 
DebugString() const32 string BoostedTreesEnsembleResource::DebugString() const {
33   return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
34                          "]");
35 }
36 
InitFromSerialized(const string & serialized,const int64 stamp_token)37 bool BoostedTreesEnsembleResource::InitFromSerialized(const string& serialized,
38                                                       const int64 stamp_token) {
39   CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
40   if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
41     set_stamp(stamp_token);
42     return true;
43   }
44   return false;
45 }
46 
SerializeAsString() const47 string BoostedTreesEnsembleResource::SerializeAsString() const {
48   return tree_ensemble_->SerializeAsString();
49 }
50 
num_trees() const51 int32 BoostedTreesEnsembleResource::num_trees() const {
52   return tree_ensemble_->trees_size();
53 }
54 
// Routes example `index_in_batch` through split node `node_id` of tree
// `tree_id` and returns the id of the child it goes to.
//  - Bucketized split: left if bucket <= threshold, otherwise right.
//  - Categorical split: left if bucket == value, otherwise right.
// Any other node kind is unsupported: DCHECK-fails in debug builds and
// returns -1 otherwise.
int32 BoostedTreesEnsembleResource::next_node(
    const int32 tree_id, const int32 node_id, const int32 index_in_batch,
    const std::vector<TTypes<int32>::ConstMatrix>& bucketized_features) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  const auto& current = tree_ensemble_->trees(tree_id).nodes(node_id);

  if (current.node_case() == boosted_trees::Node::kBucketizedSplit) {
    const auto& split = current.bucketized_split();
    const auto feature_values = bucketized_features[split.feature_id()];
    const bool goes_left =
        feature_values(index_in_batch, split.dimension_id()) <=
        split.threshold();
    return goes_left ? split.left_id() : split.right_id();
  }

  if (current.node_case() == boosted_trees::Node::kCategoricalSplit) {
    const auto& split = current.categorical_split();
    const auto feature_values = bucketized_features[split.feature_id()];
    const bool goes_left =
        feature_values(index_in_batch, split.dimension_id()) == split.value();
    return goes_left ? split.left_id() : split.right_id();
  }

  DCHECK(false) << "Node type " << current.node_case() << " not supported.";
  return -1;
}
84 
node_value(const int32 tree_id,const int32 node_id) const85 std::vector<float> BoostedTreesEnsembleResource::node_value(
86     const int32 tree_id, const int32 node_id) const {
87   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
88   DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
89   const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
90   if (node.node_case() == boosted_trees::Node::kLeaf) {
91     if (node.leaf().has_vector()) {
92       std::vector<float> leaf_values;
93       const auto& leaf_value_vector = node.leaf().vector();
94       const int size = leaf_value_vector.value_size();
95       leaf_values.reserve(size);
96       for (int i = 0; i < size; ++i) {
97         leaf_values.push_back(leaf_value_vector.value(i));
98       }
99       return leaf_values;
100     } else {
101       return {node.leaf().scalar()};
102     }
103   } else {
104     if (node.metadata().original_leaf().has_vector()) {
105       std::vector<float> node_values;
106       const auto& leaf_value_vector = node.metadata().original_leaf().vector();
107       const int size = leaf_value_vector.value_size();
108       node_values.reserve(size);
109       for (int i = 0; i < size; ++i) {
110         node_values.push_back(leaf_value_vector.value(i));
111       }
112       return node_values;
113     } else if (node.metadata().has_original_leaf()) {
114       return {node.metadata().original_leaf().scalar()};
115     } else {
116       return {};
117     }
118   }
119 }
120 
set_node_value(const int32 tree_id,const int32 node_id,const float logits)121 void BoostedTreesEnsembleResource::set_node_value(const int32 tree_id,
122                                                   const int32 node_id,
123                                                   const float logits) {
124   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
125   DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
126   auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
127   DCHECK(node->node_case() == boosted_trees::Node::kLeaf);
128   node->mutable_leaf()->set_scalar(logits);
129 }
130 
// Returns num_layers_grown recorded in the metadata of tree `tree_id`.
int32 BoostedTreesEnsembleResource::GetNumLayersGrown(
    const int32 tree_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  const auto& metadata = tree_ensemble_->tree_metadata(tree_id);
  return metadata.num_layers_grown();
}
136 
SetNumLayersGrown(const int32 tree_id,int32 new_num_layers) const137 void BoostedTreesEnsembleResource::SetNumLayersGrown(
138     const int32 tree_id, int32 new_num_layers) const {
139   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
140   tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
141       new_num_layers);
142 }
143 
UpdateLastLayerNodesRange(const int32 node_range_start,int32 node_range_end) const144 void BoostedTreesEnsembleResource::UpdateLastLayerNodesRange(
145     const int32 node_range_start, int32 node_range_end) const {
146   tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start(
147       node_range_start);
148   tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end(
149       node_range_end);
150 }
151 
GetLastLayerNodesRange(int32 * node_range_start,int32 * node_range_end) const152 void BoostedTreesEnsembleResource::GetLastLayerNodesRange(
153     int32* node_range_start, int32* node_range_end) const {
154   *node_range_start =
155       tree_ensemble_->growing_metadata().last_layer_node_start();
156   *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end();
157 }
158 
GetNumNodes(const int32 tree_id)159 int64 BoostedTreesEnsembleResource::GetNumNodes(const int32 tree_id) {
160   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
161   return tree_ensemble_->trees(tree_id).nodes_size();
162 }
163 
GetNumLayersAttempted()164 int32 BoostedTreesEnsembleResource::GetNumLayersAttempted() {
165   return tree_ensemble_->growing_metadata().num_layers_attempted();
166 }
167 
is_leaf(const int32 tree_id,const int32 node_id) const168 bool BoostedTreesEnsembleResource::is_leaf(const int32 tree_id,
169                                            const int32 node_id) const {
170   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
171   DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
172   const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
173   return node.node_case() == boosted_trees::Node::kLeaf;
174 }
175 
// Returns the feature id used by bucketized split node `node_id` of tree
// `tree_id`. The node is expected (DCHECK) to be a bucketized split.
int32 BoostedTreesEnsembleResource::feature_id(const int32 tree_id,
                                               const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference; `const auto node = ...` copied the whole Node message.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().feature_id();
}
182 
// Returns the bucket threshold of bucketized split node `node_id` of tree
// `tree_id`. The node is expected (DCHECK) to be a bucketized split.
int32 BoostedTreesEnsembleResource::bucket_threshold(
    const int32 tree_id, const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference; `const auto node = ...` copied the whole Node message.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().threshold();
}
189 
// Returns the left child id of bucketized split node `node_id` of tree
// `tree_id`. The node is expected (DCHECK) to be a bucketized split.
int32 BoostedTreesEnsembleResource::left_id(const int32 tree_id,
                                            const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference; `const auto node = ...` copied the whole Node message.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().left_id();
}
196 
// Returns the right child id of bucketized split node `node_id` of tree
// `tree_id`. The node is expected (DCHECK) to be a bucketized split.
int32 BoostedTreesEnsembleResource::right_id(const int32 tree_id,
                                             const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference; `const auto node = ...` copied the whole Node message.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().right_id();
}
203 
GetTreeWeights() const204 std::vector<float> BoostedTreesEnsembleResource::GetTreeWeights() const {
205   return {tree_ensemble_->tree_weights().begin(),
206           tree_ensemble_->tree_weights().end()};
207 }
208 
GetTreeWeight(const int32 tree_id) const209 float BoostedTreesEnsembleResource::GetTreeWeight(const int32 tree_id) const {
210   return tree_ensemble_->tree_weights(tree_id);
211 }
212 
IsTreeFinalized(const int32 tree_id) const213 float BoostedTreesEnsembleResource::IsTreeFinalized(const int32 tree_id) const {
214   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
215   return tree_ensemble_->tree_metadata(tree_id).is_finalized();
216 }
217 
IsTreePostPruned(const int32 tree_id) const218 float BoostedTreesEnsembleResource::IsTreePostPruned(
219     const int32 tree_id) const {
220   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
221   return tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size() >
222          0;
223 }
224 
SetIsFinalized(const int32 tree_id,const bool is_finalized)225 void BoostedTreesEnsembleResource::SetIsFinalized(const int32 tree_id,
226                                                   const bool is_finalized) {
227   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
228   return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
229       is_finalized);
230 }
231 
232 // Sets the weight of i'th tree.
SetTreeWeight(const int32 tree_id,const float weight)233 void BoostedTreesEnsembleResource::SetTreeWeight(const int32 tree_id,
234                                                  const float weight) {
235   DCHECK_GE(tree_id, 0);
236   DCHECK_LT(tree_id, num_trees());
237   tree_ensemble_->set_tree_weights(tree_id, weight);
238 }
239 
UpdateGrowingMetadata() const240 void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
241   tree_ensemble_->mutable_growing_metadata()->set_num_layers_attempted(
242       tree_ensemble_->growing_metadata().num_layers_attempted() + 1);
243 
244   const int n_trees = num_trees();
245 
246   if (n_trees <= 0 ||
247       // Checks if we are building the first layer of the dummy empty tree
248       ((n_trees == 1 || IsTreeFinalized(n_trees - 2)) &&
249        (tree_ensemble_->trees(n_trees - 1).nodes_size() == 1))) {
250     tree_ensemble_->mutable_growing_metadata()->set_num_trees_attempted(
251         tree_ensemble_->growing_metadata().num_trees_attempted() + 1);
252   }
253 }
254 
255 // Add a tree to the ensemble and returns a new tree_id.
AddNewTree(const float weight,const int32 logits_dimension)256 int32 BoostedTreesEnsembleResource::AddNewTree(const float weight,
257                                                const int32 logits_dimension) {
258   const std::vector<float> empty_leaf(logits_dimension);
259   return AddNewTreeWithLogits(weight, empty_leaf, logits_dimension);
260 }
261 
// Appends a new tree consisting of a single leaf node, plus a tree weight
// `weight` and a default tree-metadata entry. Returns the id of the new tree.
// The leaf stores logits[0] as a scalar when logits_dimension == 1. Otherwise
// it stores the first `logits_dimension` entries of `logits` as a vector.
int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(
    const float weight, const std::vector<float>& logits,
    const int32 logits_dimension) {
  // `logits` is indexed up to logits_dimension - 1 below; catch short inputs
  // in debug builds instead of reading out of bounds.
  DCHECK_GE(static_cast<int64>(logits.size()), logits_dimension);
  const int32 new_tree_id = tree_ensemble_->trees_size();
  auto* node = tree_ensemble_->add_trees()->add_nodes();
  if (logits_dimension == 1) {
    node->mutable_leaf()->set_scalar(logits[0]);
  } else {
    for (int32 i = 0; i < logits_dimension; ++i) {
      node->mutable_leaf()->mutable_vector()->add_value(logits[i]);
    }
  }
  tree_ensemble_->add_tree_weights(weight);
  tree_ensemble_->add_tree_metadata();

  return new_tree_id;
}
279 
AddBucketizedSplitNode(const int32 tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32 logits_dimension,int32 * left_node_id,int32 * right_node_id)280 void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
281     const int32 tree_id,
282     const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
283     const int32 logits_dimension, int32* left_node_id, int32* right_node_id) {
284   const auto candidate = split_entry.second;
285   auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
286                             left_node_id, right_node_id);
287   auto* new_split = node->mutable_bucketized_split();
288   new_split->set_feature_id(candidate.feature_id);
289   new_split->set_threshold(candidate.threshold);
290   new_split->set_dimension_id(candidate.dimension_id);
291   new_split->set_left_id(*left_node_id);
292   new_split->set_right_id(*right_node_id);
293 
294   boosted_trees::SplitTypeWithDefault split_type_with_default;
295   bool parsed = boosted_trees::SplitTypeWithDefault_Parse(
296       candidate.split_type, &split_type_with_default);
297   DCHECK(parsed);
298   if (split_type_with_default == boosted_trees::INEQUALITY_DEFAULT_RIGHT) {
299     new_split->set_default_direction(boosted_trees::DEFAULT_RIGHT);
300   } else {
301     new_split->set_default_direction(boosted_trees::DEFAULT_LEFT);
302   }
303 }
304 
AddCategoricalSplitNode(const int32 tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32 logits_dimension,int32 * left_node_id,int32 * right_node_id)305 void BoostedTreesEnsembleResource::AddCategoricalSplitNode(
306     const int32 tree_id,
307     const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
308     const int32 logits_dimension, int32* left_node_id, int32* right_node_id) {
309   const auto candidate = split_entry.second;
310   auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
311                             left_node_id, right_node_id);
312   auto* new_split = node->mutable_categorical_split();
313   new_split->set_feature_id(candidate.feature_id);
314   new_split->set_value(candidate.threshold);
315   new_split->set_dimension_id(candidate.dimension_id);
316   new_split->set_left_id(*left_node_id);
317   new_split->set_right_id(*right_node_id);
318 }
319 
AddLeafNodes(const int32 tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32 logits_dimension,int32 * left_node_id,int32 * right_node_id)320 boosted_trees::Node* BoostedTreesEnsembleResource::AddLeafNodes(
321     const int32 tree_id,
322     const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
323     const int32 logits_dimension, int32* left_node_id, int32* right_node_id) {
324   auto* tree = tree_ensemble_->mutable_trees(tree_id);
325   const auto node_id = split_entry.first;
326   const auto candidate = split_entry.second;
327   auto* node = tree->mutable_nodes(node_id);
328   DCHECK_EQ(node->node_case(), boosted_trees::Node::kLeaf);
329   *left_node_id = tree->nodes_size();
330   *right_node_id = *left_node_id + 1;
331   auto* left_node = tree->add_nodes();
332   auto* right_node = tree->add_nodes();
333   const bool has_leaf_value =
334       node->has_leaf() &&
335       ((logits_dimension == 1 && (node->leaf().scalar() != 0)) ||
336        node->leaf().has_vector());
337   if (node_id != 0 || has_leaf_value) {
338     // Save previous leaf value if it is not the first leaf in the tree.
339     node->mutable_metadata()->mutable_original_leaf()->Swap(
340         node->mutable_leaf());
341   }
342   node->mutable_metadata()->set_gain(candidate.gain);
343   // TODO(nponomareva): this is LAYER-BY-LAYER boosting; add WHOLE-TREE.
344   if (logits_dimension == 1) {
345     const float prev_logit_value = node->metadata().original_leaf().scalar();
346     left_node->mutable_leaf()->set_scalar(prev_logit_value +
347                                           candidate.left_node_contribs[0]);
348     right_node->mutable_leaf()->set_scalar(prev_logit_value +
349                                            candidate.right_node_contribs[0]);
350   } else {
351     if (has_leaf_value) {
352       DCHECK_EQ(logits_dimension,
353                 node->metadata().original_leaf().vector().value_size());
354     }
355     float prev_logit_value = 0.0;
356     for (int32 i = 0; i < logits_dimension; ++i) {
357       if (has_leaf_value) {
358         prev_logit_value = node->metadata().original_leaf().vector().value(i);
359       }
360       left_node->mutable_leaf()->mutable_vector()->add_value(
361           prev_logit_value + candidate.left_node_contribs[i]);
362       right_node->mutable_leaf()->mutable_vector()->add_value(
363           prev_logit_value + candidate.right_node_contribs[i]);
364     }
365   }
366   return node;
367 }
368 
Reset()369 void BoostedTreesEnsembleResource::Reset() {
370   // Reset stamp.
371   set_stamp(-1);
372 
373   // Clear tree ensemle.
374   arena_.Reset();
375   tree_ensemble_ =
376       protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(&arena_);
377 }
378 
PostPruneTree(const int32 current_tree,const int32 logits_dimension)379 void BoostedTreesEnsembleResource::PostPruneTree(const int32 current_tree,
380                                                  const int32 logits_dimension) {
381   // No-op if tree is empty.
382   auto* tree = tree_ensemble_->mutable_trees(current_tree);
383   int32 num_nodes = tree->nodes_size();
384   if (num_nodes == 0) {
385     return;
386   }
387 
388   std::vector<int32> nodes_to_delete;
389   // If a node was pruned, we need to save the change of the prediction from
390   // this node to its parent, as well as the parent id.
391   std::vector<std::pair<int32, std::vector<float>>> nodes_changes;
392   nodes_changes.reserve(num_nodes);
393   for (int32 i = 0; i < num_nodes; ++i) {
394     std::vector<float> prune_logit_changes;
395     nodes_changes.emplace_back(i, prune_logit_changes);
396   }
397   // Prune the tree recursively starting from the root. Each node that has
398   // negative gain and only leaf children will be pruned recursively up from
399   // the bottom of the tree. This method returns the list of nodes pruned, and
400   // updates the nodes in the tree not to refer to those pruned nodes.
401   RecursivelyDoPostPrunePreparation(current_tree, 0, &nodes_to_delete,
402                                     &nodes_changes);
403 
404   if (nodes_to_delete.empty()) {
405     // No pruning happened, and no post-processing needed.
406     return;
407   }
408 
409   // Sort node ids so they are in asc order.
410   std::sort(nodes_to_delete.begin(), nodes_to_delete.end());
411 
412   // We need to
413   // - update split left and right children ids with new indices
414   // - actually remove the nodes that need to be removed
415   // - save the information about pruned node so we could recover the
416   // predictions from cache. Build a map for old node index=>new node index.
417   // nodes_to_delete contains nodes who's indices should be skipped, in
418   // ascending order. Save the information about new indices into meta.
419   std::map<int32, int32> old_to_new_ids;
420   int32 new_index = 0;
421   int32 index_for_deleted = 0;
422   auto* post_prune_meta = tree_ensemble_->mutable_tree_metadata(current_tree)
423                               ->mutable_post_pruned_nodes_meta();
424 
425   for (int32 i = 0; i < num_nodes; ++i) {
426     const int64 nodes_to_delete_size = nodes_to_delete.size();
427     if (index_for_deleted < nodes_to_delete_size &&
428         i == nodes_to_delete[index_for_deleted]) {
429       // Node i will get removed,
430       ++index_for_deleted;
431       // Update meta info that will allow us to use cached predictions from
432       // those nodes.
433       int32 new_id;
434       std::vector<float> logit_changes;
435       logit_changes.reserve(logits_dimension);
436       CalculateParentAndLogitUpdate(i, nodes_changes, &new_id, &logit_changes);
437       auto* meta = post_prune_meta->Add();
438       meta->set_new_node_id(old_to_new_ids[new_id]);
439       for (int32 j = 0; j < logits_dimension; ++j) {
440         meta->add_logit_change(logit_changes[j]);
441       }
442     } else {
443       old_to_new_ids[i] = new_index++;
444       auto* meta = post_prune_meta->Add();
445       // Update meta info that will allow us to use cached predictions from
446       // those nodes.
447       meta->set_new_node_id(old_to_new_ids[i]);
448       for (int32 i = 0; i < logits_dimension; ++i) {
449         meta->add_logit_change(0.0);
450       }
451     }
452   }
453   index_for_deleted = 0;
454   int32 i = 0;
455   protobuf::RepeatedPtrField<boosted_trees::Node> new_nodes;
456   new_nodes.Reserve(old_to_new_ids.size());
457   for (auto node : *(tree->mutable_nodes())) {
458     const int64 nodes_to_delete_size = nodes_to_delete.size();
459     if (index_for_deleted < nodes_to_delete_size &&
460         i == nodes_to_delete[index_for_deleted]) {
461       ++index_for_deleted;
462       ++i;
463       continue;
464     } else {
465       if (node.node_case() == boosted_trees::Node::kBucketizedSplit) {
466         node.mutable_bucketized_split()->set_left_id(
467             old_to_new_ids[node.bucketized_split().left_id()]);
468         node.mutable_bucketized_split()->set_right_id(
469             old_to_new_ids[node.bucketized_split().right_id()]);
470       }
471       *new_nodes.Add() = std::move(node);
472     }
473     ++i;
474   }
475   // Replace all the nodes in a tree with the ones we keep.
476   *tree->mutable_nodes() = std::move(new_nodes);
477 
478   // Note that if the whole tree got pruned, we will end up with one node.
479   // We can't remove that tree because it will cause problems with cache.
480 }
481 
GetPostPruneCorrection(const int32 tree_id,const int32 initial_node_id,int32 * current_node_id,std::vector<float> * logit_updates) const482 void BoostedTreesEnsembleResource::GetPostPruneCorrection(
483     const int32 tree_id, const int32 initial_node_id, int32* current_node_id,
484     std::vector<float>* logit_updates) const {
485   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
486   if (IsTreeFinalized(tree_id) && IsTreePostPruned(tree_id)) {
487     DCHECK_LT(
488         initial_node_id,
489         tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size());
490     const auto& meta =
491         tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta(
492             initial_node_id);
493     *current_node_id = meta.new_node_id();
494     DCHECK_EQ(meta.logit_change().size(), logit_updates->size());
495     for (int32 i = 0; i < meta.logit_change().size(); ++i) {
496       logit_updates->at(i) = meta.logit_change(i);
497     }
498   }
499 }
500 
IsTerminalSplitNode(const int32 tree_id,const int32 node_id) const501 bool BoostedTreesEnsembleResource::IsTerminalSplitNode(
502     const int32 tree_id, const int32 node_id) const {
503   const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
504   DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
505   const int32 left_id = node.bucketized_split().left_id();
506   const int32 right_id = node.bucketized_split().right_id();
507   return is_leaf(tree_id, left_id) && is_leaf(tree_id, right_id);
508 }
509 
510 // For each pruned node, finds the leaf where it finally ended up and
511 // calculates the total update from that pruned node prediction.
CalculateParentAndLogitUpdate(const int32 start_node_id,const std::vector<std::pair<int32,std::vector<float>>> & nodes_changes,int32 * parent_id,std::vector<float> * changes) const512 void BoostedTreesEnsembleResource::CalculateParentAndLogitUpdate(
513     const int32 start_node_id,
514     const std::vector<std::pair<int32, std::vector<float>>>& nodes_changes,
515     int32* parent_id, std::vector<float>* changes) const {
516   const int logits_dimension = nodes_changes[start_node_id].second.size();
517   for (int i = 0; i < logits_dimension; ++i) {
518     changes->emplace_back(0.0);
519   }
520   int32 node_id = start_node_id;
521   int32 parent = nodes_changes[node_id].first;
522   while (parent != node_id) {
523     for (int i = 0; i < logits_dimension; ++i) {
524       changes->at(i) += nodes_changes[node_id].second[i];
525     }
526     node_id = parent;
527     parent = nodes_changes[node_id].first;
528   }
529   *parent_id = parent;
530 }
531 
RecursivelyDoPostPrunePreparation(const int32 tree_id,const int32 node_id,std::vector<int32> * nodes_to_delete,std::vector<std::pair<int32,std::vector<float>>> * nodes_meta)532 void BoostedTreesEnsembleResource::RecursivelyDoPostPrunePreparation(
533     const int32 tree_id, const int32 node_id,
534     std::vector<int32>* nodes_to_delete,
535     std::vector<std::pair<int32, std::vector<float>>>* nodes_meta) {
536   auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
537   DCHECK_NE(node->node_case(), boosted_trees::Node::NODE_NOT_SET);
538   // Base case when we reach a leaf.
539   if (node->node_case() == boosted_trees::Node::kLeaf) {
540     return;
541   }
542 
543   // Traverse node children first and recursively prune their sub-trees.
544   RecursivelyDoPostPrunePreparation(tree_id, node->bucketized_split().left_id(),
545                                     nodes_to_delete, nodes_meta);
546   RecursivelyDoPostPrunePreparation(tree_id,
547                                     node->bucketized_split().right_id(),
548                                     nodes_to_delete, nodes_meta);
549 
550   // Two conditions must be satisfied to prune the node:
551   // 1- The split gain is negative.
552   // 2- After depth-first pruning, the node only has leaf children.
553   const auto& node_metadata = node->metadata();
554   if (node_metadata.gain() < 0 && IsTerminalSplitNode(tree_id, node_id)) {
555     const int32 left_id = node->bucketized_split().left_id();
556     const int32 right_id = node->bucketized_split().right_id();
557 
558     // Save children that need to be deleted.
559     nodes_to_delete->push_back(left_id);
560     nodes_to_delete->push_back(right_id);
561 
562     // Change node back into leaf.
563     *node->mutable_leaf() = node_metadata.original_leaf();
564 
565     // Save the old values of weights of children.
566     nodes_meta->at(left_id).first = node_id;
567     nodes_meta->at(right_id).first = node_id;
568     const auto& left_child_values = node_value(tree_id, left_id);
569     const auto& right_child_values = node_value(tree_id, right_id);
570     std::vector<float> parent_values(left_child_values.size(), 0.0);
571     if (node_metadata.has_original_leaf()) {
572       parent_values = node_value(tree_id, node_id);
573     }
574     for (int32 i = 0, end = parent_values.size(); i < end; ++i) {
575       nodes_meta->at(left_id).second.emplace_back(parent_values[i] -
576                                                   left_child_values[i]);
577       nodes_meta->at(right_id).second.emplace_back(parent_values[i] -
578                                                    right_child_values[i]);
579     }
580     // Clear gain for leaf node.
581     node->clear_metadata();
582   }
583 }
584 
585 }  // namespace tensorflow
586