# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Named tuples for TF-GAN.

TF-GAN training occurs in four steps, and each step communicates with the next
step via one of these named tuples. At each step, you can either use a TF-GAN
helper function in `train.py`, or you can manually construct a tuple.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

__all__ = [
    'GANModel',
    'InfoGANModel',
    'ACGANModel',
    'CycleGANModel',
    'StarGANModel',
    'GANLoss',
    'CycleGANLoss',
    'GANTrainOps',
    'GANTrainSteps',
]


class GANModel(
    collections.namedtuple('GANModel', (
        'generator_inputs',
        'generated_data',
        'generator_variables',
        'generator_scope',
        'generator_fn',
        'real_data',
        'discriminator_real_outputs',
        'discriminator_gen_outputs',
        'discriminator_variables',
        'discriminator_scope',
        'discriminator_fn',
    ))):
  """A GANModel contains all the pieces needed for GAN training.

  Generative Adversarial Networks (https://arxiv.org/abs/1406.2661) attempt
  to create an implicit generative model of data by solving a two-agent game.
  The generator generates candidate examples that are supposed to match the
  data distribution, and the discriminator aims to tell the real examples
  apart from the generated samples.

  Args:
    generator_inputs: The random noise source that acts as input to the
      generator.
    generated_data: The generated output data of the GAN.
    generator_variables: A list of all generator variables.
    generator_scope: Variable scope all generator variables live in.
    generator_fn: The generator function.
    real_data: A tensor of real data.
    discriminator_real_outputs: The discriminator's output on real data.
    discriminator_gen_outputs: The discriminator's output on generated data.
    discriminator_variables: A list of all discriminator variables.
    discriminator_scope: Variable scope all discriminator variables live in.
    discriminator_fn: The discriminator function.
  """


# TODO(joelshor): Have this class inherit from `GANModel`.
class InfoGANModel(
    collections.namedtuple('InfoGANModel', GANModel._fields + (
        'structured_generator_inputs',
        'predicted_distributions',
        'discriminator_and_aux_fn',
    ))):
  """An InfoGANModel contains all the pieces needed for InfoGAN training.

  See https://arxiv.org/abs/1606.03657 for more details.

  The tuple starts with every field of `GANModel` (in the same order) and is
  followed by the InfoGAN-specific fields documented below.

  Args:
    structured_generator_inputs: A list of Tensors representing the random noise
      that must have high mutual information with the generator output. List
      length should match `predicted_distributions`.
    predicted_distributions: A list of `tfp.distributions.Distribution`s.
      Predicted by the recognizer, and used to evaluate the likelihood of the
      structured noise. List length should match `structured_generator_inputs`.
    discriminator_and_aux_fn: The original discriminator function that returns
      a tuple of (logits, `predicted_distributions`).
  """


class ACGANModel(
    collections.namedtuple('ACGANModel', GANModel._fields +
                           ('one_hot_labels',
                            'discriminator_real_classification_logits',
                            'discriminator_gen_classification_logits',))):
  """An ACGANModel contains all the pieces needed for ACGAN training.

  See https://arxiv.org/abs/1610.09585 for more details.

  The tuple starts with every field of `GANModel` (in the same order) and is
  followed by the ACGAN-specific fields documented below.

  Args:
    one_hot_labels: A Tensor holding one-hot-labels for the batch.
    discriminator_real_classification_logits: Classification logits for real
      data.
    discriminator_gen_classification_logits: Classification logits for generated
      data.
  """


class CycleGANModel(
    collections.namedtuple(
        'CycleGANModel',
        ('model_x2y', 'model_y2x', 'reconstructed_x', 'reconstructed_y'))):
  """A CycleGANModel contains all the pieces needed for CycleGAN training.

  The model `model_x2y` generator F maps data set X to Y, while the model
  `model_y2x` generator G maps data set Y to X.

  See https://arxiv.org/abs/1703.10593 for more details.

  Args:
    model_x2y: A `GANModel` namedtuple whose generator maps data set X to Y.
    model_y2x: A `GANModel` namedtuple whose generator maps data set Y to X.
    reconstructed_x: A `Tensor` of reconstructed data X which is G(F(X)).
    reconstructed_y: A `Tensor` of reconstructed data Y which is F(G(Y)).
  """


