• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
16 #include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
17 #include "tensorflow/core/framework/tensor_testutil.h"
18 #include "tensorflow/core/lib/core/status_test_util.h"
19 #include "tensorflow/core/platform/test.h"
20 
21 namespace tensorflow {
22 namespace boosted_trees {
23 namespace trees {
24 namespace {
25 
26 class DecisionTreeTest : public ::testing::Test {
27  protected:
DecisionTreeTest()28   DecisionTreeTest() : batch_features_(2) {
29     // Create a batch of two examples having one dense float, two sparse float
30     // and one sparse int features, and one sparse multi-column float feature
31     // (SparseFM).
32     // The first example is missing the second sparse feature column and the
33     // second example is missing the first sparse feature column.
34     // This looks like the following:
35     // Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 | SparseFM (3 cols)
36     // 0        |   7     |   -3     |          |    3     | 3.0 |   | 1.0
37     // 1        |  -2     |          |   4      |          | 1.5 |3.5|
38     auto dense_float_matrix = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
39     auto sparse_float_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
40     auto sparse_float_values1 = test::AsTensor<float>({-3.0f});
41     auto sparse_float_shape1 = test::AsTensor<int64>({2, 1});
42     auto sparse_float_indices2 = test::AsTensor<int64>({1, 0}, {1, 2});
43     auto sparse_float_values2 = test::AsTensor<float>({4.0f});
44     auto sparse_float_shape2 = test::AsTensor<int64>({2, 1});
45     auto sparse_int_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
46     auto sparse_int_values1 = test::AsTensor<int64>({3});
47     auto sparse_int_shape1 = test::AsTensor<int64>({2, 1});
48 
49     // Multivalent sparse feature.
50     auto multi_sparse_float_indices =
51         test::AsTensor<int64>({0, 0, 0, 2, 1, 0, 1, 1}, {4, 2});
52     auto multi_sparse_float_values =
53         test::AsTensor<float>({3.0f, 1.0f, 1.5f, 3.5f});
54     auto multi_sparse_float_shape = test::AsTensor<int64>({2, 3});
55 
56     TF_EXPECT_OK(batch_features_.Initialize(
57         {dense_float_matrix},
58         {sparse_float_indices1, sparse_float_indices2,
59          multi_sparse_float_indices},
60         {sparse_float_values1, sparse_float_values2, multi_sparse_float_values},
61         {sparse_float_shape1, sparse_float_shape2, multi_sparse_float_shape},
62         {sparse_int_indices1}, {sparse_int_values1}, {sparse_int_shape1}));
63   }
64 
65   template <typename SplitType>
TestLinkChildrenBinary(TreeNode * node,SplitType * split)66   void TestLinkChildrenBinary(TreeNode* node, SplitType* split) {
67     // Verify children were linked.
68     DecisionTree::LinkChildren({3, 8}, node);
69     EXPECT_EQ(3, split->left_id());
70     EXPECT_EQ(8, split->right_id());
71 
72     // Invalid cases.
73     EXPECT_DEATH(DecisionTree::LinkChildren({}, node),
74                  "A binary split node must have exactly two children.");
75     EXPECT_DEATH(DecisionTree::LinkChildren({3}, node),
76                  "A binary split node must have exactly two children.");
77     EXPECT_DEATH(DecisionTree::LinkChildren({1, 2, 3}, node),
78                  "A binary split node must have exactly two children.");
79   }
80 
TestGetChildren(const TreeNode & node,const std::vector<uint32> & expected_children)81   void TestGetChildren(const TreeNode& node,
82                        const std::vector<uint32>& expected_children) {
83     // Verify children were linked.
84     auto children = DecisionTree::GetChildren(node);
85     EXPECT_EQ(children.size(), expected_children.size());
86     for (size_t idx = 0; idx < children.size(); ++idx) {
87       EXPECT_EQ(children[idx], expected_children[idx]);
88     }
89   }
90 
91   utils::BatchFeatures batch_features_;
92 };
93 
// Traversing a tree config that contains no nodes is expected to return -1.
TEST_F(DecisionTreeTest, TraverseEmpty) {
  DecisionTreeConfig empty_tree;
  auto examples = batch_features_.examples_iterable(0, 1);
  auto first_example = *examples.begin();
  EXPECT_EQ(-1, DecisionTree::Traverse(empty_tree, 0, first_example));
}
99 
// A tree consisting of a single leaf is expected to resolve to that leaf
// (node 0).
TEST_F(DecisionTreeTest, TraverseBias) {
  DecisionTreeConfig single_leaf_tree;
  single_leaf_tree.add_nodes()->mutable_leaf();

  auto examples = batch_features_.examples_iterable(0, 1);
  auto first_example = *examples.begin();
  EXPECT_EQ(0, DecisionTree::Traverse(single_leaf_tree, 0, first_example));
}
106 
// Starting traversal from sub-root id 10 in a one-node tree is expected to
// return -1.
TEST_F(DecisionTreeTest, TraverseInvalidSubRoot) {
  DecisionTreeConfig single_leaf_tree;
  single_leaf_tree.add_nodes()->mutable_leaf();

  auto examples = batch_features_.examples_iterable(0, 1);
  auto first_example = *examples.begin();
  EXPECT_EQ(-1, DecisionTree::Traverse(single_leaf_tree, 10, first_example));
}
113 
// Exercises a single dense float split whose two children are leaves.
TEST_F(DecisionTreeTest, TraverseDenseBinarySplit) {
  // Node 0 splits dense column 0 at threshold 0; nodes 1 and 2 are leaves.
  DecisionTreeConfig tree;
  auto* dense_split = tree.add_nodes()->mutable_dense_float_binary_split();
  dense_split->set_left_id(1);
  dense_split->set_right_id(2);
  dense_split->set_feature_column(0);
  dense_split->set_threshold(0.0f);
  tree.add_nodes()->mutable_leaf();
  tree.add_nodes()->mutable_leaf();

  auto examples = batch_features_.examples_iterable(0, 2);
  auto it = examples.begin();

  // First example: !(7 <= 0), so the right child is expected.
  EXPECT_EQ(2, DecisionTree::Traverse(tree, 0, *it));

  // Second example: (-2 <= 0), so the left child is expected.
  ++it;
  EXPECT_EQ(1, DecisionTree::Traverse(tree, 0, *it));
}
133 
TEST_F(DecisionTreeTest,TraverseSparseBinarySplit)134 TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) {
135   auto example_iterable = batch_features_.examples_iterable(0, 2);
136   // Split on SparseF1.
137   // Test first sparse feature which is missing for the second example.
138   {
139     DecisionTreeConfig tree_config;
140     auto* split_node = tree_config.add_nodes()
141                            ->mutable_sparse_float_binary_split_default_left()
142                            ->mutable_split();
143     split_node->set_feature_column(0);
144     split_node->set_threshold(-20.0f);
145     split_node->set_left_id(1);
146     split_node->set_right_id(2);
147     tree_config.add_nodes()->mutable_leaf();
148     tree_config.add_nodes()->mutable_leaf();
149 
150     // Expect right child to be picked as !(-3 <= -20).
151     auto example_it = example_iterable.begin();
152     EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it));
153 
154     // Expect left child to be picked as default direction.
155     EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it));
156   }
157   // Split on SparseF2.
158   // Test second sparse feature which is missing for the first example.
159   {
160     DecisionTreeConfig tree_config;
161     auto* split_node = tree_config.add_nodes()
162                            ->mutable_sparse_float_binary_split_default_right()
163                            ->mutable_split();
164     split_node->set_feature_column(1);
165     split_node->set_threshold(4.0f);
166     split_node->set_left_id(1);
167     split_node->set_right_id(2);
168     tree_config.add_nodes()->mutable_leaf();
169     tree_config.add_nodes()->mutable_leaf();
170 
171     // Expect right child to be picked as default direction.
172     auto example_it = example_iterable.begin();
173     EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it));
174 
175     // Expect left child to be picked as (4 <= 4).
176     EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it));
177   }
178   // Split on SparseFM.
179   // Test second sparse feature which is missing for the first example.
180   {
181     DecisionTreeConfig tree_config;
182     auto* split_node = tree_config.add_nodes()
183                            ->mutable_sparse_float_binary_split_default_right()
184                            ->mutable_split();
185     split_node->set_feature_column(2);
186 
187     split_node->set_left_id(1);
188     split_node->set_right_id(2);
189     tree_config.add_nodes()->mutable_leaf();
190     tree_config.add_nodes()->mutable_leaf();
191 
192     // Split on first column
193     split_node->set_dimension_id(0);
194     split_node->set_threshold(2.0f);
195 
196     // Both instances have this feature value.
197     auto example_it = example_iterable.begin();
198     EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it));
199     EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it));
200 
201     // Split on second column
202     split_node->set_dimension_id(1);
203     split_node->set_threshold(5.0f);
204 
205     // First instance does not have it (default right), second does have it.
206     example_it = example_iterable.begin();
207     EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it));
208     EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it));
209 
210     // Split on third column
211     split_node->set_dimension_id(2);
212     split_node->set_threshold(3.0f);
213     example_it = example_iterable.begin();
214 
215     // First instance has it, second does not (default right).
216     EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *example_it));
217     EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *++example_it));
218   }
219 }
220 
// Exercises a categorical split that tests for a single feature id.
TEST_F(DecisionTreeTest, TraverseCategoricalIdBinarySplit) {
  // Node 0 tests sparse int column 0 against id 3; nodes 1 and 2 are leaves.
  DecisionTreeConfig tree;
  auto* id_split = tree.add_nodes()->mutable_categorical_id_binary_split();
  id_split->set_left_id(1);
  id_split->set_right_id(2);
  id_split->set_feature_column(0);
  id_split->set_feature_id(3);
  tree.add_nodes()->mutable_leaf();
  tree.add_nodes()->mutable_leaf();

  auto examples = batch_features_.examples_iterable(0, 2);
  auto it = examples.begin();

  // First example: 3 == 3, so the left child is expected.
  EXPECT_EQ(1, DecisionTree::Traverse(tree, 0, *it));

  // Second example: the feature is missing, so the right child is expected.
  ++it;
  EXPECT_EQ(2, DecisionTree::Traverse(tree, 0, *it));
}
240 
// Exercises a categorical split that tests membership in a set of feature
// ids.
TEST_F(DecisionTreeTest, TraverseCategoricalIdSetMembershipBinarySplit) {
  // Node 0 tests sparse int column 0 against the id set {3}; nodes 1 and 2
  // are leaves.
  DecisionTreeConfig tree;
  auto* set_split =
      tree.add_nodes()->mutable_categorical_id_set_membership_binary_split();
  set_split->set_left_id(1);
  set_split->set_right_id(2);
  set_split->set_feature_column(0);
  set_split->add_feature_ids(3);
  tree.add_nodes()->mutable_leaf();
  tree.add_nodes()->mutable_leaf();

  auto examples = batch_features_.examples_iterable(0, 2);
  auto it = examples.begin();

  // First example: 3 is in {3}, so the left child is expected.
  EXPECT_EQ(1, DecisionTree::Traverse(tree, 0, *it));

  // Second example: the feature is missing, so the right child is expected.
  ++it;
  EXPECT_EQ(2, DecisionTree::Traverse(tree, 0, *it));
}
260 
// Exercises a tree that chains a dense split, a sparse split and a
// categorical split.
//
// Node layout (by index in the config):
//   0: dense split      -> left 1, right 2
//   1: sparse split     -> left 3, right 4 (default left)
//   2: leaf
//   3: categorical split -> left 5, right 6
//   4, 5, 6: leaves
TEST_F(DecisionTreeTest, TraverseHybridSplits) {
  DecisionTreeConfig tree;

  // Node 0: dense float split on column 0 at threshold 9.
  auto* dense_split = tree.add_nodes()->mutable_dense_float_binary_split();
  dense_split->set_feature_column(0);
  dense_split->set_threshold(9.0f);
  dense_split->set_left_id(1);   // sparse split.
  dense_split->set_right_id(2);  // leaf

  // Node 1: sparse float split on SparseF1 at threshold -20, default left.
  auto* sparse_split = tree.add_nodes()
                           ->mutable_sparse_float_binary_split_default_left()
                           ->mutable_split();
  sparse_split->set_feature_column(0);
  sparse_split->set_threshold(-20.0f);
  sparse_split->set_left_id(3);
  sparse_split->set_right_id(4);

  // Node 2: leaf reached by going right at the dense split.
  tree.add_nodes()->mutable_leaf();

  // Node 3: categorical split on sparse int column 0, testing for id 2.
  auto* id_split = tree.add_nodes()->mutable_categorical_id_binary_split();
  id_split->set_feature_column(0);
  id_split->set_feature_id(2);
  id_split->set_left_id(5);
  id_split->set_right_id(6);

  // Nodes 4, 5 and 6: leaves.
  tree.add_nodes()->mutable_leaf();
  tree.add_nodes()->mutable_leaf();
  tree.add_nodes()->mutable_leaf();

  auto examples = batch_features_.examples_iterable(0, 2);
  auto it = examples.begin();

  // First example: goes left at the dense split as (7.0f <= 9.0f), then
  // right at the sparse split as !(-3 <= -20), ending at leaf 4.
  EXPECT_EQ(4, DecisionTree::Traverse(tree, 0, *it));

  // Second example: goes left at the dense split as (-2.0f <= 9.0f), then
  // left at the sparse split (default direction, SparseF1 is missing), then
  // right at the categorical split, ending at leaf 6. The fixture gives this
  // example no SparseI1 value, so nothing matches feature id 2.
  ++it;
  EXPECT_EQ(6, DecisionTree::Traverse(tree, 0, *it));
}
298 
// Linking an empty child list into a leaf is accepted; linking any child is
// expected to die.
TEST_F(DecisionTreeTest, LinkChildrenLeaf) {
  TreeNode leaf_node;
  leaf_node.mutable_leaf();

  // An empty child list is a no-op for a leaf.
  DecisionTree::LinkChildren({}, &leaf_node);

  // A non-empty child list is invalid for a leaf.
  EXPECT_DEATH(DecisionTree::LinkChildren({1}, &leaf_node),
               "A leaf node cannot have children.");
}
311 
// Runs the shared binary-link checks against a dense float split node.
TEST_F(DecisionTreeTest, LinkChildrenDenseFloatBinarySplit) {
  TreeNode split_node;
  auto* dense_split = split_node.mutable_dense_float_binary_split();
  // Start from placeholder ids so the helper can observe them being replaced.
  dense_split->set_right_id(-1);
  dense_split->set_left_id(-1);
  TestLinkChildrenBinary(&split_node, dense_split);
}
319 
// Runs the shared binary-link checks against a default-left sparse float
// split node.
TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultLeft) {
  TreeNode split_node;
  auto* default_left = split_node.mutable_sparse_float_binary_split_default_left();
  auto* sparse_split = default_left->mutable_split();
  // Start from placeholder ids so the helper can observe them being replaced.
  sparse_split->set_right_id(-1);
  sparse_split->set_left_id(-1);
  TestLinkChildrenBinary(&split_node, sparse_split);
}
328 
// Runs the shared binary-link checks against a default-right sparse float
// split node.
TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultRight) {
  TreeNode split_node;
  auto* default_right =
      split_node.mutable_sparse_float_binary_split_default_right();
  auto* sparse_split = default_right->mutable_split();
  // Start from placeholder ids so the helper can observe them being replaced.
  sparse_split->set_right_id(-1);
  sparse_split->set_left_id(-1);
  TestLinkChildrenBinary(&split_node, sparse_split);
}
337 
// Runs the shared binary-link checks against a categorical single-id split
// node.
TEST_F(DecisionTreeTest, LinkChildrenCategoricalSingleIdBinarySplit) {
  TreeNode split_node;
  auto* id_split = split_node.mutable_categorical_id_binary_split();
  // Start from placeholder ids so the helper can observe them being replaced.
  id_split->set_right_id(-1);
  id_split->set_left_id(-1);
  TestLinkChildrenBinary(&split_node, id_split);
}
345 
// Linking a child into a node whose type was never set is expected to die.
TEST_F(DecisionTreeTest, LinkChildrenNodeNotSet) {
  // Default-constructed: neither a leaf nor any kind of split.
  TreeNode unset_node;

  EXPECT_DEATH(DecisionTree::LinkChildren({1}, &unset_node),
               "A non-set node cannot have children.");
}
354 
// A leaf node is expected to report no children.
TEST_F(DecisionTreeTest, GetChildrenLeaf) {
  TreeNode leaf_node;
  leaf_node.mutable_leaf();
  TestGetChildren(leaf_node, {});
}
360 
// A dense float split is expected to report its left id followed by its
// right id.
TEST_F(DecisionTreeTest, GetChildrenDenseFloatBinarySplit) {
  TreeNode split_node;
  auto* dense_split = split_node.mutable_dense_float_binary_split();
  dense_split->set_right_id(24);
  dense_split->set_left_id(23);
  TestGetChildren(split_node, {23, 24});
}
368 
// A default-left sparse float split is expected to report its left id
// followed by its right id.
TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultLeft) {
  TreeNode split_node;
  auto* default_left = split_node.mutable_sparse_float_binary_split_default_left();
  auto* sparse_split = default_left->mutable_split();
  sparse_split->set_right_id(13);
  sparse_split->set_left_id(12);
  TestGetChildren(split_node, {12, 13});
}
377 
// A default-right sparse float split is expected to report its left id
// followed by its right id.
TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultRight) {
  TreeNode split_node;
  auto* default_right =
      split_node.mutable_sparse_float_binary_split_default_right();
  auto* sparse_split = default_right->mutable_split();
  sparse_split->set_right_id(2);
  sparse_split->set_left_id(1);
  TestGetChildren(split_node, {1, 2});
}
386 
// A categorical single-id split is expected to report its left id followed
// by its right id.
TEST_F(DecisionTreeTest, GetChildrenCategoricalSingleIdBinarySplit) {
  TreeNode split_node;
  auto* id_split = split_node.mutable_categorical_id_binary_split();
  id_split->set_right_id(8);
  id_split->set_left_id(7);
  TestGetChildren(split_node, {7, 8});
}
394 
// A node whose type was never set is expected to report no children.
TEST_F(DecisionTreeTest, GetChildrenNodeNotSet) {
  TreeNode unset_node;
  TestGetChildren(unset_node, {});
}
399 
400 }  // namespace
401 }  // namespace trees
402 }  // namespace boosted_trees
403 }  // namespace tensorflow
404