1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h"
16 #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
17 #include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
18 #include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
19
20 namespace tensorflow {
21 namespace boosted_trees {
22 namespace models {
23
Predict(const boosted_trees::trees::DecisionTreeEnsembleConfig & config,const std::vector<int32> & trees_to_include,const boosted_trees::utils::BatchFeatures & features,tensorflow::thread::ThreadPool * const worker_threads,tensorflow::TTypes<float>::Matrix output_predictions,Tensor * const output_leaf_index)24 void MultipleAdditiveTrees::Predict(
25 const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
26 const std::vector<int32>& trees_to_include,
27 const boosted_trees::utils::BatchFeatures& features,
28 tensorflow::thread::ThreadPool* const worker_threads,
29 tensorflow::TTypes<float>::Matrix output_predictions,
30 Tensor* const output_leaf_index) {
31 // Zero out predictions as the model is additive.
32 output_predictions.setZero();
33
34 // Get batch size.
35 const int64 batch_size = features.batch_size();
36 if (batch_size <= 0) {
37 return;
38 }
39
40 // Lambda for doing a block of work.
41 auto update_predictions = [&config, &features, &trees_to_include,
42 &output_predictions,
43 &output_leaf_index](int64 start, int64 end) {
44 auto examples_iterable = features.examples_iterable(start, end);
45 Tensor dummy_tensor(DT_INT32, TensorShape({1, 1}));
46 tensorflow::TTypes<int>::Matrix output_leaf_index_mat =
47 output_leaf_index != nullptr ? output_leaf_index->matrix<int>()
48 : dummy_tensor.matrix<int>();
49 for (const auto& example : examples_iterable) {
50 for (const int32 tree_idx : trees_to_include) {
51 const boosted_trees::trees::DecisionTreeConfig& tree =
52 config.trees(tree_idx);
53 const float tree_weight = config.tree_weights(tree_idx);
54 const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example);
55 QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString();
56 // Checks if output leaf tree index is required.
57 if (output_leaf_index != nullptr) {
58 output_leaf_index_mat(example.example_idx, tree_idx) = leaf_idx;
59 }
60 const auto& leaf_node = tree.nodes(leaf_idx);
61 QCHECK(leaf_node.has_leaf())
62 << "Invalid leaf node: " << leaf_node.DebugString();
63 if (leaf_node.leaf().has_sparse_vector()) {
64 const auto& leaf = leaf_node.leaf().sparse_vector();
65 QCHECK_EQ(leaf.index_size(), leaf.value_size());
66 for (size_t logit_dim = 0; logit_dim < leaf.index_size();
67 ++logit_dim) {
68 const float value = tree_weight * leaf.value(logit_dim);
69 output_predictions(example.example_idx, leaf.index(logit_dim)) +=
70 value;
71 }
72 } else {
73 QCHECK(leaf_node.leaf().has_vector()) << "Unknown leaf type";
74 const auto& leaf = leaf_node.leaf().vector();
75 for (size_t i = 0; i < leaf.value_size(); ++i) {
76 const float value = tree_weight * leaf.value(i);
77 output_predictions(example.example_idx, i) += value;
78 }
79 }
80 }
81 }
82 };
83
84 // TODO(salehay): parallelize this for low latency in serving path where
85 // batch size tends to be small but ensemble size tends to be large.
86 boosted_trees::utils::ParallelFor(batch_size, worker_threads->NumThreads(),
87 worker_threads, update_predictions);
88 }
89
90 } // namespace models
91 } // namespace boosted_trees
92 } // namespace tensorflow
93