# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for manipulating the loss collections."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["losses.add_loss"])
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
  """Adds an externally defined loss to the collection of losses.

  Args:
    loss: A loss `Tensor`.
    loss_collection: Optional collection to add the loss to. If falsy
      (e.g. `None` or `""`), the loss is not added anywhere.
  """
  # Since we have no way of figuring out when a training iteration starts or
  # ends, holding on to a loss when executing eagerly is indistinguishable from
  # leaking memory. We instead leave the collection empty.
  if loss_collection and not context.executing_eagerly():
    ops.add_to_collection(loss_collection, loss)


@tf_export(v1=["losses.get_losses"])
def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
  """Gets the list of losses from the loss_collection.

  Args:
    scope: An optional scope name for filtering the losses to return.
    loss_collection: Optional losses collection.

  Returns:
    a list of loss tensors.
  """
  return ops.get_collection(loss_collection, scope)


@tf_export(v1=["losses.get_regularization_losses"])
def get_regularization_losses(scope=None):
  """Gets the list of regularization losses.

  Args:
    scope: An optional scope name for filtering the losses to return.

  Returns:
    A list of regularization losses as Tensors.
  """
  return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)


@tf_export(v1=["losses.get_regularization_loss"])
def get_regularization_loss(scope=None, name="total_regularization_loss"):
  """Gets the total regularization loss.

  Args:
    scope: An optional scope name for filtering the losses to return.
    name: The name of the returned tensor.

  Returns:
    A scalar regularization loss. If no regularization losses have been
    collected (in `scope`), a constant `0.0` tensor is returned instead,
    since `tf.add_n` does not accept an empty list.
  """
  losses = get_regularization_losses(scope)
  if losses:
    return math_ops.add_n(losses, name=name)
  else:
    # add_n raises on an empty input list; fall back to a scalar zero so
    # callers can always add this result into a total loss.
    return constant_op.constant(0.0)


@tf_export(v1=["losses.get_total_loss"])
def get_total_loss(add_regularization_losses=True, name="total_loss"):
  """Returns a tensor whose value represents the total loss.

  In particular, this adds any losses you have added with `tf.add_loss()` to
  any regularization losses that have been added by regularization parameters
  on layers constructors e.g. `tf.layers`. Be very sure to use this if you
  are constructing a loss_op manually. Otherwise regularization arguments
  on `tf.layers` methods will not function.

  Args:
    add_regularization_losses: A boolean indicating whether or not to use the
      regularization losses in the sum.
    name: The name of the returned tensor.

  Returns:
    A `Tensor` whose value represents the total loss.

  Raises:
    ValueError: if `losses` is not iterable, or if no losses have been
      collected at all (`tf.add_n` rejects an empty input list).
  """
  losses = get_losses()
  if add_regularization_losses:
    # get_collection returns a fresh copy, so += does not mutate the
    # underlying LOSSES collection.
    losses += get_regularization_losses()
  return math_ops.add_n(losses, name=name)