# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for manipulating the loss collections."""

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export


def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
  """Squeezes or expands the last dimension if needed.

  1. Squeezes the last dim of `y_pred` or `y_true` if their ranks differ by 1
     (using `confusion_matrix.remove_squeezable_dimensions`).
  2. Squeezes or expands the last dim of `sample_weight` if its rank differs
     by 1 from the new rank of `y_pred`. If `sample_weight` is scalar, it is
     kept scalar.

  This uses static shapes if available. Otherwise, it adds graph operations,
  which could result in a performance hit.

  Args:
    y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
    y_true: Optional label `Tensor` whose dimensions match `y_pred`.
    sample_weight: Optional weight scalar or `Tensor` whose dimensions match
      `y_pred`.

  Returns:
    Tuple of `y_pred`, `y_true` and `sample_weight`, each possibly with its
    last dimension squeezed; `sample_weight` may instead have been expanded
    by one dimension. If `sample_weight` is None, only `(y_pred, y_true)` is
    returned.
  """
  y_pred_shape = y_pred.shape
  y_pred_rank = y_pred_shape.ndims
  if y_true is not None:

    # If sparse labels are provided as `y_true`, the last dimension of
    # `y_pred` may be > 1. E.g. y_true = [0, 1, 2] (shape=(3,)),
    # y_pred = [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]] (shape=(3, 3)).
    # In this case, we should not try to remove the squeezable dimension.
    y_true_shape = y_true.shape
    y_true_rank = y_true_shape.ndims
    if (y_true_rank is not None) and (y_pred_rank is not None):
      # Use static rank for `y_true` and `y_pred`.
      if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1:
        y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
            y_true, y_pred)
    else:
      # Use dynamic rank.
      rank_diff = array_ops.rank(y_pred) - array_ops.rank(y_true)
      squeeze_dims = lambda: confusion_matrix.remove_squeezable_dimensions(  # pylint: disable=g-long-lambda
          y_true, y_pred)
      is_last_dim_1 = math_ops.equal(1, array_ops.shape(y_pred)[-1])
      maybe_squeeze_dims = lambda: control_flow_ops.cond(  # pylint: disable=g-long-lambda
          is_last_dim_1, squeeze_dims, lambda: (y_true, y_pred))
      y_true, y_pred = control_flow_ops.cond(
          math_ops.equal(1, rank_diff), maybe_squeeze_dims, squeeze_dims)

  if sample_weight is None:
    return y_pred, y_true

  weights_shape = sample_weight.shape
  weights_rank = weights_shape.ndims
  if weights_rank == 0:  # If weights is scalar, do nothing.
    return y_pred, y_true, sample_weight

  if (y_pred_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - y_pred_rank == 1:
      sample_weight = array_ops.squeeze(sample_weight, [-1])
    elif y_pred_rank - weights_rank == 1:
      sample_weight = array_ops.expand_dims(sample_weight, [-1])
    return y_pred, y_true, sample_weight

  # Use dynamic rank.
  weights_rank_tensor = array_ops.rank(sample_weight)
  rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
  maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])

  def _maybe_expand_weights():
    expand_weights = lambda: array_ops.expand_dims(sample_weight, [-1])
    return control_flow_ops.cond(
        math_ops.equal(rank_diff, -1), expand_weights, lambda: sample_weight)

  def _maybe_adjust_weights():
    return control_flow_ops.cond(
        math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
        _maybe_expand_weights)

  # Squeeze or expand the last dim of `sample_weight` if its rank differs by
  # 1 from the new rank of `y_pred`.
  sample_weight = control_flow_ops.cond(
      math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
      _maybe_adjust_weights)
  return y_pred, y_true, sample_weight
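

# Illustrative sketch (not part of the original module): how
# `squeeze_or_expand_dimensions` reconciles ranks on concrete shapes. The
# helper name `_example_squeeze_or_expand` is hypothetical; nothing in this
# module calls it, and it assumes eager execution.
def _example_squeeze_or_expand():
  """Hypothetical demo of the rank-reconciliation rules above."""
  # Case 1: `y_pred` has a trailing dim of 1 and one more rank than `y_true`,
  # so its last dimension is squeezed away.
  y_true = constant_op.constant([0., 1., 0.])            # shape (3,)
  y_pred = constant_op.constant([[.2], [.8], [.1]])      # shape (3, 1)
  y_pred, y_true = squeeze_or_expand_dimensions(y_pred, y_true)
  # y_pred now has shape (3,), matching y_true.

  # Case 2: sparse labels with a class axis in `y_pred`; nothing is squeezed,
  # and a rank-1 sample weight is expanded to match the rank of `y_pred`.
  y_true = constant_op.constant([0, 1, 2])               # shape (3,)
  y_pred = constant_op.constant([[.9, .05, .05],
                                 [.05, .9, .05],
                                 [.05, .05, .9]])        # shape (3, 3)
  sample_weight = constant_op.constant([1., 2., 3.])     # shape (3,)
  y_pred, y_true, sample_weight = squeeze_or_expand_dimensions(
      y_pred, y_true, sample_weight)
  # sample_weight now has shape (3, 1) and will broadcast against y_pred.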


def scale_losses_by_sample_weight(losses, sample_weight):
  """Scales loss values by the given sample weights.

  The dimensions of `sample_weight` are adjusted to match those of `losses`,
  if possible, by squeezing, expanding, or broadcasting.

  Args:
    losses: Loss tensor.
    sample_weight: Sample weights tensor.

  Returns:
    `losses` scaled by `sample_weight` with dtype float32.
  """
  # TODO(psv): Handle the casting here in a better way, e.g. if losses is
  # float64 we do not want to lose precision.
  losses = math_ops.cast(losses, dtypes.float32)
  sample_weight = math_ops.cast(sample_weight, dtypes.float32)

  # Update dimensions of `sample_weight` to match `losses` if possible.
  losses, _, sample_weight = squeeze_or_expand_dimensions(
      losses, None, sample_weight)
  return math_ops.multiply(losses, sample_weight)


@tf_contextlib.contextmanager
def check_per_example_loss_rank(per_example_loss):
  """Context manager that checks that `per_example_loss` has rank >= 1.

  Args:
    per_example_loss: Per-example loss tensor.

  Yields:
    Nothing. In the dynamic-rank case, ops created inside the context carry a
    control dependency on the rank assertion.

  Raises:
    ValueError: If `per_example_loss` has a static rank of 0.
  """
  loss_rank = per_example_loss.shape.rank
  if loss_rank is not None:
    # Handle static rank.
    if loss_rank == 0:
      raise ValueError(
          "Invalid value passed for `per_example_loss`. Expected a tensor "
          f"with at least rank 1. Received per_example_loss={per_example_loss} "
          f"with rank {loss_rank}")
    yield
  else:
    # Handle dynamic rank.
    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            array_ops.rank(per_example_loss),
            math_ops.cast(1, dtype=dtypes.int32),
            message="Invalid value passed for `per_example_loss`. Expected a "
            "tensor with at least rank 1.")
    ]):
      yield
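

# Illustrative sketch (not part of the original module): weighting a batch of
# per-timestep losses. The helper name `_example_scale_losses` is
# hypothetical; nothing in this module calls it, and it assumes eager
# execution.
def _example_scale_losses():
  """Hypothetical demo of `scale_losses_by_sample_weight`."""
  losses = constant_op.constant([[1., 1., 1., 1.],
                                 [2., 2., 2., 2.]])     # shape (2, 4)
  sample_weight = constant_op.constant([0.5, 1.0])      # shape (2,)
  # The rank check passes statically (rank 2 >= 1); the (2,) weight is
  # expanded to (2, 1) and broadcast across the per-timestep axis.
  with check_per_example_loss_rank(losses):
    scaled = scale_losses_by_sample_weight(losses, sample_weight)
  # scaled == [[.5, .5, .5, .5], [2., 2., 2., 2.]], dtype float32.
  return scaled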


@tf_export(v1=["losses.add_loss"])
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
  """Adds an externally defined loss to the collection of losses.

  Args:
    loss: A loss `Tensor`.
    loss_collection: Optional collection to add the loss to.
  """
  # Since we have no way of figuring out when a training iteration starts or
  # ends, holding on to a loss when executing eagerly is indistinguishable
  # from leaking memory. We instead leave the collection empty.
  if loss_collection and not context.executing_eagerly():
    ops.add_to_collection(loss_collection, loss)


@tf_export(v1=["losses.get_losses"])
def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
  """Gets the list of losses from the loss_collection.

  Args:
    scope: An optional scope name for filtering the losses to return.
    loss_collection: Optional losses collection.

  Returns:
    A list of loss tensors.
  """
  return ops.get_collection(loss_collection, scope)


@tf_export(v1=["losses.get_regularization_losses"])
def get_regularization_losses(scope=None):
  """Gets the list of regularization losses.

  Args:
    scope: An optional scope name for filtering the losses to return.

  Returns:
    A list of regularization losses as Tensors.
  """
  return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)


@tf_export(v1=["losses.get_regularization_loss"])
def get_regularization_loss(scope=None, name="total_regularization_loss"):
  """Gets the total regularization loss.

  Args:
    scope: An optional scope name for filtering the losses to return.
    name: The name of the returned tensor.

  Returns:
    A scalar regularization loss.
  """
  losses = get_regularization_losses(scope)
  if losses:
    return math_ops.add_n(losses, name=name)
  else:
    return constant_op.constant(0.0)


@tf_export(v1=["losses.get_total_loss"])
def get_total_loss(add_regularization_losses=True,
                   name="total_loss",
                   scope=None):
  """Returns a tensor whose value represents the total loss.

  In particular, this adds any losses you have added with `tf.add_loss()` to
  any regularization losses that have been added via regularization arguments
  in layer constructors (e.g. `tf.layers`). Be sure to use this if you are
  constructing a loss op manually; otherwise regularization arguments on
  `tf.layers` methods will have no effect.

  Args:
    add_regularization_losses: A boolean indicating whether or not to include
      the regularization losses in the sum.
    name: The name of the returned tensor.
    scope: An optional scope name for filtering the losses to return. Note
      that this filters the losses added with `tf.add_loss()` as well as the
      regularization losses to that scope.

  Returns:
    A `Tensor` whose value represents the total loss.

  Raises:
    ValueError: if `losses` is not iterable.
  """
  losses = get_losses(scope=scope)
  if add_regularization_losses:
    losses += get_regularization_losses(scope=scope)
  return math_ops.add_n(losses, name=name)
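

# Illustrative sketch (not part of the original module): the collection-based
# workflow the v1 endpoints above support. Collections are only populated in
# graph mode, so the demo builds an explicit `Graph`. The helper name
# `_example_loss_collections` is hypothetical; nothing in this module calls
# it.
def _example_loss_collections():
  """Hypothetical demo of `add_loss` / `get_total_loss`."""
  with ops.Graph().as_default():
    add_loss(constant_op.constant(1.5))
    add_loss(constant_op.constant(0.5))
    ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                          constant_op.constant(0.25))
    # When evaluated in a session, total == 2.25 and reg == 0.25.
    total = get_total_loss()
    reg = get_regularization_loss()
    return total, reg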