1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
16 #include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
17 #include "tensorflow/core/framework/tensor_testutil.h"
18 #include "tensorflow/core/lib/core/status_test_util.h"
19 #include "tensorflow/core/platform/test.h"
20
21 namespace tensorflow {
22 namespace boosted_trees {
23 namespace trees {
24 namespace {
25
26 class DecisionTreeTest : public ::testing::Test {
27 protected:
DecisionTreeTest()28 DecisionTreeTest() : batch_features_(2) {
29 // Create a batch of two examples having one dense float, two sparse float
30 // and one sparse int features, and one sparse multi-column float feature
31 // (SparseFM).
32 // The first example is missing the second sparse feature column and the
33 // second example is missing the first sparse feature column.
34 // This looks like the following:
35 // Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 | SparseFM (3 cols)
36 // 0 | 7 | -3 | | 3 | 3.0 | | 1.0
37 // 1 | -2 | | 4 | | 1.5 |3.5|
38 auto dense_float_matrix = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
39 auto sparse_float_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
40 auto sparse_float_values1 = test::AsTensor<float>({-3.0f});
41 auto sparse_float_shape1 = test::AsTensor<int64>({2, 1});
42 auto sparse_float_indices2 = test::AsTensor<int64>({1, 0}, {1, 2});
43 auto sparse_float_values2 = test::AsTensor<float>({4.0f});
44 auto sparse_float_shape2 = test::AsTensor<int64>({2, 1});
45 auto sparse_int_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
46 auto sparse_int_values1 = test::AsTensor<int64>({3});
47 auto sparse_int_shape1 = test::AsTensor<int64>({2, 1});
48
49 // Multivalent sparse feature.
50 auto multi_sparse_float_indices =
51 test::AsTensor<int64>({0, 0, 0, 2, 1, 0, 1, 1}, {4, 2});
52 auto multi_sparse_float_values =
53 test::AsTensor<float>({3.0f, 1.0f, 1.5f, 3.5f});
54 auto multi_sparse_float_shape = test::AsTensor<int64>({2, 3});
55
56 TF_EXPECT_OK(batch_features_.Initialize(
57 {dense_float_matrix},
58 {sparse_float_indices1, sparse_float_indices2,
59 multi_sparse_float_indices},
60 {sparse_float_values1, sparse_float_values2, multi_sparse_float_values},
61 {sparse_float_shape1, sparse_float_shape2, multi_sparse_float_shape},
62 {sparse_int_indices1}, {sparse_int_values1}, {sparse_int_shape1}));
63 }
64
65 template <typename SplitType>
TestLinkChildrenBinary(TreeNode * node,SplitType * split)66 void TestLinkChildrenBinary(TreeNode* node, SplitType* split) {
67 // Verify children were linked.
68 DecisionTree::LinkChildren({3, 8}, node);
69 EXPECT_EQ(3, split->left_id());
70 EXPECT_EQ(8, split->right_id());
71
72 // Invalid cases.
73 EXPECT_DEATH(DecisionTree::LinkChildren({}, node),
74 "A binary split node must have exactly two children.");
75 EXPECT_DEATH(DecisionTree::LinkChildren({3}, node),
76 "A binary split node must have exactly two children.");
77 EXPECT_DEATH(DecisionTree::LinkChildren({1, 2, 3}, node),
78 "A binary split node must have exactly two children.");
79 }
80
TestGetChildren(const TreeNode & node,const std::vector<uint32> & expected_children)81 void TestGetChildren(const TreeNode& node,
82 const std::vector<uint32>& expected_children) {
83 // Verify children were linked.
84 auto children = DecisionTree::GetChildren(node);
85 EXPECT_EQ(children.size(), expected_children.size());
86 for (size_t idx = 0; idx < children.size(); ++idx) {
87 EXPECT_EQ(children[idx], expected_children[idx]);
88 }
89 }
90
91 utils::BatchFeatures batch_features_;
92 };
93
// Traversing a tree config that holds no nodes is expected to report -1.
TEST_F(DecisionTreeTest, TraverseEmpty) {
  DecisionTreeConfig empty_tree;
  auto rows = batch_features_.examples_iterable(0, 1);
  // Copy the first example out of the iterator before using it.
  auto first_row = *rows.begin();
  const auto reached = DecisionTree::Traverse(empty_tree, 0, first_row);
  EXPECT_EQ(-1, reached);
}
99
// A tree made of a single leaf is expected to resolve to that leaf (node 0).
TEST_F(DecisionTreeTest, TraverseBias) {
  DecisionTreeConfig single_leaf_tree;
  single_leaf_tree.add_nodes()->mutable_leaf();

  auto rows = batch_features_.examples_iterable(0, 1);
  auto first_row = *rows.begin();
  const auto reached = DecisionTree::Traverse(single_leaf_tree, 0, first_row);
  EXPECT_EQ(0, reached);
}
106
// Starting the traversal from sub-root id 10 in a one-node tree is expected
// to report -1.
TEST_F(DecisionTreeTest, TraverseInvalidSubRoot) {
  DecisionTreeConfig single_leaf_tree;
  single_leaf_tree.add_nodes()->mutable_leaf();

  auto rows = batch_features_.examples_iterable(0, 1);
  auto first_row = *rows.begin();
  const auto reached = DecisionTree::Traverse(single_leaf_tree, 10, first_row);
  EXPECT_EQ(-1, reached);
}
113
// Node 0 is a dense float split on column 0 with threshold 0; nodes 1 and 2
// are its left and right leaves.
TEST_F(DecisionTreeTest, TraverseDenseBinarySplit) {
  DecisionTreeConfig tree;
  auto* root_split = tree.add_nodes()->mutable_dense_float_binary_split();
  root_split->set_left_id(1);
  root_split->set_right_id(2);
  root_split->set_feature_column(0);
  root_split->set_threshold(0.0f);
  for (int leaf = 0; leaf < 2; ++leaf) {
    tree.add_nodes()->mutable_leaf();
  }

  auto rows = batch_features_.examples_iterable(0, 2);
  auto row = rows.begin();

  // First example: DenseF1 is 7 and !(7 <= 0), so the right leaf is expected.
  EXPECT_EQ(2, DecisionTree::Traverse(tree, 0, *row));

  // Second example: DenseF1 is -2 and (-2 <= 0), so the left leaf is expected.
  ++row;
  EXPECT_EQ(1, DecisionTree::Traverse(tree, 0, *row));
}
133
// Exercises sparse float binary splits (default-left and default-right) on
// each of the three sparse float feature columns of the fixture. Every
// sub-case builds a three-node tree: split at node 0, leaves at nodes 1 and 2.
TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) {
  auto rows = batch_features_.examples_iterable(0, 2);

  // SparseF1 (column 0), default direction left.
  // The second example has no value for this feature.
  {
    DecisionTreeConfig tree;
    auto* split = tree.add_nodes()
                      ->mutable_sparse_float_binary_split_default_left()
                      ->mutable_split();
    split->set_left_id(1);
    split->set_right_id(2);
    split->set_feature_column(0);
    split->set_threshold(-20.0f);
    tree.add_nodes()->mutable_leaf();
    tree.add_nodes()->mutable_leaf();

    auto row = rows.begin();
    // First example: !(-3 <= -20), so the right leaf is expected.
    EXPECT_EQ(2, DecisionTree::Traverse(tree, 0, *row));

    // Second example: feature missing, so the default (left) leaf is expected.
    ++row;
    EXPECT_EQ(1, DecisionTree::Traverse(tree, 0, *row));
  }

  // SparseF2 (column 1), default direction right.
  // The first example has no value for this feature.
  {
    DecisionTreeConfig tree;
    auto* split = tree.add_nodes()
                      ->mutable_sparse_float_binary_split_default_right()
                      ->mutable_split();
    split->set_left_id(1);
    split->set_right_id(2);
    split->set_feature_column(1);
    split->set_threshold(4.0f);
    tree.add_nodes()->mutable_leaf();
    tree.add_nodes()->mutable_leaf();

    auto row = rows.begin();
    // First example: feature missing, so the default (right) leaf is expected.
    EXPECT_EQ(2, DecisionTree::Traverse(tree, 0, *row));

    // Second example: (4 <= 4), so the left leaf is expected.
    ++row;
    EXPECT_EQ(1, DecisionTree::Traverse(tree, 0, *row));
  }

  // SparseFM (column 2), default direction right.
  // The same tree is reused while the split is re-pointed at each of the
  // three dimensions of the multi-column feature in turn.
  {
    DecisionTreeConfig tree;
    auto* split = tree.add_nodes()
                      ->mutable_sparse_float_binary_split_default_right()
                      ->mutable_split();
    split->set_feature_column(2);
    split->set_right_id(2);
    split->set_left_id(1);
    tree.add_nodes()->mutable_leaf();
    tree.add_nodes()->mutable_leaf();

    // Dimension 0: both examples carry a value (3.0 and 1.5).
    split->set_threshold(2.0f);
    split->set_dimension_id(0);
    auto row = rows.begin();
    EXPECT_EQ(2, DecisionTree::Traverse(tree, 0, *row));
    ++row;
    EXPECT_EQ(1, DecisionTree::Traverse(tree, 0, *row));

    // Dimension 1: the first example has no value (default right), the second
    // example has 3.5.
    split->set_threshold(5.0f);
    split->set_dimension_id(1);
    row = rows.begin();
    EXPECT_EQ(2, DecisionTree::Traverse(tree, 0, *row));
    ++row;
    EXPECT_EQ(1, DecisionTree::Traverse(tree, 0, *row));

    // Dimension 2: the first example has 1.0, the second example has no value
    // (default right).
    split->set_threshold(3.0f);
    split->set_dimension_id(2);
    row = rows.begin();
    EXPECT_EQ(1, DecisionTree::Traverse(tree, 0, *row));
    ++row;
    EXPECT_EQ(2, DecisionTree::Traverse(tree, 0, *row));
  }
}
220
// Node 0 is a categorical split on sparse int column 0 testing for id 3;
// nodes 1 and 2 are its left and right leaves.
TEST_F(DecisionTreeTest, TraverseCategoricalIdBinarySplit) {
  DecisionTreeConfig tree;
  auto* id_split = tree.add_nodes()->mutable_categorical_id_binary_split();
  id_split->set_left_id(1);
  id_split->set_right_id(2);
  id_split->set_feature_id(3);
  id_split->set_feature_column(0);
  for (int leaf = 0; leaf < 2; ++leaf) {
    tree.add_nodes()->mutable_leaf();
  }

  auto rows = batch_features_.examples_iterable(0, 2);
  auto row = rows.begin();

  // First example: SparseI1 is 3, which matches, so the left leaf is expected.
  EXPECT_EQ(1, DecisionTree::Traverse(tree, 0, *row));

  // Second example: SparseI1 is missing, so the right leaf is expected.
  ++row;
  EXPECT_EQ(2, DecisionTree::Traverse(tree, 0, *row));
}
240
// Node 0 is a set-membership split on sparse int column 0 with the id set
// {3}; nodes 1 and 2 are its left and right leaves.
TEST_F(DecisionTreeTest, TraverseCategoricalIdSetMembershipBinarySplit) {
  DecisionTreeConfig tree;
  auto* membership_split =
      tree.add_nodes()->mutable_categorical_id_set_membership_binary_split();
  membership_split->set_left_id(1);
  membership_split->set_right_id(2);
  membership_split->add_feature_ids(3);
  membership_split->set_feature_column(0);
  for (int leaf = 0; leaf < 2; ++leaf) {
    tree.add_nodes()->mutable_leaf();
  }

  auto rows = batch_features_.examples_iterable(0, 2);
  auto row = rows.begin();

  // First example: SparseI1 is 3 and 3 is in {3}, so the left leaf is
  // expected.
  EXPECT_EQ(1, DecisionTree::Traverse(tree, 0, *row));

  // Second example: SparseI1 is missing, so the right leaf is expected.
  ++row;
  EXPECT_EQ(2, DecisionTree::Traverse(tree, 0, *row));
}
260
// Builds a seven-node tree mixing all split kinds and checks which leaf each
// example lands in. Node layout (index: kind -> left, right):
//   0: dense float split (column 0, threshold 9)        -> 1, 2
//   1: sparse float split, default left (column 0, -20) -> 3, 4
//   2: leaf
//   3: categorical id split (column 0, id 2)            -> 5, 6
//   4, 5, 6: leaves
TEST_F(DecisionTreeTest, TraverseHybridSplits) {
  DecisionTreeConfig tree;

  // Node 0.
  auto* dense_split = tree.add_nodes()->mutable_dense_float_binary_split();
  dense_split->set_left_id(1);   // sparse split.
  dense_split->set_right_id(2);  // leaf
  dense_split->set_feature_column(0);
  dense_split->set_threshold(9.0f);

  // Node 1.
  auto* sparse_split = tree.add_nodes()
                           ->mutable_sparse_float_binary_split_default_left()
                           ->mutable_split();
  sparse_split->set_left_id(3);
  sparse_split->set_right_id(4);
  sparse_split->set_feature_column(0);
  sparse_split->set_threshold(-20.0f);

  // Node 2.
  tree.add_nodes()->mutable_leaf();

  // Node 3.
  auto* id_split = tree.add_nodes()->mutable_categorical_id_binary_split();
  id_split->set_left_id(5);
  id_split->set_right_id(6);
  id_split->set_feature_column(0);
  id_split->set_feature_id(2);

  // Nodes 4, 5 and 6.
  for (int leaf = 0; leaf < 3; ++leaf) {
    tree.add_nodes()->mutable_leaf();
  }

  auto rows = batch_features_.examples_iterable(0, 2);
  auto row = rows.begin();

  // First example: goes left at the dense split as (7.0f <= 9.0f), then right
  // at the sparse split as !(-3 <= -20), ending in leaf 4.
  EXPECT_EQ(4, DecisionTree::Traverse(tree, 0, *row));

  // Second example: goes left at the dense split as (-2.0f <= 9.0f), then
  // left at the sparse split by default because SparseF1 is missing, then
  // right at the categorical split because this example has no SparseI1 value
  // in the fixture (so id 2 is not present), ending in leaf 6.
  ++row;
  EXPECT_EQ(6, DecisionTree::Traverse(tree, 0, *row));
}
298
// A leaf accepts an empty child list, but dies when given any child.
TEST_F(DecisionTreeTest, LinkChildrenLeaf) {
  TreeNode leaf_node;
  leaf_node.mutable_leaf();

  // Linking zero children into a leaf must not die.
  DecisionTree::LinkChildren({}, &leaf_node);

  // Linking one child into a leaf is rejected.
  EXPECT_DEATH(DecisionTree::LinkChildren({1}, &leaf_node),
               "A leaf node cannot have children.");
}
311
// Runs the shared binary-split linking checks against a dense float split.
TEST_F(DecisionTreeTest, LinkChildrenDenseFloatBinarySplit) {
  TreeNode split_node;
  auto* dense_split = split_node.mutable_dense_float_binary_split();
  // Start both child ids at -1 before the helper links real children.
  dense_split->set_right_id(-1);
  dense_split->set_left_id(-1);
  TestLinkChildrenBinary(&split_node, dense_split);
}
319
// Runs the shared binary-split linking checks against a sparse float split
// whose default direction is left.
TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultLeft) {
  TreeNode split_node;
  auto* default_left = split_node.mutable_sparse_float_binary_split_default_left();
  auto* sparse_split = default_left->mutable_split();
  // Start both child ids at -1 before the helper links real children.
  sparse_split->set_right_id(-1);
  sparse_split->set_left_id(-1);
  TestLinkChildrenBinary(&split_node, sparse_split);
}
328
// Runs the shared binary-split linking checks against a sparse float split
// whose default direction is right.
TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultRight) {
  TreeNode split_node;
  auto* default_right =
      split_node.mutable_sparse_float_binary_split_default_right();
  auto* sparse_split = default_right->mutable_split();
  // Start both child ids at -1 before the helper links real children.
  sparse_split->set_right_id(-1);
  sparse_split->set_left_id(-1);
  TestLinkChildrenBinary(&split_node, sparse_split);
}
337
// Runs the shared binary-split linking checks against a categorical id split.
TEST_F(DecisionTreeTest, LinkChildrenCategoricalSingleIdBinarySplit) {
  TreeNode split_node;
  auto* id_split = split_node.mutable_categorical_id_binary_split();
  // Start both child ids at -1 before the helper links real children.
  id_split->set_right_id(-1);
  id_split->set_left_id(-1);
  TestLinkChildrenBinary(&split_node, id_split);
}
345
// Linking a child into a node whose kind was never set is expected to die.
TEST_F(DecisionTreeTest, LinkChildrenNodeNotSet) {
  // Default-constructed: no leaf or split has been selected on this node.
  TreeNode unset_node;

  EXPECT_DEATH(DecisionTree::LinkChildren({1}, &unset_node),
               "A non-set node cannot have children.");
}
354
// A leaf node is expected to report no children.
TEST_F(DecisionTreeTest, GetChildrenLeaf) {
  TreeNode leaf_node;
  leaf_node.mutable_leaf();
  const std::vector<uint32> no_children;
  TestGetChildren(leaf_node, no_children);
}
360
// A dense float split is expected to report its left id, then its right id.
TEST_F(DecisionTreeTest, GetChildrenDenseFloatBinarySplit) {
  TreeNode split_node;
  auto* dense_split = split_node.mutable_dense_float_binary_split();
  dense_split->set_right_id(24);
  dense_split->set_left_id(23);
  const std::vector<uint32> left_then_right = {23, 24};
  TestGetChildren(split_node, left_then_right);
}
368
// A default-left sparse float split is expected to report its left id, then
// its right id.
TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultLeft) {
  TreeNode split_node;
  auto* default_left = split_node.mutable_sparse_float_binary_split_default_left();
  auto* sparse_split = default_left->mutable_split();
  sparse_split->set_right_id(13);
  sparse_split->set_left_id(12);
  const std::vector<uint32> left_then_right = {12, 13};
  TestGetChildren(split_node, left_then_right);
}
377
// A default-right sparse float split is expected to report its left id, then
// its right id.
TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultRight) {
  TreeNode split_node;
  auto* default_right =
      split_node.mutable_sparse_float_binary_split_default_right();
  auto* sparse_split = default_right->mutable_split();
  sparse_split->set_right_id(2);
  sparse_split->set_left_id(1);
  const std::vector<uint32> left_then_right = {1, 2};
  TestGetChildren(split_node, left_then_right);
}
386
// A categorical id split is expected to report its left id, then its right
// id.
TEST_F(DecisionTreeTest, GetChildrenCategoricalSingleIdBinarySplit) {
  TreeNode split_node;
  auto* id_split = split_node.mutable_categorical_id_binary_split();
  id_split->set_right_id(8);
  id_split->set_left_id(7);
  const std::vector<uint32> left_then_right = {7, 8};
  TestGetChildren(split_node, left_then_right);
}
394
// A node whose kind was never set is expected to report no children.
TEST_F(DecisionTreeTest, GetChildrenNodeNotSet) {
  TreeNode unset_node;
  const std::vector<uint32> no_children;
  TestGetChildren(unset_node, no_children);
}
399
400 } // namespace
401 } // namespace trees
402 } // namespace boosted_trees
403 } // namespace tensorflow
404