1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Losses that are useful for training GANs. 16 17The losses belong to two main groups, but there are others that do not: 181) xxxxx_generator_loss 192) xxxxx_discriminator_loss 20 21Example: 221) wasserstein_generator_loss 232) wasserstein_discriminator_loss 24 25Other example: 26wasserstein_gradient_penalty 27 28All losses must be able to accept 1D or 2D Tensors, so as to be compatible with 29patchGAN style losses (https://arxiv.org/abs/1611.07004). 30 31To make these losses usable in the TF-GAN framework, please create a tuple 32version of the losses with `losses_utils.py`. 33""" 34 35from __future__ import absolute_import 36from __future__ import division 37from __future__ import print_function 38 39 40from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib 41from tensorflow.python.framework import dtypes 42from tensorflow.python.framework import ops 43from tensorflow.python.framework import tensor_util 44from tensorflow.python.ops import array_ops 45from tensorflow.python.ops import clip_ops 46from tensorflow.python.ops import gradients_impl 47from tensorflow.python.ops import math_ops 48from tensorflow.python.ops import random_ops 49from tensorflow.python.ops import variable_scope 50from tensorflow.python.ops.losses import losses 51from tensorflow.python.ops.losses import util 52from tensorflow.python.summary import summary 53 54 55__all__ = [ 56 'acgan_discriminator_loss', 57 'acgan_generator_loss', 58 'least_squares_discriminator_loss', 59 'least_squares_generator_loss', 60 'modified_discriminator_loss', 61 'modified_generator_loss', 62 'minimax_discriminator_loss', 63 'minimax_generator_loss', 64 'wasserstein_discriminator_loss', 65 'wasserstein_generator_loss', 66 'wasserstein_gradient_penalty', 67 'mutual_information_penalty', 68 'combine_adversarial_loss', 69 'cycle_consistency_loss', 70] 71 72 73def _to_float(tensor): 74 return math_ops.cast(tensor, dtypes.float32) 75 76 77# Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875). 78def wasserstein_generator_loss( 79 discriminator_gen_outputs, 80 weights=1.0, 81 scope=None, 82 loss_collection=ops.GraphKeys.LOSSES, 83 reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, 84 add_summaries=False): 85 """Wasserstein generator loss for GANs. 86 87 See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details. 88 89 Args: 90 discriminator_gen_outputs: Discriminator output on generated data. Expected 91 to be in the range of (-inf, inf). 92 weights: Optional `Tensor` whose rank is either 0, or the same rank as 93 `discriminator_gen_outputs`, and must be broadcastable to 94 `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or 95 the same as the corresponding dimension). 96 scope: The scope for the operations performed in computing the loss. 97 loss_collection: collection to which this loss will be added. 98 reduction: A `tf.losses.Reduction` to apply to loss. 99 add_summaries: Whether or not to add detailed summaries for the loss. 100 101 Returns: 102 A loss Tensor. The shape depends on `reduction`. 103 """ 104 with ops.name_scope(scope, 'generator_wasserstein_loss', ( 105 discriminator_gen_outputs, weights)) as scope: 106 discriminator_gen_outputs = _to_float(discriminator_gen_outputs) 107 108 loss = - discriminator_gen_outputs 109 loss = losses.compute_weighted_loss( 110 loss, weights, scope, loss_collection, reduction) 111 112 if add_summaries: 113 summary.scalar('generator_wass_loss', loss) 114 115 return loss 116 117 118def wasserstein_discriminator_loss( 119 discriminator_real_outputs, 120 discriminator_gen_outputs, 121 real_weights=1.0, 122 generated_weights=1.0, 123 scope=None, 124 loss_collection=ops.GraphKeys.LOSSES, 125 reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, 126 add_summaries=False): 127 """Wasserstein discriminator loss for GANs. 128 129 See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details. 130 131 Args: 132 discriminator_real_outputs: Discriminator output on real data. 133 discriminator_gen_outputs: Discriminator output on generated data. Expected 134 to be in the range of (-inf, inf). 135 real_weights: Optional `Tensor` whose rank is either 0, or the same rank as 136 `discriminator_real_outputs`, and must be broadcastable to 137 `discriminator_real_outputs` (i.e., all dimensions must be either `1`, or 138 the same as the corresponding dimension). 139 generated_weights: Same as `real_weights`, but for 140 `discriminator_gen_outputs`. 141 scope: The scope for the operations performed in computing the loss. 142 loss_collection: collection to which this loss will be added. 143 reduction: A `tf.losses.Reduction` to apply to loss. 144 add_summaries: Whether or not to add summaries for the loss. 145 146 Returns: 147 A loss Tensor. The shape depends on `reduction`. 148 """ 149 with ops.name_scope(scope, 'discriminator_wasserstein_loss', ( 150 discriminator_real_outputs, discriminator_gen_outputs, real_weights, 151 generated_weights)) as scope: 152 discriminator_real_outputs = _to_float(discriminator_real_outputs) 153 discriminator_gen_outputs = _to_float(discriminator_gen_outputs) 154 discriminator_real_outputs.shape.assert_is_compatible_with( 155 discriminator_gen_outputs.shape) 156 157 loss_on_generated = losses.compute_weighted_loss( 158 discriminator_gen_outputs, generated_weights, scope, 159 loss_collection=None, reduction=reduction) 160 loss_on_real = losses.compute_weighted_loss( 161 discriminator_real_outputs, real_weights, scope, loss_collection=None, 162 reduction=reduction) 163 loss = loss_on_generated - loss_on_real 164 util.add_loss(loss, loss_collection) 165 166 if add_summaries: 167 summary.scalar('discriminator_gen_wass_loss', loss_on_generated) 168 summary.scalar('discriminator_real_wass_loss', loss_on_real) 169 summary.scalar('discriminator_wass_loss', loss) 170 171 return loss 172 173 174# ACGAN losses from `Conditional Image Synthesis With Auxiliary Classifier GANs` 175# (https://arxiv.org/abs/1610.09585). 176def acgan_discriminator_loss( 177 discriminator_real_classification_logits, 178 discriminator_gen_classification_logits, 179 one_hot_labels, 180 label_smoothing=0.0, 181 real_weights=1.0, 182 generated_weights=1.0, 183 scope=None, 184 loss_collection=ops.GraphKeys.LOSSES, 185 reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, 186 add_summaries=False): 187 """ACGAN loss for the discriminator. 188 189 The ACGAN loss adds a classification loss to the conditional discriminator. 190 Therefore, the discriminator must output a tuple consisting of 191 (1) the real/fake prediction and 192 (2) the logits for the classification (usually the last conv layer, 193 flattened). 194 195 For more details: 196 ACGAN: https://arxiv.org/abs/1610.09585 197 198 Args: 199 discriminator_real_classification_logits: Classification logits for real 200 data. 201 discriminator_gen_classification_logits: Classification logits for generated 202 data. 203 one_hot_labels: A Tensor holding one-hot labels for the batch. 204 label_smoothing: A float in [0, 1]. If greater than 0, smooth the labels for 205 "discriminator on real data" as suggested in 206 https://arxiv.org/pdf/1701.00160 207 real_weights: Optional `Tensor` whose rank is either 0, or the same rank as 208 `discriminator_real_outputs`, and must be broadcastable to 209 `discriminator_real_outputs` (i.e., all dimensions must be either `1`, or 210 the same as the corresponding dimension). 211 generated_weights: Same as `real_weights`, but for 212 `discriminator_gen_classification_logits`. 213 scope: The scope for the operations performed in computing the loss. 214 loss_collection: collection to which this loss will be added. 215 reduction: A `tf.losses.Reduction` to apply to loss. 216 add_summaries: Whether or not to add summaries for the loss. 217 218 Returns: 219 A loss Tensor. Shape depends on `reduction`. 220 221 Raises: 222 TypeError: If the discriminator does not output a tuple. 223 """ 224 with ops.name_scope( 225 scope, 'acgan_discriminator_loss', 226 (discriminator_real_classification_logits, 227 discriminator_gen_classification_logits, one_hot_labels)) as scope: 228 loss_on_generated = losses.softmax_cross_entropy( 229 one_hot_labels, discriminator_gen_classification_logits, 230 weights=generated_weights, scope=scope, loss_collection=None, 231 reduction=reduction) 232 loss_on_real = losses.softmax_cross_entropy( 233 one_hot_labels, discriminator_real_classification_logits, 234 weights=real_weights, label_smoothing=label_smoothing, scope=scope, 235 loss_collection=None, reduction=reduction) 236 loss = loss_on_generated + loss_on_real 237 util.add_loss(loss, loss_collection) 238 239 if add_summaries: 240 summary.scalar('discriminator_gen_ac_loss', loss_on_generated) 241 summary.scalar('discriminator_real_ac_loss', loss_on_real) 242 summary.scalar('discriminator_ac_loss', loss) 243 244 return loss 245 246 247def acgan_generator_loss( 248 discriminator_gen_classification_logits, 249 one_hot_labels, 250 weights=1.0, 251 scope=None, 252 loss_collection=ops.GraphKeys.LOSSES, 253 reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, 254 add_summaries=False): 255 """ACGAN loss for the generator. 256 257 The ACGAN loss adds a classification loss to the conditional discriminator. 258 Therefore, the discriminator must output a tuple consisting of 259 (1) the real/fake prediction and 260 (2) the logits for the classification (usually the last conv layer, 261 flattened). 262 263 For more details: 264 ACGAN: https://arxiv.org/abs/1610.09585 265 266 Args: 267 discriminator_gen_classification_logits: Classification logits for generated 268 data. 269 one_hot_labels: A Tensor holding one-hot labels for the batch. 270 weights: Optional `Tensor` whose rank is either 0, or the same rank as 271 `discriminator_gen_classification_logits`, and must be broadcastable to 272 `discriminator_gen_classification_logits` (i.e., all dimensions must be 273 either `1`, or the same as the corresponding dimension). 274 scope: The scope for the operations performed in computing the loss. 275 loss_collection: collection to which this loss will be added. 276 reduction: A `tf.losses.Reduction` to apply to loss. 277 add_summaries: Whether or not to add summaries for the loss. 278 279 Returns: 280 A loss Tensor. Shape depends on `reduction`. 281 282 Raises: 283 ValueError: if arg module not either `generator` or `discriminator` 284 TypeError: if the discriminator does not output a tuple. 285 """ 286 with ops.name_scope( 287 scope, 'acgan_generator_loss', 288 (discriminator_gen_classification_logits, one_hot_labels)) as scope: 289 loss = losses.softmax_cross_entropy( 290 one_hot_labels, discriminator_gen_classification_logits, 291 weights=weights, scope=scope, loss_collection=loss_collection, 292 reduction=reduction) 293 294 if add_summaries: 295 summary.scalar('generator_ac_loss', loss) 296 297 return loss 298 299 300# Wasserstein Gradient Penalty losses from `Improved Training of Wasserstein 301# GANs` (https://arxiv.org/abs/1704.00028). 302 303 304def wasserstein_gradient_penalty( 305 real_data, 306 generated_data, 307 generator_inputs, 308 discriminator_fn, 309 discriminator_scope, 310 epsilon=1e-10, 311 target=1.0, 312 one_sided=False, 313 weights=1.0, 314 scope=None, 315 loss_collection=ops.GraphKeys.LOSSES, 316 reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, 317 add_summaries=False): 318 """The gradient penalty for the Wasserstein discriminator loss. 319 320 See `Improved Training of Wasserstein GANs` 321 (https://arxiv.org/abs/1704.00028) for more details. 322 323 Args: 324 real_data: Real data. 325 generated_data: Output of the generator. 326 generator_inputs: Exact argument to pass to the generator, which is used 327 as optional conditioning to the discriminator. 328 discriminator_fn: A discriminator function that conforms to TF-GAN API. 329 discriminator_scope: If not `None`, reuse discriminators from this scope. 330 epsilon: A small positive number added for numerical stability when 331 computing the gradient norm. 332 target: Optional Python number or `Tensor` indicating the target value of 333 gradient norm. Defaults to 1.0. 334 one_sided: If `True`, penalty proposed in https://arxiv.org/abs/1709.08894 335 is used. Defaults to `False`. 336 weights: Optional `Tensor` whose rank is either 0, or the same rank as 337 `real_data` and `generated_data`, and must be broadcastable to 338 them (i.e., all dimensions must be either `1`, or the same as the 339 corresponding dimension). 340 scope: The scope for the operations performed in computing the loss. 341 loss_collection: collection to which this loss will be added. 342 reduction: A `tf.losses.Reduction` to apply to loss. 343 add_summaries: Whether or not to add summaries for the loss. 344 345 Returns: 346 A loss Tensor. The shape depends on `reduction`. 347 348 Raises: 349 ValueError: If the rank of data Tensors is unknown. 350 """ 351 with ops.name_scope(scope, 'wasserstein_gradient_penalty', 352 (real_data, generated_data)) as scope: 353 real_data = ops.convert_to_tensor(real_data) 354 generated_data = ops.convert_to_tensor(generated_data) 355 if real_data.shape.ndims is None: 356 raise ValueError('`real_data` can\'t have unknown rank.') 357 if generated_data.shape.ndims is None: 358 raise ValueError('`generated_data` can\'t have unknown rank.') 359 360 differences = generated_data - real_data 361 batch_size = differences.shape.dims[0].value or array_ops.shape( 362 differences)[0] 363 alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1) 364 alpha = random_ops.random_uniform(shape=alpha_shape) 365 interpolates = real_data + (alpha * differences) 366 367 with ops.name_scope(None): # Clear scope so update ops are added properly. 368 # Reuse variables if variables already exists. 369 with variable_scope.variable_scope(discriminator_scope, 'gpenalty_dscope', 370 reuse=variable_scope.AUTO_REUSE): 371 disc_interpolates = discriminator_fn(interpolates, generator_inputs) 372 373 if isinstance(disc_interpolates, tuple): 374 # ACGAN case: disc outputs more than one tensor 375 disc_interpolates = disc_interpolates[0] 376 377 gradients = gradients_impl.gradients(disc_interpolates, interpolates)[0] 378 gradient_squares = math_ops.reduce_sum( 379 math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims))) 380 # Propagate shape information, if possible. 381 if isinstance(batch_size, int): 382 gradient_squares.set_shape([ 383 batch_size] + gradient_squares.shape.as_list()[1:]) 384 # For numerical stability, add epsilon to the sum before taking the square 385 # root. Note tf.norm does not add epsilon. 386 slopes = math_ops.sqrt(gradient_squares + epsilon) 387 penalties = slopes / target - 1.0 388 if one_sided: 389 penalties = math_ops.maximum(0., penalties) 390 penalties_squared = math_ops.square(penalties) 391 penalty = losses.compute_weighted_loss( 392 penalties_squared, weights, scope=scope, 393 loss_collection=loss_collection, reduction=reduction) 394 395 if add_summaries: 396 summary.scalar('gradient_penalty_loss', penalty) 397 398 return penalty 399 400 401# Original losses from `Generative Adversarial Nets` 402# (https://arxiv.org/abs/1406.2661). 403 404 405def minimax_discriminator_loss( 406 discriminator_real_outputs, 407 discriminator_gen_outputs, 408 label_smoothing=0.25, 409 real_weights=1.0, 410 generated_weights=1.0, 411 scope=None, 412 loss_collection=ops.GraphKeys.LOSSES, 413 reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, 414 add_summaries=False): 415 """Original minimax discriminator loss for GANs, with label smoothing. 416 417 Note that the authors don't recommend using this loss. A more practically 418 useful loss is `modified_discriminator_loss`. 419 420 L = - real_weights * log(sigmoid(D(x))) 421 - generated_weights * log(1 - sigmoid(D(G(z)))) 422 423 See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more 424 details. 425 426 Args: 427 discriminator_real_outputs: Discriminator output on real data. 428 discriminator_gen_outputs: Discriminator output on generated data. Expected 429 to be in the range of (-inf, inf). 430 label_smoothing: The amount of smoothing for positive labels. This technique 431 is taken from `Improved Techniques for Training GANs` 432 (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. 433 real_weights: Optional `Tensor` whose rank is either 0, or the same rank as 434 `real_data`, and must be broadcastable to `real_data` (i.e., all 435 dimensions must be either `1`, or the same as the corresponding 436 dimension). 437 generated_weights: Same as `real_weights`, but for `generated_data`. 438 scope: The scope for the operations performed in computing the loss. 439 loss_collection: collection to which this loss will be added. 440 reduction: A `tf.losses.Reduction` to apply to loss. 441 add_summaries: Whether or not to add summaries for the loss. 442 443 Returns: 444 A loss Tensor. The shape depends on `reduction`. 445 """ 446 with ops.name_scope(scope, 'discriminator_minimax_loss', ( 447 discriminator_real_outputs, discriminator_gen_outputs, real_weights, 448 generated_weights, label_smoothing)) as scope: 449 450 # -log((1 - label_smoothing) - sigmoid(D(x))) 451 loss_on_real = losses.sigmoid_cross_entropy( 452 array_ops.ones_like(discriminator_real_outputs), 453 discriminator_real_outputs, real_weights, label_smoothing, scope, 454 loss_collection=None, reduction=reduction) 455 # -log(- sigmoid(D(G(x)))) 456 loss_on_generated = losses.sigmoid_cross_entropy( 457 array_ops.zeros_like(discriminator_gen_outputs), 458 discriminator_gen_outputs, generated_weights, scope=scope, 459 loss_collection=None, reduction=reduction) 460 461 loss = loss_on_real + loss_on_generated 462 util.add_loss(loss, loss_collection) 463 464 if add_summaries: 465 summary.scalar('discriminator_gen_minimax_loss', loss_on_generated) 466 summary.scalar('discriminator_real_minimax_loss', loss_on_real) 467 summary.scalar('discriminator_minimax_loss', loss) 468 469 return loss 470 471 472def minimax_generator_loss( 473 discriminator_gen_outputs, 474 label_smoothing=0.0, 475 weights=1.0, 476 scope=None, 477 loss_collection=ops.GraphKeys.LOSSES, 478 reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, 479 add_summaries=False): 480 """Original minimax generator loss for GANs. 481 482 Note that the authors don't recommend using this loss. A more practically 483 useful loss is `modified_generator_loss`. 484 485 L = log(sigmoid(D(x))) + log(1 - sigmoid(D(G(z)))) 486 487 See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more 488 details. 489 490 Args: 491 discriminator_gen_outputs: Discriminator output on generated data. Expected 492 to be in the range of (-inf, inf). 493 label_smoothing: The amount of smoothing for positive labels. This technique 494 is taken from `Improved Techniques for Training GANs` 495 (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. 496 weights: Optional `Tensor` whose rank is either 0, or the same rank as 497 `discriminator_gen_outputs`, and must be broadcastable to 498 `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or 499 the same as the corresponding dimension). 500 scope: The scope for the operations performed in computing the loss. 501 loss_collection: collection to which this loss will be added. 502 reduction: A `tf.losses.Reduction` to apply to loss. 503 add_summaries: Whether or not to add summaries for the loss. 504 505 Returns: 506 A loss Tensor. The shape depends on `reduction`. 507 """ 508 with ops.name_scope(scope, 'generator_minimax_loss') as scope: 509 loss = - minimax_discriminator_loss( 510 array_ops.ones_like(discriminator_gen_outputs), 511 discriminator_gen_outputs, label_smoothing, weights, weights, scope, 512 loss_collection, reduction, add_summaries=False) 513 514 if add_summaries: 515 summary.scalar('generator_minimax_loss', loss) 516 517 return loss 518 519 520def modified_discriminator_loss( 521 discriminator_real_outputs, 522 discriminator_gen_outputs, 523 label_smoothing=0.25, 524 real_weights=1.0, 525 generated_weights=1.0, 526 scope=None, 527 loss_collection=ops.GraphKeys.LOSSES, 528 reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, 529 add_summaries=False): 530 """Same as minimax discriminator loss. 531 532 See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more 533 details. 534 535 Args: 536 discriminator_real_outputs: Discriminator output on real data. 537 discriminator_gen_outputs: Discriminator output on generated data. Expected 538 to be in the range of (-inf, inf). 539 label_smoothing: The amount of smoothing for positive labels. This technique 540 is taken from `Improved Techniques for Training GANs` 541 (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. 542 real_weights: Optional `Tensor` whose rank is either 0, or the same rank as 543 `discriminator_gen_outputs`, and must be broadcastable to 544 `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or 545 the same as the corresponding dimension). 546 generated_weights: Same as `real_weights`, but for 547 `discriminator_gen_outputs`. 548 scope: The scope for the operations performed in computing the loss. 549 loss_collection: collection to which this loss will be added. 550 reduction: A `tf.losses.Reduction` to apply to loss. 551 add_summaries: Whether or not to add summaries for the loss. 552 553 Returns: 554 A loss Tensor. The shape depends on `reduction`. 555 """ 556 return minimax_discriminator_loss( 557 discriminator_real_outputs, 558 discriminator_gen_outputs, 559 label_smoothing, 560 real_weights, 561 generated_weights, 562 scope or 'discriminator_modified_loss', 563 loss_collection, 564 reduction, 565 add_summaries) 566 567 568def modified_generator_loss( 569 discriminator_gen_outputs, 570 label_smoothing=0.0, 571 weights=1.0, 572 scope=None, 573 loss_collection=ops.GraphKeys.LOSSES, 574 reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, 575 add_summaries=False): 576 """Modified generator loss for GANs. 577 578 L = -log(sigmoid(D(G(z)))) 579 580 This is the trick used in the original paper to avoid vanishing gradients 581 early in training. See `Generative Adversarial Nets` 582 (https://arxiv.org/abs/1406.2661) for more details. 583 584 Args: 585 discriminator_gen_outputs: Discriminator output on generated data. Expected 586 to be in the range of (-inf, inf). 587 label_smoothing: The amount of smoothing for positive labels. This technique 588 is taken from `Improved Techniques for Training GANs` 589 (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. 590 weights: Optional `Tensor` whose rank is either 0, or the same rank as 591 `discriminator_gen_outputs`, and must be broadcastable to `labels` (i.e., 592 all dimensions must be either `1`, or the same as the corresponding 593 dimension). 594 scope: The scope for the operations performed in computing the loss. 595 loss_collection: collection to which this loss will be added. 596 reduction: A `tf.losses.Reduction` to apply to loss. 597 add_summaries: Whether or not to add summaries for the loss. 598 599 Returns: 600 A loss Tensor. The shape depends on `reduction`. 601 """ 602 with ops.name_scope(scope, 'generator_modified_loss', 603 [discriminator_gen_outputs]) as scope: 604 loss = losses.sigmoid_cross_entropy( 605 array_ops.ones_like(discriminator_gen_outputs), 606 discriminator_gen_outputs, weights, label_smoothing, scope, 607 loss_collection, reduction) 608 609 if add_summaries: 610 summary.scalar('generator_modified_loss', loss) 611 612 return loss 613 614 615# Least Squares loss from `Least Squares Generative Adversarial Networks` 616# (https://arxiv.org/abs/1611.04076). 617 618 619def least_squares_generator_loss( 620 discriminator_gen_outputs, 621 real_label=1, 622 weights=1.0, 623 scope=None, 624 loss_collection=ops.GraphKeys.LOSSES, 625 reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, 626 add_summaries=False): 627 """Least squares generator loss. 628 629 This loss comes from `Least Squares Generative Adversarial Networks` 630 (https://arxiv.org/abs/1611.04076). 631 632 L = 1/2 * (D(G(z)) - `real_label`) ** 2 633 634 where D(y) are discriminator logits. 635 636 Args: 637 discriminator_gen_outputs: Discriminator output on generated data. Expected 638 to be in the range of (-inf, inf). 639 real_label: The value that the generator is trying to get the discriminator 640 to output on generated data. 641 weights: Optional `Tensor` whose rank is either 0, or the same rank as 642 `discriminator_gen_outputs`, and must be broadcastable to 643 `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or 644 the same as the corresponding dimension). 645 scope: The scope for the operations performed in computing the loss. 646 loss_collection: collection to which this loss will be added. 647 reduction: A `tf.losses.Reduction` to apply to loss. 648 add_summaries: Whether or not to add summaries for the loss. 649 650 Returns: 651 A loss Tensor. The shape depends on `reduction`. 652 """ 653 with ops.name_scope(scope, 'lsq_generator_loss', 654 (discriminator_gen_outputs, real_label)) as scope: 655 discriminator_gen_outputs = _to_float(discriminator_gen_outputs) 656 loss = math_ops.squared_difference( 657 discriminator_gen_outputs, real_label) / 2.0 658 loss = losses.compute_weighted_loss( 659 loss, weights, scope, loss_collection, reduction) 660 661 if add_summaries: 662 summary.scalar('generator_lsq_loss', loss) 663 664 return loss 665 666 667def least_squares_discriminator_loss( 668 discriminator_real_outputs, 669 discriminator_gen_outputs, 670 real_label=1, 671 fake_label=0, 672 real_weights=1.0, 673 generated_weights=1.0, 674 scope=None, 675 loss_collection=ops.GraphKeys.LOSSES, 676 reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, 677 add_summaries=False): 678 """Least squares discriminator loss. 679 680 This loss comes from `Least Squares Generative Adversarial Networks` 681 (https://arxiv.org/abs/1611.04076). 682 683 L = 1/2 * (D(x) - `real`) ** 2 + 684 1/2 * (D(G(z)) - `fake_label`) ** 2 685 686 where D(y) are discriminator logits. 687 688 Args: 689 discriminator_real_outputs: Discriminator output on real data. 690 discriminator_gen_outputs: Discriminator output on generated data. Expected 691 to be in the range of (-inf, inf). 692 real_label: The value that the discriminator tries to output for real data. 693 fake_label: The value that the discriminator tries to output for fake data. 694 real_weights: Optional `Tensor` whose rank is either 0, or the same rank as 695 `discriminator_real_outputs`, and must be broadcastable to 696 `discriminator_real_outputs` (i.e., all dimensions must be either `1`, or 697 the same as the corresponding dimension). 698 generated_weights: Same as `real_weights`, but for 699 `discriminator_gen_outputs`. 700 scope: The scope for the operations performed in computing the loss. 701 loss_collection: collection to which this loss will be added. 702 reduction: A `tf.losses.Reduction` to apply to loss. 703 add_summaries: Whether or not to add summaries for the loss. 704 705 Returns: 706 A loss Tensor. The shape depends on `reduction`. 707 """ 708 with ops.name_scope(scope, 'lsq_discriminator_loss', 709 (discriminator_gen_outputs, real_label)) as scope: 710 discriminator_real_outputs = _to_float(discriminator_real_outputs) 711 discriminator_gen_outputs = _to_float(discriminator_gen_outputs) 712 discriminator_real_outputs.shape.assert_is_compatible_with( 713 discriminator_gen_outputs.shape) 714 715 real_losses = math_ops.squared_difference( 716 discriminator_real_outputs, real_label) / 2.0 717 fake_losses = math_ops.squared_difference( 718 discriminator_gen_outputs, fake_label) / 2.0 719 720 loss_on_real = losses.compute_weighted_loss( 721 real_losses, real_weights, scope, loss_collection=None, 722 reduction=reduction) 723 loss_on_generated = losses.compute_weighted_loss( 724 fake_losses, generated_weights, scope, loss_collection=None, 725 reduction=reduction) 726 727 loss = loss_on_real + loss_on_generated 728 util.add_loss(loss, loss_collection) 729 730 if add_summaries: 731 summary.scalar('discriminator_gen_lsq_loss', loss_on_generated) 732 summary.scalar('discriminator_real_lsq_loss', loss_on_real) 733 summary.scalar('discriminator_lsq_loss', loss) 734 735 return loss 736 737 738# InfoGAN loss from `InfoGAN: Interpretable Representation Learning by 739# `Information Maximizing Generative Adversarial Nets` 740# https://arxiv.org/abs/1606.03657 741 742 743def _validate_distributions(distributions): 744 if not isinstance(distributions, (list, tuple)): 745 raise ValueError('`distributions` must be a list or tuple. Instead, ' 746 'found %s.' % type(distributions)) 747 for x in distributions: 748 # We used to check with `isinstance(x, tf.distributions.Distribution)`. 749 # However, distributions have migrated to `tfp.distributions.Distribution`, 750 # which is a new code repo, so we can't check this way anymore until 751 # TF-GAN is migrated to a new repo as well. 752 # This new check is not sufficient, but is a useful heuristic for now. 753 if not callable(getattr(x, 'log_prob', None)): 754 raise ValueError('`distributions` must be a list of `Distributions`. ' 755 'Instead, found %s.' % type(x)) 756 757 758def _validate_information_penalty_inputs( 759 structured_generator_inputs, predicted_distributions): 760 """Validate input to `mutual_information_penalty`.""" 761 _validate_distributions(predicted_distributions) 762 if len(structured_generator_inputs) != len(predicted_distributions): 763 raise ValueError('`structured_generator_inputs` length %i must be the same ' 764 'as `predicted_distributions` length %i.' % ( 765 len(structured_generator_inputs), 766 len(predicted_distributions))) 767 768 769def mutual_information_penalty( 770 structured_generator_inputs, 771 predicted_distributions, 772 weights=1.0, 773 scope=None, 774 loss_collection=ops.GraphKeys.LOSSES, 775 reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, 776 add_summaries=False): 777 """Returns a penalty on the mutual information in an InfoGAN model. 778 779 This loss comes from an InfoGAN paper https://arxiv.org/abs/1606.03657. 780 781 Args: 782 structured_generator_inputs: A list of Tensors representing the random noise 783 that must have high mutual information with the generator output. List 784 length should match `predicted_distributions`. 785 predicted_distributions: A list of `tfp.distributions.Distribution`s. 786 Predicted by the recognizer, and used to evaluate the likelihood of the 787 structured noise. List length should match `structured_generator_inputs`. 788 weights: Optional `Tensor` whose rank is either 0, or the same dimensions as 789 `structured_generator_inputs`. 790 scope: The scope for the operations performed in computing the loss. 791 loss_collection: collection to which this loss will be added. 792 reduction: A `tf.losses.Reduction` to apply to loss. 793 add_summaries: Whether or not to add summaries for the loss. 794 795 Returns: 796 A scalar Tensor representing the mutual information loss. 797 """ 798 _validate_information_penalty_inputs( 799 structured_generator_inputs, predicted_distributions) 800 801 with ops.name_scope(scope, 'mutual_information_loss') as scope: 802 # Calculate the negative log-likelihood of the reconstructed noise. 803 log_probs = [math_ops.reduce_mean(dist.log_prob(noise)) for dist, noise in 804 zip(predicted_distributions, structured_generator_inputs)] 805 loss = -1 * losses.compute_weighted_loss( 806 log_probs, weights, scope, loss_collection=loss_collection, 807 reduction=reduction) 808 809 if add_summaries: 810 summary.scalar('mutual_information_penalty', loss) 811 812 return loss 813 814 815def _numerically_stable_global_norm(tensor_list): 816 """Compute the global norm of a list of Tensors, with improved stability. 817 818 The global norm computation sometimes overflows due to the intermediate L2 819 step. To avoid this, we divide by a cheap-to-compute max over the 820 matrix elements. 821 822 Args: 823 tensor_list: A list of tensors, or `None`. 824 825 Returns: 826 A scalar tensor with the global norm. 827 """ 828 if all(x is None for x in tensor_list): 829 return 0.0 830 831 list_max = math_ops.reduce_max([math_ops.reduce_max(math_ops.abs(x)) for x in 832 tensor_list if x is not None]) 833 return list_max * clip_ops.global_norm([x / list_max for x in tensor_list 834 if x is not None]) 835 836 837def _used_weight(weights_list): 838 for weight in weights_list: 839 if weight is not None: 840 return tensor_util.constant_value(ops.convert_to_tensor(weight)) 841 842 843def _validate_args(losses_list, weight_factor, gradient_ratio): 844 for loss in losses_list: 845 loss.shape.assert_is_compatible_with([]) 846 if weight_factor is None and gradient_ratio is None: 847 raise ValueError( 848 '`weight_factor` and `gradient_ratio` cannot both be `None.`') 849 if weight_factor is not None and gradient_ratio is not None: 850 raise ValueError( 851 '`weight_factor` and `gradient_ratio` cannot both be specified.') 852 853 854# TODO(joelshor): Add ability to pass in gradients, to avoid recomputing. 855def combine_adversarial_loss(main_loss, 856 adversarial_loss, 857 weight_factor=None, 858 gradient_ratio=None, 859 gradient_ratio_epsilon=1e-6, 860 variables=None, 861 scalar_summaries=True, 862 gradient_summaries=True, 863 scope=None): 864 """Utility to combine main and adversarial losses. 865 866 This utility combines the main and adversarial losses in one of two ways. 867 1) Fixed coefficient on adversarial loss. Use `weight_factor` in this case. 868 2) Fixed ratio of gradients. Use `gradient_ratio` in this case. This is often 869 used to make sure both losses affect weights roughly equally, as in 870 https://arxiv.org/pdf/1705.05823. 871 872 One can optionally also visualize the scalar and gradient behavior of the 873 losses. 874 875 Args: 876 main_loss: A floating scalar Tensor indicating the main loss. 877 adversarial_loss: A floating scalar Tensor indication the adversarial loss. 878 weight_factor: If not `None`, the coefficient by which to multiply the 879 adversarial loss. Exactly one of this and `gradient_ratio` must be 880 non-None. 881 gradient_ratio: If not `None`, the ratio of the magnitude of the gradients. 882 Specifically, 883 gradient_ratio = grad_mag(main_loss) / grad_mag(adversarial_loss) 884 Exactly one of this and `weight_factor` must be non-None. 885 gradient_ratio_epsilon: An epsilon to add to the adversarial loss 886 coefficient denominator, to avoid division-by-zero. 887 variables: List of variables to calculate gradients with respect to. If not 888 present, defaults to all trainable variables. 889 scalar_summaries: Create scalar summaries of losses. 890 gradient_summaries: Create gradient summaries of losses. 891 scope: Optional name scope. 892 893 Returns: 894 A floating scalar Tensor indicating the desired combined loss. 895 896 Raises: 897 ValueError: Malformed input. 898 """ 899 _validate_args([main_loss, adversarial_loss], weight_factor, gradient_ratio) 900 if variables is None: 901 variables = contrib_variables_lib.get_trainable_variables() 902 903 with ops.name_scope(scope, 'adversarial_loss', 904 values=[main_loss, adversarial_loss]): 905 # Compute gradients if we will need them. 906 if gradient_summaries or gradient_ratio is not None: 907 main_loss_grad_mag = _numerically_stable_global_norm( 908 gradients_impl.gradients(main_loss, variables)) 909 adv_loss_grad_mag = _numerically_stable_global_norm( 910 gradients_impl.gradients(adversarial_loss, variables)) 911 912 # Add summaries, if applicable. 913 if scalar_summaries: 914 summary.scalar('main_loss', main_loss) 915 summary.scalar('adversarial_loss', adversarial_loss) 916 if gradient_summaries: 917 summary.scalar('main_loss_gradients', main_loss_grad_mag) 918 summary.scalar('adversarial_loss_gradients', adv_loss_grad_mag) 919 920 # Combine losses in the appropriate way. 921 # If `weight_factor` is always `0`, avoid computing the adversarial loss 922 # tensor entirely. 923 if _used_weight((weight_factor, gradient_ratio)) == 0: 924 final_loss = main_loss 925 elif weight_factor is not None: 926 final_loss = (main_loss + 927 array_ops.stop_gradient(weight_factor) * adversarial_loss) 928 elif gradient_ratio is not None: 929 grad_mag_ratio = main_loss_grad_mag / ( 930 adv_loss_grad_mag + gradient_ratio_epsilon) 931 adv_coeff = grad_mag_ratio / gradient_ratio 932 summary.scalar('adversarial_coefficient', adv_coeff) 933 final_loss = (main_loss + 934 array_ops.stop_gradient(adv_coeff) * adversarial_loss) 935 936 return final_loss 937 938 939def cycle_consistency_loss(data_x, 940 reconstructed_data_x, 941 data_y, 942 reconstructed_data_y, 943 scope=None, 944 add_summaries=False): 945 """Defines the cycle consistency loss. 946 947 The cyclegan model has two partial models where `model_x2y` generator F maps 948 data set X to Y, `model_y2x` generator G maps data set Y to X. For a `data_x` 949 in data set X, we could reconstruct it by 950 * reconstructed_data_x = G(F(data_x)) 951 Similarly 952 * reconstructed_data_y = F(G(data_y)) 953 954 The cycle consistency loss is about the difference between data and 955 reconstructed data, namely 956 * loss_x2x = |data_x - G(F(data_x))| (L1-norm) 957 * loss_y2y = |data_y - F(G(data_y))| (L1-norm) 958 * loss = (loss_x2x + loss_y2y) / 2 959 where `loss` is the final result. 960 961 For the L1-norm, we follow the original implementation: 962 https://github.com/junyanz/CycleGAN/blob/master/models/cycle_gan_model.lua 963 we use L1-norm of pixel-wise error normalized by data size such that 964 `cycle_loss_weight` can be specified independent of image size. 965 966 See https://arxiv.org/abs/1703.10593 for more details. 967 968 Args: 969 data_x: A `Tensor` of data X. 970 reconstructed_data_x: A `Tensor` of reconstructed data X. 971 data_y: A `Tensor` of data Y. 972 reconstructed_data_y: A `Tensor` of reconstructed data Y. 973 scope: The scope for the operations performed in computing the loss. 974 Defaults to None. 975 add_summaries: Whether or not to add detailed summaries for the loss. 976 Defaults to False. 977 978 Returns: 979 A scalar `Tensor` of cycle consistency loss. 980 """ 981 982 with ops.name_scope( 983 scope, 984 'cycle_consistency_loss', 985 values=[data_x, reconstructed_data_x, data_y, reconstructed_data_y]): 986 loss_x2x = losses.absolute_difference(data_x, reconstructed_data_x) 987 loss_y2y = losses.absolute_difference(data_y, reconstructed_data_y) 988 loss = (loss_x2x + loss_y2y) / 2.0 989 if add_summaries: 990 summary.scalar('cycle_consistency_loss_x2x', loss_x2x) 991 summary.scalar('cycle_consistency_loss_y2y', loss_y2y) 992 summary.scalar('cycle_consistency_loss', loss) 993 994 return loss 995