# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to loss functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.losses.Reduction', v1=[])
class ReductionV2(object):
  """Types of loss reduction.

  Contains the following values:

  * `AUTO`: Indicates that the reduction option will be determined by the usage
     context. For almost all cases this defaults to `SUM_OVER_BATCH_SIZE`. When
     used with `tf.distribute.Strategy`, outside of built-in training loops such
     as `tf.keras` `compile` and `fit`, we expect reduction value to be
     `SUM` or `NONE`. Using `AUTO` in that case will raise an error.
  * `NONE`: Weighted losses with one dimension reduced (axis=-1, or axis
     specified by loss function). When this reduction type is used with
     built-in Keras training loops like `fit`/`evaluate`, the unreduced vector
     loss is passed to the optimizer but the reported loss will be a scalar
     value.
  * `SUM`: Scalar sum of weighted losses.
  * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
     This reduction type is not supported when used with
     `tf.distribute.Strategy` outside of built-in training loops like `tf.keras`
     `compile`/`fit`.

     You can implement 'SUM_OVER_BATCH_SIZE' using global batch size like:
     ```
     with strategy.scope():
       loss_obj = tf.keras.losses.CategoricalCrossentropy(
           reduction=tf.keras.losses.Reduction.NONE)
       ....
       loss = tf.reduce_sum(loss_obj(labels, predictions)) *
           (1. / global_batch_size)
     ```

  Please see the [custom training guide](
  https://www.tensorflow.org/tutorials/distribute/custom_training) for more
  details on this.
  """

  AUTO = 'auto'
  NONE = 'none'
  SUM = 'sum'
  SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'

  @classmethod
  def all(cls):
    """Returns a tuple containing every valid reduction key."""
    return (cls.AUTO, cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)

  @classmethod
  def validate(cls, key):
    """Checks that `key` is one of the values returned by `all()`.

    Args:
      key: The reduction key to check.

    Raises:
      ValueError: If `key` is not a valid reduction key.
    """
    if key not in cls.all():
      raise ValueError('Invalid Reduction Key %s.' % key)


def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None):
  """Squeeze last dim if ranks differ from expected by exactly 1.

  In the common case where we expect shapes to match, `expected_rank_diff`
  defaults to 0, and we squeeze the last dimension of the larger rank if they
  differ by 1.

  But, for example, if `labels` contains class IDs and `predictions` contains 1
  probability per class, we expect `predictions` to have 1 more dimension than
  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
  `labels` if `rank(predictions) - rank(labels) == 0`, and
  `predictions` if `rank(predictions) - rank(labels) == 2`.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed.
  """
  with K.name_scope(name or 'remove_squeezable_dimensions'):
    # Ragged tensors are passed through untouched; everything else is
    # converted so that `.shape` is available below.
    if not isinstance(predictions, ragged_tensor.RaggedTensor):
      predictions = ops.convert_to_tensor_v2_with_dispatch(predictions)
    if not isinstance(labels, ragged_tensor.RaggedTensor):
      labels = ops.convert_to_tensor_v2_with_dispatch(labels)
    predictions_shape = predictions.shape
    predictions_rank = predictions_shape.ndims
    labels_shape = labels.shape
    labels_rank = labels_shape.ndims
    if (labels_rank is not None) and (predictions_rank is not None):
      # Use static rank.
      rank_diff = predictions_rank - labels_rank
      if (rank_diff == expected_rank_diff + 1 and
          predictions_shape.dims[-1].is_compatible_with(1)):
        predictions = array_ops.squeeze(predictions, [-1])
      elif (rank_diff == expected_rank_diff - 1 and
            labels_shape.dims[-1].is_compatible_with(1)):
        labels = array_ops.squeeze(labels, [-1])
      return labels, predictions

    # Use dynamic rank.
    rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
    # Only emit a conditional squeeze when the static shape does not already
    # rule out a last dimension of size 1.
    if (predictions_rank is None) or (
        predictions_shape.dims[-1].is_compatible_with(1)):
      predictions = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff + 1, rank_diff),
          lambda: array_ops.squeeze(predictions, [-1]),
          lambda: predictions)
    if (labels_rank is None) or (
        labels_shape.dims[-1].is_compatible_with(1)):
      labels = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff - 1, rank_diff),
          lambda: array_ops.squeeze(labels, [-1]),
          lambda: labels)
    return labels, predictions


