// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h"
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"

namespace tensorflow {
namespace boosted_trees {
namespace models {

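// Sums the weighted leaf values of the requested trees for each example in
// the batch, writing the additive result into output_predictions. If
// output_leaf_index is non-null, the leaf reached in each tree is also
// recorded per example.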
void MultipleAdditiveTrees::Predict(
    const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
    const std::vector<int32>& trees_to_include,
    const boosted_trees::utils::BatchFeatures& features,
    tensorflow::thread::ThreadPool* const worker_threads,
    tensorflow::TTypes<float>::Matrix output_predictions,
    Tensor* const output_leaf_index) {
  // Zero out predictions as the model is additive.
  output_predictions.setZero();

  // Get batch size.
  const int64 batch_size = features.batch_size();
  if (batch_size <= 0) {
    return;
  }

  // Lambda for doing a block of work.
  auto update_predictions = [&config, &features, &trees_to_include,
                             &output_predictions,
                             &output_leaf_index](int64 start, int64 end) {
    auto examples_iterable = features.examples_iterable(start, end);
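    // A 1x1 placeholder tensor keeps output_leaf_index_mat bound to a valid
    // matrix even when the caller did not request leaf indices.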
    Tensor dummy_tensor(DT_INT32, TensorShape({1, 1}));
    tensorflow::TTypes<int>::Matrix output_leaf_index_mat =
        output_leaf_index != nullptr ? output_leaf_index->matrix<int>()
                                     : dummy_tensor.matrix<int>();
    for (const auto& example : examples_iterable) {
      for (const int32 tree_idx : trees_to_include) {
        const boosted_trees::trees::DecisionTreeConfig& tree =
            config.trees(tree_idx);
        const float tree_weight = config.tree_weights(tree_idx);
        const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example);
        QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString();
        // Record the leaf index if the caller requested it.
        if (output_leaf_index != nullptr) {
          output_leaf_index_mat(example.example_idx, tree_idx) = leaf_idx;
        }
        const auto& leaf_node = tree.nodes(leaf_idx);
        QCHECK(leaf_node.has_leaf())
            << "Invalid leaf node: " << leaf_node.DebugString();
        if (leaf_node.leaf().has_sparse_vector()) {
          const auto& leaf = leaf_node.leaf().sparse_vector();
          QCHECK_EQ(leaf.index_size(), leaf.value_size());
          for (size_t logit_dim = 0; logit_dim < leaf.index_size();
               ++logit_dim) {
            const float value = tree_weight * leaf.value(logit_dim);
            output_predictions(example.example_idx, leaf.index(logit_dim)) +=
                value;
          }
        } else {
          QCHECK(leaf_node.leaf().has_vector()) << "Unknown leaf type";
          const auto& leaf = leaf_node.leaf().vector();
          for (size_t i = 0; i < leaf.value_size(); ++i) {
            const float value = tree_weight * leaf.value(i);
            output_predictions(example.example_idx, i) += value;
          }
        }
      }
    }
  };

  // TODO(salehay): parallelize this for low latency in serving path where
  // batch size tends to be small but ensemble size tends to be large.
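  // Split the batch into contiguous [start, end) ranges and run
  // update_predictions on each range using the provided thread pool.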
  boosted_trees::utils::ParallelFor(batch_size, worker_threads->NumThreads(),
                                    worker_threads, update_predictions);
}

}  // namespace models
}  // namespace boosted_trees
}  // namespace tensorflow