• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A deep MNIST classifier using convolutional layers.
16
17Sample usage:
18  python mnist.py --help
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import argparse
26import os
27import sys
28import time
29
30import tensorflow as tf
31
32from tensorflow.examples.tutorials.mnist import input_data
33
# Shorthand for the Keras layers namespace used by both model classes below.
layers = tf.keras.layers
# Populated with the parsed command-line arguments in the __main__ guard.
FLAGS = None
36
37
class Discriminator(tf.keras.Model):
  """GAN Discriminator.

  A network to differentiate between generated and real handwritten digits.
  """

  def __init__(self, data_format):
    """Creates a model for discriminating between real and generated digits.

    Args:
      data_format: Either 'channels_first' or 'channels_last'.
        'channels_first' is typically faster on GPUs while 'channels_last' is
        typically faster on CPUs. See
        https://www.tensorflow.org/performance/performance_guide#data_formats
    """
    super(Discriminator, self).__init__(name='')
    if data_format == 'channels_first':
      self._input_shape = [-1, 1, 28, 28]
    else:
      assert data_format == 'channels_last'
      self._input_shape = [-1, 28, 28, 1]
    self.conv1 = layers.Conv2D(
        64, 5, padding='SAME', data_format=data_format, activation=tf.tanh)
    self.pool1 = layers.AveragePooling2D(2, 2, data_format=data_format)
    self.conv2 = layers.Conv2D(
        128, 5, data_format=data_format, activation=tf.tanh)
    self.pool2 = layers.AveragePooling2D(2, 2, data_format=data_format)
    self.flatten = layers.Flatten()
    self.fc1 = layers.Dense(1024, activation=tf.tanh)
    # One unscaled logit per image; the sigmoid is applied inside the loss
    # functions (tf.losses.sigmoid_cross_entropy), not here.
    self.fc2 = layers.Dense(1, activation=None)

  def call(self, inputs):
    """Return a single logit per image estimating input authenticity.

    Users should invoke __call__ to run the network, which delegates to this
    method (and not call this method directly).

    Args:
      inputs: A batch of images as a Tensor with shape [batch_size, 28, 28, 1]
        or [batch_size, 1, 28, 28]

    Returns:
      A Tensor with shape [batch_size, 1] containing logits estimating
      the probability that the corresponding digit is real.
    """
    # Normalize the layout to whichever shape matches data_format.
    x = tf.reshape(inputs, self._input_shape)
    x = self.conv1(x)
    x = self.pool1(x)
    x = self.conv2(x)
    x = self.pool2(x)
    x = self.flatten(x)
    x = self.fc1(x)
    x = self.fc2(x)
    return x
92
93
class Generator(tf.keras.Model):
  """Generator of handwritten digits similar to the ones in the MNIST dataset.
  """

  def __init__(self, data_format):
    """Creates a model that generates digit images from noise vectors.

    Args:
      data_format: Either 'channels_first' or 'channels_last'.
        'channels_first' is typically faster on GPUs while 'channels_last' is
        typically faster on CPUs. See
        https://www.tensorflow.org/performance/performance_guide#data_formats
    """
    super(Generator, self).__init__(name='')
    self.data_format = data_format
    # We are using 128 6x6 channels as input to the first deconvolution layer
    if data_format == 'channels_first':
      self._pre_conv_shape = [-1, 128, 6, 6]
    else:
      assert data_format == 'channels_last'
      self._pre_conv_shape = [-1, 6, 6, 128]
    self.fc1 = layers.Dense(6 * 6 * 128, activation=tf.tanh)

    # In call(), we reshape the output of fc1 to _pre_conv_shape

    # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64)
    self.conv1 = layers.Conv2DTranspose(
        64, 4, strides=2, activation=None, data_format=data_format)

    # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1)
    # Sigmoid keeps pixel values in [0, 1] like the MNIST inputs.
    self.conv2 = layers.Conv2DTranspose(
        1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)

  def call(self, inputs):
    """Return a batch of generated images.

    Users should invoke __call__ to run the network, which delegates to this
    method (and not call this method directly).

    Args:
      inputs: A batch of noise vectors as a Tensor with shape
        [batch_size, length of noise vectors].

    Returns:
      A Tensor containing generated images. If data_format is 'channels_last',
      the shape of returned images is [batch_size, 28, 28, 1], else
      [batch_size, 1, 28, 28]
    """

    x = self.fc1(inputs)
    x = tf.reshape(x, shape=self._pre_conv_shape)
    x = self.conv1(x)
    x = self.conv2(x)
    return x