def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
  """Squeeze or expand last dimension if needed.

  1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
  (using `remove_squeezable_dimensions`).
  2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
  from the new rank of `y_pred`.
  If `sample_weight` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
    y_true: Optional label `Tensor` whose dimensions match `y_pred`.
    sample_weight: Optional weight scalar or `Tensor` whose dimensions match
      `y_pred`.

  Returns:
    Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
    the last dimension squeezed,
    `sample_weight` could be extended by one dimension.
    If `sample_weight` is None, (y_pred, y_true) is returned.
  """
  y_pred_shape = y_pred.shape
  y_pred_rank = y_pred_shape.ndims
  if y_true is not None:

    # If sparse matrix is provided as `y_true`, the last dimension in `y_pred`
    # may be > 1. Eg: y_true = [0, 1, 2] (shape=(3,)),
    # y_pred = [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]] (shape=(3, 3))
    # In this case, we should not try to remove squeezable dimension.
    y_true_shape = y_true.shape
    y_true_rank = y_true_shape.ndims
    if (y_true_rank is not None) and (y_pred_rank is not None):
      # Use static rank for `y_true` and `y_pred`.
      if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1:
        y_true, y_pred = remove_squeezable_dimensions(
            y_true, y_pred)
    else:
      # Use dynamic rank.
      rank_diff = array_ops.rank(y_pred) - array_ops.rank(y_true)
      squeeze_dims = lambda: remove_squeezable_dimensions(  # pylint: disable=g-long-lambda
          y_true, y_pred)
      is_last_dim_1 = math_ops.equal(1, array_ops.shape(y_pred)[-1])
      maybe_squeeze_dims = lambda: control_flow_ops.cond(  # pylint: disable=g-long-lambda
          is_last_dim_1, squeeze_dims, lambda: (y_true, y_pred))
      y_true, y_pred = control_flow_ops.cond(
          math_ops.equal(1, rank_diff), maybe_squeeze_dims, squeeze_dims)

  if sample_weight is None:
    return y_pred, y_true

  weights_shape = sample_weight.shape
  weights_rank = weights_shape.ndims
  if weights_rank == 0:  # If weights is scalar, do nothing.
    return y_pred, y_true, sample_weight

  if (y_pred_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - y_pred_rank == 1:
      sample_weight = array_ops.squeeze(sample_weight, [-1])
    elif y_pred_rank - weights_rank == 1:
      sample_weight = array_ops.expand_dims(sample_weight, [-1])
    return y_pred, y_true, sample_weight

  # Use dynamic rank.
  weights_rank_tensor = array_ops.rank(sample_weight)
  rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
  maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])

  def _maybe_expand_weights():
    expand_weights = lambda: array_ops.expand_dims(sample_weight, [-1])
    return control_flow_ops.cond(
        math_ops.equal(rank_diff, -1), expand_weights, lambda: sample_weight)

  def _maybe_adjust_weights():
    return control_flow_ops.cond(
        math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
        _maybe_expand_weights)

  # squeeze or expand last dim of `sample_weight` if its rank differs by 1
  # from the new rank of `y_pred`.
  sample_weight = control_flow_ops.cond(
      math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
      _maybe_adjust_weights)
  return y_pred, y_true, sample_weight


def _safe_mean(losses, num_present):
  """Computes a safe mean of the losses.

  Args:
    losses: `Tensor` whose elements contain individual loss measurements.
    num_present: The number of measurable elements in `losses`.

  Returns:
    A scalar representing the mean of `losses`. If `num_present` is zero,
      then zero is returned.
  """
  total_loss = math_ops.reduce_sum(losses)
  return math_ops.div_no_nan(total_loss, num_present, name='value')


def _num_elements(losses):
  """Computes the number of elements in `losses` tensor."""
  with K.name_scope('num_elements') as scope:
    # The count is cast to the dtype of `losses` so it can be used directly as
    # a divisor for the summed loss.
    return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype)


