# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TF-GAN project provides a lightweight GAN training/testing framework.

This file contains the core helper functions to create and train a GAN model.
See the README or examples in `tensorflow_models` for details on how to use.

TF-GAN training occurs in four steps:
1) Create a model
2) Add a loss
3) Create train ops
4) Run the train ops

The functions in this file are organized around these four steps. Each function
corresponds to one of the steps.
"""
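
# A minimal commented sketch of the four steps above. The
# `my_generator_fn`/`my_discriminator_fn` callables and the `noise` and
# `images` tensors are hypothetical placeholders (assuming user code with
# `import tensorflow as tf`), not part of this module:
#
#   model = gan_model(my_generator_fn, my_discriminator_fn,
#                     real_data=images, generator_inputs=noise)
#   loss = gan_loss(model, gradient_penalty_weight=1.0)
#   train_ops = gan_train_ops(
#       model, loss,
#       generator_optimizer=tf.train.AdamOptimizer(1e-4),
#       discriminator_optimizer=tf.train.AdamOptimizer(1e-4))
#   gan_train(train_ops, logdir='/tmp/tfgan_logdir')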

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.gan.python import losses as tfgan_losses
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl
from tensorflow.contrib.slim.python.slim import learning as slim_learning
from tensorflow.contrib.training.python.training import training
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util

__all__ = [
    'gan_model',
    'infogan_model',
    'acgan_model',
    'cyclegan_model',
    'stargan_model',
    'gan_loss',
    'cyclegan_loss',
    'stargan_loss',
    'gan_train_ops',
    'gan_train',
    'get_sequential_train_hooks',
    'get_joint_train_hooks',
    'get_sequential_train_steps',
    'RunTrainOpsHook',
]


def gan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Returns GAN model outputs and variables.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated_data`
      and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the conditional
      GAN case, this might be the generator's conditioning.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that the generator produces Tensors that are
      the same shape as the real data. Otherwise, skip this check.

  Returns:
    A GANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
  """
  # Create models
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as dis_scope:
    discriminator_gen_outputs = discriminator_fn(generated_data,
                                                 generator_inputs)
  with variable_scope.variable_scope(dis_scope, reuse=True):
    real_data = _convert_tensor_or_l_or_d(real_data)
    discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)

  if check_shapes:
    if not generated_data.shape.is_compatible_with(real_data.shape):
      raise ValueError(
          'Generator output shape (%s) must be the same shape as real data '
          '(%s).' % (generated_data.shape, real_data.shape))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(dis_scope)

  return namedtuples.GANModel(
      generator_inputs, generated_data, generator_variables, gen_scope,
      generator_fn, real_data, discriminator_real_outputs,
      discriminator_gen_outputs, discriminator_variables, dis_scope,
      discriminator_fn)

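# A commented sketch of how `gan_model` is typically called. The layer choices
# and names below are illustrative assumptions only:
#
#   def my_generator_fn(noise):  # hypothetical; output matches real data shape
#     return tf.layers.dense(noise, 784, activation=tf.tanh)
#
#   def my_discriminator_fn(data, unused_conditioning):  # hypothetical
#     return tf.layers.dense(data, 1)  # unscaled real/fake logits
#
#   # `flat_images` is assumed to be a [batch, 784] float Tensor.
#   model = gan_model(my_generator_fn, my_discriminator_fn,
#                     real_data=flat_images, generator_inputs=noise)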

def infogan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    unstructured_generator_inputs,
    structured_generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator'):
  """Returns an InfoGAN model's outputs and variables.

  See https://arxiv.org/abs/1606.03657 for more details.

  Args:
    generator_fn: A python lambda that takes a list of Tensors as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated_data`
      and `generator_inputs`. Outputs a 2-tuple of (logits, distribution_list).
      `logits` are in the range [-inf, inf], and `distribution_list` is a list
      of TensorFlow distributions in which the i-th entry represents the
      predicted noise distribution of the i-th structured noise input.
    real_data: A Tensor representing the real data.
    unstructured_generator_inputs: A list of Tensors to the generator.
      These tensors represent the unstructured noise or conditioning.
    structured_generator_inputs: A list of Tensors to the generator.
      These tensors must have high mutual information with the recognizer.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    An InfoGANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
    ValueError: If the discriminator output is malformed.
  """
  # Create models
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    unstructured_generator_inputs = _convert_tensor_or_l_or_d(
        unstructured_generator_inputs)
    structured_generator_inputs = _convert_tensor_or_l_or_d(
        structured_generator_inputs)
    generator_inputs = (
        unstructured_generator_inputs + structured_generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as disc_scope:
    dis_gen_outputs, predicted_distributions = discriminator_fn(
        generated_data, generator_inputs)
  _validate_distributions(predicted_distributions, structured_generator_inputs)
  with variable_scope.variable_scope(disc_scope, reuse=True):
    real_data = ops.convert_to_tensor(real_data)
    dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)

  if not generated_data.get_shape().is_compatible_with(real_data.get_shape()):
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (generated_data.get_shape(), real_data.get_shape()))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(disc_scope)

  return namedtuples.InfoGANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data,
      dis_real_outputs,
      dis_gen_outputs,
      discriminator_variables,
      disc_scope,
      lambda x, y: discriminator_fn(x, y)[0],  # conform to non-InfoGAN API
      structured_generator_inputs,
      predicted_distributions,
      discriminator_fn)

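# A commented sketch of the discriminator contract `infogan_model` expects.
# The network internals are hypothetical; the key point is the 2-tuple of
# (logits, distribution_list), with one distribution per structured input:
#
#   def my_infogan_discriminator_fn(data, unused_conditioning):  # hypothetical
#     net = tf.layers.dense(data, 64, activation=tf.nn.relu)
#     logits = tf.layers.dense(net, 1)
#     q_loc = tf.layers.dense(net, 1)
#     return logits, [tf.distributions.Normal(loc=q_loc, scale=1.0)]
#
#   model = infogan_model(
#       my_generator_fn, my_infogan_discriminator_fn, real_data=flat_images,
#       unstructured_generator_inputs=[noise],
#       structured_generator_inputs=[structured_noise])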

def acgan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    one_hot_labels,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Returns an ACGANModel containing all the pieces needed for ACGAN training.

  The `acgan_model` is the same as the `gan_model` with the only difference
  being that the discriminator additionally outputs logits to classify its
  input (whether real or generated).
  Therefore, an explicit field holding `one_hot_labels` is necessary, as well
  as a `discriminator_fn` that outputs a 2-tuple holding the logits for
  real/fake and for classification.

  See https://arxiv.org/abs/1610.09585 for more details.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated_data`
      and `generator_inputs`. Outputs a tuple consisting of two Tensors:
        (1) real/fake logits in the range [-inf, inf]
        (2) classification logits in the range [-inf, inf]
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the conditional
      GAN case, this might be the generator's conditioning.
    one_hot_labels: A Tensor holding one-hot labels for the batch. Needed by
      acgan_loss.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that the generator produces Tensors that are
      the same shape as the real data. Otherwise, skip this check.

  Returns:
    An ACGANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
    TypeError: If the discriminator does not output a tuple consisting of
      (discrimination logits, classification logits).
  """
  # Create models
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as dis_scope:
    with ops.name_scope(dis_scope.name + '/generated/'):
      (discriminator_gen_outputs, discriminator_gen_classification_logits
      ) = _validate_acgan_discriminator_outputs(
          discriminator_fn(generated_data, generator_inputs))
  with variable_scope.variable_scope(dis_scope, reuse=True):
    with ops.name_scope(dis_scope.name + '/real/'):
      real_data = ops.convert_to_tensor(real_data)
      (discriminator_real_outputs, discriminator_real_classification_logits
      ) = _validate_acgan_discriminator_outputs(
          discriminator_fn(real_data, generator_inputs))
  if check_shapes:
    if not generated_data.shape.is_compatible_with(real_data.shape):
      raise ValueError(
          'Generator output shape (%s) must be the same shape as real data '
          '(%s).' % (generated_data.shape, real_data.shape))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(dis_scope)

  return namedtuples.ACGANModel(
      generator_inputs, generated_data, generator_variables, gen_scope,
      generator_fn, real_data, discriminator_real_outputs,
      discriminator_gen_outputs, discriminator_variables, dis_scope,
      discriminator_fn, one_hot_labels,
      discriminator_real_classification_logits,
      discriminator_gen_classification_logits)

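# A commented sketch of the 2-tuple contract `acgan_model` expects from its
# discriminator. The layer sizes and `num_classes` are hypothetical:
#
#   def my_acgan_discriminator_fn(data, unused_conditioning):  # hypothetical
#     net = tf.layers.dense(data, 64, activation=tf.nn.relu)
#     real_fake_logits = tf.layers.dense(net, 1)
#     class_logits = tf.layers.dense(net, num_classes)
#     return real_fake_logits, class_logits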

def cyclegan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # data X and Y.
    data_x,
    data_y,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    model_x2y_scope='ModelX2Y',
    model_y2x_scope='ModelY2X',
    # Options.
    check_shapes=True):
  """Returns a CycleGAN model's outputs and variables.

  See https://arxiv.org/abs/1703.10593 for more details.

  Args:
    generator_fn: A python lambda that takes `data_x` or `data_y` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated_data`
      and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
    data_x: A `Tensor` of dataset X. Must be the same shape as `data_y`.
    data_y: A `Tensor` of dataset Y. Must be the same shape as `data_x`.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created. Defaults to 'Generator'.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created. Defaults to
      'Discriminator'.
    model_x2y_scope: Optional variable scope for model x2y variables. Defaults
      to 'ModelX2Y'.
    model_y2x_scope: Optional variable scope for model y2x variables. Defaults
      to 'ModelY2X'.
    check_shapes: If `True`, check that the generator produces Tensors that are
      the same shape as `data_x` (`data_y`). Otherwise, skip this check.

  Returns:
    A `CycleGANModel` namedtuple.

  Raises:
    ValueError: If `check_shapes` is True and `data_x` or the generator output
      does not have the same shape as `data_y`.
  """

  # Create models.
  def _define_partial_model(input_data, output_data):
    return gan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        real_data=output_data,
        generator_inputs=input_data,
        generator_scope=generator_scope,
        discriminator_scope=discriminator_scope,
        check_shapes=check_shapes)

  with variable_scope.variable_scope(model_x2y_scope):
    model_x2y = _define_partial_model(data_x, data_y)
  with variable_scope.variable_scope(model_y2x_scope):
    model_y2x = _define_partial_model(data_y, data_x)

  with variable_scope.variable_scope(model_y2x.generator_scope, reuse=True):
    reconstructed_x = model_y2x.generator_fn(model_x2y.generated_data)
  with variable_scope.variable_scope(model_x2y.generator_scope, reuse=True):
    reconstructed_y = model_x2y.generator_fn(model_y2x.generated_data)

  return namedtuples.CycleGANModel(model_x2y, model_y2x, reconstructed_x,
                                   reconstructed_y)

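# A commented sketch of a typical `cyclegan_model` call, where `images_x` and
# `images_y` are assumed to be same-shaped image batches from two domains:
#
#   model = cyclegan_model(my_generator_fn, my_discriminator_fn,
#                          data_x=images_x, data_y=images_y)
#   loss = cyclegan_loss(model, cycle_consistency_loss_weight=10.0)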

def stargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator'):
  """Returns a StarGAN model's outputs and variables.

  See https://arxiv.org/abs/1711.09020 for more details.

  Args:
    generator_fn: A python lambda that takes `inputs` and `targets` as inputs
      and returns `generated_data` as the transformed version of `inputs` based
      on the `targets`. `inputs` has shape (n, h, w, c), `targets` has shape
      (n, num_domains), and `generated_data` has the same shape as `inputs`.
    discriminator_fn: A python lambda that takes `inputs` and `num_domains` as
      inputs and returns a tuple (`source_prediction`, `domain_prediction`).
      `source_prediction` represents the source (real/generated) prediction by
      the discriminator, and `domain_prediction` represents the domain
      prediction/classification by the discriminator. `source_prediction` has
      shape (n) and `domain_prediction` has shape (n, num_domains).
    input_data: Tensor or a list of tensors of shape (n, h, w, c) representing
      the real input images.
    input_data_domain_label: Tensor or a list of tensors of shape (batch_size,
      num_domains) representing the domain label associated with the real
      images.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    A StarGANModel namedtuple containing the tensors needed to compute the
    loss.

  Raises:
    ValueError: If the shape of `input_data_domain_label` is not rank 2 or is
      not fully defined in every dimension.
  """

  # Convert to tensor.
  input_data = _convert_tensor_or_l_or_d(input_data)
  input_data_domain_label = _convert_tensor_or_l_or_d(input_data_domain_label)

  # Convert list of tensor to a single tensor if applicable.
  if isinstance(input_data, (list, tuple)):
    input_data = array_ops.concat(
        [ops.convert_to_tensor(x) for x in input_data], 0)
  if isinstance(input_data_domain_label, (list, tuple)):
    input_data_domain_label = array_ops.concat(
        [ops.convert_to_tensor(x) for x in input_data_domain_label], 0)

  # Get batch_size, num_domains from the labels.
  input_data_domain_label.shape.assert_has_rank(2)
  input_data_domain_label.shape.assert_is_fully_defined()
  batch_size, num_domains = input_data_domain_label.shape.as_list()

  # Transform input_data to random target domains.
  with variable_scope.variable_scope(generator_scope) as generator_scope:
    generated_data_domain_target = _generate_stargan_random_domain_target(
        batch_size, num_domains)
    generated_data = generator_fn(input_data, generated_data_domain_target)

  # Transform generated_data back to the original input_data domain.
  with variable_scope.variable_scope(generator_scope, reuse=True):
    reconstructed_data = generator_fn(generated_data, input_data_domain_label)

  # Predict source and domain for the generated_data using the discriminator.
  with variable_scope.variable_scope(
      discriminator_scope) as discriminator_scope:
    disc_gen_data_source_pred, disc_gen_data_domain_pred = discriminator_fn(
        generated_data, num_domains)

  # Predict source and domain for the input_data using the discriminator.
  with variable_scope.variable_scope(discriminator_scope, reuse=True):
    disc_input_data_source_pred, disc_input_data_domain_pred = discriminator_fn(
        input_data, num_domains)

  # Collect trainable variables from the neural networks.
  generator_variables = variables_lib.get_trainable_variables(generator_scope)
  discriminator_variables = variables_lib.get_trainable_variables(
      discriminator_scope)

  # Create the StarGANModel namedtuple.
  return namedtuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=input_data_domain_label,
      generated_data=generated_data,
      generated_data_domain_target=generated_data_domain_target,
      reconstructed_data=reconstructed_data,
      discriminator_input_data_source_predication=disc_input_data_source_pred,
      discriminator_generated_data_source_predication=disc_gen_data_source_pred,
      discriminator_input_data_domain_predication=disc_input_data_domain_pred,
      discriminator_generated_data_domain_predication=disc_gen_data_domain_pred,
      generator_variables=generator_variables,
      generator_scope=generator_scope,
      generator_fn=generator_fn,
      discriminator_variables=discriminator_variables,
      discriminator_scope=discriminator_scope,
      discriminator_fn=discriminator_fn)

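# A commented sketch of a typical `stargan_model` call. `images` is assumed to
# be an (n, h, w, c) batch and `domain_labels` a fully defined
# (n, num_domains) one-hot Tensor:
#
#   model = stargan_model(my_stargan_generator_fn, my_stargan_discriminator_fn,
#                         input_data=images,
#                         input_data_domain_label=domain_labels)
#   loss = stargan_loss(model)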

def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'):
  if isinstance(aux_loss_weight, ops.Tensor):
    aux_loss_weight.shape.assert_is_compatible_with([])
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(aux_loss_weight, 0.0)]):
      aux_loss_weight = array_ops.identity(aux_loss_weight)
  elif aux_loss_weight is not None and aux_loss_weight < 0:
    raise ValueError('`%s` must be non-negative. Instead, was %s' %
                     (name, aux_loss_weight))
  return aux_loss_weight


def _use_aux_loss(aux_loss_weight):
  if aux_loss_weight is not None:
    if not isinstance(aux_loss_weight, ops.Tensor):
      return aux_loss_weight > 0
    else:
      return True
  else:
    return False


def _tensor_pool_adjusted_model(model, tensor_pool_fn):
  """Adjusts model using `tensor_pool_fn`.

  Args:
    model: A GANModel tuple.
    tensor_pool_fn: A function that takes (generated_data, generator_inputs),
      stores them in an internal pool and returns a previously stored
      (generated_data, generator_inputs) with some probability. For example,
      `tfgan.features.tensor_pool`.

  Returns:
    A new GANModel tuple where discriminator outputs are adjusted by taking
    pooled generator outputs as inputs. Returns the original model if
    `tensor_pool_fn` is None.

  Raises:
    ValueError: If tensor pool does not support the `model`.
  """
  if isinstance(model, namedtuples.GANModel):
    pooled_generator_inputs, pooled_generated_data = tensor_pool_fn(
        (model.generator_inputs, model.generated_data))
    with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
      dis_gen_outputs = model.discriminator_fn(pooled_generated_data,
                                               pooled_generator_inputs)
    return model._replace(
        generator_inputs=pooled_generator_inputs,
        generated_data=pooled_generated_data,
        discriminator_gen_outputs=dis_gen_outputs)
  elif isinstance(model, namedtuples.ACGANModel):
    pooled_generator_inputs, pooled_generated_data = tensor_pool_fn(
        (model.generator_inputs, model.generated_data))
    with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
      (pooled_discriminator_gen_outputs,
       pooled_discriminator_gen_classification_logits) = model.discriminator_fn(
           pooled_generated_data, pooled_generator_inputs)
    return model._replace(
        generator_inputs=pooled_generator_inputs,
        generated_data=pooled_generated_data,
        discriminator_gen_outputs=pooled_discriminator_gen_outputs,
        discriminator_gen_classification_logits=
        pooled_discriminator_gen_classification_logits)
  elif isinstance(model, namedtuples.InfoGANModel):
    pooled_generator_inputs, pooled_generated_data, pooled_structured_input = (
        tensor_pool_fn((model.generator_inputs, model.generated_data,
                        model.structured_generator_inputs)))
    with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
      (pooled_discriminator_gen_outputs,
       pooled_predicted_distributions) = model.discriminator_and_aux_fn(
           pooled_generated_data, pooled_generator_inputs)
    return model._replace(
        generator_inputs=pooled_generator_inputs,
        generated_data=pooled_generated_data,
        structured_generator_inputs=pooled_structured_input,
        discriminator_gen_outputs=pooled_discriminator_gen_outputs,
        predicted_distributions=pooled_predicted_distributions)
  else:
    raise ValueError('Tensor pool does not support `model`: %s.' % type(model))

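# A commented sketch of building a `tensor_pool_fn` for `gan_loss` from the
# TF-GAN feature library; the pool size of 50 is an arbitrary assumption:
#
#   import functools
#   pool_fn = functools.partial(tf.contrib.gan.features.tensor_pool,
#                               pool_size=50)
#   loss = gan_loss(model, tensor_pool_fn=pool_fn)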

def gan_loss(
    # GANModel.
    model,
    # Loss functions.
    generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
    # Auxiliary losses.
    gradient_penalty_weight=None,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    mutual_information_penalty_weight=None,
    aux_cond_generator_weight=None,
    aux_cond_discriminator_weight=None,
    tensor_pool_fn=None,
    # Options.
    add_summaries=True):
  """Returns losses necessary to train generator and discriminator.

  Args:
    model: A GANModel tuple.
    generator_loss_fn: The loss function on the generator. Takes a GANModel
      tuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      GANModel tuple.
    gradient_penalty_weight: If not `None`, must be a non-negative Python
      number or Tensor indicating how much to weight the gradient penalty. See
      https://arxiv.org/pdf/1704.00028.pdf for more details.
    gradient_penalty_epsilon: If `gradient_penalty_weight` is not None, the
      small positive value used by the gradient penalty function for numerical
      stability. Note some applications will need to increase this value to
      avoid NaNs.
    gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python
      number or `Tensor` indicating the target value of gradient norm. See the
      CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0.
    gradient_penalty_one_sided: If `True`, penalty proposed in
      https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
    mutual_information_penalty_weight: If not `None`, must be a non-negative
      Python number or Tensor indicating how much to weight the mutual
      information penalty. See https://arxiv.org/abs/1606.03657 for more
      details.
    aux_cond_generator_weight: If not `None`, add a classification loss as in
      https://arxiv.org/abs/1610.09585.
    aux_cond_discriminator_weight: If not `None`, add a classification loss as
      in https://arxiv.org/abs/1610.09585.
    tensor_pool_fn: A function that takes (generated_data, generator_inputs),
      stores them in an internal pool and returns previously stored
      (generated_data, generator_inputs). For example,
      `tfgan.features.tensor_pool`. Defaults to None (not using tensor pool).
    add_summaries: Whether or not to add summaries for the losses.

  Returns:
    A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes
    regularization losses.

  Raises:
    ValueError: If any of the auxiliary loss weights is provided and negative.
    ValueError: If `mutual_information_penalty_weight` is provided, but the
      `model` isn't an `InfoGANModel`.
  """
  # Validate arguments.
  gradient_penalty_weight = _validate_aux_loss_weight(
      gradient_penalty_weight, 'gradient_penalty_weight')
  mutual_information_penalty_weight = _validate_aux_loss_weight(
      mutual_information_penalty_weight, 'infogan_weight')
  aux_cond_generator_weight = _validate_aux_loss_weight(
      aux_cond_generator_weight, 'aux_cond_generator_weight')
  aux_cond_discriminator_weight = _validate_aux_loss_weight(
      aux_cond_discriminator_weight, 'aux_cond_discriminator_weight')

  # Verify configuration for mutual information penalty.
  if (_use_aux_loss(mutual_information_penalty_weight) and
      not isinstance(model, namedtuples.InfoGANModel)):
    raise ValueError(
        'When `mutual_information_penalty_weight` is provided, `model` must be '
        'an `InfoGANModel`. Instead, was %s.' % type(model))

  # Verify configuration for auxiliary condition loss (ACGAN).
  if ((_use_aux_loss(aux_cond_generator_weight) or
       _use_aux_loss(aux_cond_discriminator_weight)) and
      not isinstance(model, namedtuples.ACGANModel)):
    raise ValueError(
        'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` '
        'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
        type(model))

  # Optionally create pooled model.
  if tensor_pool_fn:
    pooled_model = _tensor_pool_adjusted_model(model, tensor_pool_fn)
  else:
    pooled_model = model

  # Create standard losses.
  gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
  dis_loss = discriminator_loss_fn(pooled_model, add_summaries=add_summaries)

  # Add optional extra losses.
  if _use_aux_loss(gradient_penalty_weight):
    gp_loss = tfgan_losses.wasserstein_gradient_penalty(
        pooled_model,
        epsilon=gradient_penalty_epsilon,
        target=gradient_penalty_target,
        one_sided=gradient_penalty_one_sided,
        add_summaries=add_summaries)
    dis_loss += gradient_penalty_weight * gp_loss
  if _use_aux_loss(mutual_information_penalty_weight):
    gen_info_loss = tfgan_losses.mutual_information_penalty(
        model, add_summaries=add_summaries)
    if tensor_pool_fn is None:
      dis_info_loss = gen_info_loss
    else:
      dis_info_loss = tfgan_losses.mutual_information_penalty(
          pooled_model, add_summaries=add_summaries)
    gen_loss += mutual_information_penalty_weight * gen_info_loss
    dis_loss += mutual_information_penalty_weight * dis_info_loss
  if _use_aux_loss(aux_cond_generator_weight):
    ac_gen_loss = tfgan_losses.acgan_generator_loss(
        model, add_summaries=add_summaries)
    gen_loss += aux_cond_generator_weight * ac_gen_loss
  if _use_aux_loss(aux_cond_discriminator_weight):
    ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
        pooled_model, add_summaries=add_summaries)
    dis_loss += aux_cond_discriminator_weight * ac_disc_loss

  # Gather regularization losses.
  if model.generator_scope:
    gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name)
  else:
    gen_reg_loss = 0
  if model.discriminator_scope:
    dis_reg_loss = losses.get_regularization_loss(
        model.discriminator_scope.name)
  else:
    dis_reg_loss = 0

  return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss)

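# A commented sketch of `gan_loss` in the WGAN-GP configuration; `model` is
# assumed to come from `gan_model` above:
#
#   loss = gan_loss(
#       model,
#       generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
#       discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
#       gradient_penalty_weight=10.0)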

def cyclegan_loss(
    model,
    # Loss functions.
    generator_loss_fn=tfgan_losses.least_squares_generator_loss,
    discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss,
    # Auxiliary losses.
    cycle_consistency_loss_fn=tfgan_losses.cycle_consistency_loss,
    cycle_consistency_loss_weight=10.0,
    # Options
    **kwargs):
  """Returns the losses for a `CycleGANModel`.

  See https://arxiv.org/abs/1703.10593 for more details.

  Args:
    model: A `CycleGANModel` namedtuple.
    generator_loss_fn: The loss function on the generator. Takes a `GANModel`
      namedtuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      `GANModel` namedtuple.
    cycle_consistency_loss_fn: The cycle consistency loss function. Takes a
      `CycleGANModel` namedtuple.
    cycle_consistency_loss_weight: A non-negative Python number or a scalar
      `Tensor` indicating how much to weigh the cycle consistency loss.
    **kwargs: Keyword args to pass directly to `gan_loss` to construct the loss
      for each partial model of `model`.

  Returns:
    A `CycleGANLoss` namedtuple.

  Raises:
    ValueError: If `model` is not a `CycleGANModel` namedtuple.
  """
  # Sanity checks.
  if not isinstance(model, namedtuples.CycleGANModel):
    raise ValueError(
        '`model` must be a `CycleGANModel`. Instead, was %s.' % type(model))

  # Defines cycle consistency loss.
  cycle_consistency_loss = cycle_consistency_loss_fn(
      model, add_summaries=kwargs.get('add_summaries', True))
  cycle_consistency_loss_weight = _validate_aux_loss_weight(
      cycle_consistency_loss_weight, 'cycle_consistency_loss_weight')
  aux_loss = cycle_consistency_loss_weight * cycle_consistency_loss

  # Defines losses for each partial model.
  def _partial_loss(partial_model):
    partial_loss = gan_loss(
        partial_model,
        generator_loss_fn=generator_loss_fn,
        discriminator_loss_fn=discriminator_loss_fn,
        **kwargs)
    return partial_loss._replace(generator_loss=partial_loss.generator_loss +
                                 aux_loss)

  with ops.name_scope('cyclegan_loss_x2y'):
    loss_x2y = _partial_loss(model.model_x2y)
  with ops.name_scope('cyclegan_loss_y2x'):
    loss_y2x = _partial_loss(model.model_y2x)

  return namedtuples.CycleGANLoss(loss_x2y, loss_y2x)


def stargan_loss(
    model,
    generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper(
        tfgan_losses_impl.wasserstein_generator_loss),
    discriminator_loss_fn=tfgan_losses.stargan_discriminator_loss_wrapper(
        tfgan_losses_impl.wasserstein_discriminator_loss),
    gradient_penalty_weight=10.0,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    reconstruction_loss_fn=losses.absolute_difference,
    reconstruction_loss_weight=10.0,
    classification_loss_fn=losses.softmax_cross_entropy,
    classification_loss_weight=1.0,
    classification_one_hot=True,
    add_summaries=True):
  """StarGAN Loss.

  Args:
    model: (StarGAN) Model output of the `stargan_model()` function call.
    generator_loss_fn: The loss function on the generator. Takes a
      `StarGANModel` namedtuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      `StarGANModel` namedtuple.
    gradient_penalty_weight: (float) Gradient penalty weight. Defaults to 10
      per the original paper, https://arxiv.org/abs/1711.09020. Set to 0 or
      None to turn off gradient penalty.
    gradient_penalty_epsilon: (float) A small positive number added for
      numerical stability when computing the gradient norm.
    gradient_penalty_target: (float, or tf.float `Tensor`) The target value of
      gradient norm. Defaults to 1.0.
    gradient_penalty_one_sided: (bool) If `True`, penalty proposed in
      https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
    reconstruction_loss_fn: The reconstruction loss function. Defaults to the
      L1 norm; the function must conform to the `tf.losses` API.
    reconstruction_loss_weight: Reconstruction loss weight. Defaults to 10.0.
    classification_loss_fn: The loss function on the discriminator's ability to
      classify the domain of the input. Defaults to one-hot softmax cross
      entropy loss; the function must conform to the `tf.losses` API.
    classification_loss_weight: (float) Classification loss weight. Defaults to
      1.0.
    classification_one_hot: (bool) Whether the label is a one-hot
      representation. Defaults to `True`. If `False`, `classification_loss_fn`
      needs to be a sigmoid cross entropy loss instead.
    add_summaries: (bool) Whether to add the losses to the summary.

  Returns:
    A GANLoss namedtuple containing the generator loss and the discriminator
    loss.

  Raises:
    ValueError: If `model.input_data_domain_label` does not have rank 2, or if
      its second dimension is not defined.
  """

  def _classification_loss_helper(true_labels, predict_logits, scope_name):
    """Classification Loss Function Helper.

    Args:
      true_labels: Tensor of shape [batch_size, num_domains] representing the
        label where each row is a one-hot vector.
      predict_logits: Tensor of shape [batch_size, num_domains] representing
        the predicted label logits, i.e. the unscaled output of the network.
      scope_name: (string) Name scope of the loss component.

    Returns:
      A single scalar tensor representing the classification loss.
    """

    with ops.name_scope(scope_name, values=(true_labels, predict_logits)):

      loss = classification_loss_fn(
          onehot_labels=true_labels, logits=predict_logits)

      if not classification_one_hot:
        loss = math_ops.reduce_sum(loss, axis=1)
      loss = math_ops.reduce_mean(loss)

      if add_summaries:
        summary.scalar(scope_name, loss)

      return loss

  # Check input shape.
  model.input_data_domain_label.shape.assert_has_rank(2)
  model.input_data_domain_label.shape[1:].assert_is_fully_defined()

  # Adversarial Loss.
  generator_loss = generator_loss_fn(model, add_summaries=add_summaries)
  discriminator_loss = discriminator_loss_fn(model, add_summaries=add_summaries)

  # Gradient Penalty.
  if _use_aux_loss(gradient_penalty_weight):
    gradient_penalty_fn = tfgan_losses.stargan_gradient_penalty_wrapper(
        tfgan_losses_impl.wasserstein_gradient_penalty)
    discriminator_loss += gradient_penalty_fn(
        model,
        epsilon=gradient_penalty_epsilon,
        target=gradient_penalty_target,
        one_sided=gradient_penalty_one_sided,
        add_summaries=add_summaries) * gradient_penalty_weight

  # Reconstruction Loss.
  reconstruction_loss = reconstruction_loss_fn(model.input_data,
                                               model.reconstructed_data)
  generator_loss += reconstruction_loss * reconstruction_loss_weight
  if add_summaries:
    summary.scalar('reconstruction_loss', reconstruction_loss)

  # Classification Loss.
  generator_loss += _classification_loss_helper(
      true_labels=model.generated_data_domain_target,
      predict_logits=model.discriminator_generated_data_domain_predication,
      scope_name='generator_classification_loss') * classification_loss_weight
  discriminator_loss += _classification_loss_helper(
      true_labels=model.input_data_domain_label,
      predict_logits=model.discriminator_input_data_domain_predication,
      scope_name='discriminator_classification_loss'
  ) * classification_loss_weight

  return namedtuples.GANLoss(generator_loss, discriminator_loss)


def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
  """Gets generator and discriminator update ops.

  Args:
    kwargs: A dictionary of kwargs to be passed to `create_train_op`.
      `update_ops` is removed, if present.
    gen_scope: A scope for the generator.
    dis_scope: A scope for the discriminator.
    check_for_unused_ops: A Python bool. If `True`, throws an exception if
      there are unused update ops.

  Returns:
    A 2-tuple of (generator update ops, discriminator update ops).

  Raises:
    ValueError: If there are update ops outside of the generator or
      discriminator scopes.
  """
  if 'update_ops' in kwargs:
    update_ops = set(kwargs['update_ops'])
    del kwargs['update_ops']
  else:
    update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))

  all_gen_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, gen_scope))
  all_dis_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, dis_scope))

  if check_for_unused_ops:
    unused_ops = update_ops - all_gen_ops - all_dis_ops
    if unused_ops:
      raise ValueError('There are unused update ops: %s' % unused_ops)

  gen_update_ops = list(all_gen_ops & update_ops)
  dis_update_ops = list(all_dis_ops & update_ops)

  return gen_update_ops, dis_update_ops


def gan_train_ops(
    model,
    loss,
    generator_optimizer,
    discriminator_optimizer,
    check_for_unused_update_ops=True,
    is_chief=True,
    # Optional args to pass directly to the `create_train_op`.
    **kwargs):
  """Returns GAN train ops.

  The highest-level call in TF-GAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  Args:
    model: A GANModel.
    loss: A GANLoss.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: The optimizer for the discriminator updates.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    **kwargs: Keyword args to pass directly to
      `training.create_train_op` for both the generator and
      discriminator train op.

  Returns:
    A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that can
    be used to train a generator/discriminator pair.
  """
  if isinstance(model, namedtuples.CycleGANModel):
    # Get and store all arguments other than model and loss from locals.
    # The contents of locals() should not be modified, and modifying it may
    # not affect values, so make a copy.
    # https://docs.python.org/2/library/functions.html#locals.
    saved_params = dict(locals())
    saved_params.pop('model', None)
    saved_params.pop('loss', None)
    kwargs = saved_params.pop('kwargs', {})
    saved_params.update(kwargs)
    with ops.name_scope('cyclegan_x2y_train'):
      train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y,
                                    **saved_params)
    with ops.name_scope('cyclegan_y2x_train'):
      train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x,
                                    **saved_params)
    return namedtuples.GANTrainOps(
        (train_ops_x2y.generator_train_op, train_ops_y2x.generator_train_op),
        (train_ops_x2y.discriminator_train_op,
         train_ops_y2x.discriminator_train_op),
        training_util.get_or_create_global_step().assign_add(1))

  # Create global step increment op.
  global_step = training_util.get_or_create_global_step()
  global_step_inc = global_step.assign_add(1)

  # Get generator and discriminator update ops. We split them so that update
  # ops aren't accidentally run multiple times. For now, throw an error if
  # there are update ops that aren't associated with either the generator or
  # the discriminator. Might modify the `kwargs` dictionary.
  gen_update_ops, dis_update_ops = _get_update_ops(
      kwargs, model.generator_scope.name, model.discriminator_scope.name,
      check_for_unused_update_ops)

  # Get the sync hooks if these are needed.
  sync_hooks = []

  generator_global_step = None
  if isinstance(generator_optimizer,
                sync_replicas_optimizer.SyncReplicasOptimizer):
    # TODO(joelshor): Figure out a way to get this to work without including
    # the dummy global step in the checkpoint.
    # WARNING: Making this variable a local variable causes sync replicas to
    # hang forever.
    generator_global_step = variable_scope.get_variable(
        'dummy_global_step_generator',
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    gen_update_ops += [generator_global_step.assign(global_step)]
    sync_hooks.append(generator_optimizer.make_session_run_hook(is_chief))
  with ops.name_scope('generator_train'):
    gen_train_op = training.create_train_op(
        total_loss=loss.generator_loss,
        optimizer=generator_optimizer,
        variables_to_train=model.generator_variables,
        global_step=generator_global_step,
        update_ops=gen_update_ops,
        **kwargs)

  discriminator_global_step = None
  if isinstance(discriminator_optimizer,
                sync_replicas_optimizer.SyncReplicasOptimizer):
    # See comment above `generator_global_step`.
    discriminator_global_step = variable_scope.get_variable(
        'dummy_global_step_discriminator',
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    dis_update_ops += [discriminator_global_step.assign(global_step)]
    sync_hooks.append(discriminator_optimizer.make_session_run_hook(is_chief))
  with ops.name_scope('discriminator_train'):
    disc_train_op = training.create_train_op(
        total_loss=loss.discriminator_loss,
        optimizer=discriminator_optimizer,
        variables_to_train=model.discriminator_variables,
        global_step=discriminator_global_step,
        update_ops=dis_update_ops,
        **kwargs)

  return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc,
                                 sync_hooks)


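# A commented sketch of creating train ops from a model and loss. The Adam
# learning rates and `beta1` here are arbitrary assumptions:
#
#   train_ops = gan_train_ops(
#       model, loss,
#       generator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
#       discriminator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5))
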
# TODO(joelshor): Implement a dynamic GAN train loop, as in `Real-Time Adaptive
# Image Compression` (https://arxiv.org/abs/1705.05823)
class RunTrainOpsHook(session_run_hook.SessionRunHook):
  """A hook to run train ops a fixed number of times."""

  def __init__(self, train_ops, train_steps):
    """Run train ops a certain number of times.

    Args:
      train_ops: A train op or iterable of train ops to run.
      train_steps: The number of times to run the op(s).
    """
    if not isinstance(train_ops, (list, tuple)):
      train_ops = [train_ops]
    self._train_ops = train_ops
    self._train_steps = train_steps

  def before_run(self, run_context):
    for _ in range(self._train_steps):
      run_context.session.run(self._train_ops)


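# A commented sketch of using `RunTrainOpsHook` directly, e.g. to run the
# discriminator five times per generator step (this ignores any sync hooks in
# `train_ops.train_hooks` for brevity; names are placeholders):
#
#   d_hook = RunTrainOpsHook(train_ops.discriminator_train_op, train_steps=5)
#   g_hook = RunTrainOpsHook(train_ops.generator_train_op, train_steps=1)
#   gan_train(train_ops, logdir='/tmp/tfgan_logdir',
#             get_hooks_fn=lambda _: [g_hook, d_hook])
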
def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a hooks function for sequential GAN training.

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that takes a GANTrainOps tuple and returns a list of hooks.
  """

  def get_hooks(train_ops):
    generator_hook = RunTrainOpsHook(train_ops.generator_train_op,
                                     train_steps.generator_train_steps)
    discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op,
                                         train_steps.discriminator_train_steps)
    return [generator_hook, discriminator_hook] + list(train_ops.train_hooks)

  return get_hooks


def _num_joint_steps(train_steps):
  g_steps = train_steps.generator_train_steps
  d_steps = train_steps.discriminator_train_steps
  # Get the number of each type of step that should be run.
  num_d_and_g_steps = min(g_steps, d_steps)
  num_g_steps = g_steps - num_d_and_g_steps
  num_d_steps = d_steps - num_d_and_g_steps

  return num_d_and_g_steps, num_g_steps, num_d_steps


def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a hooks function for joint GAN training.

  When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON
  ALL OPTIMIZERS TO AVOID RACE CONDITIONS.

  The order of steps taken is:
  1) Combined generator and discriminator steps
  2) Generator only steps, if any remain
  3) Discriminator only steps, if any remain

  **NOTE**: Unlike `get_sequential_train_hooks`, this method performs updates
  for the generator and discriminator simultaneously whenever possible. This
  reduces the number of `tf.Session` calls, and can also change the training
  semantics.

  To illustrate the difference, look at the following example:

  `train_steps=namedtuples.GANTrainSteps(3, 5)` will cause
  `get_sequential_train_hooks` to make 8 session calls:
  1) 3 generator steps
  2) 5 discriminator steps

  In contrast, `get_joint_train_hooks` will make 5 session calls:
  1) 3 generator + discriminator steps
  2) 2 discriminator steps

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that takes a GANTrainOps tuple and returns a list of hooks.
  """
  num_d_and_g_steps, num_g_steps, num_d_steps = _num_joint_steps(train_steps)

  def get_hooks(train_ops):
    g_op = train_ops.generator_train_op
    d_op = train_ops.discriminator_train_op

    joint_hook = RunTrainOpsHook([g_op, d_op], num_d_and_g_steps)
    g_hook = RunTrainOpsHook(g_op, num_g_steps)
    d_hook = RunTrainOpsHook(d_op, num_d_steps)

    return [joint_hook, g_hook, d_hook] + list(train_ops.train_hooks)

  return get_hooks


# TODO(joelshor): This function currently returns the global step. Find a
# good way for it to return the generator, discriminator, and final losses.
def gan_train(train_ops,
              logdir,
              get_hooks_fn=get_sequential_train_hooks(),
              master='',
              is_chief=True,
              scaffold=None,
              hooks=None,
              chief_only_hooks=None,
              save_checkpoint_secs=600,
              save_summaries_steps=100,
              config=None):
  """A wrapper around `contrib.training.train` that uses GAN hooks.

  Args:
    train_ops: A GANTrainOps named tuple.
    logdir: The directory where the graph and checkpoints are saved.
    get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
      of hooks.
    master: The URL of the master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    scaffold: A tf.train.Scaffold instance.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
      training loop.
    chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
      inside the training loop for the chief trainer only.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If `save_checkpoint_secs` is set to
      `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If
      `save_summaries_steps` is set to `None`, then the default summary saver
      isn't used.
    config: An instance of `tf.ConfigProto`.

  Returns:
    Output of the call to `training.train`.
  """
  new_hooks = get_hooks_fn(train_ops)
  if hooks is not None:
    hooks = list(hooks) + list(new_hooks)
  else:
    hooks = new_hooks
  return training.train(
      train_ops.global_step_inc_op,
      logdir,
      master=master,
      is_chief=is_chief,
      scaffold=scaffold,
      hooks=hooks,
      chief_only_hooks=chief_only_hooks,
      save_checkpoint_secs=save_checkpoint_secs,
      save_summaries_steps=save_summaries_steps,
      config=config)

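# A commented sketch of launching training with joint hooks instead of the
# default sequential ones (the log directory is a placeholder):
#
#   gan_train(train_ops,
#             logdir='/tmp/tfgan_logdir',
#             get_hooks_fn=get_joint_train_hooks(
#                 namedtuples.GANTrainSteps(1, 5)))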

def get_sequential_train_steps(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a thin wrapper around slim.learning.train_step, for GANs.

  This function is to provide support for the Supervisor. For new code, please
  use `MonitoredSession` and `get_sequential_train_hooks`.

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that can be used for `train_step_fn` for GANs.
  """

  def sequential_train_steps(sess, train_ops, global_step, train_step_kwargs):
    """A thin wrapper around slim.learning.train_step, for GANs.

    Args:
      sess: A TensorFlow session.
      train_ops: A GANTrainOps tuple of train ops to run.
      global_step: The global step.
      train_step_kwargs: Dictionary controlling `train_step` behavior.

    Returns:
      A scalar final loss and a bool whether or not the train loop should stop.
    """
    # Only run `should_stop` at the end, if required. Make a local copy of
    # `train_step_kwargs`, if necessary, so as not to modify the caller's
    # dictionary.
    should_stop_op, train_kwargs = None, train_step_kwargs
    if 'should_stop' in train_step_kwargs:
      should_stop_op = train_step_kwargs['should_stop']
      train_kwargs = train_step_kwargs.copy()
      del train_kwargs['should_stop']

    # Run generator training steps.
    gen_loss = 0
    for _ in range(train_steps.generator_train_steps):
      cur_gen_loss, _ = slim_learning.train_step(
          sess, train_ops.generator_train_op, global_step, train_kwargs)
      gen_loss += cur_gen_loss

    # Run discriminator training steps.
    dis_loss = 0
    for _ in range(train_steps.discriminator_train_steps):
      cur_dis_loss, _ = slim_learning.train_step(
          sess, train_ops.discriminator_train_op, global_step, train_kwargs)
      dis_loss += cur_dis_loss

    sess.run(train_ops.global_step_inc_op)

    # Run the `should_stop` op after the global step has been incremented, so
    # that the `should_stop` aligns with the proper `global_step` count.
    if should_stop_op is not None:
      should_stop = sess.run(should_stop_op)
    else:
      should_stop = False

    return gen_loss + dis_loss, should_stop

  return sequential_train_steps


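# A commented sketch of the Supervisor-era usage this wrapper supports, via
# `slim_learning.train`; passing the `GANTrainOps` tuple through as the
# `train_op` argument is an assumption of this sketch:
#
#   train_step_fn = get_sequential_train_steps(namedtuples.GANTrainSteps(1, 1))
#   slim_learning.train(train_ops, '/tmp/tfgan_logdir',
#                       train_step_fn=train_step_fn)
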
# Helpers


def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
  """Convert input, list of inputs, or dictionary of inputs to Tensors."""
  if isinstance(tensor_or_l_or_d, (list, tuple)):
    return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
  elif isinstance(tensor_or_l_or_d, dict):
    return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
  else:
    return ops.convert_to_tensor(tensor_or_l_or_d)


def _validate_distributions(distributions_l, noise_l):
  if not isinstance(distributions_l, (tuple, list)):
    raise ValueError('`predicted_distributions` must be a list. Instead, found '
                     '%s.' % type(distributions_l))
  if len(distributions_l) != len(noise_l):
    raise ValueError('Length of `predicted_distributions` %i must be the same '
                     'as the length of structured noise %i.' %
                     (len(distributions_l), len(noise_l)))


def _validate_acgan_discriminator_outputs(discriminator_output):
  try:
    a, b = discriminator_output
  except (TypeError, ValueError):
    raise TypeError(
        'A discriminator function for ACGAN must output a tuple '
        'consisting of (discrimination logits, classification logits).')
  return a, b


def _generate_stargan_random_domain_target(batch_size, num_domains):
  """Generates random domain labels.

  Args:
    batch_size: (int) Number of random domain labels to generate.
    num_domains: (int) Number of domains represented by the labels.

  Returns:
    Tensor of shape (batch_size, num_domains) representing random labels.
  """
  domain_idx = random_ops.random_uniform(
      [batch_size], minval=0, maxval=num_domains, dtype=dtypes.int32)

  return array_ops.one_hot(domain_idx, num_domains)
1313