148
149
def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs):
  """Original discriminator loss for GANs, with label smoothing.

  See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
  details.

  Args:
    discriminator_real_outputs: Discriminator output on real data.
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).

  Returns:
    A scalar loss Tensor.
  """
  real_labels = tf.ones_like(discriminator_real_outputs)
  generated_labels = tf.zeros_like(discriminator_gen_outputs)

  # Real images get smoothed labels (see the label_smoothing argument of
  # tf.losses.sigmoid_cross_entropy); generated images keep hard zeros.
  loss_on_real = tf.losses.sigmoid_cross_entropy(
      real_labels, discriminator_real_outputs, label_smoothing=0.25)
  loss_on_generated = tf.losses.sigmoid_cross_entropy(
      generated_labels, discriminator_gen_outputs)

  loss = loss_on_real + loss_on_generated
  tf.contrib.summary.scalar('discriminator_loss', loss)
  return loss
174
175
def generator_loss(discriminator_gen_outputs):
  """Original generator loss for GANs.

  L = -log(sigmoid(D(G(z))))

  See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661)
  for more details.

  Args:
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).

  Returns:
    A scalar loss Tensor.
  """
  # The generator wants the discriminator to label its output as real (ones).
  target_labels = tf.ones_like(discriminator_gen_outputs)
  loss = tf.losses.sigmoid_cross_entropy(
      target_labels, discriminator_gen_outputs)
  tf.contrib.summary.scalar('generator_loss', loss)
  return loss
195
196
def train_one_epoch(generator, discriminator, generator_optimizer,
                    discriminator_optimizer, dataset, step_counter,
                    log_interval, noise_dim):
  """Trains `generator` and `discriminator` models on `dataset`.

  Performs one update of each model per batch: the generator produces images
  from noise, the discriminator scores both generated and real images, and
  each optimizer applies gradients of its respective loss.

  Args:
    generator: Generator model.
    discriminator: Discriminator model.
    generator_optimizer: Optimizer to use for generator.
    discriminator_optimizer: Optimizer to use for discriminator.
    dataset: Dataset of images to train on.
    step_counter: An integer variable, used to write summaries regularly.
    log_interval: How many steps to wait between logging and collecting
      summaries.
    noise_dim: Dimension of noise vector to use.
  """

  total_generator_loss = 0.0
  total_discriminator_loss = 0.0
  for (batch_index, images) in enumerate(dataset):
    # The step counter is pure bookkeeping, so it is pinned to the CPU.
    with tf.device('/cpu:0'):
      tf.assign_add(step_counter, 1)

    # Summaries written inside this context are only recorded every
    # `log_interval` global steps.
    with tf.contrib.summary.record_summaries_every_n_global_steps(
        log_interval, global_step=step_counter):
      current_batch_size = images.shape[0]
      # Seeding with the batch index makes the noise reproducible per batch.
      noise = tf.random_uniform(
          shape=[current_batch_size, noise_dim],
          minval=-1.,
          maxval=1.,
          seed=batch_index)

      # we can use 2 tapes or a single persistent tape.
      # Using two tapes is memory efficient since intermediate tensors can be
      # released between the two .gradient() calls below
      with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise)
        tf.contrib.summary.image(
            'generated_images',
            tf.reshape(generated_images, [-1, 28, 28, 1]),
            max_images=10)

        # Both forward passes use the same discriminator weights; only the
        # inputs differ.
        discriminator_gen_outputs = discriminator(generated_images)
        discriminator_real_outputs = discriminator(images)
        discriminator_loss_val = discriminator_loss(discriminator_real_outputs,
                                                    discriminator_gen_outputs)
        total_discriminator_loss += discriminator_loss_val

        generator_loss_val = generator_loss(discriminator_gen_outputs)
        total_generator_loss += generator_loss_val

      # Each (non-persistent) tape supports exactly one .gradient() call,
      # made after exiting the recording context above.
      generator_grad = gen_tape.gradient(generator_loss_val,
                                         generator.variables)
      discriminator_grad = disc_tape.gradient(discriminator_loss_val,
                                              discriminator.variables)

      generator_optimizer.apply_gradients(
          zip(generator_grad, generator.variables))
      discriminator_optimizer.apply_gradients(
          zip(discriminator_grad, discriminator.variables))

      if log_interval and batch_index > 0 and batch_index % log_interval == 0:
        # NOTE(review): at this point the totals cover batch_index + 1
        # batches (enumerate starts at 0) but are divided by batch_index, so
        # the printed averages are slightly inflated — confirm if intended.
        print('Batch #%d\tAverage Generator Loss: %.6f\t'
              'Average Discriminator Loss: %.6f' %
              (batch_index, total_generator_loss / batch_index,
               total_discriminator_loss / batch_index))