def reduce_weighted_loss(weighted_losses,
                         reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
  """Reduces the individual weighted loss measurements.

  Args:
    weighted_losses: `Tensor` of already-weighted loss values.
    reduction: Type of `tf.keras.losses.Reduction` to apply. `NONE` returns
      `weighted_losses` unchanged. `SUM_OVER_BATCH_SIZE` returns the sum of all
      elements divided by the number of elements. Any other value is treated
      like `SUM` and returns the sum of all elements; note that `AUTO` is not
      resolved here, so callers should map it before calling.

  Returns:
    The reduced loss `Tensor`.
  """
  if reduction == ReductionV2.NONE:
    loss = weighted_losses
  else:
    loss = math_ops.reduce_sum(weighted_losses)
    if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
      loss = _safe_mean(loss, _num_elements(weighted_losses))
  return loss


def compute_weighted_loss(losses,
                          sample_weight=None,
                          reduction=ReductionV2.SUM_OVER_BATCH_SIZE,
                          name=None):
  """Computes the weighted loss.

  Args:
    losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
      `losses`, or be broadcastable to `losses`.
    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss.
      Default value is `SUM_OVER_BATCH_SIZE`.
    name: Optional name for the op.

  Raises:
    ValueError: If the shape of `sample_weight` is not compatible with `losses`.

  Returns:
    Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
    `NONE`, this has the same shape as `losses`; otherwise, it is scalar.
  """
  ReductionV2.validate(reduction)

  # If this function is called directly, then we just default 'AUTO' to
  # 'SUM_OVER_BATCH_SIZE'. Eg. Canned estimator use cases.
  if reduction == ReductionV2.AUTO:
    reduction = ReductionV2.SUM_OVER_BATCH_SIZE
  if sample_weight is None:
    sample_weight = 1.0
  with K.name_scope(name or 'weighted_loss'):
    # Save the `reduction` argument for loss normalization when distributing
    # to multiple replicas. Used only for estimator + v1 optimizer flow.
    ops.get_default_graph()._last_loss_reduction = reduction  # pylint: disable=protected-access

    if not isinstance(losses,
                      (keras_tensor.KerasTensor, ragged_tensor.RaggedTensor)):
      losses = ops.convert_to_tensor_v2_with_dispatch(losses)
    input_dtype = losses.dtype

    if not isinstance(sample_weight, keras_tensor.KerasTensor):
      sample_weight = ops.convert_to_tensor_v2_with_dispatch(sample_weight)

    # TODO(psv): Handle casting here in a better way, eg. if losses is float64
    # we do not want to lose precision.
    losses = math_ops.cast(losses, 'float32')
    sample_weight = math_ops.cast(sample_weight, 'float32')
    # Update dimensions of `sample_weight` to match with `losses` if possible.
    losses, _, sample_weight = squeeze_or_expand_dimensions(  # pylint: disable=unbalanced-tuple-unpacking
        losses, None, sample_weight)
    weighted_losses = math_ops.multiply(losses, sample_weight)

    # Apply reduction function to the individual weighted losses.
    loss = reduce_weighted_loss(weighted_losses, reduction)
    # Convert the result back to the input type.
    loss = math_ops.cast(loss, input_dtype)
    return loss


def scale_loss_for_distribution(loss_value):
  """Scales and returns the given loss value by the number of replicas.

  Args:
    loss_value: The loss value to scale.

  Returns:
    `loss_value` multiplied by `1 / num_replicas_in_sync` of the current
    distribution strategy when more than one replica is in sync; otherwise
    `loss_value` unchanged.
  """
  num_replicas = (
      distribution_strategy_context.get_strategy().num_replicas_in_sync)
  if num_replicas > 1:
    loss_value *= (1. / num_replicas)
  return loss_value


def cast_losses_to_common_dtype(losses):
  """Cast a list of losses to a common dtype.

  If any loss is floating-point, they will all be casted to the most-precise
  floating-point loss. Otherwise the losses are not casted. We also skip casting
  losses if there are any complex losses.

  When both `bfloat16` and `float16` losses are present and no wider float has
  been seen, the common dtype is `float32`, since neither 16-bit format can
  represent every value of the other.

  Args:
    losses: A list of losses.

  Returns:
    `losses`, but they have been casted to a common dtype.
  """
  highest_float = None
  for loss in losses:
    if loss.dtype.is_floating:
      if highest_float is None or loss.dtype.size > highest_float.size:
        highest_float = loss.dtype
      elif {loss.dtype, highest_float} == {dtypes.bfloat16, dtypes.float16}:
        # Compare against `DType` objects rather than strings: a `DType` does
        # not hash like its name, so a set of dtypes never equals a set of
        # name strings. Keeping `highest_float` a `DType` (not the string
        # 'float32') also keeps the `.size` comparison above valid on later
        # iterations.
        highest_float = dtypes.float32
    if loss.dtype.is_complex:
      return losses  # If we find any complex losses, do not cast any losses
  if highest_float is not None:
    losses = [math_ops.cast(loss, highest_float) for loss in losses]
  return losses