<!-- TODO(joelshor): Add images to the examples. -->
<!-- TODO(joelshor): Add link to new location when b/122114187 is done. -->
# TensorFlow-GAN (TF-GAN)

TF-GAN is a lightweight library for training and evaluating Generative
Adversarial Networks (GANs). This technique allows you to train a network
(called the 'generator') to sample from a distribution, without having to
explicitly model the distribution and without writing an explicit loss. For
example, the generator could learn to draw samples from the distribution of
natural images. For more details on this technique, see
['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by
Goodfellow et al. See
[tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/)
for examples, and [this tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an introduction.

#### Usage

```python
import tensorflow as tf
tfgan = tf.contrib.gan
```

## Why TF-GAN?

* Easily train generator and discriminator networks with well-tested, flexible [library calls](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py). You can mix TF-GAN, native TF, and other custom frameworks.
* Use already-implemented [GAN losses and penalties](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/losses_impl.py) (e.g. the Wasserstein loss, gradient penalty, and mutual information penalty).
* [Monitor and visualize](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/summaries_impl.py) GAN progress during training, and [evaluate](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py) the trained model.
* Use already-implemented [tricks](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/) to stabilize and improve training.
* Develop based on examples of [common GAN setups](https://github.com/tensorflow/models/tree/master/research/gan/).
* Use the TF-GAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to easily train a GAN model, as sketched after this list.
* Improvements in TF-GAN infrastructure will automatically benefit your TF-GAN project.
* Stay up-to-date with research as we add more algorithms.

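For the quickest start, `GANEstimator` wraps GAN training in the standard `tf.estimator.Estimator` interface. The following is a minimal sketch rather than a complete recipe: `generator_fn`, `discriminator_fn`, and `train_input_fn` stand in for functions you define for your own problem, and the optimizer settings are illustrative.

```python
import tensorflow as tf
tfgan = tf.contrib.gan

# Sketch only: `generator_fn`, `discriminator_fn`, and `train_input_fn`
# are placeholders for functions you define for your own problem.
gan_estimator = tfgan.estimator.GANEstimator(
    model_dir='/tmp/gan_model',
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5))

# Train like any other `tf.estimator.Estimator`.
gan_estimator.train(train_input_fn, max_steps=10000)
```
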
## What are the TF-GAN components?

TF-GAN is composed of several parts, which were designed to exist independently.
These include the following main pieces (explained in detail below).

*   [core](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py):
    provides the main infrastructure needed to train a GAN. Training occurs in
    four phases, and each phase can be completed by custom code or by using a
    TF-GAN library call.

*   [features](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/):
    Many common GAN operations and normalization techniques are implemented for
    you to use, such as instance normalization and conditioning.

*   [losses](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/):
    Easily experiment with already-implemented and well-tested losses and
    penalties, such as the Wasserstein loss, gradient penalty, and mutual
    information penalty.

*   [evaluation](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/):
    Use `Inception Score`, `Frechet Distance`, or `Kernel Distance` with a
    pretrained Inception network to evaluate your unconditional generative
    model. You can also use your own pretrained classifier for more specific
    performance numbers, or use other methods for evaluating conditional
    generative models. See the sketch after this list.

*   [examples](https://github.com/tensorflow/models/tree/master/research/gan/)
    and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TF-GAN to make
    GAN training easier, or use the more complicated examples to jump-start your
    own project. These include unconditional and conditional GANs, InfoGANs,
    adversarial losses on existing networks, and image-to-image translation.

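As a sketch of how the evaluation metrics are used: the metric functions live in `tfgan.eval` and run a pretrained Inception network under the hood. The placeholders below are illustrative; this assumes image batches resized to the Inception input size (299x299x3) and scaled to [-1, 1], and the exact signatures live in `classifier_metrics_impl.py`.

```python
import tensorflow as tf
tfgan = tf.contrib.gan

# Assumed inputs: image batches at the Inception input size, scaled to
# [-1, 1]. How you produce them is up to your input pipeline.
real_images = tf.placeholder(tf.float32, [None, 299, 299, 3])
generated_images = tf.placeholder(tf.float32, [None, 299, 299, 3])

# Inception Score of the generated images, computed with a pretrained
# Inception network.
inception_score = tfgan.eval.inception_score(generated_images)

# Frechet Inception Distance between the real and generated batches.
fid = tfgan.eval.frechet_inception_distance(real_images, generated_images)
```
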
## Training a GAN model

Training in TF-GAN typically consists of the following steps:

1. Specify the input to your networks.
1. Set up your generator and discriminator using a `GANModel`.
1. Specify your loss using a `GANLoss`.
1. Create your train ops using a `GANTrainOps`.
1. Run your train ops.

At each stage, you can either use TF-GAN's convenience functions, or you can
perform the step manually for fine-grained control. We provide examples below.

There are various types of GAN setups. For instance, you can train a generator
to sample unconditionally from a learned distribution, or you can condition on
extra information such as a class label. TF-GAN is compatible with many setups,
and we demonstrate a few below:

### Examples

#### Unconditional MNIST generation

This example trains a generator to produce handwritten MNIST digits. The generator maps
random draws from a multivariate normal distribution to MNIST digit images. See
['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by
Goodfellow et al.

```python
# Set up the input.
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=mnist.unconditional_generator,  # you define
    discriminator_fn=mnist.unconditional_discriminator,  # you define
    real_data=images,
    generator_inputs=noise)

# Build the GAN loss.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss)

# Create the train ops, which calculate gradients and apply updates to weights.
train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))

# Run the train ops in the alternating training scheme.
tfgan.gan_train(
    train_ops,
    hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],
    logdir=FLAGS.train_log_dir)
```

#### Conditional MNIST generation

This example trains a generator to generate MNIST images *of a given class*.
The generator maps random draws from a multivariate normal distribution and a
one-hot label of the desired digit class to an MNIST digit image. See
['Conditional Generative Adversarial Nets'](https://arxiv.org/abs/1411.1784) by
Mirza and Osindero.

```python
# Set up the input.
images, one_hot_labels = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=mnist.conditional_generator,  # you define
    discriminator_fn=mnist.conditional_discriminator,  # you define
    real_data=images,
    generator_inputs=(noise, one_hot_labels))

# The rest is the same as in the unconditional case.
...
```

#### Adversarial loss

This example combines an L1 pixel loss and an adversarial loss to learn to
autoencode images. The bottleneck layer can be used to transmit compressed
representations of the image. Neural networks trained with only a pixel-wise loss tend to
produce blurry results, so the GAN can be used to make the reconstructions more
plausible. See ['Full Resolution Image Compression with Recurrent Neural Networks'](https://arxiv.org/abs/1608.05148) by Toderici et al.
for an example of neural networks used for image compression, and ['Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network'](https://arxiv.org/abs/1609.04802) by Ledig et al. for a more detailed description of
how GANs can sharpen image output.

```python
# Set up the input pipeline.
images = image_provider.provide_data(FLAGS.batch_size)

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=nets.autoencoder,  # you define
    discriminator_fn=nets.discriminator,  # you define
    real_data=images,
    generator_inputs=images)

# Build the GAN loss and standard pixel loss.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty_weight=1.0)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

# Modify the loss tuple to include the pixel loss.
gan_loss = tfgan.losses.combine_adversarial_loss(
    gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)

# The rest is the same as in the unconditional case.
...
```

#### Image-to-image translation

This example maps images in one domain to images of the same size in a different
domain. For example, it can map segmentation masks to street images, or
grayscale images to color. See ['Image-to-Image Translation with Conditional Adversarial Networks'](https://arxiv.org/abs/1611.07004) by Isola et al. for more details.

```python
# Set up the input pipeline.
input_image, target_image = data_provider.provide_data(FLAGS.batch_size)

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=nets.generator,  # you define
    discriminator_fn=nets.discriminator,  # you define
    real_data=target_image,
    generator_inputs=input_image)

# Build the GAN loss and standard pixel loss.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.least_squares_generator_loss,
    discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

# Modify the loss tuple to include the pixel loss.
gan_loss = tfgan.losses.combine_adversarial_loss(
    gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)

# The rest is the same as in the unconditional case.
...
```

#### InfoGAN

Train a generator to generate specific MNIST digit images, and control for digit style *without using any labels*. See ['InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets'](https://arxiv.org/abs/1606.03657) by Chen et al. for more details.

```python
# Set up the input pipeline.
images = mnist_data_provider.provide_data(FLAGS.batch_size)

# Build the generator and discriminator.
gan_model = tfgan.infogan_model(
    generator_fn=mnist.infogan_generator,  # you define
    discriminator_fn=mnist.infogan_discriminator,  # you define
    real_data=images,
    unstructured_generator_inputs=unstructured_inputs,  # you define
    structured_generator_inputs=structured_inputs)  # you define

# Build the GAN loss with mutual information penalty.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty_weight=1.0,
    mutual_information_penalty_weight=1.0)

# The rest is the same as in the unconditional case.
...
```

#### Custom model creation

Train an unconditional GAN to generate MNIST digits, but manually construct
the `GANModel` tuple for more fine-grained control.

```python
# Set up the input pipeline.
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# Manually build the generator and discriminator.
with tf.variable_scope('Generator') as gen_scope:
  generated_images = generator_fn(noise)  # you define
with tf.variable_scope('Discriminator') as dis_scope:
  discriminator_gen_outputs = discriminator_fn(generated_images)  # you define
with tf.variable_scope(dis_scope, reuse=True):
  discriminator_real_outputs = discriminator_fn(images)
generator_variables = tf.contrib.framework.get_trainable_variables(gen_scope)
discriminator_variables = tf.contrib.framework.get_trainable_variables(dis_scope)

# Depending on what TF-GAN features you use, you don't always need to supply
# every `GANModel` field. At a minimum, you need to include the discriminator
# outputs and variables if you want to use TF-GAN to construct losses.
gan_model = tfgan.GANModel(
    generator_inputs=noise,
    generated_data=generated_images,
    generator_variables=generator_variables,
    generator_scope=gen_scope,
    generator_fn=generator_fn,
    real_data=images,
    discriminator_real_outputs=discriminator_real_outputs,
    discriminator_gen_outputs=discriminator_gen_outputs,
    discriminator_variables=discriminator_variables,
    discriminator_scope=dis_scope,
    discriminator_fn=discriminator_fn)

# The rest is the same as the unconditional case.
...
```

## Authors

Joel Shor (github: [joel-shor](https://github.com/joel-shor)) and Sergio Guadarrama (github: [sguada](https://github.com/sguada))