263
264
def main(_):
  """Trains the GAN on MNIST for 100 epochs, checkpointing every epoch."""
  # Pick device and the matching data layout: channels_first on GPU,
  # channels_last on CPU.
  if FLAGS.no_gpu or tf.contrib.eager.num_gpus() <= 0:
    device, data_format = '/cpu:0', 'channels_last'
  else:
    device, data_format = '/gpu:0', 'channels_first'
  print('Using device %s, and data format %s.' % (device, data_format))

  # Build a shuffled, batched pipeline over the MNIST training images.
  mnist = input_data.read_data_sets(FLAGS.data_dir)
  dataset = tf.data.Dataset.from_tensor_slices(mnist.train.images)
  dataset = dataset.shuffle(60000).batch(FLAGS.batch_size)

  # Models, optimizers and the global step, keyed by the keyword-argument
  # names that train_one_epoch expects.
  model_objects = {
      'generator': Generator(data_format),
      'discriminator': Discriminator(data_format),
      'generator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
      'discriminator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
      'step_counter': tf.train.get_or_create_global_step(),
  }

  # Prepare summary writer and checkpoint info.
  summary_writer = tf.contrib.summary.create_summary_file_writer(
      FLAGS.output_dir, flush_millis=1000)
  checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
  latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
  if latest_cpkt:
    print('Using latest checkpoint at ' + latest_cpkt)
  checkpoint = tf.train.Checkpoint(**model_objects)
  # Restoring from a None checkpoint is a no-op, so this is safe on first run.
  checkpoint.restore(latest_cpkt)

  with tf.device(device):
    for _ in range(100):
      start = time.time()
      with summary_writer.as_default():
        train_one_epoch(dataset=dataset, log_interval=FLAGS.log_interval,
                        noise_dim=FLAGS.noise, **model_objects)
      end = time.time()
      checkpoint.save(checkpoint_prefix)
      print('\nTrain time for epoch #%d (step %d): %f' %
            (checkpoint.save_counter.numpy(),
             checkpoint.step_counter.numpy(),
             end - start))
309
310
if __name__ == '__main__':
  # Eager execution must be enabled before any TensorFlow ops are created.
  tf.enable_eager_execution()

  # NOTE(review): flag spelling mixes dashes (--data-dir, --batch-size) and
  # underscores (--output_dir, --checkpoint_dir). argparse maps both to
  # underscore attributes, so FLAGS access works, but the CLI would read
  # better with one convention.
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data-dir',
      type=str,
      default='/tmp/tensorflow/mnist/input_data',
      help=('Directory for storing input data (default '
            '/tmp/tensorflow/mnist/input_data)'))
  parser.add_argument(
      '--batch-size',
      type=int,
      default=128,
      metavar='N',
      help='input batch size for training (default: 128)')
  parser.add_argument(
      '--log-interval',
      type=int,
      default=100,
      metavar='N',
      help=('number of batches between logging and writing summaries '
            '(default: 100)'))
  parser.add_argument(
      '--output_dir',
      type=str,
      default=None,
      metavar='DIR',
      help='Directory to write TensorBoard summaries (defaults to none)')
  parser.add_argument(
      '--checkpoint_dir',
      type=str,
      default='/tmp/tensorflow/mnist/checkpoints/',
      metavar='DIR',
      help=('Directory to save checkpoints in (once per epoch) (default '
            '/tmp/tensorflow/mnist/checkpoints/)'))
  parser.add_argument(
      '--lr',
      type=float,
      default=0.001,
      metavar='LR',
      help='learning rate (default: 0.001)')
  parser.add_argument(
      '--noise',
      type=int,
      default=100,
      metavar='N',
      help='Length of noise vector for generator input (default: 100)')
  parser.add_argument(
      '--no-gpu',
      action='store_true',
      default=False,
      help='disables GPU usage even if a GPU is available')

  # parse_known_args keeps unrecognized flags so tf.app.run can forward them.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
367