# TensorForest

TensorForest is an implementation of random forests in TensorFlow, using an
online, [extremely randomized trees](
https://en.wikipedia.org/wiki/Random_forest#ExtraTrees)
training algorithm. It supports both
classification (binary and multiclass) and regression (scalar and vector).

## Usage

TensorForest is a tf.learn Estimator:

```python
import tensorflow as tf

params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
    num_classes=2, num_features=10, regression=False,
    num_trees=50, max_nodes=1000)

classifier = tf.contrib.tensor_forest.client.random_forest.TensorForestEstimator(
    params)

classifier.fit(x=x_train, y=y_train)

y_out = classifier.predict(x=x_test)
```

TensorForest users should properly shuffle their training data,
as our training algorithm strongly assumes it is in random order.

## Algorithm

Each tree in the forest is trained independently and in parallel. For each
tree, we maintain the following data:

* The tree structure, giving the two children of each non-leaf node and
the *split* used to route data between them. Each split looks at a single
input feature and compares it to a threshold value.

* Leaf statistics. Each leaf needs to gather statistics, and those
statistics have the property that at the end of training, they can be
turned into predictions. For classification problems, the statistics are
class counts, and for regression problems they are the vector sum of the
values seen at the leaf, along with a count of those values.

* Growing statistics. Each leaf needs to gather data that will potentially
allow it to grow into a non-leaf parent node. That data usually consists
of a list of potential splits, along with statistics for each of those splits.
Split statistics in turn consist of leaf statistics for their left and
right branches, along with some other information that allows us to assess
the quality of the split. For classification problems, that's usually
the [Gini
impurity](https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity)
of the split, while for regression problems it's the mean squared error.

At the start of training, the tree structure is initialized to a root node,
and its leaf and growing statistics are both empty. Then, for
each batch `{(x_i, y_i)}` of training data, the following steps are performed:

1. Given the current tree structure, each `x_i` is used to find the leaf
assignment `l_i`.

2. `y_i` is used to update the leaf statistics of leaf `l_i`.

3. If the growing statistics for leaf `l_i` do not yet contain
`num_splits_to_consider` splits, `x_i` is used to generate another split.
Specifically, a random input feature is chosen, and `x_i`'s value for that
feature is used as the split's threshold.

4. Otherwise, `(x_i, y_i)` is used to update the statistics of every
split in the growing statistics of leaf `l_i`. If leaf `l_i` has now seen
`split_after_samples` data points since creating all of its potential splits,
the split with the best score is chosen, and the tree structure is grown.
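The following is a minimal sketch of these four steps for a single tree on a
classification problem, written in plain Python rather than the actual
TensorFlow ops. The `Split` and `Leaf` classes and the `route_to_leaf` and
`grow` helpers are hypothetical stand-ins for the real data structures
described in the Implementation section below.

```python
import random

class Split(object):
    """A candidate split: one feature, one threshold, per-branch class counts."""
    def __init__(self, feature, threshold, num_classes):
        self.feature, self.threshold = feature, threshold
        self.left_counts = [0] * num_classes   # leaf statistics, left branch
        self.right_counts = [0] * num_classes  # leaf statistics, right branch

    def update(self, x, y):
        if x[self.feature] <= self.threshold:
            self.left_counts[y] += 1
        else:
            self.right_counts[y] += 1

    def score(self):
        # Count-weighted Gini impurity summed over both branches; lower is better.
        def gini(counts):
            n = float(sum(counts))
            return 0.0 if n == 0 else n * (1.0 - sum((c / n) ** 2 for c in counts))
        return gini(self.left_counts) + gini(self.right_counts)

class Leaf(object):
    def __init__(self, num_classes):
        self.class_counts = [0] * num_classes  # leaf statistics
        self.splits = []                       # growing statistics
        self.samples_since_full = 0

def train_one_batch(tree, batch, num_classes,
                    num_splits_to_consider=10, split_after_samples=250):
    for x, y in batch:
        leaf = tree.route_to_leaf(x)                    # step 1
        leaf.class_counts[y] += 1                       # step 2
        if len(leaf.splits) < num_splits_to_consider:   # step 3
            feature = random.randrange(len(x))
            leaf.splits.append(Split(feature, x[feature], num_classes))
        else:                                           # step 4
            for split in leaf.splits:
                split.update(x, y)
            leaf.samples_since_full += 1
            if leaf.samples_since_full >= split_after_samples:
                best = min(leaf.splits, key=lambda s: s.score())
                tree.grow(leaf, best)  # leaf becomes a non-leaf with two children
```

In the real implementation these steps are performed by a handful of custom
C++ ops that process a whole batch at a time; the sketch above only
illustrates the per-example logic. For regression, the class counts would be
replaced by running sums and counts of the target values, scored by mean
squared error.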
## Parameters

The following ForestHParams parameters are required:

* `num_classes`. The number of classes in a classification problem, or
the number of dimensions in the output of a regression problem.

* `num_features`. The number of input features.

The following ForestHParams parameters are important but not required:

* `regression`. True for regression problems, False for classification.
Defaults to False (classification).
For regression problems, TensorForest's outputs are the predicted regression
values. For classification, the outputs are the per-class probabilities.

* `num_trees`. The number of trees to create. Defaults to 100. There
usually isn't any accuracy gain from using higher values.

* `max_nodes`. Defaults to 10,000. No tree is allowed to grow beyond
`max_nodes` nodes, and training stops when all trees in the forest are this
large.

The remaining ForestHParams parameters don't usually need to be set by the
user:

* `num_splits_to_consider`. Defaults to `sqrt(num_features)`, capped to be
between 10 and 1000. In the extremely randomized trees training algorithm,
only this many potential splits are evaluated at each tree node.

* `split_after_samples`. Defaults to 250. In our online version of
extremely randomized trees training, we pick a split for a node after it has
accumulated this many training samples.

* `bagging_fraction`. If less than 1.0, then each tree sees only a
different, randomly sampled (without replacement), `bagging_fraction`-sized
subset of the training data. Defaults to 1.0 (no bagging) because it has
failed to give any accuracy improvement in our experiments so far.

* `feature_bagging_fraction`. If less than 1.0, then each tree sees only
a different `feature_bagging_fraction * num_features`-sized subset of the
input features. Defaults to 1.0 (no feature bagging).

* `base_random_seed`. By default (`base_random_seed = 0`), the random number
generator for each tree is seeded by the current time (in microseconds) when
the tree is first created. Using a non-zero value makes tree training
deterministic, in that the i-th tree's random number generator is seeded
with the value `base_random_seed + i`.

## Implementation

The Python code in `python/tensor_forest.py` assigns default values to the
parameters, handles both instance and feature bagging, and creates the
TensorFlow graphs for training and inference. The graphs themselves are
quite simple, as most of the work is done in custom ops. There is a single
op (`model_ops.tree_predictions_v4`) that does inference for a single tree,
and four custom ops that do training on a single tree over a single batch,
with each op roughly corresponding to one of the four steps from the
algorithm section above.

The training data itself is stored in TensorFlow _resources_, which provide
a means of non-tensor-based persistent storage. (See
`core/framework/resource_mgr.h` for more information about resources.)
The tree structure is stored in the `DecisionTreeResource` defined in
`kernels/v4/decision-tree-resource.h`, and the leaf and growing statistics
are stored in the `FertileStatsResource` defined in
`kernels/v4/fertile-stats-resource.h`.

## More information

* [Kaggle kernel demonstrating TensorForest on the Iris
  dataset](https://www.kaggle.com/thomascolthurst/tensorforest-on-iris/notebook)
* [TensorForest
  paper from the NIPS 2016 Workshop](https://docs.google.com/viewer?a=v&pid=sites&srcid=ZGVmYXVsdGRvbWFpbnxtbHN5c25pcHMyMDE2fGd4OjFlNTRiOWU2OGM2YzA4MjE)