# NOTE: "predication" in the discriminator field names below is a misspelling
# of "prediction". The field names are part of the public tuple interface, so
# they are intentionally left unchanged.
class StarGANModel(
    collections.namedtuple('StarGANModel', (
        'input_data',
        'input_data_domain_label',
        'generated_data',
        'generated_data_domain_target',
        'reconstructed_data',
        'discriminator_input_data_source_predication',
        'discriminator_generated_data_source_predication',
        'discriminator_input_data_domain_predication',
        'discriminator_generated_data_domain_predication',
        'generator_variables',
        'generator_scope',
        'generator_fn',
        'discriminator_variables',
        'discriminator_scope',
        'discriminator_fn',
    ))):
  """A StarGANModel contains all the pieces needed for StarGAN training.

  Args:
    input_data: The real images that need to be transferred by the generator.
    input_data_domain_label: The real domain labels associated with the real
      images.
    generated_data: The generated images produced by the generator. It has the
      same shape as the input_data.
    generated_data_domain_target: The target domain that the generated images
      belong to. It has the same shape as the input_data_domain_label.
    reconstructed_data: The reconstructed images produced by the G(enerator).
      reconstructed_data = G(G(input_data, generated_data_domain_target),
      input_data_domain_label).
    discriminator_input_data_source_predication: The discriminator's output
      for predicting the source (real/generated) of input_data.
    discriminator_generated_data_source_predication: The discriminator's
      output for predicting the source (real/generated) of generated_data.
    discriminator_input_data_domain_predication: The discriminator's output for
      predicting the domain_label for the input_data.
    discriminator_generated_data_domain_predication: The discriminator's output
      for predicting the domain_target for the generated_data.
    generator_variables: A list of all generator variables.
    generator_scope: Variable scope all generator variables live in.
    generator_fn: The generator function.
    discriminator_variables: A list of all discriminator variables.
    discriminator_scope: Variable scope all discriminator variables live in.
    discriminator_fn: The discriminator function.
  """


class GANLoss(
    collections.namedtuple('GANLoss', (
        'generator_loss',
        'discriminator_loss'
    ))):
  """GANLoss contains the generator and discriminator losses.

  Args:
    generator_loss: A tensor for the generator loss.
    discriminator_loss: A tensor for the discriminator loss.
  """


class CycleGANLoss(
    collections.namedtuple('CycleGANLoss', ('loss_x2y', 'loss_y2x'))):
  """CycleGANLoss contains the losses for `CycleGANModel`.

  See https://arxiv.org/abs/1703.10593 for more details.

  Args:
    loss_x2y: A `GANLoss` namedtuple representing the loss of `model_x2y`.
    loss_y2x: A `GANLoss` namedtuple representing the loss of `model_y2x`.
  """


class GANTrainOps(
    collections.namedtuple('GANTrainOps', (
        'generator_train_op',
        'discriminator_train_op',
        'global_step_inc_op',
        'train_hooks'
    ))):
  """GANTrainOps contains the training ops.

  Args:
    generator_train_op: Op that performs a generator update step.
    discriminator_train_op: Op that performs a discriminator update step.
    global_step_inc_op: Op that increments the shared global step.
    train_hooks: A list or tuple containing hooks related to training that need
      to be populated when training ops are instantiated. Used primarily for
      sync hooks. Optional; defaults to an empty tuple.
  """

  def __new__(cls, generator_train_op, discriminator_train_op,
              global_step_inc_op, train_hooks=()):
    # Overridden only so that `train_hooks` may be omitted. The default is an
    # immutable empty tuple, so sharing it across instances is safe.
    return super(GANTrainOps, cls).__new__(cls, generator_train_op,
                                           discriminator_train_op,
                                           global_step_inc_op, train_hooks)


class GANTrainSteps(
    collections.namedtuple('GANTrainSteps', (
        'generator_train_steps',
        'discriminator_train_steps'
    ))):
  """Contains configuration for the GAN Training.

  Args:
    generator_train_steps: Number of generator steps to take in each GAN step.
    discriminator_train_steps: Number of discriminator steps to take in each GAN
      step.
  """