• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/boosted_trees/resources.h"
17 
18 #include "tensorflow/core/framework/resource_mgr.h"
19 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
20 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
21 #include "tensorflow/core/platform/mutex.h"
22 #include "tensorflow/core/platform/protobuf.h"
23 
24 namespace tensorflow {
25 
26 // Constructor.
BoostedTreesEnsembleResource()27 BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
28     : tree_ensemble_(
29           protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
30               &arena_)) {}
31 
DebugString() const32 string BoostedTreesEnsembleResource::DebugString() const {
33   return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
34                          "]");
35 }
36 
InitFromSerialized(const string & serialized,const int64_t stamp_token)37 bool BoostedTreesEnsembleResource::InitFromSerialized(
38     const string& serialized, const int64_t stamp_token) {
39   CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
40   if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
41     set_stamp(stamp_token);
42     return true;
43   }
44   return false;
45 }
46 
SerializeAsString() const47 string BoostedTreesEnsembleResource::SerializeAsString() const {
48   return tree_ensemble_->SerializeAsString();
49 }
50 
num_trees() const51 int32 BoostedTreesEnsembleResource::num_trees() const {
52   return tree_ensemble_->trees_size();
53 }
54 
next_node(const int32_t tree_id,const int32_t node_id,const int32_t index_in_batch,const std::vector<TTypes<int32>::ConstMatrix> & bucketized_features) const55 int32 BoostedTreesEnsembleResource::next_node(
56     const int32_t tree_id, const int32_t node_id, const int32_t index_in_batch,
57     const std::vector<TTypes<int32>::ConstMatrix>& bucketized_features) const {
58   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
59   DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
60   const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
61 
62   switch (node.node_case()) {
63     case boosted_trees::Node::kBucketizedSplit: {
64       const auto& split = node.bucketized_split();
65       const auto bucketized_feature = bucketized_features[split.feature_id()];
66       return bucketized_feature(index_in_batch, split.dimension_id()) <=
67                      split.threshold()
68                  ? split.left_id()
69                  : split.right_id();
70     }
71     case boosted_trees::Node::kCategoricalSplit: {
72       const auto& split = node.categorical_split();
73       const auto bucketized_feature = bucketized_features[split.feature_id()];
74       return bucketized_feature(index_in_batch, split.dimension_id()) ==
75                      split.value()
76                  ? split.left_id()
77                  : split.right_id();
78     }
79     default:
80       DCHECK(false) << "Node type " << node.node_case() << " not supported.";
81   }
82   return -1;
83 }
84 
node_value(const int32_t tree_id,const int32_t node_id) const85 std::vector<float> BoostedTreesEnsembleResource::node_value(
86     const int32_t tree_id, const int32_t node_id) const {
87   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
88   DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
89   const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
90   if (node.node_case() == boosted_trees::Node::kLeaf) {
91     if (node.leaf().has_vector()) {
92       std::vector<float> leaf_values;
93       const auto& leaf_value_vector = node.leaf().vector();
94       const int size = leaf_value_vector.value_size();
95       leaf_values.reserve(size);
96       for (int i = 0; i < size; ++i) {
97         leaf_values.push_back(leaf_value_vector.value(i));
98       }
99       return leaf_values;
100     } else {
101       return {node.leaf().scalar()};
102     }
103   } else {
104     if (node.metadata().original_leaf().has_vector()) {
105       std::vector<float> node_values;
106       const auto& leaf_value_vector = node.metadata().original_leaf().vector();
107       const int size = leaf_value_vector.value_size();
108       node_values.reserve(size);
109       for (int i = 0; i < size; ++i) {
110         node_values.push_back(leaf_value_vector.value(i));
111       }
112       return node_values;
113     } else if (node.metadata().has_original_leaf()) {
114       return {node.metadata().original_leaf().scalar()};
115     } else {
116       return {};
117     }
118   }
119 }
120 
set_node_value(const int32_t tree_id,const int32_t node_id,const float logits)121 void BoostedTreesEnsembleResource::set_node_value(const int32_t tree_id,
122                                                   const int32_t node_id,
123                                                   const float logits) {
124   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
125   DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
126   auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
127   DCHECK(node->node_case() == boosted_trees::Node::kLeaf);
128   node->mutable_leaf()->set_scalar(logits);
129 }
130 
GetNumLayersGrown(const int32_t tree_id) const131 int32 BoostedTreesEnsembleResource::GetNumLayersGrown(
132     const int32_t tree_id) const {
133   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
134   return tree_ensemble_->tree_metadata(tree_id).num_layers_grown();
135 }
136 
SetNumLayersGrown(const int32_t tree_id,int32_t new_num_layers) const137 void BoostedTreesEnsembleResource::SetNumLayersGrown(
138     const int32_t tree_id, int32_t new_num_layers) const {
139   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
140   tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
141       new_num_layers);
142 }
143 
UpdateLastLayerNodesRange(const int32_t node_range_start,int32_t node_range_end) const144 void BoostedTreesEnsembleResource::UpdateLastLayerNodesRange(
145     const int32_t node_range_start, int32_t node_range_end) const {
146   tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start(
147       node_range_start);
148   tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end(
149       node_range_end);
150 }
151 
GetLastLayerNodesRange(int32 * node_range_start,int32 * node_range_end) const152 void BoostedTreesEnsembleResource::GetLastLayerNodesRange(
153     int32* node_range_start, int32* node_range_end) const {
154   *node_range_start =
155       tree_ensemble_->growing_metadata().last_layer_node_start();
156   *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end();
157 }
158 
GetNumNodes(const int32_t tree_id)159 int64 BoostedTreesEnsembleResource::GetNumNodes(const int32_t tree_id) {
160   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
161   return tree_ensemble_->trees(tree_id).nodes_size();
162 }
163 
GetNumLayersAttempted()164 int32 BoostedTreesEnsembleResource::GetNumLayersAttempted() {
165   return tree_ensemble_->growing_metadata().num_layers_attempted();
166 }
167 
is_leaf(const int32_t tree_id,const int32_t node_id) const168 bool BoostedTreesEnsembleResource::is_leaf(const int32_t tree_id,
169                                            const int32_t node_id) const {
170   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
171   DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
172   const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
173   return node.node_case() == boosted_trees::Node::kLeaf;
174 }
175 
feature_id(const int32_t tree_id,const int32_t node_id) const176 int32 BoostedTreesEnsembleResource::feature_id(const int32_t tree_id,
177                                                const int32_t node_id) const {
178   const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
179   DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
180   return node.bucketized_split().feature_id();
181 }
182 
bucket_threshold(const int32_t tree_id,const int32_t node_id) const183 int32 BoostedTreesEnsembleResource::bucket_threshold(
184     const int32_t tree_id, const int32_t node_id) const {
185   const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
186   DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
187   return node.bucketized_split().threshold();
188 }
189 
left_id(const int32_t tree_id,const int32_t node_id) const190 int32 BoostedTreesEnsembleResource::left_id(const int32_t tree_id,
191                                             const int32_t node_id) const {
192   const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
193   DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
194   return node.bucketized_split().left_id();
195 }
196 
right_id(const int32_t tree_id,const int32_t node_id) const197 int32 BoostedTreesEnsembleResource::right_id(const int32_t tree_id,
198                                              const int32_t node_id) const {
199   const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
200   DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
201   return node.bucketized_split().right_id();
202 }
203 
GetTreeWeights() const204 std::vector<float> BoostedTreesEnsembleResource::GetTreeWeights() const {
205   return {tree_ensemble_->tree_weights().begin(),
206           tree_ensemble_->tree_weights().end()};
207 }
208 
GetTreeWeight(const int32_t tree_id) const209 float BoostedTreesEnsembleResource::GetTreeWeight(const int32_t tree_id) const {
210   return tree_ensemble_->tree_weights(tree_id);
211 }
212 
IsTreeFinalized(const int32_t tree_id) const213 float BoostedTreesEnsembleResource::IsTreeFinalized(
214     const int32_t tree_id) const {
215   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
216   return tree_ensemble_->tree_metadata(tree_id).is_finalized();
217 }
218 
IsTreePostPruned(const int32_t tree_id) const219 float BoostedTreesEnsembleResource::IsTreePostPruned(
220     const int32_t tree_id) const {
221   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
222   return tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size() >
223          0;
224 }
225 
SetIsFinalized(const int32_t tree_id,const bool is_finalized)226 void BoostedTreesEnsembleResource::SetIsFinalized(const int32_t tree_id,
227                                                   const bool is_finalized) {
228   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
229   return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
230       is_finalized);
231 }
232 
233 // Sets the weight of i'th tree.
SetTreeWeight(const int32_t tree_id,const float weight)234 void BoostedTreesEnsembleResource::SetTreeWeight(const int32_t tree_id,
235                                                  const float weight) {
236   DCHECK_GE(tree_id, 0);
237   DCHECK_LT(tree_id, num_trees());
238   tree_ensemble_->set_tree_weights(tree_id, weight);
239 }
240 
UpdateGrowingMetadata() const241 void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
242   tree_ensemble_->mutable_growing_metadata()->set_num_layers_attempted(
243       tree_ensemble_->growing_metadata().num_layers_attempted() + 1);
244 
245   const int n_trees = num_trees();
246 
247   if (n_trees <= 0 ||
248       // Checks if we are building the first layer of the dummy empty tree
249       ((n_trees == 1 || IsTreeFinalized(n_trees - 2)) &&
250        (tree_ensemble_->trees(n_trees - 1).nodes_size() == 1))) {
251     tree_ensemble_->mutable_growing_metadata()->set_num_trees_attempted(
252         tree_ensemble_->growing_metadata().num_trees_attempted() + 1);
253   }
254 }
255 
256 // Add a tree to the ensemble and returns a new tree_id.
AddNewTree(const float weight,const int32_t logits_dimension)257 int32 BoostedTreesEnsembleResource::AddNewTree(const float weight,
258                                                const int32_t logits_dimension) {
259   const std::vector<float> empty_leaf(logits_dimension);
260   return AddNewTreeWithLogits(weight, empty_leaf, logits_dimension);
261 }
262 
AddNewTreeWithLogits(const float weight,const std::vector<float> & logits,const int32_t logits_dimension)263 int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(
264     const float weight, const std::vector<float>& logits,
265     const int32_t logits_dimension) {
266   const int32_t new_tree_id = tree_ensemble_->trees_size();
267   auto* node = tree_ensemble_->add_trees()->add_nodes();
268   if (logits_dimension == 1) {
269     node->mutable_leaf()->set_scalar(logits[0]);
270   } else {
271     for (int32_t i = 0; i < logits_dimension; ++i) {
272       node->mutable_leaf()->mutable_vector()->add_value(logits[i]);
273     }
274   }
275   tree_ensemble_->add_tree_weights(weight);
276   tree_ensemble_->add_tree_metadata();
277 
278   return new_tree_id;
279 }
280 
AddBucketizedSplitNode(const int32_t tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32_t logits_dimension,int32 * left_node_id,int32 * right_node_id)281 void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
282     const int32_t tree_id,
283     const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
284     const int32_t logits_dimension, int32* left_node_id, int32* right_node_id) {
285   const auto candidate = split_entry.second;
286   auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
287                             left_node_id, right_node_id);
288   auto* new_split = node->mutable_bucketized_split();
289   new_split->set_feature_id(candidate.feature_id);
290   new_split->set_threshold(candidate.threshold);
291   new_split->set_dimension_id(candidate.dimension_id);
292   new_split->set_left_id(*left_node_id);
293   new_split->set_right_id(*right_node_id);
294 
295   boosted_trees::SplitTypeWithDefault split_type_with_default;
296   bool parsed = boosted_trees::SplitTypeWithDefault_Parse(
297       candidate.split_type, &split_type_with_default);
298   DCHECK(parsed);
299   if (split_type_with_default == boosted_trees::INEQUALITY_DEFAULT_RIGHT) {
300     new_split->set_default_direction(boosted_trees::DEFAULT_RIGHT);
301   } else {
302     new_split->set_default_direction(boosted_trees::DEFAULT_LEFT);
303   }
304 }
305 
AddCategoricalSplitNode(const int32_t tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32_t logits_dimension,int32 * left_node_id,int32 * right_node_id)306 void BoostedTreesEnsembleResource::AddCategoricalSplitNode(
307     const int32_t tree_id,
308     const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
309     const int32_t logits_dimension, int32* left_node_id, int32* right_node_id) {
310   const auto candidate = split_entry.second;
311   auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
312                             left_node_id, right_node_id);
313   auto* new_split = node->mutable_categorical_split();
314   new_split->set_feature_id(candidate.feature_id);
315   new_split->set_value(candidate.threshold);
316   new_split->set_dimension_id(candidate.dimension_id);
317   new_split->set_left_id(*left_node_id);
318   new_split->set_right_id(*right_node_id);
319 }
320 
AddLeafNodes(const int32_t tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32_t logits_dimension,int32 * left_node_id,int32 * right_node_id)321 boosted_trees::Node* BoostedTreesEnsembleResource::AddLeafNodes(
322     const int32_t tree_id,
323     const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
324     const int32_t logits_dimension, int32* left_node_id, int32* right_node_id) {
325   auto* tree = tree_ensemble_->mutable_trees(tree_id);
326   const auto node_id = split_entry.first;
327   const auto candidate = split_entry.second;
328   auto* node = tree->mutable_nodes(node_id);
329   DCHECK_EQ(node->node_case(), boosted_trees::Node::kLeaf);
330   *left_node_id = tree->nodes_size();
331   *right_node_id = *left_node_id + 1;
332   auto* left_node = tree->add_nodes();
333   auto* right_node = tree->add_nodes();
334   const bool has_leaf_value =
335       node->has_leaf() &&
336       ((logits_dimension == 1 && (node->leaf().scalar() != 0)) ||
337        node->leaf().has_vector());
338   if (node_id != 0 || has_leaf_value) {
339     // Save previous leaf value if it is not the first leaf in the tree.
340     node->mutable_metadata()->mutable_original_leaf()->Swap(
341         node->mutable_leaf());
342   }
343   node->mutable_metadata()->set_gain(candidate.gain);
344   // TODO(nponomareva): this is LAYER-BY-LAYER boosting; add WHOLE-TREE.
345   if (logits_dimension == 1) {
346     const float prev_logit_value = node->metadata().original_leaf().scalar();
347     left_node->mutable_leaf()->set_scalar(prev_logit_value +
348                                           candidate.left_node_contribs[0]);
349     right_node->mutable_leaf()->set_scalar(prev_logit_value +
350                                            candidate.right_node_contribs[0]);
351   } else {
352     if (has_leaf_value) {
353       DCHECK_EQ(logits_dimension,
354                 node->metadata().original_leaf().vector().value_size());
355     }
356     float prev_logit_value = 0.0;
357     for (int32_t i = 0; i < logits_dimension; ++i) {
358       if (has_leaf_value) {
359         prev_logit_value = node->metadata().original_leaf().vector().value(i);
360       }
361       left_node->mutable_leaf()->mutable_vector()->add_value(
362           prev_logit_value + candidate.left_node_contribs[i]);
363       right_node->mutable_leaf()->mutable_vector()->add_value(
364           prev_logit_value + candidate.right_node_contribs[i]);
365     }
366   }
367   return node;
368 }
369 
Reset()370 void BoostedTreesEnsembleResource::Reset() {
371   // Reset stamp.
372   set_stamp(-1);
373 
374   // Clear tree ensemle.
375   arena_.Reset();
376   tree_ensemble_ =
377       protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(&arena_);
378 }
379 
PostPruneTree(const int32_t current_tree,const int32_t logits_dimension)380 void BoostedTreesEnsembleResource::PostPruneTree(
381     const int32_t current_tree, const int32_t logits_dimension) {
382   // No-op if tree is empty.
383   auto* tree = tree_ensemble_->mutable_trees(current_tree);
384   int32_t num_nodes = tree->nodes_size();
385   if (num_nodes == 0) {
386     return;
387   }
388 
389   std::vector<int32> nodes_to_delete;
390   // If a node was pruned, we need to save the change of the prediction from
391   // this node to its parent, as well as the parent id.
392   std::vector<std::pair<int32, std::vector<float>>> nodes_changes;
393   nodes_changes.reserve(num_nodes);
394   for (int32_t i = 0; i < num_nodes; ++i) {
395     std::vector<float> prune_logit_changes;
396     nodes_changes.emplace_back(i, prune_logit_changes);
397   }
398   // Prune the tree recursively starting from the root. Each node that has
399   // negative gain and only leaf children will be pruned recursively up from
400   // the bottom of the tree. This method returns the list of nodes pruned, and
401   // updates the nodes in the tree not to refer to those pruned nodes.
402   RecursivelyDoPostPrunePreparation(current_tree, 0, &nodes_to_delete,
403                                     &nodes_changes);
404 
405   if (nodes_to_delete.empty()) {
406     // No pruning happened, and no post-processing needed.
407     return;
408   }
409 
410   // Sort node ids so they are in asc order.
411   std::sort(nodes_to_delete.begin(), nodes_to_delete.end());
412 
413   // We need to
414   // - update split left and right children ids with new indices
415   // - actually remove the nodes that need to be removed
416   // - save the information about pruned node so we could recover the
417   // predictions from cache. Build a map for old node index=>new node index.
418   // nodes_to_delete contains nodes who's indices should be skipped, in
419   // ascending order. Save the information about new indices into meta.
420   std::map<int32, int32> old_to_new_ids;
421   int32_t new_index = 0;
422   int32_t index_for_deleted = 0;
423   auto* post_prune_meta = tree_ensemble_->mutable_tree_metadata(current_tree)
424                               ->mutable_post_pruned_nodes_meta();
425 
426   for (int32_t i = 0; i < num_nodes; ++i) {
427     const int64_t nodes_to_delete_size = nodes_to_delete.size();
428     if (index_for_deleted < nodes_to_delete_size &&
429         i == nodes_to_delete[index_for_deleted]) {
430       // Node i will get removed,
431       ++index_for_deleted;
432       // Update meta info that will allow us to use cached predictions from
433       // those nodes.
434       int32_t new_id;
435       std::vector<float> logit_changes;
436       logit_changes.reserve(logits_dimension);
437       CalculateParentAndLogitUpdate(i, nodes_changes, &new_id, &logit_changes);
438       auto* meta = post_prune_meta->Add();
439       meta->set_new_node_id(old_to_new_ids[new_id]);
440       for (int32_t j = 0; j < logits_dimension; ++j) {
441         meta->add_logit_change(logit_changes[j]);
442       }
443     } else {
444       old_to_new_ids[i] = new_index++;
445       auto* meta = post_prune_meta->Add();
446       // Update meta info that will allow us to use cached predictions from
447       // those nodes.
448       meta->set_new_node_id(old_to_new_ids[i]);
449       for (int32_t i = 0; i < logits_dimension; ++i) {
450         meta->add_logit_change(0.0);
451       }
452     }
453   }
454   index_for_deleted = 0;
455   int32_t i = 0;
456   protobuf::RepeatedPtrField<boosted_trees::Node> new_nodes;
457   new_nodes.Reserve(old_to_new_ids.size());
458   for (auto node : *(tree->mutable_nodes())) {
459     const int64_t nodes_to_delete_size = nodes_to_delete.size();
460     if (index_for_deleted < nodes_to_delete_size &&
461         i == nodes_to_delete[index_for_deleted]) {
462       ++index_for_deleted;
463       ++i;
464       continue;
465     } else {
466       if (node.node_case() == boosted_trees::Node::kBucketizedSplit) {
467         node.mutable_bucketized_split()->set_left_id(
468             old_to_new_ids[node.bucketized_split().left_id()]);
469         node.mutable_bucketized_split()->set_right_id(
470             old_to_new_ids[node.bucketized_split().right_id()]);
471       }
472       *new_nodes.Add() = std::move(node);
473     }
474     ++i;
475   }
476   // Replace all the nodes in a tree with the ones we keep.
477   *tree->mutable_nodes() = std::move(new_nodes);
478 
479   // Note that if the whole tree got pruned, we will end up with one node.
480   // We can't remove that tree because it will cause problems with cache.
481 }
482 
GetPostPruneCorrection(const int32_t tree_id,const int32_t initial_node_id,int32 * current_node_id,std::vector<float> * logit_updates) const483 void BoostedTreesEnsembleResource::GetPostPruneCorrection(
484     const int32_t tree_id, const int32_t initial_node_id,
485     int32* current_node_id, std::vector<float>* logit_updates) const {
486   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
487   if (IsTreeFinalized(tree_id) && IsTreePostPruned(tree_id)) {
488     DCHECK_LT(
489         initial_node_id,
490         tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size());
491     const auto& meta =
492         tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta(
493             initial_node_id);
494     *current_node_id = meta.new_node_id();
495     DCHECK_EQ(meta.logit_change().size(), logit_updates->size());
496     for (int32_t i = 0; i < meta.logit_change().size(); ++i) {
497       logit_updates->at(i) = meta.logit_change(i);
498     }
499   }
500 }
501 
IsTerminalSplitNode(const int32_t tree_id,const int32_t node_id) const502 bool BoostedTreesEnsembleResource::IsTerminalSplitNode(
503     const int32_t tree_id, const int32_t node_id) const {
504   const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
505   DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
506   const int32_t left_id = node.bucketized_split().left_id();
507   const int32_t right_id = node.bucketized_split().right_id();
508   return is_leaf(tree_id, left_id) && is_leaf(tree_id, right_id);
509 }
510 
511 // For each pruned node, finds the leaf where it finally ended up and
512 // calculates the total update from that pruned node prediction.
CalculateParentAndLogitUpdate(const int32_t start_node_id,const std::vector<std::pair<int32,std::vector<float>>> & nodes_changes,int32 * parent_id,std::vector<float> * changes) const513 void BoostedTreesEnsembleResource::CalculateParentAndLogitUpdate(
514     const int32_t start_node_id,
515     const std::vector<std::pair<int32, std::vector<float>>>& nodes_changes,
516     int32* parent_id, std::vector<float>* changes) const {
517   const int logits_dimension = nodes_changes[start_node_id].second.size();
518   for (int i = 0; i < logits_dimension; ++i) {
519     changes->emplace_back(0.0);
520   }
521   int32_t node_id = start_node_id;
522   int32_t parent = nodes_changes[node_id].first;
523   while (parent != node_id) {
524     for (int i = 0; i < logits_dimension; ++i) {
525       changes->at(i) += nodes_changes[node_id].second[i];
526     }
527     node_id = parent;
528     parent = nodes_changes[node_id].first;
529   }
530   *parent_id = parent;
531 }
532 
RecursivelyDoPostPrunePreparation(const int32_t tree_id,const int32_t node_id,std::vector<int32> * nodes_to_delete,std::vector<std::pair<int32,std::vector<float>>> * nodes_meta)533 void BoostedTreesEnsembleResource::RecursivelyDoPostPrunePreparation(
534     const int32_t tree_id, const int32_t node_id,
535     std::vector<int32>* nodes_to_delete,
536     std::vector<std::pair<int32, std::vector<float>>>* nodes_meta) {
537   auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
538   DCHECK_NE(node->node_case(), boosted_trees::Node::NODE_NOT_SET);
539   // Base case when we reach a leaf.
540   if (node->node_case() == boosted_trees::Node::kLeaf) {
541     return;
542   }
543 
544   // Traverse node children first and recursively prune their sub-trees.
545   RecursivelyDoPostPrunePreparation(tree_id, node->bucketized_split().left_id(),
546                                     nodes_to_delete, nodes_meta);
547   RecursivelyDoPostPrunePreparation(tree_id,
548                                     node->bucketized_split().right_id(),
549                                     nodes_to_delete, nodes_meta);
550 
551   // Two conditions must be satisfied to prune the node:
552   // 1- The split gain is negative.
553   // 2- After depth-first pruning, the node only has leaf children.
554   const auto& node_metadata = node->metadata();
555   if (node_metadata.gain() < 0 && IsTerminalSplitNode(tree_id, node_id)) {
556     const int32_t left_id = node->bucketized_split().left_id();
557     const int32_t right_id = node->bucketized_split().right_id();
558 
559     // Save children that need to be deleted.
560     nodes_to_delete->push_back(left_id);
561     nodes_to_delete->push_back(right_id);
562 
563     // Change node back into leaf.
564     *node->mutable_leaf() = node_metadata.original_leaf();
565 
566     // Save the old values of weights of children.
567     nodes_meta->at(left_id).first = node_id;
568     nodes_meta->at(right_id).first = node_id;
569     const auto& left_child_values = node_value(tree_id, left_id);
570     const auto& right_child_values = node_value(tree_id, right_id);
571     std::vector<float> parent_values(left_child_values.size(), 0.0);
572     if (node_metadata.has_original_leaf()) {
573       parent_values = node_value(tree_id, node_id);
574     }
575     for (int32_t i = 0, end = parent_values.size(); i < end; ++i) {
576       nodes_meta->at(left_id).second.emplace_back(parent_values[i] -
577                                                   left_child_values[i]);
578       nodes_meta->at(right_id).second.emplace_back(parent_values[i] -
579                                                    right_child_values[i]);
580     }
581     // Clear gain for leaf node.
582     node->clear_metadata();
583   }
584 }
585 
586 }  // namespace tensorflow
587