1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15r"""Demonstrates multiclass MNIST TF Boosted trees example. 16 17 This example demonstrates how to run experiments with TF Boosted Trees on 18 a MNIST dataset. We are using layer by layer boosting with diagonal hessian 19 strategy for multiclass handling, and cross entropy loss. 20 21 Example Usage: 22 python tensorflow/contrib/boosted_trees/examples/mnist.py \ 23 --output_dir="/tmp/mnist" --depth=4 --learning_rate=0.3 --batch_size=60000 \ 24 --examples_per_layer=60000 --eval_batch_size=10000 --num_eval_steps=1 \ 25 --num_trees=10 --l2=1 --vmodule=training_ops=1 26 27 When training is done, accuracy on eval data is reported. 
Point tensorboard 28 to the directory for the run to see how the training progresses: 29 30 tensorboard --logdir=/tmp/mnist 31 32""" 33from __future__ import absolute_import 34from __future__ import division 35from __future__ import print_function 36 37import argparse 38import sys 39 40import numpy as np 41import tensorflow as tf 42from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier 43from tensorflow.contrib.boosted_trees.proto import learner_pb2 44from tensorflow.contrib.learn import learn_runner 45 46 47def get_input_fn(dataset_split, 48 batch_size, 49 capacity=10000, 50 min_after_dequeue=3000): 51 """Input function over MNIST data.""" 52 53 def _input_fn(): 54 """Prepare features and labels.""" 55 images_batch, labels_batch = tf.train.shuffle_batch( 56 tensors=[dataset_split.images, 57 dataset_split.labels.astype(np.int32)], 58 batch_size=batch_size, 59 capacity=capacity, 60 min_after_dequeue=min_after_dequeue, 61 enqueue_many=True, 62 num_threads=4) 63 features_map = {"images": images_batch} 64 return features_map, labels_batch 65 66 return _input_fn 67 68 69# Main config - creates a TF Boosted Trees Estimator based on flags. 
def _get_tfbt(output_dir):
  """Builds the boosted-trees classifier described by the command-line flags.

  Reads `FLAGS.learning_rate`, `FLAGS.l2`, `FLAGS.examples_per_layer`,
  `FLAGS.depth` and `FLAGS.num_trees`.

  Args:
    output_dir: passed to the estimator as its `model_dir`.

  Returns:
    A `GradientBoostedDecisionTreeClassifier` set up for 10 classes.
  """
  class_count = 10

  tree_config = learner_pb2.LearnerConfig()
  tree_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate
  tree_config.num_classes = class_count
  tree_config.regularization.l1 = 0.0
  # The --l2 flag value is divided by the number of examples per layer.
  tree_config.regularization.l2 = FLAGS.l2 / FLAGS.examples_per_layer
  tree_config.constraints.max_tree_depth = FLAGS.depth
  tree_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER

  # Checkpoint every 300 seconds.
  checkpoint_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)

  tree_config.multi_class_strategy = (
      learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)

  # Create a TF Boosted trees estimator that can take in custom loss.
  return GradientBoostedDecisionTreeClassifier(
      learner_config=tree_config,
      n_classes=class_count,
      examples_per_layer=FLAGS.examples_per_layer,
      model_dir=output_dir,
      num_trees=FLAGS.num_trees,
      center_bias=False,
      config=checkpoint_config)


def _make_experiment_fn(output_dir):
  """Assembles the train/eval experiment for the boosted-trees estimator.

  Args:
    output_dir: forwarded to `_get_tfbt` as the estimator's model directory.

  Returns:
    A `tf.contrib.learn.Experiment` that trains on the MNIST train split and
    evaluates on the validation split for `FLAGS.num_eval_steps` steps.
  """
  mnist = tf.contrib.learn.datasets.mnist.load_mnist()
  training_fn = get_input_fn(mnist.train, FLAGS.batch_size)
  validation_fn = get_input_fn(mnist.validation, FLAGS.eval_batch_size)
  classifier = _get_tfbt(output_dir)

  # train_steps=None: no explicit step limit is handed to the experiment.
  return tf.contrib.learn.Experiment(
      estimator=classifier,
      train_input_fn=training_fn,
      eval_input_fn=validation_fn,
      train_steps=None,
      eval_steps=FLAGS.num_eval_steps,
      eval_metrics=None)


def main(unused_argv):
  """Runs the "train_and_evaluate" schedule, writing to `FLAGS.output_dir`."""
  learn_runner.run(
      experiment_fn=_make_experiment_fn,
      output_dir=FLAGS.output_dir,
      schedule="train_and_evaluate")


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  parser = argparse.ArgumentParser()

  # (flag name, add_argument keyword options), registered in listed order.
  flag_specs = [
      # Flags that users can change.
      ("--output_dir",
       dict(type=str, required=True, help="Choose the dir for the output.")),
      ("--batch_size",
       dict(type=int, default=1000, help="The batch size for reading data.")),
      ("--eval_batch_size",
       dict(type=int, default=1000, help="Size of the batch for eval.")),
      ("--num_eval_steps",
       dict(
           type=int,
           default=1,
           help="The number of steps to run evaluation for.")),
      # Flags for gradient boosted trees config.
      ("--depth",
       dict(type=int, default=4, help="Maximum depth of weak learners.")),
      ("--l2",
       dict(type=float, default=1.0, help="l2 regularization per batch.")),
      ("--learning_rate",
       dict(
           type=float,
           default=0.1,
           help="Learning rate (shrinkage weight) with which each new tree is added."
       )),
      ("--examples_per_layer",
       dict(
           type=int,
           default=1000,
           help="Number of examples to accumulate stats for per layer.")),
      ("--num_trees",
       dict(
           type=int,
           default=None,
           required=True,
           help="Number of trees to grow before stopping.")),
  ]
  for flag_name, flag_options in flag_specs:
    parser.add_argument(flag_name, **flag_options)

  # Flags argparse does not recognise are passed through to tf.app.run.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)