# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Generative Adversarial Network (GAN) that generates MNIST digits.

A convolutional discriminator and a deconvolutional generator are trained
against each other on the MNIST dataset, using eager execution.

Sample usage:
  python mnist.py --help
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys
import time

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

layers = tf.keras.layers
FLAGS = None


class Discriminator(tf.keras.Model):
  """GAN Discriminator.

  A network to differentiate between generated and real handwritten digits.
  """

  def __init__(self, data_format):
    """Creates a model for discriminating between real and generated digits.

    Args:
      data_format: Either 'channels_first' or 'channels_last'.
        'channels_first' is typically faster on GPUs while 'channels_last' is
        typically faster on CPUs. See
        https://www.tensorflow.org/performance/performance_guide#data_formats
    """
    super(Discriminator, self).__init__(name='')
    if data_format == 'channels_first':
      self._input_shape = [-1, 1, 28, 28]
    else:
      assert data_format == 'channels_last'
      self._input_shape = [-1, 28, 28, 1]
    self.conv1 = layers.Conv2D(
        64, 5, padding='SAME', data_format=data_format, activation=tf.tanh)
    self.pool1 = layers.AveragePooling2D(2, 2, data_format=data_format)
    self.conv2 = layers.Conv2D(
        128, 5, data_format=data_format, activation=tf.tanh)
    self.pool2 = layers.AveragePooling2D(2, 2, data_format=data_format)
    self.flatten = layers.Flatten()
    self.fc1 = layers.Dense(1024, activation=tf.tanh)
    self.fc2 = layers.Dense(1, activation=None)

  def call(self, inputs):
    """Returns a single logit per image estimating input authenticity.

    Users should invoke __call__ to run the network, which delegates to this
    method (and should not call this method directly).

    Args:
      inputs: A batch of images as a Tensor with shape [batch_size, 28, 28, 1]
        or [batch_size, 1, 28, 28].

    Returns:
      A Tensor with shape [batch_size, 1] containing logits estimating the
      probability that the corresponding digits are real.
    """
    x = tf.reshape(inputs, self._input_shape)
    x = self.conv1(x)
    x = self.pool1(x)
    x = self.conv2(x)
    x = self.pool2(x)
    x = self.flatten(x)
    x = self.fc1(x)
    x = self.fc2(x)
    return x
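

# Shape walk-through for the Discriminator above (a sketch, assuming
# data_format='channels_last', 28x28 MNIST inputs, and the Keras defaults of
# 'VALID' padding for conv2 and stride-2 average pooling):
#
#   input    [batch, 28, 28, 1]
#   conv1    [batch, 28, 28, 64]    5x5 kernel, 'SAME' padding
#   pool1    [batch, 14, 14, 64]    2x2 average pool, stride 2
#   conv2    [batch, 10, 10, 128]   5x5 kernel, 'VALID' padding
#   pool2    [batch, 5, 5, 128]     2x2 average pool, stride 2
#   flatten  [batch, 3200]
#   fc1      [batch, 1024]
#   fc2      [batch, 1]             one authenticity logit per image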


class Generator(tf.keras.Model):
  """Generator of handwritten digits similar to the ones in the MNIST dataset.
  """

  def __init__(self, data_format):
    """Creates a model for generating digits from noise vectors.

    Args:
      data_format: Either 'channels_first' or 'channels_last'.
        'channels_first' is typically faster on GPUs while 'channels_last' is
        typically faster on CPUs. See
        https://www.tensorflow.org/performance/performance_guide#data_formats
    """
    super(Generator, self).__init__(name='')
    self.data_format = data_format
    # We use 128 feature maps of size 6x6 as input to the first deconvolution
    # layer.
    if data_format == 'channels_first':
      self._pre_conv_shape = [-1, 128, 6, 6]
    else:
      assert data_format == 'channels_last'
      self._pre_conv_shape = [-1, 6, 6, 128]
    self.fc1 = layers.Dense(6 * 6 * 128, activation=tf.tanh)

    # In call(), we reshape the output of fc1 to _pre_conv_shape.

    # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64)
    self.conv1 = layers.Conv2DTranspose(
        64, 4, strides=2, activation=None, data_format=data_format)

    # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1)
    self.conv2 = layers.Conv2DTranspose(
        1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)

  def call(self, inputs):
    """Returns a batch of generated images.

    Users should invoke __call__ to run the network, which delegates to this
    method (and should not call this method directly).

    Args:
      inputs: A batch of noise vectors as a Tensor with shape
        [batch_size, length of noise vectors].

    Returns:
      A Tensor containing generated images. If data_format is 'channels_last',
      the shape of the returned images is [batch_size, 28, 28, 1], else
      [batch_size, 1, 28, 28].
    """
    x = self.fc1(inputs)
    x = tf.reshape(x, shape=self._pre_conv_shape)
    x = self.conv1(x)
    x = self.conv2(x)
    return x


def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs):
  """Original discriminator loss for GANs, with label smoothing.

  See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
  details.

  Args:
    discriminator_real_outputs: Discriminator output on real data.
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).

  Returns:
    A scalar loss Tensor.
  """
  loss_on_real = tf.losses.sigmoid_cross_entropy(
      tf.ones_like(discriminator_real_outputs),
      discriminator_real_outputs,
      label_smoothing=0.25)
  loss_on_generated = tf.losses.sigmoid_cross_entropy(
      tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs)
  loss = loss_on_real + loss_on_generated
  tf.contrib.summary.scalar('discriminator_loss', loss)
  return loss


def generator_loss(discriminator_gen_outputs):
  """Original generator loss for GANs.

  L = -log(sigmoid(D(G(z))))

  See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661)
  for more details.

  Args:
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).

  Returns:
    A scalar loss Tensor.
  """
  loss = tf.losses.sigmoid_cross_entropy(
      tf.ones_like(discriminator_gen_outputs), discriminator_gen_outputs)
  tf.contrib.summary.scalar('generator_loss', loss)
  return loss
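

# The generator loss above is the non-saturating form: sigmoid cross entropy
# against all-ones labels reduces to -log(sigmoid(D(G(z)))). A minimal
# numerical check of that identity (a sketch in plain NumPy; not part of the
# training path):
#
#   import numpy as np
#   logit = 0.3                                   # a raw discriminator output
#   xent = np.log1p(np.exp(-logit))               # cross entropy, label = 1
#   direct = -np.log(1. / (1. + np.exp(-logit)))  # -log(sigmoid(logit))
#   assert np.isclose(xent, direct)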


def train_one_epoch(generator, discriminator, generator_optimizer,
                    discriminator_optimizer, dataset, step_counter,
                    log_interval, noise_dim):
  """Trains `generator` and `discriminator` models on `dataset`.

  Args:
    generator: Generator model.
    discriminator: Discriminator model.
    generator_optimizer: Optimizer to use for the generator.
    discriminator_optimizer: Optimizer to use for the discriminator.
    dataset: Dataset of images to train on.
    step_counter: An integer variable, used to write summaries regularly.
    log_interval: How many steps to wait between logging and collecting
      summaries.
    noise_dim: Dimension of noise vector to use.
  """
  total_generator_loss = 0.0
  total_discriminator_loss = 0.0
  for (batch_index, images) in enumerate(dataset):
    with tf.device('/cpu:0'):
      tf.assign_add(step_counter, 1)

    with tf.contrib.summary.record_summaries_every_n_global_steps(
        log_interval, global_step=step_counter):
      current_batch_size = images.shape[0]
      noise = tf.random_uniform(
          shape=[current_batch_size, noise_dim],
          minval=-1.,
          maxval=1.,
          seed=batch_index)

      # We could use a single persistent tape instead, but two tapes are more
      # memory efficient: intermediate tensors can be released between the two
      # .gradient() calls below.
      with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise)
        tf.contrib.summary.image(
            'generated_images',
            tf.reshape(generated_images, [-1, 28, 28, 1]),
            max_images=10)

        discriminator_gen_outputs = discriminator(generated_images)
        discriminator_real_outputs = discriminator(images)
        discriminator_loss_val = discriminator_loss(discriminator_real_outputs,
                                                    discriminator_gen_outputs)
        total_discriminator_loss += discriminator_loss_val

        generator_loss_val = generator_loss(discriminator_gen_outputs)
        total_generator_loss += generator_loss_val

      generator_grad = gen_tape.gradient(generator_loss_val,
                                         generator.variables)
      discriminator_grad = disc_tape.gradient(discriminator_loss_val,
                                              discriminator.variables)

      generator_optimizer.apply_gradients(
          zip(generator_grad, generator.variables))
      discriminator_optimizer.apply_gradients(
          zip(discriminator_grad, discriminator.variables))

      if log_interval and batch_index > 0 and batch_index % log_interval == 0:
        print('Batch #%d\tAverage Generator Loss: %.6f\t'
              'Average Discriminator Loss: %.6f' %
              (batch_index, total_generator_loss / batch_index,
               total_discriminator_loss / batch_index))
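

# Once trained, sampling digits needs only the generator. A minimal sketch
# (illustrative only, assuming eager execution and a `generator` built with
# data_format='channels_last' and noise length 100):
#
#   noise = tf.random_uniform([16, 100], minval=-1., maxval=1.)
#   samples = generator(noise)  # shape [16, 28, 28, 1], values in (0, 1)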


def main(_):
  (device, data_format) = ('/gpu:0', 'channels_first')
  if FLAGS.no_gpu or tf.contrib.eager.num_gpus() <= 0:
    (device, data_format) = ('/cpu:0', 'channels_last')
  print('Using device %s and data format %s.' % (device, data_format))

  # Load the dataset.
  data = input_data.read_data_sets(FLAGS.data_dir)
  dataset = (
      tf.data.Dataset.from_tensor_slices(data.train.images).shuffle(60000)
      .batch(FLAGS.batch_size))

  # Create the models and optimizers.
  model_objects = {
      'generator': Generator(data_format),
      'discriminator': Discriminator(data_format),
      'generator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
      'discriminator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
      'step_counter': tf.train.get_or_create_global_step(),
  }

  # Prepare the summary writer and checkpoint info.
  summary_writer = tf.contrib.summary.create_summary_file_writer(
      FLAGS.output_dir, flush_millis=1000)
  checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
  latest_ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
  if latest_ckpt:
    print('Using latest checkpoint at ' + latest_ckpt)
  checkpoint = tf.train.Checkpoint(**model_objects)
  # Restore variables on creation if a checkpoint exists.
  checkpoint.restore(latest_ckpt)

  with tf.device(device):
    for _ in range(100):
      start = time.time()
      with summary_writer.as_default():
        train_one_epoch(dataset=dataset, log_interval=FLAGS.log_interval,
                        noise_dim=FLAGS.noise, **model_objects)
      end = time.time()
      checkpoint.save(checkpoint_prefix)
      print('\nTrain time for epoch #%d (step %d): %f' %
            (checkpoint.save_counter.numpy(),
             checkpoint.step_counter.numpy(),
             end - start))


if __name__ == '__main__':
  tf.enable_eager_execution()

  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data-dir',
      type=str,
      default='/tmp/tensorflow/mnist/input_data',
      help=('Directory for storing input data (default '
            '/tmp/tensorflow/mnist/input_data)'))
  parser.add_argument(
      '--batch-size',
      type=int,
      default=128,
      metavar='N',
      help='Input batch size for training (default: 128)')
  parser.add_argument(
      '--log-interval',
      type=int,
      default=100,
      metavar='N',
      help=('Number of batches between logging and writing summaries '
            '(default: 100)'))
  parser.add_argument(
      '--output_dir',
      type=str,
      default=None,
      metavar='DIR',
      help='Directory to write TensorBoard summaries (defaults to none)')
  parser.add_argument(
      '--checkpoint_dir',
      type=str,
      default='/tmp/tensorflow/mnist/checkpoints/',
      metavar='DIR',
      help=('Directory to save checkpoints in (once per epoch) (default '
            '/tmp/tensorflow/mnist/checkpoints/)'))
  parser.add_argument(
      '--lr',
      type=float,
      default=0.001,
      metavar='LR',
      help='Learning rate (default: 0.001)')
  parser.add_argument(
      '--noise',
      type=int,
      default=100,
      metavar='N',
      help='Length of noise vector for generator input (default: 100)')
  parser.add_argument(
      '--no-gpu',
      action='store_true',
      default=False,
      help='Disables GPU usage even if a GPU is available')

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
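
# Example invocation (the paths are illustrative only; pass --no-gpu to force
# CPU training even when a GPU is available):
#
#   python mnist.py --output_dir=/tmp/tensorflow/mnist/summaries \
#       --checkpoint_dir=/tmp/tensorflow/mnist/checkpoints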