# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Builds the CIFAR-10 network with additional variables to support pruning.
16
17Summary of available functions:
18
19 # Compute input images and labels for training. If you would like to run
20 # evaluations, use inputs() instead.
21 inputs, labels = distorted_inputs()
22
23 # Compute inference on the model inputs to make a prediction.
24 predictions = inference(inputs)
25
26 # Compute the total loss of the prediction with respect to the labels.
27 loss = loss(predictions, labels)
28
29 # Create a graph to run one step of training with respect to the loss.
30 train_op = train(loss, global_step)
31"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re
import sys
import tarfile

from six.moves import urllib
import tensorflow as tf

from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_input
from tensorflow.contrib.model_pruning.python import pruning

# Global constants describing the CIFAR-10 data set.
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
NUM_CLASSES = cifar10_input.NUM_CLASSES
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN  # pylint: disable=line-too-long
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
BATCH_SIZE = 128
DATA_DIR = '/tmp/cifar10_data'

# Constants describing the training process.
MOVING_AVERAGE_DECAY = 0.9999  # The decay to use for the moving average.
NUM_EPOCHS_PER_DECAY = 350.0  # Epochs after which learning rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.1  # Learning rate decay factor.
INITIAL_LEARNING_RATE = 0.1  # Initial learning rate.

# If a model is trained with multiple GPUs, prefix all Op names with tower_name
# to differentiate the operations. Note that this prefix is removed from the
# names of the summaries when visualizing a model.
TOWER_NAME = 'tower'

DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'


def _activation_summary(x):
  """Helper to create summaries for activations.

  Creates a summary that provides a histogram of activations.
  Creates a summary that measures the sparsity of activations.

  Args:
    x: Tensor
  Returns:
    nothing
  """
  # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
  # session. This helps the clarity of presentation on tensorboard.
  tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
  tf.summary.histogram(tensor_name + '/activations', x)
  tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))


def _variable_on_cpu(name, shape, initializer):
  """Helper to create a Variable stored on CPU memory.

  Args:
    name: name of the variable
    shape: list of ints
    initializer: initializer for Variable

  Returns:
    Variable Tensor
  """
  with tf.device('/cpu:0'):
    dtype = tf.float32
    var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
  return var


def _variable_with_weight_decay(name, shape, stddev, wd):
  """Helper to create an initialized Variable with weight decay.

  Note that the Variable is initialized with a truncated normal distribution.
  A weight decay is added only if one is specified.

  Args:
    name: name of the variable
    shape: list of ints
    stddev: standard deviation of a truncated Gaussian
    wd: add L2Loss weight decay multiplied by this float. If None, weight
        decay is not added for this Variable.

  Returns:
    Variable Tensor
  """
  dtype = tf.float32
  var = _variable_on_cpu(name, shape,
                         tf.truncated_normal_initializer(
                             stddev=stddev, dtype=dtype))
  if wd is not None:
    weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
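    # loss() below sums everything in the 'losses' collection, so this
    # weight-decay term is automatically folded into the total loss.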
  return var


def distorted_inputs():
  """Construct distorted input for CIFAR training using the Reader ops.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.

  Raises:
    ValueError: If no data_dir
  """
  if not DATA_DIR:
    raise ValueError('Please supply a data_dir')
  data_dir = os.path.join(DATA_DIR, 'cifar-10-batches-bin')
  images, labels = cifar10_input.distorted_inputs(
      data_dir=data_dir, batch_size=BATCH_SIZE)
  return images, labels


def inputs(eval_data):
  """Construct input for CIFAR evaluation using the Reader ops.

  Args:
    eval_data: bool, indicating if one should use the train or eval data set.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.

  Raises:
    ValueError: If no data_dir
  """
  if not DATA_DIR:
    raise ValueError('Please supply a data_dir')
  data_dir = os.path.join(DATA_DIR, 'cifar-10-batches-bin')
  images, labels = cifar10_input.inputs(
      eval_data=eval_data, data_dir=data_dir, batch_size=BATCH_SIZE)
  return images, labels


def inference(images):
  """Build the CIFAR-10 model.

  Args:
    images: Images returned from distorted_inputs() or inputs().

  Returns:
    Logits.
  """
  # We instantiate all variables using tf.get_variable() instead of
  # tf.Variable() in order to share variables across multiple GPU training runs.
  # If we only ran this model on a single GPU, we could simplify this function
  # by replacing all instances of tf.get_variable() with tf.Variable().
  #
  # While instantiating conv and local layers, we add mask and threshold
  # variables to the layer by calling the pruning.apply_mask() function.
  # Note that the masks are applied only to the weight tensors.

  # conv1
  with tf.variable_scope('conv1') as scope:
    kernel = _variable_with_weight_decay(
        'weights', shape=[5, 5, 3, 64], stddev=5e-2, wd=0.0)

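    # pruning.apply_mask() adds auxiliary 'mask' and 'threshold' variables
    # under this scope and returns the kernel multiplied element-wise by its
    # mask, so the convolution only sees the unpruned weights.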
    conv = tf.nn.conv2d(
        images, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
    pre_activation = tf.nn.bias_add(conv, biases)
    conv1 = tf.nn.relu(pre_activation, name=scope.name)
    _activation_summary(conv1)

  # pool1
  pool1 = tf.nn.max_pool(
      conv1,
      ksize=[1, 3, 3, 1],
      strides=[1, 2, 2, 1],
      padding='SAME',
      name='pool1')
  # norm1
  norm1 = tf.nn.lrn(
      pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1')

  # conv2
  with tf.variable_scope('conv2') as scope:
    kernel = _variable_with_weight_decay(
        'weights', shape=[5, 5, 64, 64], stddev=5e-2, wd=0.0)
    conv = tf.nn.conv2d(
        norm1, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
    pre_activation = tf.nn.bias_add(conv, biases)
    conv2 = tf.nn.relu(pre_activation, name=scope.name)
    _activation_summary(conv2)

  # norm2
  norm2 = tf.nn.lrn(
      conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2')
  # pool2
  pool2 = tf.nn.max_pool(
      norm2,
      ksize=[1, 3, 3, 1],
      strides=[1, 2, 2, 1],
      padding='SAME',
      name='pool2')

  # local3
  with tf.variable_scope('local3') as scope:
    # Move everything into depth so we can perform a single matrix multiply.
    reshape = tf.reshape(pool2, [BATCH_SIZE, -1])
    dim = reshape.get_shape()[1].value
    weights = _variable_with_weight_decay(
        'weights', shape=[dim, 384], stddev=0.04, wd=0.004)
    biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
    local3 = tf.nn.relu(
        tf.matmul(reshape, pruning.apply_mask(weights, scope)) + biases,
        name=scope.name)
    _activation_summary(local3)

  # local4
  with tf.variable_scope('local4') as scope:
    weights = _variable_with_weight_decay(
        'weights', shape=[384, 192], stddev=0.04, wd=0.004)
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
    local4 = tf.nn.relu(
        tf.matmul(local3, pruning.apply_mask(weights, scope)) + biases,
        name=scope.name)
    _activation_summary(local4)

  # linear layer (WX + b).
  # We don't apply softmax here because
  # tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits
  # and performs the softmax internally for efficiency.
  with tf.variable_scope('softmax_linear') as scope:
    weights = _variable_with_weight_decay(
        'weights', [192, NUM_CLASSES], stddev=1 / 192.0, wd=0.0)
    biases = _variable_on_cpu('biases', [NUM_CLASSES],
                              tf.constant_initializer(0.0))
    softmax_linear = tf.add(
        tf.matmul(local4, pruning.apply_mask(weights, scope)),
        biases,
        name=scope.name)
    _activation_summary(softmax_linear)

  return softmax_linear


def loss(logits, labels):
  """Add L2Loss to all the trainable variables.

  Add summary for "Loss" and "Loss/avg".
  Args:
    logits: Logits from inference().
    labels: Labels from distorted_inputs or inputs(). 1-D tensor
            of shape [batch_size]

  Returns:
    Loss tensor of type float.
  """
  # Calculate the average cross entropy loss across the batch.
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)

  # The total loss is defined as the cross entropy loss plus all of the weight
  # decay terms (L2 loss).
  return tf.add_n(tf.get_collection('losses'), name='total_loss')


def _add_loss_summaries(total_loss):
  """Add summaries for losses in CIFAR-10 model.

  Generates moving average for all losses and associated summaries for
  visualizing the performance of the network.

  Args:
    total_loss: Total loss from loss().
  Returns:
    loss_averages_op: op for generating moving averages of losses.
  """
  # Compute the moving average of all individual losses and the total loss.
  loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
  losses = tf.get_collection('losses')
  loss_averages_op = loss_averages.apply(losses + [total_loss])

  # Attach a scalar summary to all individual losses and the total loss; do the
  # same for the averaged version of the losses.
  for l in losses + [total_loss]:
    # Name each loss as '(raw)' and name the moving average version of the loss
    # as the original loss name.
    tf.summary.scalar(l.op.name + ' (raw)', l)
    tf.summary.scalar(l.op.name, loss_averages.average(l))

  return loss_averages_op


def train(total_loss, global_step):
  """Train CIFAR-10 model.

  Create an optimizer and apply to all trainable variables. Add moving
  average for all trainable variables.

  Args:
    total_loss: Total loss from loss().
    global_step: Integer Variable counting the number of training steps
      processed.
  Returns:
    train_op: op for training.
  """
  # Variables that affect learning rate.
  num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / BATCH_SIZE
  decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
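  # With BATCH_SIZE = 128 and 50,000 training examples per epoch
  # (cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN), this works out to
  # 50000 / 128 * 350, i.e. about 136,718 steps between learning-rate decays.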

  # Decay the learning rate exponentially based on the number of steps.
  lr = tf.train.exponential_decay(
      INITIAL_LEARNING_RATE,
      global_step,
      decay_steps,
      LEARNING_RATE_DECAY_FACTOR,
      staircase=True)
  tf.summary.scalar('learning_rate', lr)

  # Generate moving averages of all losses and associated summaries.
  loss_averages_op = _add_loss_summaries(total_loss)

  # Compute gradients.
  with tf.control_dependencies([loss_averages_op]):
    opt = tf.train.GradientDescentOptimizer(lr)
    grads = opt.compute_gradients(total_loss)

  # Apply gradients.
  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

  # Add histograms for trainable variables.
  for var in tf.trainable_variables():
    tf.summary.histogram(var.op.name, var)

  # Add histograms for gradients.
  for grad, var in grads:
    if grad is not None:
      tf.summary.histogram(var.op.name + '/gradients', grad)

  # Track the moving averages of all trainable variables.
  variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,
                                                        global_step)
  variables_averages_op = variable_averages.apply(tf.trainable_variables())

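  # Group the gradient update and the moving-average update into a single
  # train_op; the no-op below runs only after both control dependencies
  # have completed.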
  with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
    train_op = tf.no_op(name='train')

  return train_op


def maybe_download_and_extract():
  """Download and extract the tarball from Alex's website."""
  dest_directory = DATA_DIR
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):

    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' %
                       (filename,
                        float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')

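  # Note that the archive is re-extracted each time this function is called,
  # even if it has already been unpacked under dest_directory.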
  tarfile.open(filepath, 'r:gz').extractall(dest_directory)