<!-- TODO(joelshor): Add images to the examples. -->
<!-- TODO(joelshor): Add link to new location when b/122114187 is done. -->
# TensorFlow-GAN (TF-GAN)

TF-GAN is a lightweight library for training and evaluating Generative
Adversarial Networks (GANs). This technique allows you to train a network
(called the 'generator') to sample from a distribution, without having to
explicitly model the distribution and without writing an explicit loss. For
example, the generator could learn to draw samples from the distribution of
natural images. For more details on this technique, see
['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by
Goodfellow et al. See
[tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/)
for examples, and [this tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an introduction.

#### Usage
```python
import tensorflow as tf
tfgan = tf.contrib.gan
```

## Why TF-GAN?

* Easily train generator and discriminator networks with well-tested, flexible [library calls](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py). You can
mix TF-GAN, native TF, and other custom frameworks.
* Use already-implemented [GAN losses and penalties](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/losses_impl.py) (e.g. the Wasserstein loss, gradient penalty, and mutual information penalty).
* [Monitor and visualize](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/summaries_impl.py) GAN progress during training, and [evaluate](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py) trained models.
* Use already-implemented [tricks](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/) to stabilize and improve training.
* Develop based on examples of [common GAN setups](https://github.com/tensorflow/models/tree/master/research/gan/).
* Use the TF-GAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to easily train a GAN model, as sketched below this list.
* Improvements in TF-GAN infrastructure will automatically benefit your TF-GAN project.
* Stay up-to-date with research as we add more algorithms.
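
If you want a trainable model without writing the training loop yourself, the
`GANEstimator` mentioned above wraps TF-GAN in the standard
`tf.estimator.Estimator` interface. The following is a minimal sketch:
`generator_fn`, `discriminator_fn`, and `train_input_fn` are hypothetical
user-defined functions, and the optimizer settings are illustrative
assumptions, not recommendations.

```python
# A minimal GANEstimator sketch. `generator_fn`, `discriminator_fn`, and
# `train_input_fn` are hypothetical user-defined functions; learning rates
# and step counts are illustrative assumptions.
gan_estimator = tfgan.estimator.GANEstimator(
    model_dir='/tmp/gan_estimator',
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5))

# Train like any other tf.estimator.Estimator.
gan_estimator.train(train_input_fn, max_steps=10000)
```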

## What are the TF-GAN components?

TF-GAN is composed of several parts, which were designed to exist independently.
These include the following main pieces (explained in detail below).

* [core](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py):
  provides the main infrastructure needed to train a GAN. Training occurs in
  four phases, and each phase can be completed by custom code or by using a
  TF-GAN library call.

* [features](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/):
  Many common GAN operations and normalization techniques are implemented for
  you to use, such as instance normalization and conditioning.

* [losses](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/):
  Easily experiment with already-implemented and well-tested losses and
  penalties, such as the Wasserstein loss, gradient penalty, and mutual
  information penalty.

* [evaluation](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/):
  Use `Inception Score`, `Frechet Distance`, or `Kernel Distance` with a
  pretrained Inception network to evaluate your unconditional generative
  model. You can also use your own pretrained classifier for more specific
  performance numbers, or use other methods for evaluating conditional
  generative models. A short evaluation sketch follows this list.

* [examples](https://github.com/tensorflow/models/tree/master/research/gan/)
  and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TF-GAN to make
  GAN training easier, or use the more complicated examples to jump-start your
  own project. These include unconditional and conditional GANs, InfoGANs,
  adversarial losses on existing networks, and image-to-image translation.
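
As promised in the evaluation bullet above, here is a minimal sketch of
computing two common metrics. `real_images` and `generated_images` are
hypothetical image batches; both functions run a pretrained Inception network
internally.

```python
# A minimal evaluation sketch. `real_images` and `generated_images` are
# hypothetical batches of images, assumed to be scaled to [-1, 1].
inception_score = tfgan.eval.inception_score(generated_images)
frechet_distance = tfgan.eval.frechet_inception_distance(
    real_images, generated_images)
```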

## Training a GAN model

Training in TF-GAN typically consists of the following steps:

1. Specify the input to your networks.
1. Set up your generator and discriminator using a `GANModel`.
1. Specify your loss using a `GANLoss`.
1. Create your train ops using a `GANTrainOps`.
1. Run your train ops.

At each stage, you can either use TF-GAN's convenience functions, or you can
perform the step manually for fine-grained control. We provide examples below.

There are various types of GAN setups. For instance, you can train a generator
to sample unconditionally from a learned distribution, or you can condition on
extra information such as a class label. TF-GAN is compatible with many setups,
and we demonstrate a few below:

### Examples

#### Unconditional MNIST generation

This example trains a generator to produce handwritten MNIST digits. The generator maps
random draws from a multivariate normal distribution to MNIST digit images. See
['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by
Goodfellow et al.

```python
# Set up the input.
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=mnist.unconditional_generator,  # you define
    discriminator_fn=mnist.unconditional_discriminator,  # you define
    real_data=images,
    generator_inputs=noise)

# Build the GAN loss.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss)

# Create the train ops, which calculate gradients and apply updates to weights.
train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))

# Run the train ops in the alternating training scheme.
tfgan.gan_train(
    train_ops,
    hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],
    logdir=FLAGS.train_log_dir)
```
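
To monitor progress in TensorBoard, you can attach TF-GAN's summary helpers
after building `gan_model` and before creating the train ops. A minimal
sketch; the grid size is an illustrative assumption:

```python
# Writes a grid of generated images, plus model loss and variable summaries.
tfgan.eval.add_gan_model_image_summaries(gan_model, grid_size=4)
tfgan.eval.add_gan_model_summaries(gan_model)
```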

#### Conditional MNIST generation
This example trains a generator to generate MNIST images *of a given class*.
The generator maps random draws from a multivariate normal distribution and a
one-hot label of the desired digit class to an MNIST digit image. See
['Conditional Generative Adversarial Nets'](https://arxiv.org/abs/1411.1784) by
Mirza and Osindero.

```python
# Set up the input.
images, one_hot_labels = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=mnist.conditional_generator,  # you define
    discriminator_fn=mnist.conditional_discriminator,  # you define
    real_data=images,
    generator_inputs=(noise, one_hot_labels))

# The rest is the same as in the unconditional case.
...
```
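
The conditional networks themselves are user-defined. As one hedged
illustration of how a generator can fold the label in, TF-GAN's features
module provides `condition_tensor_from_onehot`; the layer widths below are
arbitrary assumptions, not the architecture used in the MNIST example.

```python
# A hypothetical conditional generator sketch. Layer widths are arbitrary.
def conditional_generator(inputs):
  noise, one_hot_labels = inputs
  net = tf.layers.dense(noise, 64, activation=tf.nn.relu)
  # Fold the class label into the intermediate representation.
  net = tfgan.features.condition_tensor_from_onehot(net, one_hot_labels)
  net = tf.layers.dense(net, 28 * 28, activation=tf.tanh)
  return tf.reshape(net, [-1, 28, 28, 1])
```
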
#### Adversarial loss
This example combines an L1 pixel loss and an adversarial loss to learn to
autoencode images. The bottleneck layer can be used to transmit compressed
representations of the image. Neural networks trained with a pixel-wise loss
alone tend to produce blurry results, so the GAN can be used to make the
reconstructions more plausible. See ['Full Resolution Image Compression with Recurrent Neural Networks'](https://arxiv.org/abs/1608.05148) by Toderici et al.
for an example of neural networks used for image compression, and ['Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network'](https://arxiv.org/abs/1609.04802) by Ledig et al. for a more detailed description of
how GANs can sharpen image output.

```python
# Set up the input pipeline.
images = image_provider.provide_data(FLAGS.batch_size)

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=nets.autoencoder,  # you define
    discriminator_fn=nets.discriminator,  # you define
    real_data=images,
    generator_inputs=images)

# Build the GAN loss and standard pixel loss.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty_weight=1.0)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

# Modify the loss tuple to include the pixel loss.
gan_loss = tfgan.losses.combine_adversarial_loss(
    gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)

# The rest is the same as in the unconditional case.
...
```

#### Image-to-image translation
This example maps images in one domain to images of the same size in a
different domain. For example, it can map segmentation masks to street images,
or grayscale images to color. See ['Image-to-Image Translation with Conditional Adversarial Networks'](https://arxiv.org/abs/1611.07004) by Isola et al. for more details.

```python
# Set up the input pipeline.
input_image, target_image = data_provider.provide_data(FLAGS.batch_size)

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=nets.generator,  # you define
    discriminator_fn=nets.discriminator,  # you define
    real_data=target_image,
    generator_inputs=input_image)

# Build the GAN loss and standard pixel loss.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.least_squares_generator_loss,
    discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

# Modify the loss tuple to include the pixel loss.
gan_loss = tfgan.losses.combine_adversarial_loss(
    gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)

# The rest is the same as in the unconditional case.
...
```

#### InfoGAN
Train a generator to generate specific MNIST digit images, and control digit
style *without using any labels*. See ['InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets'](https://arxiv.org/abs/1606.03657) by Chen et al. for more details.

```python
# Set up the input pipeline.
images = mnist_data_provider.provide_data(FLAGS.batch_size)

# Build the generator and discriminator.
gan_model = tfgan.infogan_model(
    generator_fn=mnist.infogan_generator,  # you define
    discriminator_fn=mnist.infogan_discriminator,  # you define
    real_data=images,
    unstructured_generator_inputs=unstructured_inputs,  # you define
    structured_generator_inputs=structured_inputs)  # you define

# Build the GAN loss with a mutual information penalty.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty_weight=1.0,
    mutual_information_penalty_weight=1.0)

# The rest is the same as in the unconditional case.
...
```
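
The latent inputs above are user-defined. A hedged sketch of one common
choice: Gaussian unstructured noise, plus a categorical code and a continuous
code as the structured inputs. All dimensions, and the use of lists of
tensors, are illustrative assumptions.

```python
# Hypothetical InfoGAN latent codes; dimensions are arbitrary assumptions.
cat_dim, cont_dim = 10, 2  # categorical digit code, continuous style codes

# Unstructured noise, as in an ordinary GAN.
unstructured_inputs = [tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])]

# Structured codes whose mutual information with the output is encouraged.
categorical_code = tf.one_hot(
    tf.random_uniform([FLAGS.batch_size], maxval=cat_dim, dtype=tf.int32),
    cat_dim)
continuous_code = tf.random_uniform(
    [FLAGS.batch_size, cont_dim], minval=-1.0, maxval=1.0)
structured_inputs = [categorical_code, continuous_code]
```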

#### Custom model creation
Train an unconditional GAN to generate MNIST digits, but manually construct
the `GANModel` tuple for more fine-grained control.

```python
# Set up the input pipeline.
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# Manually build the generator and discriminator.
with tf.variable_scope('Generator') as gen_scope:
  generated_images = generator_fn(noise)
with tf.variable_scope('Discriminator') as dis_scope:
  discriminator_gen_outputs = discriminator_fn(generated_images)
with tf.variable_scope(dis_scope, reuse=True):
  discriminator_real_outputs = discriminator_fn(images)
generator_variables = tf.contrib.framework.get_trainable_variables(gen_scope)
discriminator_variables = tf.contrib.framework.get_trainable_variables(
    dis_scope)

# Depending on what TF-GAN features you use, you don't always need to supply
# every `GANModel` field. At a minimum, you need to include the discriminator
# outputs and variables if you want to use TF-GAN to construct losses.
gan_model = tfgan.GANModel(
    generator_inputs=noise,
    generated_data=generated_images,
    generator_variables=generator_variables,
    generator_scope=gen_scope,
    generator_fn=generator_fn,
    real_data=images,
    discriminator_real_outputs=discriminator_real_outputs,
    discriminator_gen_outputs=discriminator_gen_outputs,
    discriminator_variables=discriminator_variables,
    discriminator_scope=dis_scope,
    discriminator_fn=discriminator_fn)

# The rest is the same as the unconditional case.
...
```
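
You can also bypass `tfgan.gan_loss` and package hand-written losses in a
`GANLoss` tuple for the rest of the pipeline. A minimal sketch of a manual
cross-entropy GAN loss, assuming the discriminator outputs unnormalized
logits:

```python
# A minimal manual-loss sketch, assuming the discriminator returns logits.
discriminator_loss = (
    tf.losses.sigmoid_cross_entropy(
        tf.ones_like(gan_model.discriminator_real_outputs),
        gan_model.discriminator_real_outputs) +
    tf.losses.sigmoid_cross_entropy(
        tf.zeros_like(gan_model.discriminator_gen_outputs),
        gan_model.discriminator_gen_outputs))
generator_loss = tf.losses.sigmoid_cross_entropy(
    tf.ones_like(gan_model.discriminator_gen_outputs),
    gan_model.discriminator_gen_outputs)
gan_loss = tfgan.GANLoss(
    generator_loss=generator_loss, discriminator_loss=discriminator_loss)
```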


## Authors
Joel Shor (github: [joel-shor](https://github.com/joel-shor)) and Sergio Guadarrama (github: [sguada](https://github.com/sguada))
282