# TensorForest

TensorForest is an implementation of random forests in TensorFlow using an
online, [extremely randomized trees](https://en.wikipedia.org/wiki/Random_forest#ExtraTrees)
training algorithm.  It supports both classification (binary and multiclass)
and regression (scalar and vector).

## Usage

TensorForest is a tf.learn Estimator:

```python
# TensorForest lives in tf.contrib, which exists only in TensorFlow 1.x.
import tensorflow as tf

# Hyperparameters for a binary classifier over 10 input features.
params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
    num_classes=2, num_features=10, regression=False,
    num_trees=50, max_nodes=1000)

classifier = tf.contrib.tensor_forest.client.random_forest.TensorForestEstimator(
    params)

# x_train, y_train and x_test are your own (shuffled) feature and label arrays.
classifier.fit(x=x_train, y=y_train)

y_out = classifier.predict(x=x_test)
```
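
TensorForest users should shuffle their training data thoroughly, because the
training algorithm strongly assumes that it arrives in random order: candidate
split thresholds are taken from the first samples that reach each leaf, so data
sorted by label or by feature value yields unrepresentative splits.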
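A minimal sketch of such a shuffle in plain Python, assuming the features and labels are held in two parallel sequences. `shuffle_together` is a hypothetical helper, not part of TensorForest; the point is that one shared permutation is applied to both sequences, so every `(x_i, y_i)` pair stays aligned.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

X = TypeVar("X")
Y = TypeVar("Y")


# Hypothetical helper, not part of TensorForest.
def shuffle_together(xs: Sequence[X], ys: Sequence[Y],
                     seed: int = 0) -> Tuple[List[X], List[Y]]:
    """Reorder features and labels with one shared random permutation."""
    if len(xs) != len(ys):
        raise ValueError("xs and ys must have the same length")
    order = list(range(len(xs)))
    random.Random(seed).shuffle(order)  # same permutation for both sequences
    return [xs[i] for i in order], [ys[i] for i in order]


# Data sorted by label is the worst case for online split selection.
features = [[0.1], [0.2], [0.3], [0.9], [0.8], [0.7]]
labels = [0, 0, 0, 1, 1, 1]
shuffled_x, shuffled_y = shuffle_together(features, labels, seed=7)
```

With NumPy arrays, indexing both arrays with the same `np.random.permutation(len(x))` achieves the same effect.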

## Algorithm

Each tree in the forest is trained independently, in parallel with the others.
For each tree, we maintain the following data:

* The tree structure, giving the two children of each non-leaf node and
the *split* used to route data between them.  Each split looks at a single
input feature and compares it to a threshold value.

* Leaf statistics.  Each leaf gathers statistics that can be turned into
predictions at the end of training.  For classification problems, the
statistics are class counts; for regression problems, they are the vector sum
of the values seen at the leaf, along with a count of those values.

* Growing statistics.  Each leaf also gathers the data that may allow it to
grow into a non-leaf parent node.  That data usually consists of a list of
potential splits, along with statistics for each of those splits.  Split
statistics in turn consist of leaf statistics for the left and right branches,
along with some other information that allows us to assess the quality of the
split.  For classification problems, that is usually the
[Gini impurity](https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity)
of the split, while for regression problems it is the mean squared error.
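The statistics described above can be sketched in plain Python. This is an illustration, not the layout of TensorForest's C++ resources: the class names are hypothetical, the regression statistics carry an extra per-dimension sum of squares (our assumption for the "other information" needed to compute the mean squared error), and a split's score is assumed to be the count-weighted impurity of its two branches, where lower is better.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass
class ClassStats:
    """Classification leaf statistics: one count per class."""
    counts: List[float]

    def add(self, label: int) -> None:
        self.counts[label] += 1

    def total(self) -> float:
        return sum(self.counts)

    def impurity(self) -> float:
        """Gini impurity, 1 - sum_k p_k^2."""
        n = self.total()
        return 0.0 if n == 0 else 1.0 - sum((c / n) ** 2 for c in self.counts)


@dataclass
class RegressionStats:
    """Regression leaf statistics: vector sum and count of the values seen,
    plus a per-dimension sum of squares (assumed, needed for the MSE)."""
    sums: List[float]
    sq_sums: List[float]
    count: int = 0

    def add(self, value: Sequence[float]) -> None:
        for j, v in enumerate(value):
            self.sums[j] += v
            self.sq_sums[j] += v * v
        self.count += 1

    def total(self) -> float:
        return float(self.count)

    def impurity(self) -> float:
        """Mean squared error around the leaf mean, summed over dimensions."""
        if self.count == 0:
            return 0.0
        n = self.count
        return sum(sq / n - (s / n) ** 2 for s, sq in zip(self.sums, self.sq_sums))


Stats = Union[ClassStats, RegressionStats]


def split_score(left: Stats, right: Stats) -> float:
    """Count-weighted impurity of a split's branches (assumed; lower is better)."""
    n = left.total() + right.total()
    if n == 0:
        return 0.0
    return (left.total() * left.impurity() + right.total() * right.impurity()) / n
```

A split that separates the classes perfectly scores 0.0, so taking the minimum score over the candidate splits picks the most informative one.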

At the start of training, the tree structure consists of a single root node,
whose leaf and growing statistics are both empty.  Then, for each batch
`{(x_i, y_i)}` of training data, the following steps are performed:

1. Given the current tree structure, each `x_i` is used to find its leaf
assignment `l_i`.

2. `y_i` is used to update the leaf statistics of leaf `l_i`.

3. If the growing statistics for leaf `l_i` do not yet contain
`num_splits_to_consider` splits, `x_i` is used to generate another split.
Specifically, a feature is chosen at random, and `x_i`'s value for that
feature is used as the split's threshold.

4. Otherwise, `(x_i, y_i)` is used to update the statistics of every
split in the growing statistics of leaf `l_i`.  If leaf `l_i` has now seen
`split_after_samples` data points since creating all of its potential splits,
the split with the best score is chosen, and the tree structure is grown:
leaf `l_i` becomes a non-leaf node with two new leaf children.
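The four steps can be sketched for a single classification tree in plain Python. This is not TensorForest's implementation, which runs as custom TensorFlow ops: `OnlineTree` and its methods are hypothetical, and several details are our assumptions, namely that samples with `x[feature] <= threshold` go left, that a split's score is the count-weighted Gini impurity of its branches, that ready leaves are grown only after the whole batch has been processed, and that new children start from the chosen split's branch counts.

```python
import random
from dataclasses import dataclass
from typing import Dict, List, Sequence


def gini(counts: Sequence[float]) -> float:
    n = sum(counts)
    return 0.0 if n == 0 else 1.0 - sum((c / n) ** 2 for c in counts)


def split_score(left: Sequence[float], right: Sequence[float]) -> float:
    """Assumed score: count-weighted Gini impurity (lower is better)."""
    n_l, n_r = sum(left), sum(right)
    return 0.0 if n_l + n_r == 0 else (n_l * gini(left) + n_r * gini(right)) / (n_l + n_r)


@dataclass
class Node:
    feature: int = -1  # -1 marks a leaf
    threshold: float = 0.0
    left: int = -1
    right: int = -1


@dataclass
class Candidate:
    feature: int
    threshold: float
    left: List[float]   # class counts of samples with x[feature] <= threshold (assumed)
    right: List[float]  # class counts of the remaining samples


class OnlineTree:
    """Hypothetical single-tree classifier following steps 1-4."""

    def __init__(self, num_classes, num_features, num_splits_to_consider,
                 split_after_samples, max_nodes, seed=0):
        self.k, self.d = num_classes, num_features
        self.num_splits = num_splits_to_consider
        self.split_after = split_after_samples
        self.max_nodes = max_nodes
        self.rng = random.Random(seed)
        self.nodes: List[Node] = [Node()]  # a lone root leaf
        self.leaf_stats: Dict[int, List[float]] = {0: [0.0] * self.k}
        self.growing: Dict[int, List[Candidate]] = {0: []}
        self.seen: Dict[int, int] = {0: 0}  # samples since all candidates existed

    def find_leaf(self, x: Sequence[float]) -> int:
        i = 0
        while self.nodes[i].feature >= 0:
            node = self.nodes[i]
            i = node.left if x[node.feature] <= node.threshold else node.right
        return i

    def train_batch(self, xs: Sequence[Sequence[float]], ys: Sequence[int]) -> None:
        leaves = [self.find_leaf(x) for x in xs]                  # step 1
        ready = set()
        for x, y, leaf in zip(xs, ys, leaves):
            self.leaf_stats[leaf][y] += 1                          # step 2
            candidates = self.growing[leaf]
            if len(candidates) < self.num_splits:                  # step 3
                f = self.rng.randrange(self.d)
                candidates.append(Candidate(f, x[f], [0.0] * self.k, [0.0] * self.k))
                continue
            for c in candidates:                                   # step 4
                side = c.left if x[c.feature] <= c.threshold else c.right
                side[y] += 1
            self.seen[leaf] += 1
            if self.seen[leaf] >= self.split_after:
                ready.add(leaf)
        # Assumption: ready leaves are grown once the whole batch is done.
        for leaf in sorted(ready):
            if len(self.nodes) + 2 <= self.max_nodes:
                self._grow(leaf)

    def _grow(self, leaf: int) -> None:
        best = min(self.growing.pop(leaf), key=lambda c: split_score(c.left, c.right))
        del self.leaf_stats[leaf], self.seen[leaf]
        left_id, right_id = len(self.nodes), len(self.nodes) + 1
        self.nodes[leaf] = Node(best.feature, best.threshold, left_id, right_id)
        for child, counts in ((left_id, best.left), (right_id, best.right)):
            self.nodes.append(Node())
            self.leaf_stats[child] = list(counts)  # assumption: reuse branch counts
            self.growing[child] = []
            self.seen[child] = 0

    def predict(self, x: Sequence[float]) -> int:
        counts = self.leaf_stats[self.find_leaf(x)]
        return max(range(self.k), key=lambda c: counts[c])
```

Deferring growth to the end of the batch keeps the leaf assignments from step 1 valid for every sample in that batch. Calling `train_batch` on successive shuffled batches and then `predict` on new points exercises all four steps.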

## Parameters

The following ForestHParams parameters are required:

* `num_classes`.  The number of classes in a classification problem, or
the number of dimensions in the output of a regression problem.

* `num_features`.  The number of input features.

The following ForestHParams parameters are important but not required:

* `regression`.  True for regression problems, False for classification
problems.  Defaults to False (classification).  For regression problems,
TensorForest's outputs are the predicted regression values; for
classification, they are the per-class probabilities.

* `num_trees`.  The number of trees to create.  Defaults to 100.  Higher
values usually bring no further gain in accuracy.

* `max_nodes`.  Defaults to 10,000.  No tree is allowed to grow beyond
`max_nodes` nodes, and training stops once every tree in the forest has
reached this size.
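How leaf statistics might become the outputs described above can be sketched as follows. The helper names are hypothetical, and two details are our assumptions rather than documented behavior: an empty leaf predicts a uniform distribution (or zeros for regression), and the forest combines its trees by averaging their per-tree outputs element-wise, as random forests commonly do.

```python
from typing import List, Sequence


def class_probabilities(counts: Sequence[float]) -> List[float]:
    """Per-class probabilities from a leaf's class counts."""
    n = sum(counts)
    if n == 0:
        return [1.0 / len(counts)] * len(counts)  # assumption: uniform for an empty leaf
    return [c / n for c in counts]


def regression_value(sums: Sequence[float], count: int) -> List[float]:
    """Predicted regression vector: the mean of the values seen at the leaf."""
    if count == 0:
        return [0.0] * len(sums)  # assumption: zeros for an empty leaf
    return [s / count for s in sums]


def forest_output(per_tree: Sequence[Sequence[float]]) -> List[float]:
    """Assumption: the forest averages its trees' outputs element-wise."""
    return [sum(column) / len(per_tree) for column in zip(*per_tree)]


# Two trees' leaves for the same input, holding counts for classes 0 and 1.
probs = forest_output([class_probabilities([3, 1]), class_probabilities([1, 1])])
# probs == [0.625, 0.375]
```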

The remaining ForestHParams parameters usually don't need to be set by the
user:

* `num_splits_to_consider`.  Defaults to `sqrt(num_features)`, clamped to lie
between 10 and 1000.  In the extremely randomized tree training algorithm,
only this many potential splits are evaluated for each tree node.

* `split_after_samples`.  Defaults to 250.  In our online version of
extremely randomized tree training, we pick a split for a node once it has
accumulated this many training samples after creating all of its potential
splits.

* `bagging_fraction`.  If less than 1.0, each tree sees only its own random
subset of the training data, sampled without replacement, whose size is a
`bagging_fraction` share of the whole.  Defaults to 1.0 (no bagging) because
bagging has not improved accuracy in our experiments so far.

* `feature_bagging_fraction`.  If less than 1.0, each tree sees only a
different subset of `feature_bagging_fraction * num_features` input features.
Defaults to 1.0 (no feature bagging).

* `base_random_seed`.  By default (`base_random_seed = 0`), the random number
generator for each tree is seeded with the current time (in microseconds) when
the tree is first created.  A non-zero value makes tree training
deterministic: the i-th tree's random number generator is seeded with
`base_random_seed + i`.
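The defaults and randomization options above can be sketched in plain Python. This is an illustration under stated assumptions, not TensorForest's code: the function names are hypothetical, rounding the square root down with `int()` is assumed, and bagging is shown as drawing one index subset without replacement from a per-tree generator.

```python
import math
import random
import time
from typing import List


def default_num_splits_to_consider(num_features: int) -> int:
    """sqrt(num_features), clamped to [10, 1000] (rounding down is assumed)."""
    return max(10, min(1000, int(math.sqrt(num_features))))


def tree_rng(base_random_seed: int, tree_index: int) -> random.Random:
    """Per-tree generator: time-seeded when base_random_seed is 0, else base + i."""
    if base_random_seed == 0:
        return random.Random(time.time_ns() // 1000)  # current time in microseconds
    return random.Random(base_random_seed + tree_index)


def bagged_indices(n: int, fraction: float, rng: random.Random) -> List[int]:
    """Sorted indices of a random subset of size int(fraction * n), drawn without
    replacement; usable for rows (bagging_fraction) or for features
    (feature_bagging_fraction)."""
    if fraction >= 1.0:
        return list(range(n))  # the default of 1.0 disables bagging
    return sorted(rng.sample(range(n), int(fraction * n)))


# Tree 3 of a forest with base_random_seed=42 gets a reproducible feature subset.
features_for_tree_3 = bagged_indices(100, 0.5, tree_rng(42, 3))
```

Seeding each tree with `base_random_seed + i` makes the forest reproducible while still giving every tree a different stream of random choices.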

## Implementation

The Python code in `python/tensor_forest.py` assigns default values to the
parameters, handles both instance and feature bagging, and creates the
TensorFlow graphs for training and inference.  The graphs themselves are
quite simple, as most of the work is done in custom ops.  There is a single
op (`model_ops.tree_predictions_v4`) that does inference for a single tree,
and four custom ops that train a single tree on a single batch, each roughly
corresponding to one of the four steps in the algorithm section above.

The training state itself is stored in TensorFlow _resources_, which provide
persistent storage for data that is not a tensor.  (See
`core/framework/resource_mgr.h` for more information about resources.)  The
tree structure is stored in the `DecisionTreeResource` defined in
`kernels/v4/decision-tree-resource.h`, and the leaf and growing statistics
are stored in the `FertileStatsResource` defined in
`kernels/v4/fertile-stats-resource.h`.

## More information

* [Kaggle kernel demonstrating TensorForest on Iris
  dataset](https://www.kaggle.com/thomascolthurst/tensorforest-on-iris/notebook)
* [TensorForest
  paper from NIPS 2016 Workshop](https://docs.google.com/viewer?a=v&pid=sites&srcid=ZGVmYXVsdGRvbWFpbnxtbHN5c25pcHMyMDE2fGd4OjFlNTRiOWU2OGM2YzA4MjE)