• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/boosted_trees/resources.h"
17 
18 #include "tensorflow/core/framework/resource_mgr.h"
19 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
20 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
21 #include "tensorflow/core/platform/mutex.h"
22 #include "tensorflow/core/platform/protobuf.h"
23 
24 namespace tensorflow {
25 
26 // Constructor.
BoostedTreesEnsembleResource()27 BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
28     : tree_ensemble_(
29           protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
30               &arena_)) {}
31 
DebugString() const32 string BoostedTreesEnsembleResource::DebugString() const {
33   return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
34                          "]");
35 }
36 
InitFromSerialized(const string & serialized,const int64 stamp_token)37 bool BoostedTreesEnsembleResource::InitFromSerialized(const string& serialized,
38                                                       const int64 stamp_token) {
39   CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
40   if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
41     set_stamp(stamp_token);
42     return true;
43   }
44   return false;
45 }
46 
SerializeAsString() const47 string BoostedTreesEnsembleResource::SerializeAsString() const {
48   return tree_ensemble_->SerializeAsString();
49 }
50 
num_trees() const51 int32 BoostedTreesEnsembleResource::num_trees() const {
52   return tree_ensemble_->trees_size();
53 }
54 
// Returns the id of the child that example `index_in_batch` is routed to from
// node `node_id` of tree `tree_id`.
//
// - Bucketized split: goes left when the feature value is <= the threshold.
// - Categorical split: goes left when the feature value equals the split
//   value.
// In both cases the feature value is read from
// `bucketized_features[feature_id](index_in_batch, dimension_id)`.
// Any other node type trips a DCHECK and yields -1.
int32 BoostedTreesEnsembleResource::next_node(
    const int32 tree_id, const int32 node_id, const int32 index_in_batch,
    const std::vector<TTypes<int32>::ConstMatrix>& bucketized_features) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  const auto& current = tree_ensemble_->trees(tree_id).nodes(node_id);
  const auto kind = current.node_case();

  if (kind == boosted_trees::Node::kBucketizedSplit) {
    const auto& split = current.bucketized_split();
    const auto feature_column = bucketized_features[split.feature_id()];
    const bool go_left =
        feature_column(index_in_batch, split.dimension_id()) <=
        split.threshold();
    if (go_left) {
      return split.left_id();
    }
    return split.right_id();
  }

  if (kind == boosted_trees::Node::kCategoricalSplit) {
    const auto& split = current.categorical_split();
    const auto feature_column = bucketized_features[split.feature_id()];
    const bool go_left =
        feature_column(index_in_batch, split.dimension_id()) == split.value();
    if (go_left) {
      return split.left_id();
    }
    return split.right_id();
  }

  DCHECK(false) << "Node type " << kind << " not supported.";
  return -1;
}
84 
node_value(const int32 tree_id,const int32 node_id) const85 std::vector<float> BoostedTreesEnsembleResource::node_value(
86     const int32 tree_id, const int32 node_id) const {
87   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
88   DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
89   const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
90   if (node.node_case() == boosted_trees::Node::kLeaf) {
91     if (node.leaf().has_vector()) {
92       std::vector<float> leaf_values;
93       const auto& leaf_value_vector = node.leaf().vector();
94       const int size = leaf_value_vector.value_size();
95       leaf_values.reserve(size);
96       for (int i = 0; i < size; ++i) {
97         leaf_values.push_back(leaf_value_vector.value(i));
98       }
99       return leaf_values;
100     } else {
101       return {node.leaf().scalar()};
102     }
103   } else {
104     if (node.metadata().original_leaf().has_vector()) {
105       std::vector<float> node_values;
106       const auto& leaf_value_vector = node.metadata().original_leaf().vector();
107       const int size = leaf_value_vector.value_size();
108       node_values.reserve(size);
109       for (int i = 0; i < size; ++i) {
110         node_values.push_back(leaf_value_vector.value(i));
111       }
112       return node_values;
113     } else if (node.metadata().has_original_leaf()) {
114       return {node.metadata().original_leaf().scalar()};
115     } else {
116       return {};
117     }
118   }
119 }
120 
set_node_value(const int32 tree_id,const int32 node_id,const float logits)121 void BoostedTreesEnsembleResource::set_node_value(const int32 tree_id,
122                                                   const int32 node_id,
123                                                   const float logits) {
124   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
125   DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
126   auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
127   DCHECK(node->node_case() == boosted_trees::Node::kLeaf);
128   node->mutable_leaf()->set_scalar(logits);
129 }
130 
// Returns the `num_layers_grown` counter stored in the metadata of tree
// `tree_id`.
int32 BoostedTreesEnsembleResource::GetNumLayersGrown(
    const int32 tree_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  const auto& tree_meta = tree_ensemble_->tree_metadata(tree_id);
  return tree_meta.num_layers_grown();
}
136 
SetNumLayersGrown(const int32 tree_id,int32 new_num_layers) const137 void BoostedTreesEnsembleResource::SetNumLayersGrown(
138     const int32 tree_id, int32 new_num_layers) const {
139   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
140   tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
141       new_num_layers);
142 }
143 
UpdateLastLayerNodesRange(const int32 node_range_start,int32 node_range_end) const144 void BoostedTreesEnsembleResource::UpdateLastLayerNodesRange(
145     const int32 node_range_start, int32 node_range_end) const {
146   tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start(
147       node_range_start);
148   tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end(
149       node_range_end);
150 }
151 
GetLastLayerNodesRange(int32 * node_range_start,int32 * node_range_end) const152 void BoostedTreesEnsembleResource::GetLastLayerNodesRange(
153     int32* node_range_start, int32* node_range_end) const {
154   *node_range_start =
155       tree_ensemble_->growing_metadata().last_layer_node_start();
156   *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end();
157 }
158 
GetNumNodes(const int32 tree_id)159 int64 BoostedTreesEnsembleResource::GetNumNodes(const int32 tree_id) {
160   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
161   return tree_ensemble_->trees(tree_id).nodes_size();
162 }
163 
GetNumLayersAttempted()164 int32 BoostedTreesEnsembleResource::GetNumLayersAttempted() {
165   return tree_ensemble_->growing_metadata().num_layers_attempted();
166 }
167 
is_leaf(const int32 tree_id,const int32 node_id) const168 bool BoostedTreesEnsembleResource::is_leaf(const int32 tree_id,
169                                            const int32 node_id) const {
170   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
171   DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
172   const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
173   return node.node_case() == boosted_trees::Node::kLeaf;
174 }
175 
// Returns the feature id of the bucketized split at node `node_id` of tree
// `tree_id`. The node is expected to be a bucketized split (debug-checked).
int32 BoostedTreesEnsembleResource::feature_id(const int32 tree_id,
                                               const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference: `const auto` would copy the whole Node message just to
  // read a single field.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().feature_id();
}
182 
// Returns the threshold of the bucketized split at node `node_id` of tree
// `tree_id`. The node is expected to be a bucketized split (debug-checked).
int32 BoostedTreesEnsembleResource::bucket_threshold(
    const int32 tree_id, const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference: `const auto` would copy the whole Node message just to
  // read a single field.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().threshold();
}
189 
// Returns the left child id of the bucketized split at node `node_id` of tree
// `tree_id`. The node is expected to be a bucketized split (debug-checked).
int32 BoostedTreesEnsembleResource::left_id(const int32 tree_id,
                                            const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference: `const auto` would copy the whole Node message just to
  // read a single field.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().left_id();
}
196 
// Returns the right child id of the bucketized split at node `node_id` of tree
// `tree_id`. The node is expected to be a bucketized split (debug-checked).
int32 BoostedTreesEnsembleResource::right_id(const int32 tree_id,
                                             const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference: `const auto` would copy the whole Node message just to
  // read a single field.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().right_id();
}
203 
GetTreeWeights() const204 std::vector<float> BoostedTreesEnsembleResource::GetTreeWeights() const {
205   return {tree_ensemble_->tree_weights().begin(),
206           tree_ensemble_->tree_weights().end()};
207 }
208 
GetTreeWeight(const int32 tree_id) const209 float BoostedTreesEnsembleResource::GetTreeWeight(const int32 tree_id) const {
210   return tree_ensemble_->tree_weights(tree_id);
211 }
212 
IsTreeFinalized(const int32 tree_id) const213 float BoostedTreesEnsembleResource::IsTreeFinalized(const int32 tree_id) const {
214   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
215   return tree_ensemble_->tree_metadata(tree_id).is_finalized();
216 }
217 
IsTreePostPruned(const int32 tree_id) const218 float BoostedTreesEnsembleResource::IsTreePostPruned(
219     const int32 tree_id) const {
220   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
221   return tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size() >
222          0;
223 }
224 
SetIsFinalized(const int32 tree_id,const bool is_finalized)225 void BoostedTreesEnsembleResource::SetIsFinalized(const int32 tree_id,
226                                                   const bool is_finalized) {
227   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
228   return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
229       is_finalized);
230 }
231 
232 // Sets the weight of i'th tree.
SetTreeWeight(const int32 tree_id,const float weight)233 void BoostedTreesEnsembleResource::SetTreeWeight(const int32 tree_id,
234                                                  const float weight) {
235   DCHECK_GE(tree_id, 0);
236   DCHECK_LT(tree_id, num_trees());
237   tree_ensemble_->set_tree_weights(tree_id, weight);
238 }
239 
UpdateGrowingMetadata() const240 void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
241   tree_ensemble_->mutable_growing_metadata()->set_num_layers_attempted(
242       tree_ensemble_->growing_metadata().num_layers_attempted() + 1);
243 
244   const int n_trees = num_trees();
245 
246   if (n_trees <= 0 ||
247       // Checks if we are building the first layer of the dummy empty tree
248       ((n_trees == 1 || IsTreeFinalized(n_trees - 2)) &&
249        (tree_ensemble_->trees(n_trees - 1).nodes_size() == 1))) {
250     tree_ensemble_->mutable_growing_metadata()->set_num_trees_attempted(
251         tree_ensemble_->growing_metadata().num_trees_attempted() + 1);
252   }
253 }
254 
255 // Add a tree to the ensemble and returns a new tree_id.
AddNewTree(const float weight,const int32 logits_dimension)256 int32 BoostedTreesEnsembleResource::AddNewTree(const float weight,
257                                                const int32 logits_dimension) {
258   const std::vector<float> empty_leaf(logits_dimension);
259   return AddNewTreeWithLogits(weight, empty_leaf, logits_dimension);
260 }
261 
// Appends a new single-node tree to the ensemble and returns its id (the tree
// count before the insertion).
//
// The root becomes a leaf holding `logits[0]` as a scalar when
// `logits_dimension` is 1; otherwise the first `logits_dimension` entries of
// `logits` are appended to the leaf's vector. `logits` is indexed without a
// size check. A weight entry and an empty metadata entry are appended for the
// new tree as well.
int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(
    const float weight, const std::vector<float>& logits,
    const int32 logits_dimension) {
  const int32 new_tree_id = tree_ensemble_->trees_size();
  auto* root = tree_ensemble_->add_trees()->add_nodes();

  if (logits_dimension != 1) {
    for (int32 dim = 0; dim < logits_dimension; ++dim) {
      root->mutable_leaf()->mutable_vector()->add_value(logits[dim]);
    }
  } else {
    root->mutable_leaf()->set_scalar(logits[0]);
  }

  tree_ensemble_->add_tree_weights(weight);
  tree_ensemble_->add_tree_metadata();
  return new_tree_id;
}
279 
AddBucketizedSplitNode(const int32 tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32 logits_dimension,int32 * left_node_id,int32 * right_node_id)280 void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
281     const int32 tree_id,
282     const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
283     const int32 logits_dimension, int32* left_node_id, int32* right_node_id) {
284   const auto candidate = split_entry.second;
285   auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
286                             left_node_id, right_node_id);
287   auto* new_split = node->mutable_bucketized_split();
288   new_split->set_feature_id(candidate.feature_id);
289   new_split->set_threshold(candidate.threshold);
290   new_split->set_dimension_id(candidate.dimension_id);
291   new_split->set_left_id(*left_node_id);
292   new_split->set_right_id(*right_node_id);
293 
294   boosted_trees::SplitTypeWithDefault split_type_with_default;
295   bool parsed = boosted_trees::SplitTypeWithDefault_Parse(
296       candidate.split_type, &split_type_with_default);
297   DCHECK(parsed);
298   if (split_type_with_default == boosted_trees::INEQUALITY_DEFAULT_RIGHT) {
299     new_split->set_default_direction(boosted_trees::DEFAULT_RIGHT);
300   } else {
301     new_split->set_default_direction(boosted_trees::DEFAULT_LEFT);
302   }
303 }
304 
AddCategoricalSplitNode(const int32 tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32 logits_dimension,int32 * left_node_id,int32 * right_node_id)305 void BoostedTreesEnsembleResource::AddCategoricalSplitNode(
306     const int32 tree_id,
307     const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
308     const int32 logits_dimension, int32* left_node_id, int32* right_node_id) {
309   const auto candidate = split_entry.second;
310   auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
311                             left_node_id, right_node_id);
312   auto* new_split = node->mutable_categorical_split();
313   new_split->set_feature_id(candidate.feature_id);
314   new_split->set_value(candidate.threshold);
315   new_split->set_dimension_id(candidate.dimension_id);
316   new_split->set_left_id(*left_node_id);
317   new_split->set_right_id(*right_node_id);
318 }
319 
AddLeafNodes(const int32 tree_id,const std::pair<int32,boosted_trees::SplitCandidate> & split_entry,const int32 logits_dimension,int32 * left_node_id,int32 * right_node_id)320 boosted_trees::Node* BoostedTreesEnsembleResource::AddLeafNodes(
321     const int32 tree_id,
322     const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
323     const int32 logits_dimension, int32* left_node_id, int32* right_node_id) {
324   auto* tree = tree_ensemble_->mutable_trees(tree_id);
325   const auto node_id = split_entry.first;
326   const auto candidate = split_entry.second;
327   auto* node = tree->mutable_nodes(node_id);
328   DCHECK_EQ(node->node_case(), boosted_trees::Node::kLeaf);
329   *left_node_id = tree->nodes_size();
330   *right_node_id = *left_node_id + 1;
331   auto* left_node = tree->add_nodes();
332   auto* right_node = tree->add_nodes();
333   const bool has_leaf_value =
334       node->has_leaf() &&
335       ((logits_dimension == 1 && (node->leaf().scalar() != 0)) ||
336        node->leaf().has_vector());
337   if (node_id != 0 || has_leaf_value) {
338     // Save previous leaf value if it is not the first leaf in the tree.
339     node->mutable_metadata()->mutable_original_leaf()->Swap(
340         node->mutable_leaf());
341   }
342   node->mutable_metadata()->set_gain(candidate.gain);
343   // TODO(npononareva): this is LAYER-BY-LAYER boosting; add WHOLE-TREE.
344   if (logits_dimension == 1) {
345     const float prev_logit_value = node->metadata().original_leaf().scalar();
346     left_node->mutable_leaf()->set_scalar(prev_logit_value +
347                                           candidate.left_node_contribs[0]);
348     right_node->mutable_leaf()->set_scalar(prev_logit_value +
349                                            candidate.right_node_contribs[0]);
350   } else {
351     if (has_leaf_value) {
352       DCHECK_EQ(logits_dimension,
353                 node->metadata().original_leaf().vector().value_size());
354     }
355     float prev_logit_value = 0.0;
356     for (int32 i = 0; i < logits_dimension; ++i) {
357       if (has_leaf_value) {
358         prev_logit_value = node->metadata().original_leaf().vector().value(i);
359       }
360       left_node->mutable_leaf()->mutable_vector()->add_value(
361           prev_logit_value + candidate.left_node_contribs[i]);
362       right_node->mutable_leaf()->mutable_vector()->add_value(
363           prev_logit_value + candidate.right_node_contribs[i]);
364     }
365   }
366   return node;
367 }
368 
Reset()369 void BoostedTreesEnsembleResource::Reset() {
370   // Reset stamp.
371   set_stamp(-1);
372 
373   // Clear tree ensemle.
374   arena_.Reset();
375   CHECK_EQ(0, arena_.SpaceAllocated());
376   tree_ensemble_ =
377       protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(&arena_);
378 }
379 
PostPruneTree(const int32 current_tree,const int32 logits_dimension)380 void BoostedTreesEnsembleResource::PostPruneTree(const int32 current_tree,
381                                                  const int32 logits_dimension) {
382   // No-op if tree is empty.
383   auto* tree = tree_ensemble_->mutable_trees(current_tree);
384   int32 num_nodes = tree->nodes_size();
385   if (num_nodes == 0) {
386     return;
387   }
388 
389   std::vector<int32> nodes_to_delete;
390   // If a node was pruned, we need to save the change of the prediction from
391   // this node to its parent, as well as the parent id.
392   std::vector<std::pair<int32, std::vector<float>>> nodes_changes;
393   nodes_changes.reserve(num_nodes);
394   for (int32 i = 0; i < num_nodes; ++i) {
395     std::vector<float> prune_logit_changes;
396     nodes_changes.emplace_back(i, prune_logit_changes);
397   }
398   // Prune the tree recursively starting from the root. Each node that has
399   // negative gain and only leaf children will be pruned recursively up from
400   // the bottom of the tree. This method returns the list of nodes pruned, and
401   // updates the nodes in the tree not to refer to those pruned nodes.
402   RecursivelyDoPostPrunePreparation(current_tree, 0, &nodes_to_delete,
403                                     &nodes_changes);
404 
405   if (nodes_to_delete.empty()) {
406     // No pruning happened, and no post-processing needed.
407     return;
408   }
409 
410   // Sort node ids so they are in asc order.
411   std::sort(nodes_to_delete.begin(), nodes_to_delete.end());
412 
413   // We need to
414   // - update split left and right children ids with new indices
415   // - actually remove the nodes that need to be removed
416   // - save the information about pruned node so we could recover the
417   // predictions from cache. Build a map for old node index=>new node index.
418   // nodes_to_delete contains nodes who's indices should be skipped, in
419   // ascending order. Save the information about new indices into meta.
420   std::map<int32, int32> old_to_new_ids;
421   int32 new_index = 0;
422   int32 index_for_deleted = 0;
423   auto* post_prune_meta = tree_ensemble_->mutable_tree_metadata(current_tree)
424                               ->mutable_post_pruned_nodes_meta();
425 
426   for (int32 i = 0; i < num_nodes; ++i) {
427     if (index_for_deleted < nodes_to_delete.size() &&
428         i == nodes_to_delete[index_for_deleted]) {
429       // Node i will get removed,
430       ++index_for_deleted;
431       // Update meta info that will allow us to use cached predictions from
432       // those nodes.
433       int32 new_id;
434       std::vector<float> logit_changes;
435       logit_changes.reserve(logits_dimension);
436       CalculateParentAndLogitUpdate(i, nodes_changes, &new_id, &logit_changes);
437       auto* meta = post_prune_meta->Add();
438       meta->set_new_node_id(old_to_new_ids[new_id]);
439       for (int32 j = 0; j < logits_dimension; ++j) {
440         meta->add_logit_change(logit_changes[j]);
441       }
442     } else {
443       old_to_new_ids[i] = new_index++;
444       auto* meta = post_prune_meta->Add();
445       // Update meta info that will allow us to use cached predictions from
446       // those nodes.
447       meta->set_new_node_id(old_to_new_ids[i]);
448       for (int32 i = 0; i < logits_dimension; ++i) {
449         meta->add_logit_change(0.0);
450       }
451     }
452   }
453   index_for_deleted = 0;
454   int32 i = 0;
455   protobuf::RepeatedPtrField<boosted_trees::Node> new_nodes;
456   new_nodes.Reserve(old_to_new_ids.size());
457   for (auto node : *(tree->mutable_nodes())) {
458     if (index_for_deleted < nodes_to_delete.size() &&
459         i == nodes_to_delete[index_for_deleted]) {
460       ++index_for_deleted;
461       ++i;
462       continue;
463     } else {
464       if (node.node_case() == boosted_trees::Node::kBucketizedSplit) {
465         node.mutable_bucketized_split()->set_left_id(
466             old_to_new_ids[node.bucketized_split().left_id()]);
467         node.mutable_bucketized_split()->set_right_id(
468             old_to_new_ids[node.bucketized_split().right_id()]);
469       }
470       *new_nodes.Add() = std::move(node);
471     }
472     ++i;
473   }
474   // Replace all the nodes in a tree with the ones we keep.
475   *tree->mutable_nodes() = std::move(new_nodes);
476 
477   // Note that if the whole tree got pruned, we will end up with one node.
478   // We can't remove that tree because it will cause problems with cache.
479 }
480 
GetPostPruneCorrection(const int32 tree_id,const int32 initial_node_id,int32 * current_node_id,std::vector<float> * logit_updates) const481 void BoostedTreesEnsembleResource::GetPostPruneCorrection(
482     const int32 tree_id, const int32 initial_node_id, int32* current_node_id,
483     std::vector<float>* logit_updates) const {
484   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
485   if (IsTreeFinalized(tree_id) && IsTreePostPruned(tree_id)) {
486     DCHECK_LT(
487         initial_node_id,
488         tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size());
489     const auto& meta =
490         tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta(
491             initial_node_id);
492     *current_node_id = meta.new_node_id();
493     DCHECK_EQ(meta.logit_change().size(), logit_updates->size());
494     for (int32 i = 0; i < meta.logit_change().size(); ++i) {
495       logit_updates->at(i) = meta.logit_change(i);
496     }
497   }
498 }
499 
IsTerminalSplitNode(const int32 tree_id,const int32 node_id) const500 bool BoostedTreesEnsembleResource::IsTerminalSplitNode(
501     const int32 tree_id, const int32 node_id) const {
502   const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
503   DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
504   const int32 left_id = node.bucketized_split().left_id();
505   const int32 right_id = node.bucketized_split().right_id();
506   return is_leaf(tree_id, left_id) && is_leaf(tree_id, right_id);
507 }
508 
509 // For each pruned node, finds the leaf where it finally ended up and
510 // calculates the total update from that pruned node prediction.
CalculateParentAndLogitUpdate(const int32 start_node_id,const std::vector<std::pair<int32,std::vector<float>>> & nodes_changes,int32 * parent_id,std::vector<float> * changes) const511 void BoostedTreesEnsembleResource::CalculateParentAndLogitUpdate(
512     const int32 start_node_id,
513     const std::vector<std::pair<int32, std::vector<float>>>& nodes_changes,
514     int32* parent_id, std::vector<float>* changes) const {
515   const int logits_dimension = nodes_changes[start_node_id].second.size();
516   for (int i = 0; i < logits_dimension; ++i) {
517     changes->emplace_back(0.0);
518   }
519   int32 node_id = start_node_id;
520   int32 parent = nodes_changes[node_id].first;
521   while (parent != node_id) {
522     for (int i = 0; i < logits_dimension; ++i) {
523       changes->at(i) += nodes_changes[node_id].second[i];
524     }
525     node_id = parent;
526     parent = nodes_changes[node_id].first;
527   }
528   *parent_id = parent;
529 }
530 
RecursivelyDoPostPrunePreparation(const int32 tree_id,const int32 node_id,std::vector<int32> * nodes_to_delete,std::vector<std::pair<int32,std::vector<float>>> * nodes_meta)531 void BoostedTreesEnsembleResource::RecursivelyDoPostPrunePreparation(
532     const int32 tree_id, const int32 node_id,
533     std::vector<int32>* nodes_to_delete,
534     std::vector<std::pair<int32, std::vector<float>>>* nodes_meta) {
535   auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
536   DCHECK_NE(node->node_case(), boosted_trees::Node::NODE_NOT_SET);
537   // Base case when we reach a leaf.
538   if (node->node_case() == boosted_trees::Node::kLeaf) {
539     return;
540   }
541 
542   // Traverse node children first and recursively prune their sub-trees.
543   RecursivelyDoPostPrunePreparation(tree_id, node->bucketized_split().left_id(),
544                                     nodes_to_delete, nodes_meta);
545   RecursivelyDoPostPrunePreparation(tree_id,
546                                     node->bucketized_split().right_id(),
547                                     nodes_to_delete, nodes_meta);
548 
549   // Two conditions must be satisfied to prune the node:
550   // 1- The split gain is negative.
551   // 2- After depth-first pruning, the node only has leaf children.
552   const auto& node_metadata = node->metadata();
553   if (node_metadata.gain() < 0 && IsTerminalSplitNode(tree_id, node_id)) {
554     const int32 left_id = node->bucketized_split().left_id();
555     const int32 right_id = node->bucketized_split().right_id();
556 
557     // Save children that need to be deleted.
558     nodes_to_delete->push_back(left_id);
559     nodes_to_delete->push_back(right_id);
560 
561     // Change node back into leaf.
562     *node->mutable_leaf() = node_metadata.original_leaf();
563 
564     // Save the old values of weights of children.
565     nodes_meta->at(left_id).first = node_id;
566     nodes_meta->at(right_id).first = node_id;
567     const auto& left_child_values = node_value(tree_id, left_id);
568     const auto& right_child_values = node_value(tree_id, right_id);
569     std::vector<float> parent_values(left_child_values.size(), 0.0);
570     if (node_metadata.has_original_leaf()) {
571       parent_values = node_value(tree_id, node_id);
572     }
573     for (int32 i = 0; i < parent_values.size(); ++i) {
574       nodes_meta->at(left_id).second.emplace_back(parent_values[i] -
575                                                   left_child_values[i]);
576       nodes_meta->at(right_id).second.emplace_back(parent_values[i] -
577                                                    right_child_values[i]);
578     }
579     // Clear gain for leaf node.
580     node->clear_metadata();
581   }
582 }
583 
584 }  // namespace tensorflow
585