1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Utilities related to loss functions.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.distribute import distribution_strategy_context 22from tensorflow.python.framework import dtypes 23from tensorflow.python.framework import ops 24from tensorflow.python.keras import backend as K 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import confusion_matrix 27from tensorflow.python.ops import control_flow_ops 28from tensorflow.python.ops import math_ops 29from tensorflow.python.ops import weights_broadcast_ops 30from tensorflow.python.util.tf_export import keras_export 31 32 33@keras_export('keras.losses.Reduction', v1=[]) 34class ReductionV2(object): 35 """Types of loss reduction. 36 37 Contains the following values: 38 39 * `NONE`: Un-reduced weighted losses with the same shape as input. 40 * `SUM`: Scalar sum of weighted losses. 41 * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses. 42 Note that when using `tf.distribute.Strategy`, this is the global batch 43 size across all the replicas that are contributing to a single step. 44 """ 45 46 NONE = 'none' 47 SUM = 'sum' 48 SUM_OVER_BATCH_SIZE = 'sum_over_batch_size' 49 50 @classmethod 51 def all(cls): 52 return (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE) 53 54 @classmethod 55 def validate(cls, key): 56 if key not in cls.all(): 57 raise ValueError('Invalid Reduction Key %s.' % key) 58 59 60def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight): 61 """Squeeze or expand last dimension if needed. 62 63 1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1 64 (using `confusion_matrix.remove_squeezable_dimensions`). 65 2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1 66 from the new rank of `y_pred`. 67 If `sample_weight` is scalar, it is kept scalar. 68 69 This will use static shape if available. Otherwise, it will add graph 70 operations, which could result in a performance hit. 71 72 Args: 73 y_pred: Predicted values, a `Tensor` of arbitrary dimensions. 74 y_true: Optional label `Tensor` whose dimensions match `y_pred`. 75 sample_weight: Optional weight scalar or `Tensor` whose dimensions match 76 `y_pred`. 77 78 Returns: 79 Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has 80 the last dimension squeezed, 81 `sample_weight` could be extended by one dimension. 82 """ 83 y_pred_shape = y_pred.get_shape() 84 y_pred_rank = y_pred_shape.ndims 85 if y_true is not None: 86 87 # If sparse matrix is provided as `y_true`, the last dimension in `y_pred` 88 # may be > 1. Eg: y_true = [0, 1, 2] (shape=(3,)), 89 # y_pred = [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]] (shape=(3, 3)) 90 # In this case, we should not try to remove squeezable dimension. 91 y_true_shape = y_true.get_shape() 92 y_true_rank = y_true_shape.ndims 93 if (y_true_rank is not None) and (y_pred_rank is not None): 94 # Use static rank for `y_true` and `y_pred`. 95 if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1: 96 y_true, y_pred = confusion_matrix.remove_squeezable_dimensions( 97 y_true, y_pred) 98 else: 99 # Use dynamic rank. 100 rank_diff = array_ops.rank(y_pred) - array_ops.rank(y_true) 101 squeeze_dims = lambda: confusion_matrix.remove_squeezable_dimensions( # pylint: disable=g-long-lambda 102 y_true, y_pred) 103 is_last_dim_1 = math_ops.equal(1, array_ops.shape(y_pred)[-1]) 104 maybe_squeeze_dims = lambda: control_flow_ops.cond( # pylint: disable=g-long-lambda 105 is_last_dim_1, squeeze_dims, lambda: (y_true, y_pred)) 106 y_true, y_pred = control_flow_ops.cond( 107 math_ops.equal(1, rank_diff), maybe_squeeze_dims, squeeze_dims) 108 109 if sample_weight is None: 110 return y_pred, y_true, None 111 112 sample_weight = ops.convert_to_tensor(sample_weight) 113 weights_shape = sample_weight.get_shape() 114 weights_rank = weights_shape.ndims 115 if weights_rank == 0: # If weights is scalar, do nothing. 116 return y_pred, y_true, sample_weight 117 118 if (y_pred_rank is not None) and (weights_rank is not None): 119 # Use static rank. 120 if weights_rank - y_pred_rank == 1: 121 sample_weight = array_ops.squeeze(sample_weight, [-1]) 122 elif y_pred_rank - weights_rank == 1: 123 sample_weight = array_ops.expand_dims(sample_weight, [-1]) 124 return y_pred, y_true, sample_weight 125 126 # Use dynamic rank. 127 weights_rank_tensor = array_ops.rank(sample_weight) 128 rank_diff = weights_rank_tensor - array_ops.rank(y_pred) 129 maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1]) 130 131 def _maybe_expand_weights(): 132 return control_flow_ops.cond( 133 math_ops.equal(rank_diff, 134 -1), lambda: array_ops.expand_dims(sample_weight, [-1]), 135 lambda: sample_weight) 136 137 def _maybe_adjust_weights(): 138 return control_flow_ops.cond( 139 math_ops.equal(rank_diff, 1), maybe_squeeze_weights, 140 _maybe_expand_weights) 141 142 # squeeze or expand last dim of `sample_weight` if its rank differs by 1 143 # from the new rank of `y_pred`. 144 sample_weight = control_flow_ops.cond( 145 math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight, 146 _maybe_adjust_weights) 147 return y_pred, y_true, sample_weight 148 149 150def _safe_mean(losses, num_present): 151 """Computes a safe mean of the losses. 152 153 Args: 154 losses: `Tensor` whose elements contain individual loss measurements. 155 num_present: The number of measurable elements in `losses`. 156 157 Returns: 158 A scalar representing the mean of `losses`. If `num_present` is zero, 159 then zero is returned. 160 """ 161 total_loss = math_ops.reduce_sum(losses) 162 return math_ops.div_no_nan(total_loss, num_present, name='value') 163 164 165def _num_elements(losses): 166 """Computes the number of elements in `losses` tensor.""" 167 with ops.name_scope(None, 'num_elements', values=[losses]) as scope: 168 return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype) 169 170 171def reduce_weighted_loss(weighted_losses, 172 reduction=ReductionV2.SUM_OVER_BATCH_SIZE): 173 """Reduces the individual weighted loss measurements.""" 174 if reduction == ReductionV2.NONE: 175 loss = weighted_losses 176 else: 177 loss = math_ops.reduce_sum(weighted_losses) 178 if reduction == ReductionV2.SUM_OVER_BATCH_SIZE: 179 num_replicas = ( # Used to convert from local to global batch size. 180 distribution_strategy_context.get_strategy().num_replicas_in_sync) 181 loss = _safe_mean(loss, num_replicas * _num_elements(weighted_losses)) 182 return loss 183 184 185def compute_weighted_loss(losses, 186 sample_weight=None, 187 reduction=ReductionV2.SUM_OVER_BATCH_SIZE, 188 name=None): 189 """Computes the weighted loss. 190 191 Args: 192 losses: `Tensor` of shape `[batch_size, d1, ... dN]`. 193 sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as 194 `losses`, or be broadcastable to `losses`. 195 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss. 196 Default value is `SUM_OVER_BATCH_SIZE`. 197 name: Optional name for the op. 198 199 Raises: 200 ValueError: If the shape of `sample_weight` is not compatible with `losses`. 201 202 Returns: 203 Weighted loss `Tensor` of the same type as `losses`. If `reduction` is 204 `NONE`, this has the same shape as `losses`; otherwise, it is scalar. 205 """ 206 ReductionV2.validate(reduction) 207 if sample_weight is None: 208 sample_weight = 1.0 209 with ops.name_scope(name, 'weighted_loss', (losses, sample_weight)): 210 # Update dimensions of `sample_weight` to match with `losses` if possible. 211 losses, _, sample_weight = squeeze_or_expand_dimensions( 212 losses, None, sample_weight) 213 losses = ops.convert_to_tensor(losses) 214 input_dtype = losses.dtype 215 losses = math_ops.cast(losses, dtypes.float32) 216 sample_weight = math_ops.cast(sample_weight, dtypes.float32) 217 218 try: 219 # Broadcast weights if possible. 220 sample_weight = weights_broadcast_ops.broadcast_weights( 221 sample_weight, losses) 222 except ValueError: 223 # Reduce values to same ndim as weight array. 224 ndim = K.ndim(losses) 225 weight_ndim = K.ndim(sample_weight) 226 losses = K.mean(losses, axis=list(range(weight_ndim, ndim))) 227 228 sample_weight.get_shape().assert_is_compatible_with(losses.get_shape()) 229 weighted_losses = math_ops.multiply(losses, sample_weight) 230 # Apply reduction function to the individual weighted losses. 231 loss = reduce_weighted_loss(weighted_losses, reduction) 232 # Convert the result back to the input type. 233 loss = math_ops.cast(loss, input_dtype) 234 return loss 235 236 237def scale_loss_for_distribution(loss_value): 238 """Scales and returns the given loss value by the number of replicas.""" 239 num_replicas = ( 240 distribution_strategy_context.get_strategy().num_replicas_in_sync) 241 if num_replicas > 1: 242 loss_value *= (1. / num_replicas) 243 return loss_value 244