1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/boosted_trees/resources.h"
17 #include "tensorflow/core/framework/resource_mgr.h"
18 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
19 #include "tensorflow/core/platform/mutex.h"
20 #include "tensorflow/core/platform/protobuf.h"
21
22 namespace tensorflow {
23
24 // Constructor.
BoostedTreesEnsembleResource()25 BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
26 : tree_ensemble_(
27 protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
28 &arena_)) {}
29
DebugString() const30 string BoostedTreesEnsembleResource::DebugString() const {
31 return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
32 "]");
33 }
34
InitFromSerialized(const string & serialized,const int64 stamp_token)35 bool BoostedTreesEnsembleResource::InitFromSerialized(const string& serialized,
36 const int64 stamp_token) {
37 CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
38 if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
39 set_stamp(stamp_token);
40 return true;
41 }
42 return false;
43 }
44
SerializeAsString() const45 string BoostedTreesEnsembleResource::SerializeAsString() const {
46 return tree_ensemble_->SerializeAsString();
47 }
48
num_trees() const49 int32 BoostedTreesEnsembleResource::num_trees() const {
50 return tree_ensemble_->trees_size();
51 }
52
// Routes example `index_in_batch` through split node `node_id` of tree
// `tree_id` and returns the id of the child the example falls into.
//  - Bucketized split: go left iff the example's bucket for the split's
//    feature is <= the split threshold.
//  - Categorical split: go left iff the feature value equals the split value.
// Any other node type (including a leaf) is unsupported: the DCHECK fires in
// debug builds, and -1 is returned otherwise.
int32 BoostedTreesEnsembleResource::next_node(
    const int32 tree_id, const int32 node_id, const int32 index_in_batch,
    const std::vector<TTypes<int32>::ConstVec>& bucketized_features) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);

  switch (node.node_case()) {
    case boosted_trees::Node::kBucketizedSplit: {
      const auto& bucketized = node.bucketized_split();
      const int32 bucket =
          bucketized_features[bucketized.feature_id()](index_in_batch);
      if (bucket <= bucketized.threshold()) {
        return bucketized.left_id();
      }
      return bucketized.right_id();
    }
    case boosted_trees::Node::kCategoricalSplit: {
      const auto& categorical = node.categorical_split();
      const int32 category =
          bucketized_features[categorical.feature_id()](index_in_batch);
      if (category == categorical.value()) {
        return categorical.left_id();
      }
      return categorical.right_id();
    }
    default:
      DCHECK(false) << "Node type " << node.node_case() << " not supported.";
      break;
  }
  return -1;
}
80
node_value(const int32 tree_id,const int32 node_id) const81 std::vector<float> BoostedTreesEnsembleResource::node_value(
82 const int32 tree_id, const int32 node_id) const {
83 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
84 DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
85 const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
86 if (node.node_case() == boosted_trees::Node::kLeaf) {
87 // TODO(crawles): only use vector leaf even if # logits=1.
88 if (node.leaf().has_vector()) {
89 std::vector<float> leaf_values;
90 const auto& leaf_value_vector = node.leaf().vector();
91 const int size = leaf_value_vector.value_size();
92 leaf_values.reserve(size);
93 for (int i = 0; i < size; ++i) {
94 leaf_values.push_back(leaf_value_vector.value(i));
95 }
96 return leaf_values;
97 } else {
98 return {node.leaf().scalar()};
99 }
100 } else {
101 if (node.metadata().original_leaf().has_vector()) {
102 std::vector<float> node_values;
103 const auto& leaf_value_vector = node.metadata().original_leaf().vector();
104 const int size = leaf_value_vector.value_size();
105 node_values.reserve(size);
106 for (int i = 0; i < size; ++i) {
107 node_values.push_back(leaf_value_vector.value(i));
108 }
109 return node_values;
110 } else {
111 return {node.metadata().original_leaf().scalar()};
112 }
113 }
114 }
115
set_node_value(const int32 tree_id,const int32 node_id,const float logits)116 void BoostedTreesEnsembleResource::set_node_value(const int32 tree_id,
117 const int32 node_id,
118 const float logits) {
119 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
120 DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
121 auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
122 DCHECK(node->node_case() == boosted_trees::Node::kLeaf);
123 node->mutable_leaf()->set_scalar(logits);
124 }
125
// Returns the number of layers grown so far for tree `tree_id`, as recorded
// in that tree's metadata.
int32 BoostedTreesEnsembleResource::GetNumLayersGrown(
    const int32 tree_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  const auto& metadata = tree_ensemble_->tree_metadata(tree_id);
  return metadata.num_layers_grown();
}
131
SetNumLayersGrown(const int32 tree_id,int32 new_num_layers) const132 void BoostedTreesEnsembleResource::SetNumLayersGrown(
133 const int32 tree_id, int32 new_num_layers) const {
134 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
135 tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
136 new_num_layers);
137 }
138
UpdateLastLayerNodesRange(const int32 node_range_start,int32 node_range_end) const139 void BoostedTreesEnsembleResource::UpdateLastLayerNodesRange(
140 const int32 node_range_start, int32 node_range_end) const {
141 tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start(
142 node_range_start);
143 tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end(
144 node_range_end);
145 }
146
GetLastLayerNodesRange(int32 * node_range_start,int32 * node_range_end) const147 void BoostedTreesEnsembleResource::GetLastLayerNodesRange(
148 int32* node_range_start, int32* node_range_end) const {
149 *node_range_start =
150 tree_ensemble_->growing_metadata().last_layer_node_start();
151 *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end();
152 }
153
GetNumNodes(const int32 tree_id)154 int64 BoostedTreesEnsembleResource::GetNumNodes(const int32 tree_id) {
155 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
156 return tree_ensemble_->trees(tree_id).nodes_size();
157 }
158
GetNumLayersAttempted()159 int32 BoostedTreesEnsembleResource::GetNumLayersAttempted() {
160 return tree_ensemble_->growing_metadata().num_layers_attempted();
161 }
162
is_leaf(const int32 tree_id,const int32 node_id) const163 bool BoostedTreesEnsembleResource::is_leaf(const int32 tree_id,
164 const int32 node_id) const {
165 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
166 DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
167 const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
168 return node.node_case() == boosted_trees::Node::kLeaf;
169 }
170
// Returns the feature id used by bucketized split node `node_id` of tree
// `tree_id`. Debug builds check the indices and that the node is a
// bucketized split.
int32 BoostedTreesEnsembleResource::feature_id(const int32 tree_id,
                                               const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference to avoid deep-copying the Node proto.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().feature_id();
}
177
// Returns the bucket threshold of bucketized split node `node_id` of tree
// `tree_id` (examples with bucket <= threshold go left, see next_node).
// Debug builds check the indices and that the node is a bucketized split.
int32 BoostedTreesEnsembleResource::bucket_threshold(
    const int32 tree_id, const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference to avoid deep-copying the Node proto.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().threshold();
}
184
// Returns the id of the left child of bucketized split node `node_id` of tree
// `tree_id`. Debug builds check the indices and the node type.
int32 BoostedTreesEnsembleResource::left_id(const int32 tree_id,
                                            const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference to avoid deep-copying the Node proto.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().left_id();
}
191
// Returns the id of the right child of bucketized split node `node_id` of tree
// `tree_id`. Debug builds check the indices and the node type.
int32 BoostedTreesEnsembleResource::right_id(const int32 tree_id,
                                             const int32 node_id) const {
  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
  // Bind by reference to avoid deep-copying the Node proto.
  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
  return node.bucketized_split().right_id();
}
198
GetTreeWeights() const199 std::vector<float> BoostedTreesEnsembleResource::GetTreeWeights() const {
200 return {tree_ensemble_->tree_weights().begin(),
201 tree_ensemble_->tree_weights().end()};
202 }
203
GetTreeWeight(const int32 tree_id) const204 float BoostedTreesEnsembleResource::GetTreeWeight(const int32 tree_id) const {
205 return tree_ensemble_->tree_weights(tree_id);
206 }
207
IsTreeFinalized(const int32 tree_id) const208 float BoostedTreesEnsembleResource::IsTreeFinalized(const int32 tree_id) const {
209 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
210 return tree_ensemble_->tree_metadata(tree_id).is_finalized();
211 }
212
IsTreePostPruned(const int32 tree_id) const213 float BoostedTreesEnsembleResource::IsTreePostPruned(
214 const int32 tree_id) const {
215 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
216 return tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size() >
217 0;
218 }
219
SetIsFinalized(const int32 tree_id,const bool is_finalized)220 void BoostedTreesEnsembleResource::SetIsFinalized(const int32 tree_id,
221 const bool is_finalized) {
222 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
223 return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
224 is_finalized);
225 }
226
227 // Sets the weight of i'th tree.
SetTreeWeight(const int32 tree_id,const float weight)228 void BoostedTreesEnsembleResource::SetTreeWeight(const int32 tree_id,
229 const float weight) {
230 DCHECK_GE(tree_id, 0);
231 DCHECK_LT(tree_id, num_trees());
232 tree_ensemble_->set_tree_weights(tree_id, weight);
233 }
234
UpdateGrowingMetadata() const235 void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
236 tree_ensemble_->mutable_growing_metadata()->set_num_layers_attempted(
237 tree_ensemble_->growing_metadata().num_layers_attempted() + 1);
238
239 const int n_trees = num_trees();
240
241 if (n_trees <= 0 ||
242 // Checks if we are building the first layer of the dummy empty tree
243 ((n_trees == 1 || IsTreeFinalized(n_trees - 2)) &&
244 (tree_ensemble_->trees(n_trees - 1).nodes_size() == 1))) {
245 tree_ensemble_->mutable_growing_metadata()->set_num_trees_attempted(
246 tree_ensemble_->growing_metadata().num_trees_attempted() + 1);
247 }
248 }
249
250 // Add a tree to the ensemble and returns a new tree_id.
AddNewTree(const float weight)251 int32 BoostedTreesEnsembleResource::AddNewTree(const float weight) {
252 return AddNewTreeWithLogits(weight, 0.0);
253 }
254
AddNewTreeWithLogits(const float weight,const float logits)255 int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(const float weight,
256 const float logits) {
257 const int32 new_tree_id = tree_ensemble_->trees_size();
258 auto* node = tree_ensemble_->add_trees()->add_nodes();
259 node->mutable_leaf()->set_scalar(logits);
260 tree_ensemble_->add_tree_weights(weight);
261 tree_ensemble_->add_tree_metadata();
262
263 return new_tree_id;
264 }
265
AddBucketizedSplitNode(const int32 tree_id,const int32 node_id,const int32 feature_id,const int32 threshold,const float gain,const float left_contrib,const float right_contrib,int32 * left_node_id,int32 * right_node_id)266 void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
267 const int32 tree_id, const int32 node_id, const int32 feature_id,
268 const int32 threshold, const float gain, const float left_contrib,
269 const float right_contrib, int32* left_node_id, int32* right_node_id) {
270 auto* tree = tree_ensemble_->mutable_trees(tree_id);
271 auto* node = tree->mutable_nodes(node_id);
272 DCHECK_EQ(node->node_case(), boosted_trees::Node::kLeaf);
273 float prev_node_value = node->leaf().scalar();
274 *left_node_id = tree->nodes_size();
275 *right_node_id = *left_node_id + 1;
276 auto* left_node = tree->add_nodes();
277 auto* right_node = tree->add_nodes();
278 if (node_id != 0 || (node->has_leaf() && node->leaf().scalar() != 0)) {
279 // Save previous leaf value if it is not the first leaf in the tree.
280 node->mutable_metadata()->mutable_original_leaf()->Swap(
281 node->mutable_leaf());
282 }
283 node->mutable_metadata()->set_gain(gain);
284 auto* new_split = node->mutable_bucketized_split();
285 new_split->set_feature_id(feature_id);
286 new_split->set_threshold(threshold);
287 new_split->set_left_id(*left_node_id);
288 new_split->set_right_id(*right_node_id);
289 // TODO(npononareva): this is LAYER-BY-LAYER boosting; add WHOLE-TREE.
290 left_node->mutable_leaf()->set_scalar(prev_node_value + left_contrib);
291 right_node->mutable_leaf()->set_scalar(prev_node_value + right_contrib);
292 }
293
Reset()294 void BoostedTreesEnsembleResource::Reset() {
295 // Reset stamp.
296 set_stamp(-1);
297
298 // Clear tree ensemle.
299 arena_.Reset();
300 CHECK_EQ(0, arena_.SpaceAllocated());
301 tree_ensemble_ =
302 protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(&arena_);
303 }
304
PostPruneTree(const int32 current_tree)305 void BoostedTreesEnsembleResource::PostPruneTree(const int32 current_tree) {
306 // No-op if tree is empty.
307 auto* tree = tree_ensemble_->mutable_trees(current_tree);
308 int32 num_nodes = tree->nodes_size();
309 if (num_nodes == 0) {
310 return;
311 }
312
313 std::vector<int32> nodes_to_delete;
314 // If a node was pruned, we need to save the change of the prediction from
315 // this node to its parent, as well as the parent id.
316 std::vector<std::pair<int32, float>> nodes_changes;
317 nodes_changes.reserve(num_nodes);
318 for (int32 i = 0; i < num_nodes; ++i) {
319 nodes_changes.emplace_back(i, 0.0);
320 }
321 // Prune the tree recursively starting from the root. Each node that has
322 // negative gain and only leaf children will be pruned recursively up from
323 // the bottom of the tree. This method returns the list of nodes pruned, and
324 // updates the nodes in the tree not to refer to those pruned nodes.
325 RecursivelyDoPostPrunePreparation(current_tree, 0, &nodes_to_delete,
326 &nodes_changes);
327
328 if (nodes_to_delete.empty()) {
329 // No pruning happened, and no post-processing needed.
330 return;
331 }
332
333 // Sort node ids so they are in asc order.
334 std::sort(nodes_to_delete.begin(), nodes_to_delete.end());
335
336 // We need to
337 // - update split left and right children ids with new indices
338 // - actually remove the nodes that need to be removed
339 // - save the information about pruned node so we could recover the
340 // predictions from cache. Build a map for old node index=>new node index.
341 // nodes_to_delete contains nodes who's indices should be skipped, in
342 // ascending order. Save the information about new indices into meta.
343 std::map<int32, int32> old_to_new_ids;
344 int32 new_index = 0;
345 int32 index_for_deleted = 0;
346 auto* post_prune_meta = tree_ensemble_->mutable_tree_metadata(current_tree)
347 ->mutable_post_pruned_nodes_meta();
348
349 for (int32 i = 0; i < num_nodes; ++i) {
350 if (index_for_deleted < nodes_to_delete.size() &&
351 i == nodes_to_delete[index_for_deleted]) {
352 // Node i will get removed,
353 ++index_for_deleted;
354 // Update meta info that will allow us to use cached predictions from
355 // those nodes.
356 int32 new_id;
357 float logit_change;
358 CalculateParentAndLogitUpdate(i, nodes_changes, &new_id, &logit_change);
359 auto* meta = post_prune_meta->Add();
360 meta->set_new_node_id(old_to_new_ids[new_id]);
361 meta->set_logit_change(logit_change);
362 } else {
363 old_to_new_ids[i] = new_index++;
364 auto* meta = post_prune_meta->Add();
365 // Update meta info that will allow us to use cached predictions from
366 // those nodes.
367 meta->set_new_node_id(old_to_new_ids[i]);
368 meta->set_logit_change(0.0);
369 }
370 }
371 index_for_deleted = 0;
372 int32 i = 0;
373 protobuf::RepeatedPtrField<boosted_trees::Node> new_nodes;
374 new_nodes.Reserve(old_to_new_ids.size());
375 for (auto node : *(tree->mutable_nodes())) {
376 if (index_for_deleted < nodes_to_delete.size() &&
377 i == nodes_to_delete[index_for_deleted]) {
378 ++index_for_deleted;
379 ++i;
380 continue;
381 } else {
382 if (node.node_case() == boosted_trees::Node::kBucketizedSplit) {
383 node.mutable_bucketized_split()->set_left_id(
384 old_to_new_ids[node.bucketized_split().left_id()]);
385 node.mutable_bucketized_split()->set_right_id(
386 old_to_new_ids[node.bucketized_split().right_id()]);
387 }
388 *new_nodes.Add() = std::move(node);
389 }
390 ++i;
391 }
392 // Replace all the nodes in a tree with the ones we keep.
393 *tree->mutable_nodes() = std::move(new_nodes);
394
395 // Note that if the whole tree got pruned, we will end up with one node.
396 // We can't remove that tree because it will cause problems with cache.
397 }
398
GetPostPruneCorrection(const int32 tree_id,const int32 initial_node_id,int32 * current_node_id,float * logit_update) const399 void BoostedTreesEnsembleResource::GetPostPruneCorrection(
400 const int32 tree_id, const int32 initial_node_id, int32* current_node_id,
401 float* logit_update) const {
402 DCHECK_LT(tree_id, tree_ensemble_->trees_size());
403 if (IsTreeFinalized(tree_id) && IsTreePostPruned(tree_id)) {
404 DCHECK_LT(
405 initial_node_id,
406 tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size());
407 const auto& meta =
408 tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta(
409 initial_node_id);
410 *current_node_id = meta.new_node_id();
411 *logit_update += meta.logit_change();
412 }
413 }
414
IsTerminalSplitNode(const int32 tree_id,const int32 node_id) const415 bool BoostedTreesEnsembleResource::IsTerminalSplitNode(
416 const int32 tree_id, const int32 node_id) const {
417 const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
418 DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
419 const int32 left_id = node.bucketized_split().left_id();
420 const int32 right_id = node.bucketized_split().right_id();
421 return is_leaf(tree_id, left_id) && is_leaf(tree_id, right_id);
422 }
423
424 // For each pruned node, finds the leaf where it finally ended up and
425 // calculates the total update from that pruned node prediction.
CalculateParentAndLogitUpdate(const int32 start_node_id,const std::vector<std::pair<int32,float>> & nodes_change,int32 * parent_id,float * change) const426 void BoostedTreesEnsembleResource::CalculateParentAndLogitUpdate(
427 const int32 start_node_id,
428 const std::vector<std::pair<int32, float>>& nodes_change, int32* parent_id,
429 float* change) const {
430 *change = 0.0;
431 int32 node_id = start_node_id;
432 int32 parent = nodes_change[node_id].first;
433
434 while (parent != node_id) {
435 (*change) += nodes_change[node_id].second;
436 node_id = parent;
437 parent = nodes_change[node_id].first;
438 }
439 *parent_id = parent;
440 }
441
RecursivelyDoPostPrunePreparation(const int32 tree_id,const int32 node_id,std::vector<int32> * nodes_to_delete,std::vector<std::pair<int32,float>> * nodes_meta)442 void BoostedTreesEnsembleResource::RecursivelyDoPostPrunePreparation(
443 const int32 tree_id, const int32 node_id,
444 std::vector<int32>* nodes_to_delete,
445 std::vector<std::pair<int32, float>>* nodes_meta) {
446 auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
447 DCHECK_NE(node->node_case(), boosted_trees::Node::NODE_NOT_SET);
448 // Base case when we reach a leaf.
449 if (node->node_case() == boosted_trees::Node::kLeaf) {
450 return;
451 }
452
453 // Traverse node children first and recursively prune their sub-trees.
454 RecursivelyDoPostPrunePreparation(tree_id, node->bucketized_split().left_id(),
455 nodes_to_delete, nodes_meta);
456 RecursivelyDoPostPrunePreparation(tree_id,
457 node->bucketized_split().right_id(),
458 nodes_to_delete, nodes_meta);
459
460 // Two conditions must be satisfied to prune the node:
461 // 1- The split gain is negative.
462 // 2- After depth-first pruning, the node only has leaf children.
463 const auto& node_metadata = node->metadata();
464 if (node_metadata.gain() < 0 && IsTerminalSplitNode(tree_id, node_id)) {
465 const int32 left_id = node->bucketized_split().left_id();
466 const int32 right_id = node->bucketized_split().right_id();
467
468 // Save children that need to be deleted.
469 nodes_to_delete->push_back(left_id);
470 nodes_to_delete->push_back(right_id);
471
472 // Change node back into leaf.
473 *node->mutable_leaf() = node_metadata.original_leaf();
474 const auto& parent_values = node_value(tree_id, node_id);
475 DCHECK_EQ(parent_values.size(), 1);
476 const float parent_value = parent_values[0];
477
478 // Save the old values of weights of children.
479 (*nodes_meta)[left_id].first = node_id;
480 (*nodes_meta)[left_id].second =
481 parent_value - node_value(tree_id, left_id)[0];
482
483 (*nodes_meta)[right_id].first = node_id;
484 (*nodes_meta)[right_id].second =
485 parent_value - node_value(tree_id, right_id)[0];
486
487 // Clear gain for leaf node.
488 node->clear_metadata();
489 }
490 }
491
492 } // namespace tensorflow
493