# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Losses that are useful for training GANs.
16
17The losses belong to two main groups, but there are others that do not:
181) xxxxx_generator_loss
192) xxxxx_discriminator_loss
20
21Example:
221) wasserstein_generator_loss
232) wasserstein_discriminator_loss
24
25Other example:
26wasserstein_gradient_penalty
27
28All losses must be able to accept 1D or 2D Tensors, so as to be compatible with
29patchGAN style losses (https://arxiv.org/abs/1611.07004).
30
31To make these losses usable in the TF-GAN framework, please create a tuple
32version of the losses with `losses_utils.py`.
33"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.ops.losses import util
from tensorflow.python.summary import summary


__all__ = [
    'acgan_discriminator_loss',
    'acgan_generator_loss',
    'least_squares_discriminator_loss',
    'least_squares_generator_loss',
    'modified_discriminator_loss',
    'modified_generator_loss',
    'minimax_discriminator_loss',
    'minimax_generator_loss',
    'wasserstein_discriminator_loss',
    'wasserstein_generator_loss',
    'wasserstein_gradient_penalty',
    'mutual_information_penalty',
    'combine_adversarial_loss',
    'cycle_consistency_loss',
]


def _to_float(tensor):
  return math_ops.cast(tensor, dtypes.float32)


# Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875).
def wasserstein_generator_loss(
    discriminator_gen_outputs,
    weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """Wasserstein generator loss for GANs.

  See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details.

  Args:
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `discriminator_gen_outputs`, and must be broadcastable to
      `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or
      the same as the corresponding dimension).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add detailed summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
  with ops.name_scope(scope, 'generator_wasserstein_loss', (
      discriminator_gen_outputs, weights)) as scope:
    discriminator_gen_outputs = _to_float(discriminator_gen_outputs)

    loss = - discriminator_gen_outputs
    loss = losses.compute_weighted_loss(
        loss, weights, scope, loss_collection, reduction)

    if add_summaries:
      summary.scalar('generator_wass_loss', loss)

  return loss

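# Illustrative usage sketch (not part of the original module; assumes a TF 1.x
# graph context and a hypothetical logits tensor `d_on_fake`):
#
#   d_on_fake = tf.placeholder(tf.float32, [None, 1])
#   g_loss = wasserstein_generator_loss(d_on_fake)
#
# With the default reduction and unit weights, this is simply the negated mean
# of the discriminator logits on generated data: the generator maximizes
# D(G(z)).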

def wasserstein_discriminator_loss(
    discriminator_real_outputs,
    discriminator_gen_outputs,
    real_weights=1.0,
    generated_weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """Wasserstein discriminator loss for GANs.

  See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details.

  Args:
    discriminator_real_outputs: Discriminator output on real data.
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    real_weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `discriminator_real_outputs`, and must be broadcastable to
      `discriminator_real_outputs` (i.e., all dimensions must be either `1`, or
      the same as the corresponding dimension).
    generated_weights: Same as `real_weights`, but for
      `discriminator_gen_outputs`.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
  with ops.name_scope(scope, 'discriminator_wasserstein_loss', (
      discriminator_real_outputs, discriminator_gen_outputs, real_weights,
      generated_weights)) as scope:
    discriminator_real_outputs = _to_float(discriminator_real_outputs)
    discriminator_gen_outputs = _to_float(discriminator_gen_outputs)
    discriminator_real_outputs.shape.assert_is_compatible_with(
        discriminator_gen_outputs.shape)

    loss_on_generated = losses.compute_weighted_loss(
        discriminator_gen_outputs, generated_weights, scope,
        loss_collection=None, reduction=reduction)
    loss_on_real = losses.compute_weighted_loss(
        discriminator_real_outputs, real_weights, scope, loss_collection=None,
        reduction=reduction)
    loss = loss_on_generated - loss_on_real
    util.add_loss(loss, loss_collection)

    if add_summaries:
      summary.scalar('discriminator_gen_wass_loss', loss_on_generated)
      summary.scalar('discriminator_real_wass_loss', loss_on_real)
      summary.scalar('discriminator_wass_loss', loss)

  return loss

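# Illustrative usage sketch (hypothetical logits tensors `d_on_real` and
# `d_on_fake` of matching shape):
#
#   d_loss = wasserstein_discriminator_loss(d_on_real, d_on_fake)
#
# This estimates E[D(G(z))] - E[D(x)]; in WGAN it is paired with a Lipschitz
# constraint such as weight clipping or `wasserstein_gradient_penalty`.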

# ACGAN losses from `Conditional Image Synthesis With Auxiliary Classifier GANs`
# (https://arxiv.org/abs/1610.09585).
def acgan_discriminator_loss(
    discriminator_real_classification_logits,
    discriminator_gen_classification_logits,
    one_hot_labels,
    label_smoothing=0.0,
    real_weights=1.0,
    generated_weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """ACGAN loss for the discriminator.

  The ACGAN loss adds a classification loss to the conditional discriminator.
  Therefore, the discriminator must output a tuple consisting of
    (1) the real/fake prediction and
    (2) the logits for the classification (usually the last conv layer,
        flattened).

  For more details:
    ACGAN: https://arxiv.org/abs/1610.09585

  Args:
    discriminator_real_classification_logits: Classification logits for real
      data.
    discriminator_gen_classification_logits: Classification logits for
      generated data.
    one_hot_labels: A Tensor holding one-hot labels for the batch.
    label_smoothing: A float in [0, 1]. If greater than 0, smooth the labels for
      "discriminator on real data" as suggested in
      https://arxiv.org/pdf/1701.00160
    real_weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `discriminator_real_classification_logits`, and must be broadcastable to
      `discriminator_real_classification_logits` (i.e., all dimensions must be
      either `1`, or the same as the corresponding dimension).
    generated_weights: Same as `real_weights`, but for
      `discriminator_gen_classification_logits`.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. Shape depends on `reduction`.

  Raises:
    TypeError: If the discriminator does not output a tuple.
  """
  with ops.name_scope(
      scope, 'acgan_discriminator_loss',
      (discriminator_real_classification_logits,
       discriminator_gen_classification_logits, one_hot_labels)) as scope:
    loss_on_generated = losses.softmax_cross_entropy(
        one_hot_labels, discriminator_gen_classification_logits,
        weights=generated_weights, scope=scope, loss_collection=None,
        reduction=reduction)
    loss_on_real = losses.softmax_cross_entropy(
        one_hot_labels, discriminator_real_classification_logits,
        weights=real_weights, label_smoothing=label_smoothing, scope=scope,
        loss_collection=None, reduction=reduction)
    loss = loss_on_generated + loss_on_real
    util.add_loss(loss, loss_collection)

    if add_summaries:
      summary.scalar('discriminator_gen_ac_loss', loss_on_generated)
      summary.scalar('discriminator_real_ac_loss', loss_on_real)
      summary.scalar('discriminator_ac_loss', loss)

  return loss

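# Illustrative usage sketch (hypothetical tensors; the class logits have shape
# [batch_size, num_classes], matching `labels`):
#
#   labels = tf.one_hot(class_ids, depth=num_classes)
#   d_ac_loss = acgan_discriminator_loss(
#       real_class_logits, fake_class_logits, labels)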

def acgan_generator_loss(
    discriminator_gen_classification_logits,
    one_hot_labels,
    weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """ACGAN loss for the generator.

  The ACGAN loss adds a classification loss to the conditional discriminator.
  Therefore, the discriminator must output a tuple consisting of
    (1) the real/fake prediction and
    (2) the logits for the classification (usually the last conv layer,
        flattened).

  For more details:
    ACGAN: https://arxiv.org/abs/1610.09585

  Args:
    discriminator_gen_classification_logits: Classification logits for
      generated data.
    one_hot_labels: A Tensor holding one-hot labels for the batch.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `discriminator_gen_classification_logits`, and must be broadcastable to
      `discriminator_gen_classification_logits` (i.e., all dimensions must be
      either `1`, or the same as the corresponding dimension).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. Shape depends on `reduction`.

  Raises:
    TypeError: If the discriminator does not output a tuple.
  """
  with ops.name_scope(
      scope, 'acgan_generator_loss',
      (discriminator_gen_classification_logits, one_hot_labels)) as scope:
    loss = losses.softmax_cross_entropy(
        one_hot_labels, discriminator_gen_classification_logits,
        weights=weights, scope=scope, loss_collection=loss_collection,
        reduction=reduction)

    if add_summaries:
      summary.scalar('generator_ac_loss', loss)

  return loss

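# Illustrative usage sketch (same hypothetical `fake_class_logits` and one-hot
# `labels` as in the discriminator example above):
#
#   g_ac_loss = acgan_generator_loss(fake_class_logits, labels)
#
# Both players are rewarded for correct classification, so this term is
# typically added to a standard real/fake GAN loss for each player.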

# Wasserstein Gradient Penalty losses from `Improved Training of Wasserstein
# GANs` (https://arxiv.org/abs/1704.00028).


def wasserstein_gradient_penalty(
    real_data,
    generated_data,
    generator_inputs,
    discriminator_fn,
    discriminator_scope,
    epsilon=1e-10,
    target=1.0,
    one_sided=False,
    weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """The gradient penalty for the Wasserstein discriminator loss.

  See `Improved Training of Wasserstein GANs`
  (https://arxiv.org/abs/1704.00028) for more details.

  Args:
    real_data: Real data.
    generated_data: Output of the generator.
    generator_inputs: Exact argument to pass to the generator, which is used
      as optional conditioning to the discriminator.
    discriminator_fn: A discriminator function that conforms to the TF-GAN API.
    discriminator_scope: If not `None`, reuse discriminators from this scope.
    epsilon: A small positive number added for numerical stability when
      computing the gradient norm.
    target: Optional Python number or `Tensor` indicating the target value of
      gradient norm. Defaults to 1.0.
    one_sided: If `True`, penalty proposed in https://arxiv.org/abs/1709.08894
      is used. Defaults to `False`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `real_data` and `generated_data`, and must be broadcastable to
      them (i.e., all dimensions must be either `1`, or the same as the
      corresponding dimension).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.

  Raises:
    ValueError: If the rank of data Tensors is unknown.
  """
  with ops.name_scope(scope, 'wasserstein_gradient_penalty',
                      (real_data, generated_data)) as scope:
    real_data = ops.convert_to_tensor(real_data)
    generated_data = ops.convert_to_tensor(generated_data)
    if real_data.shape.ndims is None:
      raise ValueError('`real_data` can\'t have unknown rank.')
    if generated_data.shape.ndims is None:
      raise ValueError('`generated_data` can\'t have unknown rank.')

    differences = generated_data - real_data
    batch_size = differences.shape.dims[0].value or array_ops.shape(
        differences)[0]
    alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1)
    alpha = random_ops.random_uniform(shape=alpha_shape)
    interpolates = real_data + (alpha * differences)

    with ops.name_scope(None):  # Clear scope so update ops are added properly.
      # Reuse variables if they already exist.
      with variable_scope.variable_scope(discriminator_scope, 'gpenalty_dscope',
                                         reuse=variable_scope.AUTO_REUSE):
        disc_interpolates = discriminator_fn(interpolates, generator_inputs)

    if isinstance(disc_interpolates, tuple):
      # ACGAN case: the discriminator outputs more than one tensor.
      disc_interpolates = disc_interpolates[0]

    gradients = gradients_impl.gradients(disc_interpolates, interpolates)[0]
    gradient_squares = math_ops.reduce_sum(
        math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims)))
    # Propagate shape information, if possible.
    if isinstance(batch_size, int):
      gradient_squares.set_shape([
          batch_size] + gradient_squares.shape.as_list()[1:])
    # For numerical stability, add epsilon to the sum before taking the square
    # root. Note tf.norm does not add epsilon.
    slopes = math_ops.sqrt(gradient_squares + epsilon)
    penalties = slopes / target - 1.0
    if one_sided:
      penalties = math_ops.maximum(0., penalties)
    penalties_squared = math_ops.square(penalties)
    penalty = losses.compute_weighted_loss(
        penalties_squared, weights, scope=scope,
        loss_collection=loss_collection, reduction=reduction)

    if add_summaries:
      summary.scalar('gradient_penalty_loss', penalty)

    return penalty

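# Illustrative usage sketch (the names here are hypothetical, not part of this
# module). `dis_fn` must accept `(data, conditioning)` and is rebuilt on the
# interpolated samples inside `discriminator_scope`:
#
#   gp = wasserstein_gradient_penalty(
#       real_images, fake_images, None, dis_fn, dis_scope)
#   d_loss = wasserstein_discriminator_loss(d_on_real, d_on_fake) + 10.0 * gp
#
# The coefficient 10.0 is the penalty weight used in the WGAN-GP paper.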

# Original losses from `Generative Adversarial Nets`
# (https://arxiv.org/abs/1406.2661).


def minimax_discriminator_loss(
    discriminator_real_outputs,
    discriminator_gen_outputs,
    label_smoothing=0.25,
    real_weights=1.0,
    generated_weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """Original minimax discriminator loss for GANs, with label smoothing.

  Note that the authors don't recommend using this loss. A more practically
  useful loss is `modified_discriminator_loss`.

  L = - real_weights * log(sigmoid(D(x)))
      - generated_weights * log(1 - sigmoid(D(G(z))))

  See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
  details.

  Args:
    discriminator_real_outputs: Discriminator output on real data.
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    label_smoothing: The amount of smoothing for positive labels. This technique
      is taken from `Improved Techniques for Training GANs`
      (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
    real_weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `discriminator_real_outputs`, and must be broadcastable to
      `discriminator_real_outputs` (i.e., all dimensions must be either `1`, or
      the same as the corresponding dimension).
    generated_weights: Same as `real_weights`, but for
      `discriminator_gen_outputs`.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
  with ops.name_scope(scope, 'discriminator_minimax_loss', (
      discriminator_real_outputs, discriminator_gen_outputs, real_weights,
      generated_weights, label_smoothing)) as scope:

    # -log(sigmoid(D(x))), with the positive labels smoothed by
    # `label_smoothing`.
    loss_on_real = losses.sigmoid_cross_entropy(
        array_ops.ones_like(discriminator_real_outputs),
        discriminator_real_outputs, real_weights, label_smoothing, scope,
        loss_collection=None, reduction=reduction)
    # -log(1 - sigmoid(D(G(z))))
    loss_on_generated = losses.sigmoid_cross_entropy(
        array_ops.zeros_like(discriminator_gen_outputs),
        discriminator_gen_outputs, generated_weights, scope=scope,
        loss_collection=None, reduction=reduction)

    loss = loss_on_real + loss_on_generated
    util.add_loss(loss, loss_collection)

    if add_summaries:
      summary.scalar('discriminator_gen_minimax_loss', loss_on_generated)
      summary.scalar('discriminator_real_minimax_loss', loss_on_real)
      summary.scalar('discriminator_minimax_loss', loss)

  return loss

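# Illustrative usage sketch (hypothetical logits tensors; note the default
# one-sided label smoothing of 0.25 on the real-data term):
#
#   d_loss = minimax_discriminator_loss(d_on_real, d_on_fake)
#   # Or, without label smoothing:
#   d_loss = minimax_discriminator_loss(
#       d_on_real, d_on_fake, label_smoothing=0.0)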

def minimax_generator_loss(
    discriminator_gen_outputs,
    label_smoothing=0.0,
    weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """Original minimax generator loss for GANs.

  Note that the authors don't recommend using this loss. A more practically
  useful loss is `modified_generator_loss`.

  L = log(1 - sigmoid(D(G(z))))

  See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
  details.

  Args:
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    label_smoothing: The amount of smoothing for positive labels. This technique
      is taken from `Improved Techniques for Training GANs`
      (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `discriminator_gen_outputs`, and must be broadcastable to
      `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or
      the same as the corresponding dimension).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
  with ops.name_scope(scope, 'generator_minimax_loss') as scope:
    loss = - minimax_discriminator_loss(
        array_ops.ones_like(discriminator_gen_outputs),
        discriminator_gen_outputs, label_smoothing, weights, weights, scope,
        loss_collection, reduction, add_summaries=False)

  if add_summaries:
    summary.scalar('generator_minimax_loss', loss)

  return loss


def modified_discriminator_loss(
    discriminator_real_outputs,
    discriminator_gen_outputs,
    label_smoothing=0.25,
    real_weights=1.0,
    generated_weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """Same as minimax discriminator loss.

  See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
  details.

  Args:
    discriminator_real_outputs: Discriminator output on real data.
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    label_smoothing: The amount of smoothing for positive labels. This technique
      is taken from `Improved Techniques for Training GANs`
      (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
    real_weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `discriminator_real_outputs`, and must be broadcastable to
      `discriminator_real_outputs` (i.e., all dimensions must be either `1`, or
      the same as the corresponding dimension).
    generated_weights: Same as `real_weights`, but for
      `discriminator_gen_outputs`.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
  return minimax_discriminator_loss(
      discriminator_real_outputs,
      discriminator_gen_outputs,
      label_smoothing,
      real_weights,
      generated_weights,
      scope or 'discriminator_modified_loss',
      loss_collection,
      reduction,
      add_summaries)


def modified_generator_loss(
    discriminator_gen_outputs,
    label_smoothing=0.0,
    weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """Modified generator loss for GANs.

  L = -log(sigmoid(D(G(z))))

  This is the trick used in the original paper to avoid vanishing gradients
  early in training. See `Generative Adversarial Nets`
  (https://arxiv.org/abs/1406.2661) for more details.

  Args:
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    label_smoothing: The amount of smoothing for positive labels. This technique
      is taken from `Improved Techniques for Training GANs`
      (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `discriminator_gen_outputs`, and must be broadcastable to
      `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or
      the same as the corresponding dimension).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
  with ops.name_scope(scope, 'generator_modified_loss',
                      [discriminator_gen_outputs]) as scope:
    loss = losses.sigmoid_cross_entropy(
        array_ops.ones_like(discriminator_gen_outputs),
        discriminator_gen_outputs, weights, label_smoothing, scope,
        loss_collection, reduction)

    if add_summaries:
      summary.scalar('generator_modified_loss', loss)

  return loss

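# Illustrative usage sketch (hypothetical logits tensor `d_on_fake`). The
# modified loss -log(sigmoid(D(G(z)))) has much stronger gradients than the
# minimax loss when the discriminator confidently rejects generated samples:
#
#   g_loss = modified_generator_loss(d_on_fake)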

# Least Squares loss from `Least Squares Generative Adversarial Networks`
# (https://arxiv.org/abs/1611.04076).


def least_squares_generator_loss(
    discriminator_gen_outputs,
    real_label=1,
    weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """Least squares generator loss.

  This loss comes from `Least Squares Generative Adversarial Networks`
  (https://arxiv.org/abs/1611.04076).

  L = 1/2 * (D(G(z)) - `real_label`) ** 2

  where D(y) are discriminator logits.

  Args:
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    real_label: The value that the generator is trying to get the discriminator
      to output on generated data.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `discriminator_gen_outputs`, and must be broadcastable to
      `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or
      the same as the corresponding dimension).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
  with ops.name_scope(scope, 'lsq_generator_loss',
                      (discriminator_gen_outputs, real_label)) as scope:
    discriminator_gen_outputs = _to_float(discriminator_gen_outputs)
    loss = math_ops.squared_difference(
        discriminator_gen_outputs, real_label) / 2.0
    loss = losses.compute_weighted_loss(
        loss, weights, scope, loss_collection, reduction)

  if add_summaries:
    summary.scalar('generator_lsq_loss', loss)

  return loss

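# Illustrative usage sketch (hypothetical logits tensor `d_on_fake`); with the
# defaults this is mean(1/2 * (D(G(z)) - 1) ** 2):
#
#   g_loss = least_squares_generator_loss(d_on_fake)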

def least_squares_discriminator_loss(
    discriminator_real_outputs,
    discriminator_gen_outputs,
    real_label=1,
    fake_label=0,
    real_weights=1.0,
    generated_weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """Least squares discriminator loss.

  This loss comes from `Least Squares Generative Adversarial Networks`
  (https://arxiv.org/abs/1611.04076).

  L = 1/2 * (D(x) - `real_label`) ** 2 +
      1/2 * (D(G(z)) - `fake_label`) ** 2

  where D(y) are discriminator logits.

  Args:
    discriminator_real_outputs: Discriminator output on real data.
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    real_label: The value that the discriminator tries to output for real data.
    fake_label: The value that the discriminator tries to output for fake data.
    real_weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `discriminator_real_outputs`, and must be broadcastable to
      `discriminator_real_outputs` (i.e., all dimensions must be either `1`, or
      the same as the corresponding dimension).
    generated_weights: Same as `real_weights`, but for
      `discriminator_gen_outputs`.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
  with ops.name_scope(scope, 'lsq_discriminator_loss',
                      (discriminator_gen_outputs, real_label)) as scope:
    discriminator_real_outputs = _to_float(discriminator_real_outputs)
    discriminator_gen_outputs = _to_float(discriminator_gen_outputs)
    discriminator_real_outputs.shape.assert_is_compatible_with(
        discriminator_gen_outputs.shape)

    real_losses = math_ops.squared_difference(
        discriminator_real_outputs, real_label) / 2.0
    fake_losses = math_ops.squared_difference(
        discriminator_gen_outputs, fake_label) / 2.0

    loss_on_real = losses.compute_weighted_loss(
        real_losses, real_weights, scope, loss_collection=None,
        reduction=reduction)
    loss_on_generated = losses.compute_weighted_loss(
        fake_losses, generated_weights, scope, loss_collection=None,
        reduction=reduction)

    loss = loss_on_real + loss_on_generated
    util.add_loss(loss, loss_collection)

  if add_summaries:
    summary.scalar('discriminator_gen_lsq_loss', loss_on_generated)
    summary.scalar('discriminator_real_lsq_loss', loss_on_real)
    summary.scalar('discriminator_lsq_loss', loss)

  return loss

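# Illustrative usage sketch (hypothetical logits tensors); with the default
# labels this is mean(1/2 * (D(x) - 1) ** 2) + mean(1/2 * D(G(z)) ** 2):
#
#   d_loss = least_squares_discriminator_loss(d_on_real, d_on_fake)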

# InfoGAN loss from `InfoGAN: Interpretable Representation Learning by
# Information Maximizing Generative Adversarial Nets`
# (https://arxiv.org/abs/1606.03657).


def _validate_distributions(distributions):
  if not isinstance(distributions, (list, tuple)):
    raise ValueError('`distributions` must be a list or tuple. Instead, '
                     'found %s.' % type(distributions))
  for x in distributions:
    # We used to check with `isinstance(x, tf.distributions.Distribution)`.
    # However, distributions have migrated to `tfp.distributions.Distribution`,
    # which is a new code repo, so we can't check this way anymore until
    # TF-GAN is migrated to a new repo as well.
    # This new check is not sufficient, but is a useful heuristic for now.
    if not callable(getattr(x, 'log_prob', None)):
      raise ValueError('`distributions` must be a list of `Distributions`. '
                       'Instead, found %s.' % type(x))


def _validate_information_penalty_inputs(
    structured_generator_inputs, predicted_distributions):
  """Validate input to `mutual_information_penalty`."""
  _validate_distributions(predicted_distributions)
  if len(structured_generator_inputs) != len(predicted_distributions):
    raise ValueError('`structured_generator_inputs` length %i must be the same '
                     'as `predicted_distributions` length %i.' % (
                         len(structured_generator_inputs),
                         len(predicted_distributions)))


def mutual_information_penalty(
    structured_generator_inputs,
    predicted_distributions,
    weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """Returns a penalty on the mutual information in an InfoGAN model.

  This loss comes from the InfoGAN paper (https://arxiv.org/abs/1606.03657).

  Args:
    structured_generator_inputs: A list of Tensors representing the random noise
      that must have high mutual information with the generator output. List
      length should match `predicted_distributions`.
    predicted_distributions: A list of `tfp.distributions.Distribution`s.
      Predicted by the recognizer, and used to evaluate the likelihood of the
      structured noise. List length should match `structured_generator_inputs`.
    weights: Optional `Tensor` whose rank is either 0, or the same as the
      number of elements in `structured_generator_inputs`.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A scalar Tensor representing the mutual information loss.
  """
  _validate_information_penalty_inputs(
      structured_generator_inputs, predicted_distributions)

  with ops.name_scope(scope, 'mutual_information_loss') as scope:
    # Calculate the negative log-likelihood of the reconstructed noise.
    log_probs = [math_ops.reduce_mean(dist.log_prob(noise)) for dist, noise in
                 zip(predicted_distributions, structured_generator_inputs)]
    loss = -1 * losses.compute_weighted_loss(
        log_probs, weights, scope, loss_collection=loss_collection,
        reduction=reduction)

    if add_summaries:
      summary.scalar('mutual_information_penalty', loss)

  return loss

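# Illustrative usage sketch (assumes `tfp` is `tensorflow_probability`; the
# noise tensors `cat_noise` and `cont_noise` are the hypothetical structured
# inputs that were fed to the generator, while `q_logits` and `q_mu` come from
# the recognizer network):
#
#   pred_cat = tfp.distributions.Categorical(logits=q_logits)
#   pred_cont = tfp.distributions.Normal(loc=q_mu, scale=1.0)
#   info_loss = mutual_information_penalty(
#       [cat_noise, cont_noise], [pred_cat, pred_cont])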

def _numerically_stable_global_norm(tensor_list):
  """Compute the global norm of a list of Tensors, with improved stability.

  The global norm computation sometimes overflows due to the intermediate L2
  step. To avoid this, we divide by a cheap-to-compute max over the
  matrix elements.

  Args:
    tensor_list: A list of tensors, or `None`.

  Returns:
    A scalar tensor with the global norm.
  """
  if all(x is None for x in tensor_list):
    return 0.0

  list_max = math_ops.reduce_max([math_ops.reduce_max(math_ops.abs(x)) for x in
                                  tensor_list if x is not None])
  return list_max * clip_ops.global_norm([x / list_max for x in tensor_list
                                          if x is not None])

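# Worked example of the stability trick above: float32 overflows near 3.4e38,
# so squaring an entry of 1e20 in a naive global norm yields inf. Factoring
# out the max keeps the intermediate values small:
#
#   global_norm([1e20, 1e20])       -> inf (1e20 ** 2 overflows)
#   1e20 * global_norm([1.0, 1.0])  -> ~1.41e20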

def _used_weight(weights_list):
  for weight in weights_list:
    if weight is not None:
      return tensor_util.constant_value(ops.convert_to_tensor(weight))


def _validate_args(losses_list, weight_factor, gradient_ratio):
  for loss in losses_list:
    loss.shape.assert_is_compatible_with([])
  if weight_factor is None and gradient_ratio is None:
    raise ValueError(
        '`weight_factor` and `gradient_ratio` cannot both be `None`.')
  if weight_factor is not None and gradient_ratio is not None:
    raise ValueError(
        '`weight_factor` and `gradient_ratio` cannot both be specified.')


# TODO(joelshor): Add ability to pass in gradients, to avoid recomputing.
def combine_adversarial_loss(main_loss,
                             adversarial_loss,
                             weight_factor=None,
                             gradient_ratio=None,
                             gradient_ratio_epsilon=1e-6,
                             variables=None,
                             scalar_summaries=True,
                             gradient_summaries=True,
                             scope=None):
  """Utility to combine main and adversarial losses.

  This utility combines the main and adversarial losses in one of two ways.
  1) Fixed coefficient on adversarial loss. Use `weight_factor` in this case.
  2) Fixed ratio of gradients. Use `gradient_ratio` in this case. This is often
    used to make sure both losses affect weights roughly equally, as in
    https://arxiv.org/pdf/1705.05823.

  One can optionally also visualize the scalar and gradient behavior of the
  losses.

  Args:
    main_loss: A floating scalar Tensor indicating the main loss.
    adversarial_loss: A floating scalar Tensor indicating the adversarial loss.
    weight_factor: If not `None`, the coefficient by which to multiply the
      adversarial loss. Exactly one of this and `gradient_ratio` must be
      non-None.
    gradient_ratio: If not `None`, the ratio of the magnitude of the gradients.
      Specifically,
        gradient_ratio = grad_mag(main_loss) / grad_mag(adversarial_loss)
      Exactly one of this and `weight_factor` must be non-None.
    gradient_ratio_epsilon: An epsilon to add to the adversarial loss
      coefficient denominator, to avoid division-by-zero.
    variables: List of variables to calculate gradients with respect to. If not
      present, defaults to all trainable variables.
    scalar_summaries: Create scalar summaries of losses.
    gradient_summaries: Create gradient summaries of losses.
    scope: Optional name scope.

  Returns:
    A floating scalar Tensor indicating the desired combined loss.

  Raises:
    ValueError: Malformed input.
  """
  _validate_args([main_loss, adversarial_loss], weight_factor, gradient_ratio)
  if variables is None:
    variables = contrib_variables_lib.get_trainable_variables()

  with ops.name_scope(scope, 'adversarial_loss',
                      values=[main_loss, adversarial_loss]):
    # Compute gradients if we will need them.
    if gradient_summaries or gradient_ratio is not None:
      main_loss_grad_mag = _numerically_stable_global_norm(
          gradients_impl.gradients(main_loss, variables))
      adv_loss_grad_mag = _numerically_stable_global_norm(
          gradients_impl.gradients(adversarial_loss, variables))

    # Add summaries, if applicable.
    if scalar_summaries:
      summary.scalar('main_loss', main_loss)
      summary.scalar('adversarial_loss', adversarial_loss)
    if gradient_summaries:
      summary.scalar('main_loss_gradients', main_loss_grad_mag)
      summary.scalar('adversarial_loss_gradients', adv_loss_grad_mag)

    # Combine losses in the appropriate way.
    # If `weight_factor` is always `0`, avoid computing the adversarial loss
    # tensor entirely.
    if _used_weight((weight_factor, gradient_ratio)) == 0:
      final_loss = main_loss
    elif weight_factor is not None:
      final_loss = (main_loss +
                    array_ops.stop_gradient(weight_factor) * adversarial_loss)
    elif gradient_ratio is not None:
      grad_mag_ratio = main_loss_grad_mag / (
          adv_loss_grad_mag + gradient_ratio_epsilon)
      adv_coeff = grad_mag_ratio / gradient_ratio
      summary.scalar('adversarial_coefficient', adv_coeff)
      final_loss = (main_loss +
                    array_ops.stop_gradient(adv_coeff) * adversarial_loss)

  return final_loss

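# Illustrative usage sketch (hypothetical scalar loss tensors). The first call
# uses a fixed coefficient on the adversarial loss; the second rescales it so
# that the two losses produce gradients of equal magnitude:
#
#   total = combine_adversarial_loss(
#       reconstruction_loss, gan_gen_loss, weight_factor=0.01)
#   total = combine_adversarial_loss(
#       reconstruction_loss, gan_gen_loss, gradient_ratio=1.0)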

def cycle_consistency_loss(data_x,
                           reconstructed_data_x,
                           data_y,
                           reconstructed_data_y,
                           scope=None,
                           add_summaries=False):
  """Defines the cycle consistency loss.

  The CycleGAN model has two partial models, where the `model_x2y` generator F
  maps data set X to Y and the `model_y2x` generator G maps data set Y to X.
  For `data_x` in data set X, we can reconstruct it by
  * reconstructed_data_x = G(F(data_x))
  Similarly
  * reconstructed_data_y = F(G(data_y))

  The cycle consistency loss measures the difference between the data and the
  reconstructed data, namely
  * loss_x2x = |data_x - G(F(data_x))| (L1-norm)
  * loss_y2y = |data_y - F(G(data_y))| (L1-norm)
  * loss = (loss_x2x + loss_y2y) / 2
  where `loss` is the final result.

  Following the original implementation
  (https://github.com/junyanz/CycleGAN/blob/master/models/cycle_gan_model.lua),
  we use the L1-norm of the pixel-wise error normalized by data size, so that
  `cycle_loss_weight` can be specified independently of image size.

  See https://arxiv.org/abs/1703.10593 for more details.

  Args:
    data_x: A `Tensor` of data X.
    reconstructed_data_x: A `Tensor` of reconstructed data X.
    data_y: A `Tensor` of data Y.
    reconstructed_data_y: A `Tensor` of reconstructed data Y.
    scope: The scope for the operations performed in computing the loss.
      Defaults to None.
    add_summaries: Whether or not to add detailed summaries for the loss.
      Defaults to False.

  Returns:
    A scalar `Tensor` of cycle consistency loss.
  """

  with ops.name_scope(
      scope,
      'cycle_consistency_loss',
      values=[data_x, reconstructed_data_x, data_y, reconstructed_data_y]):
    loss_x2x = losses.absolute_difference(data_x, reconstructed_data_x)
    loss_y2y = losses.absolute_difference(data_y, reconstructed_data_y)
    loss = (loss_x2x + loss_y2y) / 2.0
    if add_summaries:
      summary.scalar('cycle_consistency_loss_x2x', loss_x2x)
      summary.scalar('cycle_consistency_loss_y2y', loss_y2y)
      summary.scalar('cycle_consistency_loss', loss)

  return loss

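# Illustrative usage sketch (hypothetical image batches `x` and `y`, with
# generator functions `f_x2y` and `g_y2x`):
#
#   cycle_loss = cycle_consistency_loss(
#       x, g_y2x(f_x2y(x)), y, f_x2y(g_y2x(y)))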