• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Named tuples for TF-GAN.
16
17TF-GAN training occurs in four steps, and each step communicates with the next
18step via one of these named tuples. At each step, you can either use a TF-GAN
19helper function in `train.py`, or you can manually construct a tuple.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import collections
27
__all__ = [
    # Model tuples.
    'GANModel', 'InfoGANModel', 'ACGANModel', 'CycleGANModel', 'StarGANModel',
    # Loss tuples.
    'GANLoss', 'CycleGANLoss',
    # Training-op and training-schedule tuples.
    'GANTrainOps', 'GANTrainSteps',
]
39
40
class GANModel(
    collections.namedtuple('GANModel', (
        'generator_inputs',
        'generated_data',
        'generator_variables',
        'generator_scope',
        'generator_fn',
        'real_data',
        'discriminator_real_outputs',
        'discriminator_gen_outputs',
        'discriminator_variables',
        'discriminator_scope',
        'discriminator_fn',
    ))):
  """A GANModel contains all the pieces needed for GAN training.

  Generative Adversarial Networks (https://arxiv.org/abs/1406.2661) attempt
  to create an implicit generative model of data by solving a two-agent game.
  The generator generates candidate examples that are supposed to match the
  data distribution, and the discriminator aims to tell the real examples
  apart from the generated samples.

  Args:
    generator_inputs: The random noise source that acts as input to the
      generator.
    generated_data: The generated output data of the GAN.
    generator_variables: A list of all generator variables.
    generator_scope: Variable scope all generator variables live in.
    generator_fn: The generator function.
    real_data: A tensor of real data.
    discriminator_real_outputs: The discriminator's output on real data.
    discriminator_gen_outputs: The discriminator's output on generated data.
    discriminator_variables: A list of all discriminator variables.
    discriminator_scope: Variable scope all discriminator variables live in.
    discriminator_fn: The discriminator function.
  """
77
78
# TODO(joelshor): Have this class inherit from `GANModel`.
class InfoGANModel(
    collections.namedtuple(
        'InfoGANModel',
        GANModel._fields + (
            'structured_generator_inputs',
            'predicted_distributions',
            'discriminator_and_aux_fn',
        ))):
  """Bundles everything needed for InfoGAN training.

  Contains every field of `GANModel`, followed by the InfoGAN-specific fields
  listed below. See https://arxiv.org/abs/1606.03657 for more details.

  Args:
    structured_generator_inputs: A list of Tensors representing the random
      noise that must have high mutual information with the generator output.
      Its length should match `predicted_distributions`.
    predicted_distributions: A list of `tfp.distributions.Distribution`s,
      predicted by the recognizer and used to evaluate the likelihood of the
      structured noise. Its length should match `structured_generator_inputs`.
    discriminator_and_aux_fn: The original discriminator function, which
      returns a tuple of (logits, `predicted_distributions`).
  """
100
101
class ACGANModel(
    collections.namedtuple(
        'ACGANModel',
        GANModel._fields + (
            'one_hot_labels',
            'discriminator_real_classification_logits',
            'discriminator_gen_classification_logits',
        ))):
  """Bundles everything needed for ACGAN training.

  Contains every field of `GANModel`, followed by the classification-related
  fields listed below. See https://arxiv.org/abs/1610.09585 for more details.

  Args:
    one_hot_labels: A Tensor holding one-hot labels for the batch.
    discriminator_real_classification_logits: Classification logits for real
      data.
    discriminator_gen_classification_logits: Classification logits for
      generated data.
  """
118
119
class CycleGANModel(
    collections.namedtuple(
        'CycleGANModel',
        ('model_x2y', 'model_y2x', 'reconstructed_x', 'reconstructed_y'))):
  """A CycleGANModel contains all the pieces needed for CycleGAN training.

  The generator F of `model_x2y` maps data set X to Y, while the generator G
  of `model_y2x` maps data set Y to X.

  See https://arxiv.org/abs/1703.10593 for more details.

  Args:
    model_x2y: A `GANModel` namedtuple whose generator maps data set X to Y.
    model_y2x: A `GANModel` namedtuple whose generator maps data set Y to X.
    reconstructed_x: A `Tensor` of reconstructed data X, which is G(F(X)).
    reconstructed_y: A `Tensor` of reconstructed data Y, which is F(G(Y)).
  """
