• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/boosted_trees/resources.h"
17 #include "tensorflow/core/framework/resource_mgr.h"
18 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
19 #include "tensorflow/core/platform/mutex.h"
20 #include "tensorflow/core/platform/protobuf.h"
21 
22 namespace tensorflow {
23 
24 // Constructor.
BoostedTreesEnsembleResource()25 BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
26     : tree_ensemble_(
27           protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
28               &arena_)) {}
29 
DebugString() const30 string BoostedTreesEnsembleResource::DebugString() const {
31   return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
32                          "]");
33 }
34 
InitFromSerialized(const string & serialized,const int64 stamp_token)35 bool BoostedTreesEnsembleResource::InitFromSerialized(const string& serialized,
36                                                       const int64 stamp_token) {
37   CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
38   if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
39     set_stamp(stamp_token);
40     return true;
41   }
42   return false;
43 }
44 
SerializeAsString() const45 string BoostedTreesEnsembleResource::SerializeAsString() const {
46   return tree_ensemble_->SerializeAsString();
47 }
48 
num_trees() const49 int32 BoostedTreesEnsembleResource::num_trees() const {
50   return tree_ensemble_->trees_size();
51 }
52 
// Returns the id of the child that example `index_in_batch` is routed to from
// node `node_id` of tree `tree_id`.
//
// - Bucketized split: left child if the example's feature bucket is <= the
//   split threshold, right child otherwise.
// - Categorical split: left child if the feature value equals the split
//   value, right child otherwise.
// - Any other node type: DCHECK-fails and returns -1.
int32 BoostedTreesEnsembleResource::next_node(
    const int32 tree_id, const int32 node_id, const int32 index_in_batch,
    const std::vector<TTypes<int32>::ConstVec>& bucketized_features) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);

  switch (node.node_case()) {
    case boosted_trees::Node::kBucketizedSplit: {
      const auto& split = node.bucketized_split();
      const auto bucket =
          bucketized_features[split.feature_id()](index_in_batch);
      if (bucket <= split.threshold()) {
        return split.left_id();
      }
      return split.right_id();
    }
    case boosted_trees::Node::kCategoricalSplit: {
      const auto& split = node.categorical_split();
      const auto category =
          bucketized_features[split.feature_id()](index_in_batch);
      if (category == split.value()) {
        return split.left_id();
      }
      return split.right_id();
    }
    default:
      DCHECK(false) << "Node type " << node.node_case() << " not supported.";
      break;
  }
  // Only reached for unsupported node types (when DCHECKs are compiled out).
  return -1;
}
80 
node_value(const int32 tree_id,const int32 node_id) const81 std::vector<float> BoostedTreesEnsembleResource::node_value(
82     const int32 tree_id, const int32 node_id) const {
83   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
84   DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
85   const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
86   if (node.node_case() == boosted_trees::Node::kLeaf) {
87     // TODO(crawles): only use vector leaf even if # logits=1.
88     if (node.leaf().has_vector()) {
89       std::vector<float> leaf_values;
90       const auto& leaf_value_vector = node.leaf().vector();
91       const int size = leaf_value_vector.value_size();
92       leaf_values.reserve(size);
93       for (int i = 0; i < size; ++i) {
94         leaf_values.push_back(leaf_value_vector.value(i));
95       }
96       return leaf_values;
97     } else {
98       return {node.leaf().scalar()};
99     }
100   } else {
101     if (node.metadata().original_leaf().has_vector()) {
102       std::vector<float> node_values;
103       const auto& leaf_value_vector = node.metadata().original_leaf().vector();
104       const int size = leaf_value_vector.value_size();
105       node_values.reserve(size);
106       for (int i = 0; i < size; ++i) {
107         node_values.push_back(leaf_value_vector.value(i));
108       }
109       return node_values;
110     } else {
111       return {node.metadata().original_leaf().scalar()};
112     }
113   }
114 }
115 
set_node_value(const int32 tree_id,const int32 node_id,const float logits)116 void BoostedTreesEnsembleResource::set_node_value(const int32 tree_id,
117                                                   const int32 node_id,
118                                                   const float logits) {
119   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
120   DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
121   auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
122   DCHECK(node->node_case() == boosted_trees::Node::kLeaf);
123   node->mutable_leaf()->set_scalar(logits);
124 }
125 
// Returns num_layers_grown from the metadata entry of tree `tree_id`.
// NOTE(review): the DCHECK bounds tree_id by trees_size() while the lookup
// indexes tree_metadata; presumably both always have the same length.
int32 BoostedTreesEnsembleResource::GetNumLayersGrown(
    const int32 tree_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  const auto& metadata = tree_ensemble_->tree_metadata(tree_id);
  return metadata.num_layers_grown();
}
131 
SetNumLayersGrown(const int32 tree_id,int32 new_num_layers) const132 void BoostedTreesEnsembleResource::SetNumLayersGrown(
133     const int32 tree_id, int32 new_num_layers) const {
134   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
135   tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
136       new_num_layers);
137 }
138 
UpdateLastLayerNodesRange(const int32 node_range_start,int32 node_range_end) const139 void BoostedTreesEnsembleResource::UpdateLastLayerNodesRange(
140     const int32 node_range_start, int32 node_range_end) const {
141   tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start(
142       node_range_start);
143   tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end(
144       node_range_end);
145 }
146 
GetLastLayerNodesRange(int32 * node_range_start,int32 * node_range_end) const147 void BoostedTreesEnsembleResource::GetLastLayerNodesRange(
148     int32* node_range_start, int32* node_range_end) const {
149   *node_range_start =
150       tree_ensemble_->growing_metadata().last_layer_node_start();
151   *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end();
152 }
153 
GetNumNodes(const int32 tree_id)154 int64 BoostedTreesEnsembleResource::GetNumNodes(const int32 tree_id) {
155   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
156   return tree_ensemble_->trees(tree_id).nodes_size();
157 }
158 
GetNumLayersAttempted()159 int32 BoostedTreesEnsembleResource::GetNumLayersAttempted() {
160   return tree_ensemble_->growing_metadata().num_layers_attempted();
161 }
162 
is_leaf(const int32 tree_id,const int32 node_id) const163 bool BoostedTreesEnsembleResource::is_leaf(const int32 tree_id,
164                                            const int32 node_id) const {
165   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
166   DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
167   const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
168   return node.node_case() == boosted_trees::Node::kLeaf;
169 }
170 
// Returns the feature id used by the bucketized split at node `node_id` of
// tree `tree_id`. DCHECKs that both ids are below the respective sizes and
// that the node is a bucketized split.
int32 BoostedTreesEnsembleResource::feature_id(const int32 tree_id,
                                               const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference: a by-value `auto` would copy the whole Node message on
  // every call just to read a single field.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().feature_id();
}
177 
// Returns the bucket threshold of the bucketized split at node `node_id` of
// tree `tree_id`. DCHECKs that both ids are below the respective sizes and
// that the node is a bucketized split.
int32 BoostedTreesEnsembleResource::bucket_threshold(
    const int32 tree_id, const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference: a by-value `auto` would copy the whole Node message on
  // every call just to read a single field.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().threshold();
}
184 
// Returns the id of the left child of the bucketized split at node `node_id`
// of tree `tree_id`. DCHECKs that both ids are below the respective sizes and
// that the node is a bucketized split.
int32 BoostedTreesEnsembleResource::left_id(const int32 tree_id,
                                            const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference: a by-value `auto` would copy the whole Node message on
  // every call just to read a single field.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().left_id();
}
191 
// Returns the id of the right child of the bucketized split at node `node_id`
// of tree `tree_id`. DCHECKs that both ids are below the respective sizes and
// that the node is a bucketized split.
int32 BoostedTreesEnsembleResource::right_id(const int32 tree_id,
                                             const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference: a by-value `auto` would copy the whole Node message on
  // every call just to read a single field.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().right_id();
}
198 
GetTreeWeights() const199 std::vector<float> BoostedTreesEnsembleResource::GetTreeWeights() const {
200   return {tree_ensemble_->tree_weights().begin(),
201           tree_ensemble_->tree_weights().end()};
202 }
203 
GetTreeWeight(const int32 tree_id) const204 float BoostedTreesEnsembleResource::GetTreeWeight(const int32 tree_id) const {
205   return tree_ensemble_->tree_weights(tree_id);
206 }
207 
IsTreeFinalized(const int32 tree_id) const208 float BoostedTreesEnsembleResource::IsTreeFinalized(const int32 tree_id) const {
209   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
210   return tree_ensemble_->tree_metadata(tree_id).is_finalized();
211 }
212 
IsTreePostPruned(const int32 tree_id) const213 float BoostedTreesEnsembleResource::IsTreePostPruned(
214     const int32 tree_id) const {
215   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
216   return tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size() >
217          0;
218 }
219 
SetIsFinalized(const int32 tree_id,const bool is_finalized)220 void BoostedTreesEnsembleResource::SetIsFinalized(const int32 tree_id,
221                                                   const bool is_finalized) {
222   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
223   return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
224       is_finalized);
225 }
226 
227 // Sets the weight of i'th tree.
SetTreeWeight(const int32 tree_id,const float weight)228 void BoostedTreesEnsembleResource::SetTreeWeight(const int32 tree_id,
229                                                  const float weight) {
230   DCHECK_GE(tree_id, 0);
231   DCHECK_LT(tree_id, num_trees());
232   tree_ensemble_->set_tree_weights(tree_id, weight);
233 }
234 
UpdateGrowingMetadata() const235 void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
236   tree_ensemble_->mutable_growing_metadata()->set_num_layers_attempted(
237       tree_ensemble_->growing_metadata().num_layers_attempted() + 1);
238 
239   const int n_trees = num_trees();
240 
241   if (n_trees <= 0 ||
242       // Checks if we are building the first layer of the dummy empty tree
243       ((n_trees == 1 || IsTreeFinalized(n_trees - 2)) &&
244        (tree_ensemble_->trees(n_trees - 1).nodes_size() == 1))) {
245     tree_ensemble_->mutable_growing_metadata()->set_num_trees_attempted(
246         tree_ensemble_->growing_metadata().num_trees_attempted() + 1);
247   }
248 }
249 
250 // Add a tree to the ensemble and returns a new tree_id.
AddNewTree(const float weight)251 int32 BoostedTreesEnsembleResource::AddNewTree(const float weight) {
252   return AddNewTreeWithLogits(weight, 0.0);
253 }
254 
AddNewTreeWithLogits(const float weight,const float logits)255 int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(const float weight,
256                                                          const float logits) {
257   const int32 new_tree_id = tree_ensemble_->trees_size();
258   auto* node = tree_ensemble_->add_trees()->add_nodes();
259   node->mutable_leaf()->set_scalar(logits);
260   tree_ensemble_->add_tree_weights(weight);
261   tree_ensemble_->add_tree_metadata();
262 
263   return new_tree_id;
264 }
265 
AddBucketizedSplitNode(const int32 tree_id,const int32 node_id,const int32 feature_id,const int32 threshold,const float gain,const float left_contrib,const float right_contrib,int32 * left_node_id,int32 * right_node_id)266 void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
267     const int32 tree_id, const int32 node_id, const int32 feature_id,
268     const int32 threshold, const float gain, const float left_contrib,
269     const float right_contrib, int32* left_node_id, int32* right_node_id) {
270   auto* tree = tree_ensemble_->mutable_trees(tree_id);
271   auto* node = tree->mutable_nodes(node_id);
272   DCHECK_EQ(node->node_case(), boosted_trees::Node::kLeaf);
273   float prev_node_value = node->leaf().scalar();
274   *left_node_id = tree->nodes_size();
275   *right_node_id = *left_node_id + 1;
276   auto* left_node = tree->add_nodes();
277   auto* right_node = tree->add_nodes();
278   if (node_id != 0 || (node->has_leaf() && node->leaf().scalar() != 0)) {
279     // Save previous leaf value if it is not the first leaf in the tree.
280     node->mutable_metadata()->mutable_original_leaf()->Swap(
281         node->mutable_leaf());
282   }
283   node->mutable_metadata()->set_gain(gain);
284   auto* new_split = node->mutable_bucketized_split();
285   new_split->set_feature_id(feature_id);
286   new_split->set_threshold(threshold);
287   new_split->set_left_id(*left_node_id);
288   new_split->set_right_id(*right_node_id);
289   // TODO(npononareva): this is LAYER-BY-LAYER boosting; add WHOLE-TREE.
290   left_node->mutable_leaf()->set_scalar(prev_node_value + left_contrib);
291   right_node->mutable_leaf()->set_scalar(prev_node_value + right_contrib);
292 }
293 
Reset()294 void BoostedTreesEnsembleResource::Reset() {
295   // Reset stamp.
296   set_stamp(-1);
297 
298   // Clear tree ensemle.
299   arena_.Reset();
300   CHECK_EQ(0, arena_.SpaceAllocated());
301   tree_ensemble_ =
302       protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(&arena_);
303 }
304 
PostPruneTree(const int32 current_tree)305 void BoostedTreesEnsembleResource::PostPruneTree(const int32 current_tree) {
306   // No-op if tree is empty.
307   auto* tree = tree_ensemble_->mutable_trees(current_tree);
308   int32 num_nodes = tree->nodes_size();
309   if (num_nodes == 0) {
310     return;
311   }
312 
313   std::vector<int32> nodes_to_delete;
314   // If a node was pruned, we need to save the change of the prediction from
315   // this node to its parent, as well as the parent id.
316   std::vector<std::pair<int32, float>> nodes_changes;
317   nodes_changes.reserve(num_nodes);
318   for (int32 i = 0; i < num_nodes; ++i) {
319     nodes_changes.emplace_back(i, 0.0);
320   }
321   // Prune the tree recursively starting from the root. Each node that has
322   // negative gain and only leaf children will be pruned recursively up from
323   // the bottom of the tree. This method returns the list of nodes pruned, and
324   // updates the nodes in the tree not to refer to those pruned nodes.
325   RecursivelyDoPostPrunePreparation(current_tree, 0, &nodes_to_delete,
326                                     &nodes_changes);
327 
328   if (nodes_to_delete.empty()) {
329     // No pruning happened, and no post-processing needed.
330     return;
331   }
332 
333   // Sort node ids so they are in asc order.
334   std::sort(nodes_to_delete.begin(), nodes_to_delete.end());
335 
336   // We need to
337   // - update split left and right children ids with new indices
338   // - actually remove the nodes that need to be removed
339   // - save the information about pruned node so we could recover the
340   // predictions from cache. Build a map for old node index=>new node index.
341   // nodes_to_delete contains nodes who's indices should be skipped, in
342   // ascending order. Save the information about new indices into meta.
343   std::map<int32, int32> old_to_new_ids;
344   int32 new_index = 0;
345   int32 index_for_deleted = 0;
346   auto* post_prune_meta = tree_ensemble_->mutable_tree_metadata(current_tree)
347                               ->mutable_post_pruned_nodes_meta();
348 
349   for (int32 i = 0; i < num_nodes; ++i) {
350     if (index_for_deleted < nodes_to_delete.size() &&
351         i == nodes_to_delete[index_for_deleted]) {
352       // Node i will get removed,
353       ++index_for_deleted;
354       // Update meta info that will allow us to use cached predictions from
355       // those nodes.
356       int32 new_id;
357       float logit_change;
358       CalculateParentAndLogitUpdate(i, nodes_changes, &new_id, &logit_change);
359       auto* meta = post_prune_meta->Add();
360       meta->set_new_node_id(old_to_new_ids[new_id]);
361       meta->set_logit_change(logit_change);
362     } else {
363       old_to_new_ids[i] = new_index++;
364       auto* meta = post_prune_meta->Add();
365       // Update meta info that will allow us to use cached predictions from
366       // those nodes.
367       meta->set_new_node_id(old_to_new_ids[i]);
368       meta->set_logit_change(0.0);
369     }
370   }
371   index_for_deleted = 0;
372   int32 i = 0;
373   protobuf::RepeatedPtrField<boosted_trees::Node> new_nodes;
374   new_nodes.Reserve(old_to_new_ids.size());
375   for (auto node : *(tree->mutable_nodes())) {
376     if (index_for_deleted < nodes_to_delete.size() &&
377         i == nodes_to_delete[index_for_deleted]) {
378       ++index_for_deleted;
379       ++i;
380       continue;
381     } else {
382       if (node.node_case() == boosted_trees::Node::kBucketizedSplit) {
383         node.mutable_bucketized_split()->set_left_id(
384             old_to_new_ids[node.bucketized_split().left_id()]);
385         node.mutable_bucketized_split()->set_right_id(
386             old_to_new_ids[node.bucketized_split().right_id()]);
387       }
388       *new_nodes.Add() = std::move(node);
389     }
390     ++i;
391   }
392   // Replace all the nodes in a tree with the ones we keep.
393   *tree->mutable_nodes() = std::move(new_nodes);
394 
395   // Note that if the whole tree got pruned, we will end up with one node.
396   // We can't remove that tree because it will cause problems with cache.
397 }
398 
GetPostPruneCorrection(const int32 tree_id,const int32 initial_node_id,int32 * current_node_id,float * logit_update) const399 void BoostedTreesEnsembleResource::GetPostPruneCorrection(
400     const int32 tree_id, const int32 initial_node_id, int32* current_node_id,
401     float* logit_update) const {
402   DCHECK_LT(tree_id, tree_ensemble_->trees_size());
403   if (IsTreeFinalized(tree_id) && IsTreePostPruned(tree_id)) {
404     DCHECK_LT(
405         initial_node_id,
406         tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size());
407     const auto& meta =
408         tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta(
409             initial_node_id);
410     *current_node_id = meta.new_node_id();
411     *logit_update += meta.logit_change();
412   }
413 }
414 
IsTerminalSplitNode(const int32 tree_id,const int32 node_id) const415 bool BoostedTreesEnsembleResource::IsTerminalSplitNode(
416     const int32 tree_id, const int32 node_id) const {
417   const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
418   DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
419   const int32 left_id = node.bucketized_split().left_id();
420   const int32 right_id = node.bucketized_split().right_id();
421   return is_leaf(tree_id, left_id) && is_leaf(tree_id, right_id);
422 }
423 
424 // For each pruned node, finds the leaf where it finally ended up and
425 // calculates the total update from that pruned node prediction.
CalculateParentAndLogitUpdate(const int32 start_node_id,const std::vector<std::pair<int32,float>> & nodes_change,int32 * parent_id,float * change) const426 void BoostedTreesEnsembleResource::CalculateParentAndLogitUpdate(
427     const int32 start_node_id,
428     const std::vector<std::pair<int32, float>>& nodes_change, int32* parent_id,
429     float* change) const {
430   *change = 0.0;
431   int32 node_id = start_node_id;
432   int32 parent = nodes_change[node_id].first;
433 
434   while (parent != node_id) {
435     (*change) += nodes_change[node_id].second;
436     node_id = parent;
437     parent = nodes_change[node_id].first;
438   }
439   *parent_id = parent;
440 }
441 
RecursivelyDoPostPrunePreparation(const int32 tree_id,const int32 node_id,std::vector<int32> * nodes_to_delete,std::vector<std::pair<int32,float>> * nodes_meta)442 void BoostedTreesEnsembleResource::RecursivelyDoPostPrunePreparation(
443     const int32 tree_id, const int32 node_id,
444     std::vector<int32>* nodes_to_delete,
445     std::vector<std::pair<int32, float>>* nodes_meta) {
446   auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
447   DCHECK_NE(node->node_case(), boosted_trees::Node::NODE_NOT_SET);
448   // Base case when we reach a leaf.
449   if (node->node_case() == boosted_trees::Node::kLeaf) {
450     return;
451   }
452 
453   // Traverse node children first and recursively prune their sub-trees.
454   RecursivelyDoPostPrunePreparation(tree_id, node->bucketized_split().left_id(),
455                                     nodes_to_delete, nodes_meta);
456   RecursivelyDoPostPrunePreparation(tree_id,
457                                     node->bucketized_split().right_id(),
458                                     nodes_to_delete, nodes_meta);
459 
460   // Two conditions must be satisfied to prune the node:
461   // 1- The split gain is negative.
462   // 2- After depth-first pruning, the node only has leaf children.
463   const auto& node_metadata = node->metadata();
464   if (node_metadata.gain() < 0 && IsTerminalSplitNode(tree_id, node_id)) {
465     const int32 left_id = node->bucketized_split().left_id();
466     const int32 right_id = node->bucketized_split().right_id();
467 
468     // Save children that need to be deleted.
469     nodes_to_delete->push_back(left_id);
470     nodes_to_delete->push_back(right_id);
471 
472     // Change node back into leaf.
473     *node->mutable_leaf() = node_metadata.original_leaf();
474     const auto& parent_values = node_value(tree_id, node_id);
475     DCHECK_EQ(parent_values.size(), 1);
476     const float parent_value = parent_values[0];
477 
478     // Save the old values of weights of children.
479     (*nodes_meta)[left_id].first = node_id;
480     (*nodes_meta)[left_id].second =
481         parent_value - node_value(tree_id, left_id)[0];
482 
483     (*nodes_meta)[right_id].first = node_id;
484     (*nodes_meta)[right_id].second =
485         parent_value - node_value(tree_id, right_id)[0];
486 
487     // Clear gain for leaf node.
488     node->clear_metadata();
489   }
490 }
491 
492 }  // namespace tensorflow
493