1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Built-in loss functions.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21import functools 22 23import six 24 25from tensorflow.python.autograph.core import ag_ctx 26from tensorflow.python.autograph.impl import api as autograph 27from tensorflow.python.distribute import distribution_strategy_context 28from tensorflow.python.eager import context 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import smart_cond 32from tensorflow.python.framework import tensor_spec 33from tensorflow.python.framework import tensor_util 34from tensorflow.python.keras import backend as K 35from tensorflow.python.keras.utils import losses_utils 36from tensorflow.python.keras.utils import tf_utils 37from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object 38from tensorflow.python.keras.utils.generic_utils import serialize_keras_object 39from tensorflow.python.ops import array_ops 40from tensorflow.python.ops import control_flow_ops 41from tensorflow.python.ops import math_ops 42from tensorflow.python.ops import nn 43from tensorflow.python.ops.losses import losses_impl 44from tensorflow.python.ops.ragged import ragged_map_ops 45from tensorflow.python.ops.ragged import ragged_tensor 46from tensorflow.python.ops.ragged import ragged_util 47from tensorflow.python.util import dispatch 48from tensorflow.python.util.tf_export import keras_export 49from tensorflow.tools.docs import doc_controls 50 51 52@keras_export('keras.losses.Loss') 53class Loss(object): 54 """Loss base class. 55 56 To be implemented by subclasses: 57 * `call()`: Contains the logic for loss calculation using `y_true`, `y_pred`. 58 59 Example subclass implementation: 60 61 ```python 62 class MeanSquaredError(Loss): 63 64 def call(self, y_true, y_pred): 65 y_pred = tf.convert_to_tensor_v2(y_pred) 66 y_true = tf.cast(y_true, y_pred.dtype) 67 return tf.reduce_mean(math_ops.square(y_pred - y_true), axis=-1) 68 ``` 69 70 When used with `tf.distribute.Strategy`, outside of built-in training loops 71 such as `tf.keras` `compile` and `fit`, please use 'SUM' or 'NONE' reduction 72 types, and reduce losses explicitly in your training loop. Using 'AUTO' or 73 'SUM_OVER_BATCH_SIZE' will raise an error. 74 75 Please see this custom training [tutorial]( 76 https://www.tensorflow.org/tutorials/distribute/custom_training) for more 77 details on this. 78 79 You can implement 'SUM_OVER_BATCH_SIZE' using global batch size like: 80 ```python 81 with strategy.scope(): 82 loss_obj = tf.keras.losses.CategoricalCrossentropy( 83 reduction=tf.keras.losses.Reduction.NONE) 84 .... 85 loss = (tf.reduce_sum(loss_obj(labels, predictions)) * 86 (1. 
/ global_batch_size)) 87 ``` 88 """ 89 90 def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name=None): 91 """Initializes `Loss` class. 92 93 Args: 94 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 95 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 96 option will be determined by the usage context. For almost all cases 97 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 98 `tf.distribute.Strategy`, outside of built-in training loops such as 99 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 100 will raise an error. Please see this custom training [tutorial]( 101 https://www.tensorflow.org/tutorials/distribute/custom_training) for 102 more details. 103 name: Optional name for the op. 104 """ 105 losses_utils.ReductionV2.validate(reduction) 106 self.reduction = reduction 107 self.name = name 108 # SUM_OVER_BATCH is only allowed in losses managed by `fit` or 109 # CannedEstimators. 110 self._allow_sum_over_batch_size = False 111 self._set_name_scope() 112 113 def _set_name_scope(self): 114 """Creates a valid `name_scope` name.""" 115 if self.name is None: 116 self._name_scope = self.__class__.__name__ 117 elif self.name == '<lambda>': 118 self._name_scope = 'lambda' 119 else: 120 # E.g. '_my_loss' => 'my_loss' 121 self._name_scope = self.name.strip('_') 122 123 def __call__(self, y_true, y_pred, sample_weight=None): 124 """Invokes the `Loss` instance. 125 126 Args: 127 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`, except 128 sparse loss functions such as sparse categorical crossentropy where 129 shape = `[batch_size, d0, .. dN-1]` 130 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]` 131 sample_weight: Optional `sample_weight` acts as a coefficient for the 132 loss. If a scalar is provided, then the loss is simply scaled by the 133 given value. If `sample_weight` is a tensor of size `[batch_size]`, then 134 the total loss for each sample of the batch is rescaled by the 135 corresponding element in the `sample_weight` vector. If the shape of 136 `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted to 137 this shape), then each loss element of `y_pred` is scaled 138 by the corresponding value of `sample_weight`. (Note on`dN-1`: all loss 139 functions reduce by 1 dimension, usually axis=-1.) 140 141 Returns: 142 Weighted loss float `Tensor`. If `reduction` is `NONE`, this has 143 shape `[batch_size, d0, .. dN-1]`; otherwise, it is scalar. (Note `dN-1` 144 because all loss functions reduce by 1 dimension, usually axis=-1.) 145 146 Raises: 147 ValueError: If the shape of `sample_weight` is invalid. 148 """ 149 # If we are wrapping a lambda function strip '<>' from the name as it is not 150 # accepted in scope name. 151 graph_ctx = tf_utils.graph_context_for_symbolic_tensors( 152 y_true, y_pred, sample_weight) 153 with K.name_scope(self._name_scope), graph_ctx: 154 if context.executing_eagerly(): 155 call_fn = self.call 156 else: 157 call_fn = autograph.tf_convert(self.call, ag_ctx.control_status_ctx()) 158 losses = call_fn(y_true, y_pred) 159 return losses_utils.compute_weighted_loss( 160 losses, sample_weight, reduction=self._get_reduction()) 161 162 @classmethod 163 def from_config(cls, config): 164 """Instantiates a `Loss` from its config (output of `get_config()`). 165 166 Args: 167 config: Output of `get_config()`. 168 169 Returns: 170 A `Loss` instance. 
171 """ 172 return cls(**config) 173 174 def get_config(self): 175 """Returns the config dictionary for a `Loss` instance.""" 176 return {'reduction': self.reduction, 'name': self.name} 177 178 @abc.abstractmethod 179 @doc_controls.for_subclass_implementers 180 def call(self, y_true, y_pred): 181 """Invokes the `Loss` instance. 182 183 Args: 184 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`, except 185 sparse loss functions such as sparse categorical crossentropy where 186 shape = `[batch_size, d0, .. dN-1]` 187 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]` 188 189 Returns: 190 Loss values with the shape `[batch_size, d0, .. dN-1]`. 191 """ 192 raise NotImplementedError('Must be implemented in subclasses.') 193 194 def _get_reduction(self): 195 """Handles `AUTO` reduction cases and returns the reduction value.""" 196 if (not self._allow_sum_over_batch_size and 197 distribution_strategy_context.has_strategy() and 198 (self.reduction == losses_utils.ReductionV2.AUTO or 199 self.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE)): 200 raise ValueError( 201 'Please use `tf.keras.losses.Reduction.SUM` or ' 202 '`tf.keras.losses.Reduction.NONE` for loss reduction when losses are ' 203 'used with `tf.distribute.Strategy` outside of the built-in training ' 204 'loops. You can implement ' 205 '`tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` using global batch ' 206 'size like:\n```\nwith strategy.scope():\n' 207 ' loss_obj = tf.keras.losses.CategoricalCrossentropy(' 208 'reduction=tf.keras.losses.Reduction.NONE)\n....\n' 209 ' loss = tf.reduce_sum(loss_obj(labels, predictions)) * ' 210 '(1. / global_batch_size)\n```\nPlease see ' 211 'https://www.tensorflow.org/tutorials/distribute/custom_training' 212 ' for more details.') 213 214 if self.reduction == losses_utils.ReductionV2.AUTO: 215 return losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE 216 return self.reduction 217 218 219class LossFunctionWrapper(Loss): 220 """Wraps a loss function in the `Loss` class.""" 221 222 def __init__(self, 223 fn, 224 reduction=losses_utils.ReductionV2.AUTO, 225 name=None, 226 **kwargs): 227 """Initializes `LossFunctionWrapper` class. 228 229 Args: 230 fn: The loss function to wrap, with signature `fn(y_true, y_pred, 231 **kwargs)`. 232 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 233 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 234 option will be determined by the usage context. For almost all cases 235 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 236 `tf.distribute.Strategy`, outside of built-in training loops such as 237 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 238 will raise an error. Please see this custom training [tutorial]( 239 https://www.tensorflow.org/tutorials/distribute/custom_training) for 240 more details. 241 name: (Optional) name for the loss. 242 **kwargs: The keyword arguments that are passed on to `fn`. 243 """ 244 super(LossFunctionWrapper, self).__init__(reduction=reduction, name=name) 245 self.fn = fn 246 self._fn_kwargs = kwargs 247 248 def call(self, y_true, y_pred): 249 """Invokes the `LossFunctionWrapper` instance. 250 251 Args: 252 y_true: Ground truth values. 253 y_pred: The predicted values. 254 255 Returns: 256 Loss values per sample. 
257 """ 258 if tensor_util.is_tf_type(y_pred) and tensor_util.is_tf_type(y_true): 259 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true) 260 261 ag_fn = autograph.tf_convert(self.fn, ag_ctx.control_status_ctx()) 262 return ag_fn(y_true, y_pred, **self._fn_kwargs) 263 264 def get_config(self): 265 config = {} 266 for k, v in six.iteritems(self._fn_kwargs): 267 config[k] = K.eval(v) if tf_utils.is_tensor_or_variable(v) else v 268 base_config = super(LossFunctionWrapper, self).get_config() 269 return dict(list(base_config.items()) + list(config.items())) 270 271 272@keras_export('keras.losses.MeanSquaredError') 273class MeanSquaredError(LossFunctionWrapper): 274 """Computes the mean of squares of errors between labels and predictions. 275 276 `loss = square(y_true - y_pred)` 277 278 Standalone usage: 279 280 >>> y_true = [[0., 1.], [0., 0.]] 281 >>> y_pred = [[1., 1.], [1., 0.]] 282 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 283 >>> mse = tf.keras.losses.MeanSquaredError() 284 >>> mse(y_true, y_pred).numpy() 285 0.5 286 287 >>> # Calling with 'sample_weight'. 288 >>> mse(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy() 289 0.25 290 291 >>> # Using 'sum' reduction type. 292 >>> mse = tf.keras.losses.MeanSquaredError( 293 ... reduction=tf.keras.losses.Reduction.SUM) 294 >>> mse(y_true, y_pred).numpy() 295 1.0 296 297 >>> # Using 'none' reduction type. 298 >>> mse = tf.keras.losses.MeanSquaredError( 299 ... reduction=tf.keras.losses.Reduction.NONE) 300 >>> mse(y_true, y_pred).numpy() 301 array([0.5, 0.5], dtype=float32) 302 303 Usage with the `compile()` API: 304 305 ```python 306 model.compile(optimizer='sgd', loss=tf.keras.losses.MeanSquaredError()) 307 ``` 308 """ 309 310 def __init__(self, 311 reduction=losses_utils.ReductionV2.AUTO, 312 name='mean_squared_error'): 313 """Initializes `MeanSquaredError` instance. 314 315 Args: 316 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 317 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 318 option will be determined by the usage context. For almost all cases 319 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 320 `tf.distribute.Strategy`, outside of built-in training loops such as 321 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 322 will raise an error. Please see this custom training [tutorial]( 323 https://www.tensorflow.org/tutorials/distribute/custom_training) for 324 more details. 325 name: Optional name for the op. Defaults to 'mean_squared_error'. 326 """ 327 super(MeanSquaredError, self).__init__( 328 mean_squared_error, name=name, reduction=reduction) 329 330 331@keras_export('keras.losses.MeanAbsoluteError') 332class MeanAbsoluteError(LossFunctionWrapper): 333 """Computes the mean of absolute difference between labels and predictions. 334 335 `loss = abs(y_true - y_pred)` 336 337 Standalone usage: 338 339 >>> y_true = [[0., 1.], [0., 0.]] 340 >>> y_pred = [[1., 1.], [1., 0.]] 341 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 342 >>> mae = tf.keras.losses.MeanAbsoluteError() 343 >>> mae(y_true, y_pred).numpy() 344 0.5 345 346 >>> # Calling with 'sample_weight'. 347 >>> mae(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy() 348 0.25 349 350 >>> # Using 'sum' reduction type. 351 >>> mae = tf.keras.losses.MeanAbsoluteError( 352 ... reduction=tf.keras.losses.Reduction.SUM) 353 >>> mae(y_true, y_pred).numpy() 354 1.0 355 356 >>> # Using 'none' reduction type. 357 >>> mae = tf.keras.losses.MeanAbsoluteError( 358 ... 
reduction=tf.keras.losses.Reduction.NONE) 359 >>> mae(y_true, y_pred).numpy() 360 array([0.5, 0.5], dtype=float32) 361 362 Usage with the `compile()` API: 363 364 ```python 365 model.compile(optimizer='sgd', loss=tf.keras.losses.MeanAbsoluteError()) 366 ``` 367 """ 368 369 def __init__(self, 370 reduction=losses_utils.ReductionV2.AUTO, 371 name='mean_absolute_error'): 372 """Initializes `MeanAbsoluteError` instance. 373 374 Args: 375 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 376 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 377 option will be determined by the usage context. For almost all cases 378 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 379 `tf.distribute.Strategy`, outside of built-in training loops such as 380 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 381 will raise an error. Please see this custom training [tutorial]( 382 https://www.tensorflow.org/tutorials/distribute/custom_training) for 383 more details. 384 name: Optional name for the op. Defaults to 'mean_absolute_error'. 385 """ 386 super(MeanAbsoluteError, self).__init__( 387 mean_absolute_error, name=name, reduction=reduction) 388 389 390@keras_export('keras.losses.MeanAbsolutePercentageError') 391class MeanAbsolutePercentageError(LossFunctionWrapper): 392 """Computes the mean absolute percentage error between `y_true` and `y_pred`. 393 394 `loss = 100 * abs(y_true - y_pred) / y_true` 395 396 Standalone usage: 397 398 >>> y_true = [[2., 1.], [2., 3.]] 399 >>> y_pred = [[1., 1.], [1., 0.]] 400 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 401 >>> mape = tf.keras.losses.MeanAbsolutePercentageError() 402 >>> mape(y_true, y_pred).numpy() 403 50. 404 405 >>> # Calling with 'sample_weight'. 406 >>> mape(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy() 407 20. 408 409 >>> # Using 'sum' reduction type. 410 >>> mape = tf.keras.losses.MeanAbsolutePercentageError( 411 ... reduction=tf.keras.losses.Reduction.SUM) 412 >>> mape(y_true, y_pred).numpy() 413 100. 414 415 >>> # Using 'none' reduction type. 416 >>> mape = tf.keras.losses.MeanAbsolutePercentageError( 417 ... reduction=tf.keras.losses.Reduction.NONE) 418 >>> mape(y_true, y_pred).numpy() 419 array([25., 75.], dtype=float32) 420 421 Usage with the `compile()` API: 422 423 ```python 424 model.compile(optimizer='sgd', 425 loss=tf.keras.losses.MeanAbsolutePercentageError()) 426 ``` 427 """ 428 429 def __init__(self, 430 reduction=losses_utils.ReductionV2.AUTO, 431 name='mean_absolute_percentage_error'): 432 """Initializes `MeanAbsolutePercentageError` instance. 433 434 Args: 435 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 436 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 437 option will be determined by the usage context. For almost all cases 438 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 439 `tf.distribute.Strategy`, outside of built-in training loops such as 440 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 441 will raise an error. Please see this custom training [tutorial]( 442 https://www.tensorflow.org/tutorials/distribute/custom_training) for 443 more details. 444 name: Optional name for the op. Defaults to 445 'mean_absolute_percentage_error'. 
446 """ 447 super(MeanAbsolutePercentageError, self).__init__( 448 mean_absolute_percentage_error, name=name, reduction=reduction) 449 450 451@keras_export('keras.losses.MeanSquaredLogarithmicError') 452class MeanSquaredLogarithmicError(LossFunctionWrapper): 453 """Computes the mean squared logarithmic error between `y_true` and `y_pred`. 454 455 `loss = square(log(y_true + 1.) - log(y_pred + 1.))` 456 457 Standalone usage: 458 459 >>> y_true = [[0., 1.], [0., 0.]] 460 >>> y_pred = [[1., 1.], [1., 0.]] 461 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 462 >>> msle = tf.keras.losses.MeanSquaredLogarithmicError() 463 >>> msle(y_true, y_pred).numpy() 464 0.240 465 466 >>> # Calling with 'sample_weight'. 467 >>> msle(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy() 468 0.120 469 470 >>> # Using 'sum' reduction type. 471 >>> msle = tf.keras.losses.MeanSquaredLogarithmicError( 472 ... reduction=tf.keras.losses.Reduction.SUM) 473 >>> msle(y_true, y_pred).numpy() 474 0.480 475 476 >>> # Using 'none' reduction type. 477 >>> msle = tf.keras.losses.MeanSquaredLogarithmicError( 478 ... reduction=tf.keras.losses.Reduction.NONE) 479 >>> msle(y_true, y_pred).numpy() 480 array([0.240, 0.240], dtype=float32) 481 482 Usage with the `compile()` API: 483 484 ```python 485 model.compile(optimizer='sgd', 486 loss=tf.keras.losses.MeanSquaredLogarithmicError()) 487 ``` 488 """ 489 490 def __init__(self, 491 reduction=losses_utils.ReductionV2.AUTO, 492 name='mean_squared_logarithmic_error'): 493 """Initializes `MeanSquaredLogarithmicError` instance. 494 495 Args: 496 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 497 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 498 option will be determined by the usage context. For almost all cases 499 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 500 `tf.distribute.Strategy`, outside of built-in training loops such as 501 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 502 will raise an error. Please see this custom training [tutorial]( 503 https://www.tensorflow.org/tutorials/distribute/custom_training) for 504 more details. 505 name: Optional name for the op. Defaults to 506 'mean_squared_logarithmic_error'. 507 """ 508 super(MeanSquaredLogarithmicError, self).__init__( 509 mean_squared_logarithmic_error, name=name, reduction=reduction) 510 511 512@keras_export('keras.losses.BinaryCrossentropy') 513class BinaryCrossentropy(LossFunctionWrapper): 514 """Computes the cross-entropy loss between true labels and predicted labels. 515 516 Use this cross-entropy loss for binary (0 or 1) classification applications. 517 The loss function requires the following inputs: 518 519 - `y_true` (true label): This is either 0 or 1. 520 - `y_pred` (predicted value): This is the model's prediction, i.e, a single 521 floating-point value which either represents a 522 [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf] 523 when `from_logits=True`) or a probability (i.e, value in [0., 1.] when 524 `from_logits=False`). 525 526 **Recommended Usage:** (set `from_logits=True`) 527 528 With `tf.keras` API: 529 530 ```python 531 model.compile( 532 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), 533 .... 
534 ) 535 ``` 536 537 As a standalone function: 538 539 >>> # Example 1: (batch_size = 1, number of samples = 4) 540 >>> y_true = [0, 1, 0, 0] 541 >>> y_pred = [-18.6, 0.51, 2.94, -12.8] 542 >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True) 543 >>> bce(y_true, y_pred).numpy() 544 0.865 545 546 >>> # Example 2: (batch_size = 2, number of samples = 4) 547 >>> y_true = [[0, 1], [0, 0]] 548 >>> y_pred = [[-18.6, 0.51], [2.94, -12.8]] 549 >>> # Using default 'auto'/'sum_over_batch_size' reduction type. 550 >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True) 551 >>> bce(y_true, y_pred).numpy() 552 0.865 553 >>> # Using 'sample_weight' attribute 554 >>> bce(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy() 555 0.243 556 >>> # Using 'sum' reduction` type. 557 >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, 558 ... reduction=tf.keras.losses.Reduction.SUM) 559 >>> bce(y_true, y_pred).numpy() 560 1.730 561 >>> # Using 'none' reduction type. 562 >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, 563 ... reduction=tf.keras.losses.Reduction.NONE) 564 >>> bce(y_true, y_pred).numpy() 565 array([0.235, 1.496], dtype=float32) 566 567 **Default Usage:** (set `from_logits=False`) 568 569 >>> # Make the following updates to the above "Recommended Usage" section 570 >>> # 1. Set `from_logits=False` 571 >>> tf.keras.losses.BinaryCrossentropy() # OR ...('from_logits=False') 572 >>> # 2. Update `y_pred` to use probabilities instead of logits 573 >>> y_pred = [0.6, 0.3, 0.2, 0.8] # OR [[0.6, 0.3], [0.2, 0.8]] 574 """ 575 576 def __init__(self, 577 from_logits=False, 578 label_smoothing=0, 579 reduction=losses_utils.ReductionV2.AUTO, 580 name='binary_crossentropy'): 581 """Initializes `BinaryCrossentropy` instance. 582 583 Args: 584 from_logits: Whether to interpret `y_pred` as a tensor of 585 [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we 586 assume that `y_pred` contains probabilities (i.e., values in [0, 1]). 587 **Note - Using from_logits=True may be more numerically stable. 588 label_smoothing: Float in [0, 1]. When 0, no smoothing occurs. When > 0, 589 we compute the loss between the predicted labels and a smoothed version 590 of the true labels, where the smoothing squeezes the labels towards 0.5. 591 Larger values of `label_smoothing` correspond to heavier smoothing. 592 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 593 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 594 option will be determined by the usage context. For almost all cases 595 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 596 `tf.distribute.Strategy`, outside of built-in training loops such as 597 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 598 will raise an error. Please see this custom training [tutorial]( 599 https://www.tensorflow.org/tutorials/distribute/custom_training) for 600 more details. 601 name: (Optional) Name for the op. Defaults to 'binary_crossentropy'. 602 """ 603 super(BinaryCrossentropy, self).__init__( 604 binary_crossentropy, 605 name=name, 606 reduction=reduction, 607 from_logits=from_logits, 608 label_smoothing=label_smoothing) 609 self.from_logits = from_logits 610 611 612@keras_export('keras.losses.CategoricalCrossentropy') 613class CategoricalCrossentropy(LossFunctionWrapper): 614 """Computes the crossentropy loss between the labels and predictions. 615 616 Use this crossentropy loss function when there are two or more label classes. 
617 We expect labels to be provided in a `one_hot` representation. If you want to 618 provide labels as integers, please use `SparseCategoricalCrossentropy` loss. 619 There should be `# classes` floating point values per feature. 620 621 In the snippet below, there is `# classes` floating pointing values per 622 example. The shape of both `y_pred` and `y_true` are 623 `[batch_size, num_classes]`. 624 625 Standalone usage: 626 627 >>> y_true = [[0, 1, 0], [0, 0, 1]] 628 >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] 629 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 630 >>> cce = tf.keras.losses.CategoricalCrossentropy() 631 >>> cce(y_true, y_pred).numpy() 632 1.177 633 634 >>> # Calling with 'sample_weight'. 635 >>> cce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy() 636 0.814 637 638 >>> # Using 'sum' reduction type. 639 >>> cce = tf.keras.losses.CategoricalCrossentropy( 640 ... reduction=tf.keras.losses.Reduction.SUM) 641 >>> cce(y_true, y_pred).numpy() 642 2.354 643 644 >>> # Using 'none' reduction type. 645 >>> cce = tf.keras.losses.CategoricalCrossentropy( 646 ... reduction=tf.keras.losses.Reduction.NONE) 647 >>> cce(y_true, y_pred).numpy() 648 array([0.0513, 2.303], dtype=float32) 649 650 Usage with the `compile()` API: 651 652 ```python 653 model.compile(optimizer='sgd', loss=tf.keras.losses.CategoricalCrossentropy()) 654 ``` 655 """ 656 657 def __init__(self, 658 from_logits=False, 659 label_smoothing=0, 660 reduction=losses_utils.ReductionV2.AUTO, 661 name='categorical_crossentropy'): 662 """Initializes `CategoricalCrossentropy` instance. 663 664 Args: 665 from_logits: Whether `y_pred` is expected to be a logits tensor. By 666 default, we assume that `y_pred` encodes a probability distribution. 667 **Note - Using from_logits=True is more numerically stable.** 668 label_smoothing: Float in [0, 1]. When > 0, label values are smoothed, 669 meaning the confidence on label values are relaxed. For example, if 670 `0.1`, use `0.1 / num_classes` for non-target labels and 671 `0.9 + 0.1 / num_classes` for target labels. 672 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 673 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 674 option will be determined by the usage context. For almost all cases 675 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 676 `tf.distribute.Strategy`, outside of built-in training loops such as 677 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 678 will raise an error. Please see this custom training [tutorial]( 679 https://www.tensorflow.org/tutorials/distribute/custom_training) for 680 more details. 681 name: Optional name for the op. Defaults to 'categorical_crossentropy'. 682 """ 683 super(CategoricalCrossentropy, self).__init__( 684 categorical_crossentropy, 685 name=name, 686 reduction=reduction, 687 from_logits=from_logits, 688 label_smoothing=label_smoothing) 689 690 691@keras_export('keras.losses.SparseCategoricalCrossentropy') 692class SparseCategoricalCrossentropy(LossFunctionWrapper): 693 """Computes the crossentropy loss between the labels and predictions. 694 695 Use this crossentropy loss function when there are two or more label classes. 696 We expect labels to be provided as integers. If you want to provide labels 697 using `one-hot` representation, please use `CategoricalCrossentropy` loss. 698 There should be `# classes` floating point values per feature for `y_pred` 699 and a single floating point value per feature for `y_true`. 
700 701 In the snippet below, there is a single floating point value per example for 702 `y_true` and `# classes` floating pointing values per example for `y_pred`. 703 The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is 704 `[batch_size, num_classes]`. 705 706 Standalone usage: 707 708 >>> y_true = [1, 2] 709 >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] 710 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 711 >>> scce = tf.keras.losses.SparseCategoricalCrossentropy() 712 >>> scce(y_true, y_pred).numpy() 713 1.177 714 715 >>> # Calling with 'sample_weight'. 716 >>> scce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy() 717 0.814 718 719 >>> # Using 'sum' reduction type. 720 >>> scce = tf.keras.losses.SparseCategoricalCrossentropy( 721 ... reduction=tf.keras.losses.Reduction.SUM) 722 >>> scce(y_true, y_pred).numpy() 723 2.354 724 725 >>> # Using 'none' reduction type. 726 >>> scce = tf.keras.losses.SparseCategoricalCrossentropy( 727 ... reduction=tf.keras.losses.Reduction.NONE) 728 >>> scce(y_true, y_pred).numpy() 729 array([0.0513, 2.303], dtype=float32) 730 731 Usage with the `compile()` API: 732 733 ```python 734 model.compile(optimizer='sgd', 735 loss=tf.keras.losses.SparseCategoricalCrossentropy()) 736 ``` 737 """ 738 739 def __init__(self, 740 from_logits=False, 741 reduction=losses_utils.ReductionV2.AUTO, 742 name='sparse_categorical_crossentropy'): 743 """Initializes `SparseCategoricalCrossentropy` instance. 744 745 Args: 746 from_logits: Whether `y_pred` is expected to be a logits tensor. By 747 default, we assume that `y_pred` encodes a probability distribution. 748 **Note - Using from_logits=True may be more numerically stable. 749 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 750 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 751 option will be determined by the usage context. For almost all cases 752 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 753 `tf.distribute.Strategy`, outside of built-in training loops such as 754 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 755 will raise an error. Please see this custom training [tutorial]( 756 https://www.tensorflow.org/tutorials/distribute/custom_training) for 757 more details. 758 name: Optional name for the op. Defaults to 759 'sparse_categorical_crossentropy'. 760 """ 761 super(SparseCategoricalCrossentropy, self).__init__( 762 sparse_categorical_crossentropy, 763 name=name, 764 reduction=reduction, 765 from_logits=from_logits) 766 767 768@keras_export('keras.losses.Hinge') 769class Hinge(LossFunctionWrapper): 770 """Computes the hinge loss between `y_true` and `y_pred`. 771 772 `loss = maximum(1 - y_true * y_pred, 0)` 773 774 `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are 775 provided we will convert them to -1 or 1. 776 777 Standalone usage: 778 779 >>> y_true = [[0., 1.], [0., 0.]] 780 >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] 781 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 782 >>> h = tf.keras.losses.Hinge() 783 >>> h(y_true, y_pred).numpy() 784 1.3 785 786 >>> # Calling with 'sample_weight'. 787 >>> h(y_true, y_pred, sample_weight=[1, 0]).numpy() 788 0.55 789 790 >>> # Using 'sum' reduction type. 791 >>> h = tf.keras.losses.Hinge( 792 ... reduction=tf.keras.losses.Reduction.SUM) 793 >>> h(y_true, y_pred).numpy() 794 2.6 795 796 >>> # Using 'none' reduction type. 797 >>> h = tf.keras.losses.Hinge( 798 ... 
reduction=tf.keras.losses.Reduction.NONE) 799 >>> h(y_true, y_pred).numpy() 800 array([1.1, 1.5], dtype=float32) 801 802 Usage with the `compile()` API: 803 804 ```python 805 model.compile(optimizer='sgd', loss=tf.keras.losses.Hinge()) 806 ``` 807 """ 808 809 def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='hinge'): 810 """Initializes `Hinge` instance. 811 812 Args: 813 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 814 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 815 option will be determined by the usage context. For almost all cases 816 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 817 `tf.distribute.Strategy`, outside of built-in training loops such as 818 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 819 will raise an error. Please see this custom training [tutorial]( 820 https://www.tensorflow.org/tutorials/distribute/custom_training) for 821 more details. 822 name: Optional name for the op. Defaults to 'hinge'. 823 """ 824 super(Hinge, self).__init__(hinge, name=name, reduction=reduction) 825 826 827@keras_export('keras.losses.SquaredHinge') 828class SquaredHinge(LossFunctionWrapper): 829 """Computes the squared hinge loss between `y_true` and `y_pred`. 830 831 `loss = square(maximum(1 - y_true * y_pred, 0))` 832 833 `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are 834 provided we will convert them to -1 or 1. 835 836 Standalone usage: 837 838 >>> y_true = [[0., 1.], [0., 0.]] 839 >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] 840 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 841 >>> h = tf.keras.losses.SquaredHinge() 842 >>> h(y_true, y_pred).numpy() 843 1.86 844 845 >>> # Calling with 'sample_weight'. 846 >>> h(y_true, y_pred, sample_weight=[1, 0]).numpy() 847 0.73 848 849 >>> # Using 'sum' reduction type. 850 >>> h = tf.keras.losses.SquaredHinge( 851 ... reduction=tf.keras.losses.Reduction.SUM) 852 >>> h(y_true, y_pred).numpy() 853 3.72 854 855 >>> # Using 'none' reduction type. 856 >>> h = tf.keras.losses.SquaredHinge( 857 ... reduction=tf.keras.losses.Reduction.NONE) 858 >>> h(y_true, y_pred).numpy() 859 array([1.46, 2.26], dtype=float32) 860 861 Usage with the `compile()` API: 862 863 ```python 864 model.compile(optimizer='sgd', loss=tf.keras.losses.SquaredHinge()) 865 ``` 866 """ 867 868 def __init__(self, 869 reduction=losses_utils.ReductionV2.AUTO, 870 name='squared_hinge'): 871 """Initializes `SquaredHinge` instance. 872 873 Args: 874 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 875 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 876 option will be determined by the usage context. For almost all cases 877 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 878 `tf.distribute.Strategy`, outside of built-in training loops such as 879 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 880 will raise an error. Please see this custom training [tutorial]( 881 https://www.tensorflow.org/tutorials/distribute/custom_training) for 882 more details. 883 name: Optional name for the op. Defaults to 'squared_hinge'. 884 """ 885 super(SquaredHinge, self).__init__( 886 squared_hinge, name=name, reduction=reduction) 887 888 889@keras_export('keras.losses.CategoricalHinge') 890class CategoricalHinge(LossFunctionWrapper): 891 """Computes the categorical hinge loss between `y_true` and `y_pred`. 
892 893 `loss = maximum(neg - pos + 1, 0)` 894 where `neg=maximum((1-y_true)*y_pred) and pos=sum(y_true*y_pred)` 895 896 Standalone usage: 897 898 >>> y_true = [[0, 1], [0, 0]] 899 >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] 900 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 901 >>> h = tf.keras.losses.CategoricalHinge() 902 >>> h(y_true, y_pred).numpy() 903 1.4 904 905 >>> # Calling with 'sample_weight'. 906 >>> h(y_true, y_pred, sample_weight=[1, 0]).numpy() 907 0.6 908 909 >>> # Using 'sum' reduction type. 910 >>> h = tf.keras.losses.CategoricalHinge( 911 ... reduction=tf.keras.losses.Reduction.SUM) 912 >>> h(y_true, y_pred).numpy() 913 2.8 914 915 >>> # Using 'none' reduction type. 916 >>> h = tf.keras.losses.CategoricalHinge( 917 ... reduction=tf.keras.losses.Reduction.NONE) 918 >>> h(y_true, y_pred).numpy() 919 array([1.2, 1.6], dtype=float32) 920 921 Usage with the `compile()` API: 922 923 ```python 924 model.compile(optimizer='sgd', loss=tf.keras.losses.CategoricalHinge()) 925 ``` 926 """ 927 928 def __init__(self, 929 reduction=losses_utils.ReductionV2.AUTO, 930 name='categorical_hinge'): 931 """Initializes `CategoricalHinge` instance. 932 933 Args: 934 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 935 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 936 option will be determined by the usage context. For almost all cases 937 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 938 `tf.distribute.Strategy`, outside of built-in training loops such as 939 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 940 will raise an error. Please see this custom training [tutorial]( 941 https://www.tensorflow.org/tutorials/distribute/custom_training) for 942 more details. 943 name: Optional name for the op. Defaults to 'categorical_hinge'. 944 """ 945 super(CategoricalHinge, self).__init__( 946 categorical_hinge, name=name, reduction=reduction) 947 948 949@keras_export('keras.losses.Poisson') 950class Poisson(LossFunctionWrapper): 951 """Computes the Poisson loss between `y_true` and `y_pred`. 952 953 `loss = y_pred - y_true * log(y_pred)` 954 955 Standalone usage: 956 957 >>> y_true = [[0., 1.], [0., 0.]] 958 >>> y_pred = [[1., 1.], [0., 0.]] 959 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 960 >>> p = tf.keras.losses.Poisson() 961 >>> p(y_true, y_pred).numpy() 962 0.5 963 964 >>> # Calling with 'sample_weight'. 965 >>> p(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy() 966 0.4 967 968 >>> # Using 'sum' reduction type. 969 >>> p = tf.keras.losses.Poisson( 970 ... reduction=tf.keras.losses.Reduction.SUM) 971 >>> p(y_true, y_pred).numpy() 972 0.999 973 974 >>> # Using 'none' reduction type. 975 >>> p = tf.keras.losses.Poisson( 976 ... reduction=tf.keras.losses.Reduction.NONE) 977 >>> p(y_true, y_pred).numpy() 978 array([0.999, 0.], dtype=float32) 979 980 Usage with the `compile()` API: 981 982 ```python 983 model.compile(optimizer='sgd', loss=tf.keras.losses.Poisson()) 984 ``` 985 """ 986 987 def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='poisson'): 988 """Initializes `Poisson` instance. 989 990 Args: 991 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 992 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 993 option will be determined by the usage context. For almost all cases 994 this defaults to `SUM_OVER_BATCH_SIZE`. 
When used with 995 `tf.distribute.Strategy`, outside of built-in training loops such as 996 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 997 will raise an error. Please see this custom training [tutorial]( 998 https://www.tensorflow.org/tutorials/distribute/custom_training) for 999 more details. 1000 name: Optional name for the op. Defaults to 'poisson'. 1001 """ 1002 super(Poisson, self).__init__(poisson, name=name, reduction=reduction) 1003 1004 1005@keras_export('keras.losses.LogCosh') 1006class LogCosh(LossFunctionWrapper): 1007 """Computes the logarithm of the hyperbolic cosine of the prediction error. 1008 1009 `logcosh = log((exp(x) + exp(-x))/2)`, 1010 where x is the error `y_pred - y_true`. 1011 1012 Standalone usage: 1013 1014 >>> y_true = [[0., 1.], [0., 0.]] 1015 >>> y_pred = [[1., 1.], [0., 0.]] 1016 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 1017 >>> l = tf.keras.losses.LogCosh() 1018 >>> l(y_true, y_pred).numpy() 1019 0.108 1020 1021 >>> # Calling with 'sample_weight'. 1022 >>> l(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy() 1023 0.087 1024 1025 >>> # Using 'sum' reduction type. 1026 >>> l = tf.keras.losses.LogCosh( 1027 ... reduction=tf.keras.losses.Reduction.SUM) 1028 >>> l(y_true, y_pred).numpy() 1029 0.217 1030 1031 >>> # Using 'none' reduction type. 1032 >>> l = tf.keras.losses.LogCosh( 1033 ... reduction=tf.keras.losses.Reduction.NONE) 1034 >>> l(y_true, y_pred).numpy() 1035 array([0.217, 0.], dtype=float32) 1036 1037 Usage with the `compile()` API: 1038 1039 ```python 1040 model.compile(optimizer='sgd', loss=tf.keras.losses.LogCosh()) 1041 ``` 1042 """ 1043 1044 def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='log_cosh'): 1045 """Initializes `LogCosh` instance. 1046 1047 Args: 1048 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 1049 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 1050 option will be determined by the usage context. For almost all cases 1051 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 1052 `tf.distribute.Strategy`, outside of built-in training loops such as 1053 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 1054 will raise an error. Please see this custom training [tutorial]( 1055 https://www.tensorflow.org/tutorials/distribute/custom_training) for 1056 more details. 1057 name: Optional name for the op. Defaults to 'log_cosh'. 1058 """ 1059 super(LogCosh, self).__init__(log_cosh, name=name, reduction=reduction) 1060 1061 1062@keras_export('keras.losses.KLDivergence') 1063class KLDivergence(LossFunctionWrapper): 1064 """Computes Kullback-Leibler divergence loss between `y_true` and `y_pred`. 1065 1066 `loss = y_true * log(y_true / y_pred)` 1067 1068 See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence 1069 1070 Standalone usage: 1071 1072 >>> y_true = [[0, 1], [0, 0]] 1073 >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] 1074 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 1075 >>> kl = tf.keras.losses.KLDivergence() 1076 >>> kl(y_true, y_pred).numpy() 1077 0.458 1078 1079 >>> # Calling with 'sample_weight'. 1080 >>> kl(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy() 1081 0.366 1082 1083 >>> # Using 'sum' reduction type. 1084 >>> kl = tf.keras.losses.KLDivergence( 1085 ... reduction=tf.keras.losses.Reduction.SUM) 1086 >>> kl(y_true, y_pred).numpy() 1087 0.916 1088 1089 >>> # Using 'none' reduction type. 1090 >>> kl = tf.keras.losses.KLDivergence( 1091 ... 
reduction=tf.keras.losses.Reduction.NONE) 1092 >>> kl(y_true, y_pred).numpy() 1093 array([0.916, -3.08e-06], dtype=float32) 1094 1095 Usage with the `compile()` API: 1096 1097 ```python 1098 model.compile(optimizer='sgd', loss=tf.keras.losses.KLDivergence()) 1099 ``` 1100 """ 1101 1102 def __init__(self, 1103 reduction=losses_utils.ReductionV2.AUTO, 1104 name='kl_divergence'): 1105 """Initializes `KLDivergence` instance. 1106 1107 Args: 1108 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 1109 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 1110 option will be determined by the usage context. For almost all cases 1111 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 1112 `tf.distribute.Strategy`, outside of built-in training loops such as 1113 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 1114 will raise an error. Please see this custom training [tutorial]( 1115 https://www.tensorflow.org/tutorials/distribute/custom_training) for 1116 more details. 1117 name: Optional name for the op. Defaults to 'kl_divergence'. 1118 """ 1119 super(KLDivergence, self).__init__( 1120 kl_divergence, name=name, reduction=reduction) 1121 1122 1123@keras_export('keras.losses.Huber') 1124class Huber(LossFunctionWrapper): 1125 """Computes the Huber loss between `y_true` and `y_pred`. 1126 1127 For each value x in `error = y_true - y_pred`: 1128 1129 ``` 1130 loss = 0.5 * x^2 if |x| <= d 1131 loss = 0.5 * d^2 + d * (|x| - d) if |x| > d 1132 ``` 1133 where d is `delta`. See: https://en.wikipedia.org/wiki/Huber_loss 1134 1135 Standalone usage: 1136 1137 >>> y_true = [[0, 1], [0, 0]] 1138 >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] 1139 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 1140 >>> h = tf.keras.losses.Huber() 1141 >>> h(y_true, y_pred).numpy() 1142 0.155 1143 1144 >>> # Calling with 'sample_weight'. 1145 >>> h(y_true, y_pred, sample_weight=[1, 0]).numpy() 1146 0.09 1147 1148 >>> # Using 'sum' reduction type. 1149 >>> h = tf.keras.losses.Huber( 1150 ... reduction=tf.keras.losses.Reduction.SUM) 1151 >>> h(y_true, y_pred).numpy() 1152 0.31 1153 1154 >>> # Using 'none' reduction type. 1155 >>> h = tf.keras.losses.Huber( 1156 ... reduction=tf.keras.losses.Reduction.NONE) 1157 >>> h(y_true, y_pred).numpy() 1158 array([0.18, 0.13], dtype=float32) 1159 1160 Usage with the `compile()` API: 1161 1162 ```python 1163 model.compile(optimizer='sgd', loss=tf.keras.losses.Huber()) 1164 ``` 1165 """ 1166 1167 def __init__(self, 1168 delta=1.0, 1169 reduction=losses_utils.ReductionV2.AUTO, 1170 name='huber_loss'): 1171 """Initializes `Huber` instance. 1172 1173 Args: 1174 delta: A float, the point where the Huber loss function changes from a 1175 quadratic to linear. 1176 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 1177 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 1178 option will be determined by the usage context. For almost all cases 1179 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 1180 `tf.distribute.Strategy`, outside of built-in training loops such as 1181 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 1182 will raise an error. Please see this custom training [tutorial]( 1183 https://www.tensorflow.org/tutorials/distribute/custom_training) for 1184 more details. 1185 name: Optional name for the op. Defaults to 'huber_loss'. 
1186 """ 1187 super(Huber, self).__init__( 1188 huber, name=name, reduction=reduction, delta=delta) 1189 1190 1191@keras_export('keras.metrics.mean_squared_error', 'keras.metrics.mse', 1192 'keras.metrics.MSE', 'keras.losses.mean_squared_error', 1193 'keras.losses.mse', 'keras.losses.MSE') 1194@dispatch.add_dispatch_support 1195def mean_squared_error(y_true, y_pred): 1196 """Computes the mean squared error between labels and predictions. 1197 1198 After computing the squared distance between the inputs, the mean value over 1199 the last dimension is returned. 1200 1201 `loss = mean(square(y_true - y_pred), axis=-1)` 1202 1203 Standalone usage: 1204 1205 >>> y_true = np.random.randint(0, 2, size=(2, 3)) 1206 >>> y_pred = np.random.random(size=(2, 3)) 1207 >>> loss = tf.keras.losses.mean_squared_error(y_true, y_pred) 1208 >>> assert loss.shape == (2,) 1209 >>> assert np.array_equal( 1210 ... loss.numpy(), np.mean(np.square(y_true - y_pred), axis=-1)) 1211 1212 Args: 1213 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. 1214 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 1215 1216 Returns: 1217 Mean squared error values. shape = `[batch_size, d0, .. dN-1]`. 1218 """ 1219 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 1220 y_true = math_ops.cast(y_true, y_pred.dtype) 1221 return K.mean(math_ops.squared_difference(y_pred, y_true), axis=-1) 1222 1223 1224def _ragged_tensor_apply_loss(loss_fn, y_true, y_pred): 1225 """Apply a loss function on a per batch basis. 1226 1227 Args: 1228 loss_fn: The loss function 1229 y_true: truth values (RaggedTensor) 1230 y_pred: predicted values (RaggedTensor) 1231 1232 Returns: 1233 Loss-function result. A dense tensor if the output has a single dimension 1234 (per-batch loss value); a ragged tensor otherwise. 1235 """ 1236 1237 def rt_is_equiv_dense(rt): 1238 """Returns true if this RaggedTensor has the same row_lenghts across 1239 1240 all ragged dimensions and thus can be converted to a dense tensor 1241 without loss of information. 1242 1243 Args: 1244 rt: RaggedTensor 1245 """ 1246 return math_ops.reduce_all([ 1247 math_ops.equal( 1248 math_ops.reduce_variance(math_ops.cast(row_lens, K.floatx())), 1249 constant_op.constant([0.])) for row_lens in rt.nested_row_lengths() 1250 ]) 1251 1252 def _convert_to_dense(inputs): 1253 return tuple(rt.to_tensor() for rt in inputs) 1254 1255 def _wrapper(inputs): 1256 _, y_pred = inputs 1257 if isinstance(y_pred, ragged_tensor.RaggedTensor): 1258 return control_flow_ops.cond( 1259 rt_is_equiv_dense(y_pred), 1260 lambda: loss_fn(*_convert_to_dense(inputs)), lambda: loss_fn(*inputs)) 1261 1262 return loss_fn(*inputs) 1263 1264 lshape = y_pred.shape.as_list()[1:-1] 1265 if len(lshape) > 0: 1266 spec = ragged_tensor.RaggedTensorSpec(shape=lshape, dtype=y_pred.dtype) 1267 else: 1268 spec = tensor_spec.TensorSpec(shape=[], dtype=y_pred.dtype) 1269 1270 nested_splits_list = [rt.nested_row_splits for rt in (y_true, y_pred)] 1271 assertion_list = ragged_util.assert_splits_match(nested_splits_list) 1272 with ops.control_dependencies(assertion_list): 1273 return ragged_map_ops.map_fn(_wrapper, elems=(y_true, y_pred), dtype=spec) 1274 1275 1276@dispatch.dispatch_for_types(mean_squared_error, ragged_tensor.RaggedTensor) 1277def _ragged_tensor_mse(y_true, y_pred): 1278 """ Implements support for handling RaggedTensors. 1279 1280 Args: 1281 y_true: RaggedTensor truth values. shape = `[batch_size, d0, .. dN]`. 1282 y_pred: RaggedTensor predicted values. shape = `[batch_size, d0, .. dN]`. 
1283 1284 Returns: 1285 Mean squared error values. shape = `[batch_size, d0, .. dN-1]`. 1286 When the number of dimensions of the batch feature vector [d0, .. dN] is 1287 greater than one the return value is a RaggedTensor. Otherwise a Dense 1288 tensor with dimensions [batch_size] is returned. 1289 """ 1290 return _ragged_tensor_apply_loss(mean_squared_error, y_true, y_pred) 1291 1292 1293@keras_export('keras.metrics.mean_absolute_error', 'keras.metrics.mae', 1294 'keras.metrics.MAE', 'keras.losses.mean_absolute_error', 1295 'keras.losses.mae', 'keras.losses.MAE') 1296@dispatch.add_dispatch_support 1297def mean_absolute_error(y_true, y_pred): 1298 """Computes the mean absolute error between labels and predictions. 1299 1300 `loss = mean(abs(y_true - y_pred), axis=-1)` 1301 1302 Standalone usage: 1303 1304 >>> y_true = np.random.randint(0, 2, size=(2, 3)) 1305 >>> y_pred = np.random.random(size=(2, 3)) 1306 >>> loss = tf.keras.losses.mean_absolute_error(y_true, y_pred) 1307 >>> assert loss.shape == (2,) 1308 >>> assert np.array_equal( 1309 ... loss.numpy(), np.mean(np.abs(y_true - y_pred), axis=-1)) 1310 1311 Args: 1312 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. 1313 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 1314 1315 Returns: 1316 Mean absolute error values. shape = `[batch_size, d0, .. dN-1]`. 1317 """ 1318 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 1319 y_true = math_ops.cast(y_true, y_pred.dtype) 1320 return K.mean(math_ops.abs(y_pred - y_true), axis=-1) 1321 1322 1323@dispatch.dispatch_for_types(mean_absolute_error, ragged_tensor.RaggedTensor) 1324def _ragged_tensor_mae(y_true, y_pred): 1325 """ RaggedTensor adapter for mean_absolute_error""" 1326 return _ragged_tensor_apply_loss(mean_absolute_error, y_true, y_pred) 1327 1328 1329@keras_export('keras.metrics.mean_absolute_percentage_error', 1330 'keras.metrics.mape', 'keras.metrics.MAPE', 1331 'keras.losses.mean_absolute_percentage_error', 1332 'keras.losses.mape', 'keras.losses.MAPE') 1333@dispatch.add_dispatch_support 1334def mean_absolute_percentage_error(y_true, y_pred): 1335 """Computes the mean absolute percentage error between `y_true` and `y_pred`. 1336 1337 `loss = 100 * mean(abs((y_true - y_pred) / y_true), axis=-1)` 1338 1339 Standalone usage: 1340 1341 >>> y_true = np.random.random(size=(2, 3)) 1342 >>> y_true = np.maximum(y_true, 1e-7) # Prevent division by zero 1343 >>> y_pred = np.random.random(size=(2, 3)) 1344 >>> loss = tf.keras.losses.mean_absolute_percentage_error(y_true, y_pred) 1345 >>> assert loss.shape == (2,) 1346 >>> assert np.array_equal( 1347 ... loss.numpy(), 1348 ... 100. * np.mean(np.abs((y_true - y_pred) / y_true), axis=-1)) 1349 1350 Args: 1351 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. 1352 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 1353 1354 Returns: 1355 Mean absolute percentage error values. shape = `[batch_size, d0, .. dN-1]`. 1356 """ 1357 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 1358 y_true = math_ops.cast(y_true, y_pred.dtype) 1359 diff = math_ops.abs( 1360 (y_true - y_pred) / K.maximum(math_ops.abs(y_true), K.epsilon())) 1361 return 100. 
* K.mean(diff, axis=-1) 1362 1363 1364@keras_export('keras.metrics.mean_squared_logarithmic_error', 1365 'keras.metrics.msle', 'keras.metrics.MSLE', 1366 'keras.losses.mean_squared_logarithmic_error', 1367 'keras.losses.msle', 'keras.losses.MSLE') 1368@dispatch.add_dispatch_support 1369def mean_squared_logarithmic_error(y_true, y_pred): 1370 """Computes the mean squared logarithmic error between `y_true` and `y_pred`. 1371 1372 `loss = mean(square(log(y_true + 1) - log(y_pred + 1)), axis=-1)` 1373 1374 Standalone usage: 1375 1376 >>> y_true = np.random.randint(0, 2, size=(2, 3)) 1377 >>> y_pred = np.random.random(size=(2, 3)) 1378 >>> loss = tf.keras.losses.mean_squared_logarithmic_error(y_true, y_pred) 1379 >>> assert loss.shape == (2,) 1380 >>> y_true = np.maximum(y_true, 1e-7) 1381 >>> y_pred = np.maximum(y_pred, 1e-7) 1382 >>> assert np.allclose( 1383 ... loss.numpy(), 1384 ... np.mean( 1385 ... np.square(np.log(y_true + 1.) - np.log(y_pred + 1.)), axis=-1)) 1386 1387 Args: 1388 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. 1389 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 1390 1391 Returns: 1392 Mean squared logarithmic error values. shape = `[batch_size, d0, .. dN-1]`. 1393 """ 1394 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 1395 y_true = math_ops.cast(y_true, y_pred.dtype) 1396 first_log = math_ops.log(K.maximum(y_pred, K.epsilon()) + 1.) 1397 second_log = math_ops.log(K.maximum(y_true, K.epsilon()) + 1.) 1398 return K.mean(math_ops.squared_difference(first_log, second_log), axis=-1) 1399 1400 1401def _maybe_convert_labels(y_true): 1402 """Converts binary labels into -1/1.""" 1403 are_zeros = math_ops.equal(y_true, 0) 1404 are_ones = math_ops.equal(y_true, 1) 1405 is_binary = math_ops.reduce_all(math_ops.logical_or(are_zeros, are_ones)) 1406 1407 def _convert_binary_labels(): 1408 # Convert the binary labels to -1 or 1. 1409 return 2. * y_true - 1. 1410 1411 updated_y_true = smart_cond.smart_cond(is_binary, _convert_binary_labels, 1412 lambda: y_true) 1413 return updated_y_true 1414 1415 1416@keras_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge') 1417@dispatch.add_dispatch_support 1418def squared_hinge(y_true, y_pred): 1419 """Computes the squared hinge loss between `y_true` and `y_pred`. 1420 1421 `loss = mean(square(maximum(1 - y_true * y_pred, 0)), axis=-1)` 1422 1423 Standalone usage: 1424 1425 >>> y_true = np.random.choice([-1, 1], size=(2, 3)) 1426 >>> y_pred = np.random.random(size=(2, 3)) 1427 >>> loss = tf.keras.losses.squared_hinge(y_true, y_pred) 1428 >>> assert loss.shape == (2,) 1429 >>> assert np.array_equal( 1430 ... loss.numpy(), 1431 ... np.mean(np.square(np.maximum(1. - y_true * y_pred, 0.)), axis=-1)) 1432 1433 Args: 1434 y_true: The ground truth values. `y_true` values are expected to be -1 or 1. 1435 If binary (0 or 1) labels are provided we will convert them to -1 or 1. 1436 shape = `[batch_size, d0, .. dN]`. 1437 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 1438 1439 Returns: 1440 Squared hinge loss values. shape = `[batch_size, d0, .. dN-1]`. 1441 """ 1442 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 1443 y_true = math_ops.cast(y_true, y_pred.dtype) 1444 y_true = _maybe_convert_labels(y_true) 1445 return K.mean( 1446 math_ops.square(math_ops.maximum(1. 
- y_true * y_pred, 0.)), axis=-1) 1447 1448 1449@keras_export('keras.metrics.hinge', 'keras.losses.hinge') 1450@dispatch.add_dispatch_support 1451def hinge(y_true, y_pred): 1452 """Computes the hinge loss between `y_true` and `y_pred`. 1453 1454 `loss = mean(maximum(1 - y_true * y_pred, 0), axis=-1)` 1455 1456 Standalone usage: 1457 1458 >>> y_true = np.random.choice([-1, 1], size=(2, 3)) 1459 >>> y_pred = np.random.random(size=(2, 3)) 1460 >>> loss = tf.keras.losses.hinge(y_true, y_pred) 1461 >>> assert loss.shape == (2,) 1462 >>> assert np.array_equal( 1463 ... loss.numpy(), 1464 ... np.mean(np.maximum(1. - y_true * y_pred, 0.), axis=-1)) 1465 1466 Args: 1467 y_true: The ground truth values. `y_true` values are expected to be -1 or 1. 1468 If binary (0 or 1) labels are provided they will be converted to -1 or 1. 1469 shape = `[batch_size, d0, .. dN]`. 1470 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 1471 1472 Returns: 1473 Hinge loss values. shape = `[batch_size, d0, .. dN-1]`. 1474 """ 1475 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 1476 y_true = math_ops.cast(y_true, y_pred.dtype) 1477 y_true = _maybe_convert_labels(y_true) 1478 return K.mean(math_ops.maximum(1. - y_true * y_pred, 0.), axis=-1) 1479 1480 1481@keras_export('keras.losses.categorical_hinge') 1482@dispatch.add_dispatch_support 1483def categorical_hinge(y_true, y_pred): 1484 """Computes the categorical hinge loss between `y_true` and `y_pred`. 1485 1486 `loss = maximum(neg - pos + 1, 0)` 1487 where `neg=maximum((1-y_true)*y_pred) and pos=sum(y_true*y_pred)` 1488 1489 Standalone usage: 1490 1491 >>> y_true = np.random.randint(0, 3, size=(2,)) 1492 >>> y_true = tf.keras.utils.to_categorical(y_true, num_classes=3) 1493 >>> y_pred = np.random.random(size=(2, 3)) 1494 >>> loss = tf.keras.losses.categorical_hinge(y_true, y_pred) 1495 >>> assert loss.shape == (2,) 1496 >>> pos = np.sum(y_true * y_pred, axis=-1) 1497 >>> neg = np.amax((1. - y_true) * y_pred, axis=-1) 1498 >>> assert np.array_equal(loss.numpy(), np.maximum(0., neg - pos + 1.)) 1499 1500 Args: 1501 y_true: The ground truth values. `y_true` values are expected to be 1502 either `{-1, +1}` or `{0, 1}` (i.e. a one-hot-encoded tensor). 1503 y_pred: The predicted values. 1504 1505 Returns: 1506 Categorical hinge loss values. 1507 """ 1508 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 1509 y_true = math_ops.cast(y_true, y_pred.dtype) 1510 pos = math_ops.reduce_sum(y_true * y_pred, axis=-1) 1511 neg = math_ops.reduce_max((1. - y_true) * y_pred, axis=-1) 1512 zero = math_ops.cast(0., y_pred.dtype) 1513 return math_ops.maximum(neg - pos + 1., zero) 1514 1515 1516@keras_export('keras.losses.huber', v1=[]) 1517@dispatch.add_dispatch_support 1518def huber(y_true, y_pred, delta=1.0): 1519 """Computes Huber loss value. 1520 1521 For each value x in `error = y_true - y_pred`: 1522 1523 ``` 1524 loss = 0.5 * x^2 if |x| <= d 1525 loss = 0.5 * d^2 + d * (|x| - d) if |x| > d 1526 ``` 1527 where d is `delta`. See: https://en.wikipedia.org/wiki/Huber_loss 1528 1529 Args: 1530 y_true: tensor of true targets. 1531 y_pred: tensor of predicted targets. 1532 delta: A float, the point where the Huber loss function changes from a 1533 quadratic to linear. 1534 1535 Returns: 1536 Tensor with one scalar loss entry per sample. 


@keras_export('keras.losses.log_cosh', 'keras.losses.logcosh',
              'keras.metrics.log_cosh', 'keras.metrics.logcosh')
@dispatch.add_dispatch_support
def log_cosh(y_true, y_pred):
  """Logarithm of the hyperbolic cosine of the prediction error.

  `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and
  to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly
  like the mean squared error, but will not be so strongly affected by the
  occasional wildly incorrect prediction.

  Standalone usage:

  >>> y_true = np.random.random(size=(2, 3))
  >>> y_pred = np.random.random(size=(2, 3))
  >>> loss = tf.keras.losses.logcosh(y_true, y_pred)
  >>> assert loss.shape == (2,)
  >>> x = y_pred - y_true
  >>> assert np.allclose(
  ...     loss.numpy(),
  ...     np.mean(x + np.log(np.exp(-2. * x) + 1.) - np.log(2.), axis=-1),
  ...     atol=1e-5)

  Args:
    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

  Returns:
    Logcosh error values. shape = `[batch_size, d0, .. dN-1]`.
  """
  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)

  def _logcosh(x):
    return x + nn.softplus(-2. * x) - math_ops.cast(math_ops.log(2.), x.dtype)

  return K.mean(_logcosh(y_pred - y_true), axis=-1)


@keras_export('keras.metrics.categorical_crossentropy',
              'keras.losses.categorical_crossentropy')
@dispatch.add_dispatch_support
def categorical_crossentropy(y_true,
                             y_pred,
                             from_logits=False,
                             label_smoothing=0):
  """Computes the categorical crossentropy loss.

  Standalone usage:

  >>> y_true = [[0, 1, 0], [0, 0, 1]]
  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
  >>> loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
  >>> assert loss.shape == (2,)
  >>> loss.numpy()
  array([0.0513, 2.303], dtype=float32)

  Args:
    y_true: Tensor of one-hot true targets.
    y_pred: Tensor of predicted targets.
    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
      we assume that `y_pred` encodes a probability distribution.
    label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For
      example, if `0.1`, use `0.1 / num_classes` for non-target labels
      and `0.9 + 0.1 / num_classes` for target labels.

  Returns:
    Categorical crossentropy loss value.
  """
  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  label_smoothing = ops.convert_to_tensor_v2_with_dispatch(
      label_smoothing, dtype=K.floatx())

  def _smooth_labels():
    num_classes = math_ops.cast(array_ops.shape(y_true)[-1], y_pred.dtype)
    return y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes)

  y_true = smart_cond.smart_cond(label_smoothing, _smooth_labels,
                                 lambda: y_true)
  return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
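

# Illustrative sketch only, not part of the public API: it spells out the
# `label_smoothing` behaviour documented above. With smoothing 0.1 and 3
# classes, a one-hot target [0, 1, 0] becomes roughly [0.0333, 0.9333, 0.0333].
# The helper name `_example_categorical_crossentropy_label_smoothing` is
# hypothetical and assumes eager execution.
def _example_categorical_crossentropy_label_smoothing():
  y_true = [[0., 1., 0.]]
  y_pred = [[0.05, 0.9, 0.05]]
  label_smoothing = 0.1
  # Smoothed target computed by hand: y_true * (1 - 0.1) + 0.1 / num_classes.
  smoothed = [[0.1 / 3., 0.9 + 0.1 / 3., 0.1 / 3.]]
  smooth_loss = categorical_crossentropy(
      y_true, y_pred, label_smoothing=label_smoothing)
  direct_loss = categorical_crossentropy(smoothed, y_pred)
  # Both evaluate to (approximately) the same value.
  return smooth_loss, direct_loss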


@dispatch.dispatch_for_types(categorical_crossentropy,
                             ragged_tensor.RaggedTensor)
def _ragged_tensor_categorical_crossentropy(y_true,
                                            y_pred,
                                            from_logits=False,
                                            label_smoothing=0):
  """Implements support for handling RaggedTensors.

  Expected shape: (batch, sequence_len, n_classes) with sequence_len
  being variable per batch.
  Return shape: (batch, sequence_len).

  When used by CategoricalCrossentropy() with the default reduction
  (SUM_OVER_BATCH_SIZE), the reduction averages the loss over the
  number of elements independent of the batch. E.g. if the RaggedTensor
  has 2 batches with [2, 1] values respectively, the resulting loss is
  the sum of the individual loss values divided by 3.
  """
  fn = functools.partial(
      categorical_crossentropy,
      from_logits=from_logits,
      label_smoothing=label_smoothing)
  return _ragged_tensor_apply_loss(fn, y_true, y_pred)
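

# Illustrative sketch only, not part of the public API: it shows the ragged
# dispatch above applied to two sequences of lengths 2 and 1, which should
# produce a ragged loss whose row lengths match [2, 1]. The helper name
# `_example_ragged_categorical_crossentropy` is hypothetical and assumes eager
# execution.
def _example_ragged_categorical_crossentropy():
  y_true = ragged_tensor.RaggedTensor.from_row_lengths(
      [[0., 1.], [1., 0.], [1., 0.]], row_lengths=[2, 1])
  y_pred = ragged_tensor.RaggedTensor.from_row_lengths(
      [[0.1, 0.9], [0.8, 0.2], [0.6, 0.4]], row_lengths=[2, 1])
  # Dispatches to `_ragged_tensor_categorical_crossentropy`; each row keeps
  # its own sequence length in the result.
  return categorical_crossentropy(y_true, y_pred)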


@keras_export('keras.metrics.sparse_categorical_crossentropy',
              'keras.losses.sparse_categorical_crossentropy')
@dispatch.add_dispatch_support
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
  """Computes the sparse categorical crossentropy loss.

  Standalone usage:

  >>> y_true = [1, 2]
  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
  >>> loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
  >>> assert loss.shape == (2,)
  >>> loss.numpy()
  array([0.0513, 2.303], dtype=float32)

  Args:
    y_true: Ground truth values.
    y_pred: The predicted values.
    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
      we assume that `y_pred` encodes a probability distribution.
    axis: (Optional) Defaults to -1. The dimension along which the entropy is
      computed.

  Returns:
    Sparse categorical crossentropy loss value.
  """
  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  return K.sparse_categorical_crossentropy(
      y_true, y_pred, from_logits=from_logits, axis=axis)


@keras_export('keras.metrics.binary_crossentropy',
              'keras.losses.binary_crossentropy')
@dispatch.add_dispatch_support
def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
  """Computes the binary crossentropy loss.

  Standalone usage:

  >>> y_true = [[0, 1], [0, 0]]
  >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
  >>> loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
  >>> assert loss.shape == (2,)
  >>> loss.numpy()
  array([0.916 , 0.714], dtype=float32)

  Args:
    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
      we assume that `y_pred` encodes a probability distribution.
    label_smoothing: Float in [0, 1]. If > `0` then smooth the labels by
      squeezing them towards 0.5. That is, using `1. - 0.5 * label_smoothing`
      for the target class and `0.5 * label_smoothing` for the non-target class.

  Returns:
    Binary crossentropy loss value. shape = `[batch_size, d0, .. dN-1]`.
  """
  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  label_smoothing = ops.convert_to_tensor_v2_with_dispatch(
      label_smoothing, dtype=K.floatx())

  def _smooth_labels():
    return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

  y_true = smart_cond.smart_cond(label_smoothing, _smooth_labels,
                                 lambda: y_true)
  return K.mean(
      K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)


@dispatch.dispatch_for_types(binary_crossentropy, ragged_tensor.RaggedTensor)
def _ragged_tensor_binary_crossentropy(y_true,
                                       y_pred,
                                       from_logits=False,
                                       label_smoothing=0):
  """Implements support for handling RaggedTensors.

  Expected shape: (batch, sequence_len) with sequence_len being variable
  per batch.
  Return shape: (batch,); returns the per batch mean of the loss values.

  When used by BinaryCrossentropy() with the default reduction
  (SUM_OVER_BATCH_SIZE), the reduction averages the per batch losses over
  the number of batches.
  """
  fn = functools.partial(
      binary_crossentropy,
      from_logits=from_logits,
      label_smoothing=label_smoothing)
  return _ragged_tensor_apply_loss(fn, y_true, y_pred)


@keras_export('keras.metrics.kl_divergence',
              'keras.metrics.kullback_leibler_divergence', 'keras.metrics.kld',
              'keras.metrics.KLD', 'keras.losses.kl_divergence',
              'keras.losses.kullback_leibler_divergence', 'keras.losses.kld',
              'keras.losses.KLD')
@dispatch.add_dispatch_support
def kl_divergence(y_true, y_pred):
  """Computes Kullback-Leibler divergence loss between `y_true` and `y_pred`.

  `loss = y_true * log(y_true / y_pred)`

  See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence

  Standalone usage:

  >>> y_true = np.random.randint(0, 2, size=(2, 3)).astype(np.float64)
  >>> y_pred = np.random.random(size=(2, 3))
  >>> loss = tf.keras.losses.kullback_leibler_divergence(y_true, y_pred)
  >>> assert loss.shape == (2,)
  >>> y_true = tf.keras.backend.clip(y_true, 1e-7, 1)
  >>> y_pred = tf.keras.backend.clip(y_pred, 1e-7, 1)
  >>> assert np.array_equal(
  ...     loss.numpy(), np.sum(y_true * np.log(y_true / y_pred), axis=-1))

  Args:
    y_true: Tensor of true targets.
    y_pred: Tensor of predicted targets.

  Returns:
    A `Tensor` with loss.

  Raises:
    TypeError: If `y_true` cannot be cast to the `y_pred.dtype`.
  """
  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  y_true = K.clip(y_true, K.epsilon(), 1)
  y_pred = K.clip(y_pred, K.epsilon(), 1)
  return math_ops.reduce_sum(y_true * math_ops.log(y_true / y_pred), axis=-1)


@keras_export('keras.metrics.poisson', 'keras.losses.poisson')
@dispatch.add_dispatch_support
def poisson(y_true, y_pred):
  """Computes the Poisson loss between y_true and y_pred.

  The Poisson loss is the mean of the elements of the `Tensor`
  `y_pred - y_true * log(y_pred)`.

  Standalone usage:

  >>> y_true = np.random.randint(0, 2, size=(2, 3))
  >>> y_pred = np.random.random(size=(2, 3))
  >>> loss = tf.keras.losses.poisson(y_true, y_pred)
  >>> assert loss.shape == (2,)
  >>> y_pred = y_pred + 1e-7
  >>> assert np.allclose(
  ...     loss.numpy(), np.mean(y_pred - y_true * np.log(y_pred), axis=-1),
  ...     atol=1e-5)

  Args:
    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

  Returns:
    Poisson loss value. shape = `[batch_size, d0, .. dN-1]`.

  Raises:
    InvalidArgumentError: If `y_true` and `y_pred` have incompatible shapes.
  """
  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  return K.mean(y_pred - y_true * math_ops.log(y_pred + K.epsilon()), axis=-1)


@keras_export(
    'keras.losses.cosine_similarity',
    v1=[
        'keras.metrics.cosine_proximity',
        'keras.metrics.cosine',
        'keras.losses.cosine_proximity',
        'keras.losses.cosine',
        'keras.losses.cosine_similarity',
    ])
@dispatch.add_dispatch_support
def cosine_similarity(y_true, y_pred, axis=-1):
  """Computes the cosine similarity between labels and predictions.

  Note that it is a number between -1 and 1. When it is a negative number
  between -1 and 0, 0 indicates orthogonality and values closer to -1
  indicate greater similarity. The values closer to 1 indicate greater
  dissimilarity. This makes it usable as a loss function in a setting
  where you try to maximize the proximity between predictions and
  targets. If either `y_true` or `y_pred` is a zero vector, cosine
  similarity will be 0 regardless of the proximity between predictions
  and targets.

  `loss = -sum(l2_norm(y_true) * l2_norm(y_pred))`

  Standalone usage:

  >>> y_true = [[0., 1.], [1., 1.], [1., 1.]]
  >>> y_pred = [[1., 0.], [1., 1.], [-1., -1.]]
  >>> loss = tf.keras.losses.cosine_similarity(y_true, y_pred, axis=1)
  >>> loss.numpy()
  array([-0., -0.999, 0.999], dtype=float32)

  Args:
    y_true: Tensor of true targets.
    y_pred: Tensor of predicted targets.
    axis: Axis along which to determine similarity.

  Returns:
    Cosine similarity tensor.
  """
  y_true = nn.l2_normalize(y_true, axis=axis)
  y_pred = nn.l2_normalize(y_pred, axis=axis)
  return -math_ops.reduce_sum(y_true * y_pred, axis=axis)
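

# Illustrative sketch only, not part of the public API: it demonstrates the
# two properties called out in the docstring above. A zero vector yields a
# loss of 0, identical directions yield -1 (the most negative, i.e. best,
# value), and opposite directions yield +1. The helper name
# `_example_cosine_similarity_edge_cases` is hypothetical and assumes eager
# execution.
def _example_cosine_similarity_edge_cases():
  zero_case = cosine_similarity([[0., 0.]], [[1., 1.]])   # -> [0.]
  identical = cosine_similarity([[1., 1.]], [[2., 2.]])   # -> [-1.]
  opposite = cosine_similarity([[1., 1.]], [[-1., -1.]])  # -> [1.]
  return zero_case, identical, opposite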


@keras_export('keras.losses.CosineSimilarity')
class CosineSimilarity(LossFunctionWrapper):
  """Computes the cosine similarity between labels and predictions.

  Note that it is a number between -1 and 1. When it is a negative number
  between -1 and 0, 0 indicates orthogonality and values closer to -1
  indicate greater similarity. The values closer to 1 indicate greater
  dissimilarity. This makes it usable as a loss function in a setting
  where you try to maximize the proximity between predictions and targets.
  If either `y_true` or `y_pred` is a zero vector, cosine similarity will be 0
  regardless of the proximity between predictions and targets.

  `loss = -sum(l2_norm(y_true) * l2_norm(y_pred))`

  Standalone usage:

  >>> y_true = [[0., 1.], [1., 1.]]
  >>> y_pred = [[1., 0.], [1., 1.]]
  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
  >>> cosine_loss = tf.keras.losses.CosineSimilarity(axis=1)
  >>> # l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]]
  >>> # l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]]
  >>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]]
  >>> # loss = -mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))
  >>> #      = -((0. + 0.) + (0.5 + 0.5)) / 2
  >>> cosine_loss(y_true, y_pred).numpy()
  -0.5

  >>> # Calling with 'sample_weight'.
  >>> cosine_loss(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()
  -0.0999

  >>> # Using 'sum' reduction type.
  >>> cosine_loss = tf.keras.losses.CosineSimilarity(axis=1,
  ...     reduction=tf.keras.losses.Reduction.SUM)
  >>> cosine_loss(y_true, y_pred).numpy()
  -0.999

  >>> # Using 'none' reduction type.
  >>> cosine_loss = tf.keras.losses.CosineSimilarity(axis=1,
  ...     reduction=tf.keras.losses.Reduction.NONE)
  >>> cosine_loss(y_true, y_pred).numpy()
  array([-0., -0.999], dtype=float32)

  Usage with the `compile()` API:

  ```python
  model.compile(optimizer='sgd', loss=tf.keras.losses.CosineSimilarity(axis=1))
  ```

  Args:
    axis: (Optional) Defaults to -1. The dimension along which the cosine
      similarity is computed.
    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss.
      Default value is `AUTO`. `AUTO` indicates that the reduction option will
      be determined by the usage context. For almost all cases this defaults to
      `SUM_OVER_BATCH_SIZE`. When used with `tf.distribute.Strategy`, outside of
      built-in training loops such as `tf.keras` `compile` and `fit`, using
      `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this
      custom training [tutorial](
      https://www.tensorflow.org/tutorials/distribute/custom_training) for more
      details.
    name: Optional name for the op.
  """

  def __init__(self,
               axis=-1,
               reduction=losses_utils.ReductionV2.AUTO,
               name='cosine_similarity'):
    super(CosineSimilarity, self).__init__(
        cosine_similarity, reduction=reduction, name=name, axis=axis)


# Aliases.

bce = BCE = binary_crossentropy
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
kld = KLD = kullback_leibler_divergence = kl_divergence
logcosh = log_cosh
huber_loss = huber


def is_categorical_crossentropy(loss):
  result = ((isinstance(loss, CategoricalCrossentropy) or
             (isinstance(loss, LossFunctionWrapper) and
              loss.fn == categorical_crossentropy) or
             (hasattr(loss, '__name__') and
              loss.__name__ == 'categorical_crossentropy') or
             (loss == 'categorical_crossentropy')))
  return result


@keras_export('keras.losses.serialize')
def serialize(loss):
  """Serializes loss function or `Loss` instance.

  Args:
    loss: A Keras `Loss` instance or a loss function.

  Returns:
    Loss configuration dictionary.
  """
  return serialize_keras_object(loss)


@keras_export('keras.losses.deserialize')
def deserialize(name, custom_objects=None):
  """Deserializes a serialized loss class/function instance.

  Args:
    name: Loss configuration.
    custom_objects: Optional dictionary mapping names (strings) to custom
      objects (classes and functions) to be considered during deserialization.

  Returns:
    A Keras `Loss` instance or a loss function.
  """
  return deserialize_keras_object(
      name,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='loss function')


@keras_export('keras.losses.get')
def get(identifier):
  """Retrieves a Keras loss as a `function`/`Loss` class instance.

  The `identifier` may be the string name of a loss function or `Loss` class.

  >>> loss = tf.keras.losses.get("categorical_crossentropy")
  >>> type(loss)
  <class 'function'>
  >>> loss = tf.keras.losses.get("CategoricalCrossentropy")
  >>> type(loss)
  <class '...tensorflow.python.keras.losses.CategoricalCrossentropy'>

  You can also specify a `config` of the loss to this function by passing a
  dict containing `class_name` and `config` as an identifier. Also note that
  the `class_name` must map to a `Loss` class.

  >>> identifier = {"class_name": "CategoricalCrossentropy",
  ...               "config": {"from_logits": True}}
  >>> loss = tf.keras.losses.get(identifier)
  >>> type(loss)
  <class '...tensorflow.python.keras.losses.CategoricalCrossentropy'>

  Args:
    identifier: A loss identifier. One of None, a string name of a loss
      function/class, a loss configuration dictionary, a loss function, or a
      loss class instance.

  Returns:
    A Keras loss as a `function`/`Loss` class instance.

  Raises:
    ValueError: If `identifier` cannot be interpreted.
  """
  if identifier is None:
    return None
  if isinstance(identifier, six.string_types):
    identifier = str(identifier)
    return deserialize(identifier)
  if isinstance(identifier, dict):
    return deserialize(identifier)
  elif callable(identifier):
    return identifier
  else:
    raise ValueError(
        'Could not interpret loss function identifier: {}'.format(identifier))


LABEL_DTYPES_FOR_LOSSES = {
    losses_impl.sparse_softmax_cross_entropy: 'int32',
    sparse_categorical_crossentropy: 'int32'
}
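

# Illustrative sketch only, not part of the public API: it shows the
# serialize / deserialize / get round trip provided by the functions above.
# The helper name `_example_loss_round_trip` is hypothetical.
def _example_loss_round_trip():
  original = MeanSquaredError(reduction=losses_utils.ReductionV2.SUM)
  # A dict of the form {'class_name': 'MeanSquaredError', 'config': {...}}.
  config = serialize(original)
  # A new `MeanSquaredError` instance built from that configuration.
  restored = deserialize(config)
  # Looking up a loss by its short string alias returns the plain function.
  by_name = get('mse')
  return config, restored, by_name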