137
138
class StarGANModel(
    collections.namedtuple('StarGANModel', (
        'input_data',
        'input_data_domain_label',
        'generated_data',
        'generated_data_domain_target',
        'reconstructed_data',
        'discriminator_input_data_source_predication',
        'discriminator_generated_data_source_predication',
        'discriminator_input_data_domain_predication',
        'discriminator_generated_data_domain_predication',
        'generator_variables',
        'generator_scope',
        'generator_fn',
        'discriminator_variables',
        'discriminator_scope',
        'discriminator_fn',
    ))):
  """A StarGANModel contains all the pieces needed for StarGAN training.

  See https://arxiv.org/abs/1711.09020 for more details.

  Note: the `*_predication` field names (sic, "prediction") are part of the
  public field list and are kept as-is; each holds a discriminator output.

  Args:
    input_data: The real images that need to be transferred by the generator.
    input_data_domain_label: The real domain labels associated with the real
      images.
    generated_data: The generated images produced by the generator. It has the
      same shape as the input_data.
    generated_data_domain_target: The target domain that the generated images
      belong to. It has the same shape as the input_data_domain_label.
    reconstructed_data: The reconstructed images produced by the G(enerator).
      reconstructed_data = G(G(input_data, generated_data_domain_target),
      input_data_domain_label).
    discriminator_input_data_source_predication: The discriminator's output for
      predicting the source (real/generated) of input_data.
    discriminator_generated_data_source_predication: The discriminator's output
      for predicting the source (real/generated) of generated_data.
    discriminator_input_data_domain_predication: The discriminator's output for
      predicting the domain_label for the input_data.
    discriminator_generated_data_domain_predication: The discriminator's output
      for predicting the domain_target for the generated_data.
    generator_variables: A list of all generator variables.
    generator_scope: Variable scope all generator variables live in.
    generator_fn: The generator function.
    discriminator_variables: A list of all discriminator variables.
    discriminator_scope: Variable scope all discriminator variables live in.
    discriminator_fn: The discriminator function.
  """
185
186
class GANLoss(
    collections.namedtuple('GANLoss', 'generator_loss discriminator_loss')):
  """Holds the generator loss and the discriminator loss side by side.

  Args:
    generator_loss: A tensor for the generator loss.
    discriminator_loss: A tensor for the discriminator loss.
  """
198
199
class CycleGANLoss(
    collections.namedtuple('CycleGANLoss', 'loss_x2y loss_y2x')):
  """Losses for the two translation directions of a `CycleGANModel`.

  See https://arxiv.org/abs/1703.10593 for more details.

  Args:
    loss_x2y: A `GANLoss` namedtuple holding the loss of `model_x2y`.
    loss_y2x: A `GANLoss` namedtuple holding the loss of `model_y2x`.
  """
210
211
class GANTrainOps(
    collections.namedtuple(
        'GANTrainOps',
        'generator_train_op discriminator_train_op global_step_inc_op '
        'train_hooks')):
  """Bundle of the ops that drive GAN training.

  Args:
    generator_train_op: Op that performs a generator update step.
    discriminator_train_op: Op that performs a discriminator update step.
    global_step_inc_op: Op that increments the shared global step.
    train_hooks: A list or tuple of training-related hooks that need to be
      populated when the training ops are instantiated; used primarily for
      sync hooks. Defaults to an empty tuple.
  """

  def __new__(cls, generator_train_op, discriminator_train_op,
              global_step_inc_op, train_hooks=()):
    # The override exists only to make `train_hooks` optional; every value is
    # forwarded to the namedtuple constructor unchanged.
    parent = super(GANTrainOps, cls)
    return parent.__new__(
        cls,
        generator_train_op=generator_train_op,
        discriminator_train_op=discriminator_train_op,
        global_step_inc_op=global_step_inc_op,
        train_hooks=train_hooks)
235
236
class GANTrainSteps(
    collections.namedtuple(
        'GANTrainSteps', 'generator_train_steps discriminator_train_steps')):
  """Configuration for how many updates each network gets per GAN step.

  Args:
    generator_train_steps: Number of generator steps to take in each GAN step.
    discriminator_train_steps: Number of discriminator steps to take in each
      GAN step.
  """
249