# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""A binary to train pruned CIFAR-10 using a single GPU.
16
17Accuracy:
18cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
19data) as judged by cifar10_eval.py when target sparsity in
20cifar10_pruning_spec.pbtxt is set to zero
21
22Results:
23Sparsity | Accuracy after 150K steps
24-------- | -------------------------
250%       | 86%
2650%      | 86%
2775%      | TODO(suyoggupta)
2890%      | TODO(suyoggupta)
2995%      | 77%
30
31Usage:
32Please see the tutorial and website for how to download the CIFAR-10
33data set, compile the program and train the model.
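
Example invocation (the hyperparameter values are illustrative, not a tuned
configuration; the flags are defined at the bottom of this file):
  python cifar10_train.py --train_dir=/tmp/cifar10_train \
      --pruning_hparams=name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9 \
      --max_steps=150000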
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import datetime
import sys
import time

import tensorflow as tf

from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_pruning as cifar10
from tensorflow.contrib.model_pruning.python import pruning

FLAGS = None


def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Parse pruning hyperparameters.
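    # FLAGS.pruning_hparams is a comma-separated name=value string; the
    # recognized names (e.g. target_sparsity, begin_pruning_step and
    # end_pruning_step) come from pruning.get_pruning_hparams().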
    pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

    # Create a pruning object using the pruning hyperparameters.
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    # Use the pruning_obj to add ops to the training graph to update the
    # masks. The conditional_mask_update_op will update the masks only when
    # the training step is in [begin_pruning_step, end_pruning_step], as
    # specified in the pruning spec proto.
    mask_update_op = pruning_obj.conditional_mask_update_op()

    # Use the pruning_obj to add summaries to the graph to track the
    # sparsity of each of the layers.
    pruning_obj.add_pruning_summaries()
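    # The summaries are written to FLAGS.train_dir along with the regular
    # training events, so per-layer sparsity can be monitored in TensorBoard.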

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1

      def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step % 10 == 0:
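          # Assumes the input pipeline's batch size; 128 is the default in
          # the CIFAR-10 example. Adjust if the batch size changes.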
          num_examples_per_step = 128
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
        # Update the masks.
        mon_sess.run(mask_update_op)
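        # conditional_mask_update_op is a no-op outside the
        # [begin_pruning_step, end_pruning_step] window.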


def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
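  # Start each run from a clean training directory: delete checkpoints and
  # event files left over from a previous run.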
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--train_dir',
      type=str,
      default='/tmp/cifar10_train',
      help='Directory where to write event logs and checkpoints.')
  parser.add_argument(
      '--pruning_hparams',
      type=str,
      default='',
      help='Comma-separated list of pruning-related hyperparameters.')
  parser.add_argument(
      '--max_steps',
      type=int,
      default=1000000,
      help='Number of batches to run.')
  # argparse's type=bool treats any non-empty string (including 'False') as
  # True, so expose this as a presence flag instead.
  parser.add_argument(
      '--log_device_placement',
      action='store_true',
      help='Whether to log device placement.')

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)