1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Built-in loss functions."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import functools
22
23import six
24
25from tensorflow.python.autograph.core import ag_ctx
26from tensorflow.python.autograph.impl import api as autograph
27from tensorflow.python.distribute import distribution_strategy_context
28from tensorflow.python.eager import context
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import smart_cond
32from tensorflow.python.framework import tensor_spec
33from tensorflow.python.framework import tensor_util
34from tensorflow.python.keras import backend as K
35from tensorflow.python.keras.utils import losses_utils
36from tensorflow.python.keras.utils import tf_utils
37from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
38from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import control_flow_ops
41from tensorflow.python.ops import math_ops
42from tensorflow.python.ops import nn
43from tensorflow.python.ops.losses import losses_impl
44from tensorflow.python.ops.ragged import ragged_map_ops
45from tensorflow.python.ops.ragged import ragged_tensor
46from tensorflow.python.ops.ragged import ragged_util
47from tensorflow.python.util import dispatch
48from tensorflow.python.util.tf_export import keras_export
49from tensorflow.tools.docs import doc_controls
50
51
52@keras_export('keras.losses.Loss')
53class Loss(object):
54  """Loss base class.
55
56  To be implemented by subclasses:
57  * `call()`: Contains the logic for loss calculation using `y_true`, `y_pred`.
58
59  Example subclass implementation:
60
61  ```python
62  class MeanSquaredError(Loss):
63
64    def call(self, y_true, y_pred):
65      y_pred = tf.convert_to_tensor(y_pred)
66      y_true = tf.cast(y_true, y_pred.dtype)
67      return tf.reduce_mean(tf.math.square(y_pred - y_true), axis=-1)
68  ```
69
70  When used with `tf.distribute.Strategy`, outside of built-in training loops
71  such as `tf.keras` `compile` and `fit`, please use 'SUM' or 'NONE' reduction
72  types, and reduce losses explicitly in your training loop. Using 'AUTO' or
73  'SUM_OVER_BATCH_SIZE' will raise an error.
74
75  Please see this custom training [tutorial](
76    https://www.tensorflow.org/tutorials/distribute/custom_training) for more
77  details on this.
78
79  You can implement 'SUM_OVER_BATCH_SIZE' using the global batch size like:
80  ```python
81  with strategy.scope():
82    loss_obj = tf.keras.losses.CategoricalCrossentropy(
83        reduction=tf.keras.losses.Reduction.NONE)
84    ....
85    loss = (tf.reduce_sum(loss_obj(labels, predictions)) *
86            (1. / global_batch_size))
87  ```
88  """
89
90  def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name=None):
91    """Initializes `Loss` class.
92
93    Args:
94      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
95        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
96        option will be determined by the usage context. For almost all cases
97        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
98        `tf.distribute.Strategy`, outside of built-in training loops such as
99        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
100        will raise an error. Please see this custom training [tutorial](
101          https://www.tensorflow.org/tutorials/distribute/custom_training) for
102            more details.
103      name: Optional name for the op.
104    """
105    losses_utils.ReductionV2.validate(reduction)
106    self.reduction = reduction
107    self.name = name
108    # SUM_OVER_BATCH is only allowed in losses managed by `fit` or
109    # CannedEstimators.
110    self._allow_sum_over_batch_size = False
111    self._set_name_scope()
112
113  def _set_name_scope(self):
114    """Creates a valid `name_scope` name."""
115    if self.name is None:
116      self._name_scope = self.__class__.__name__
117    elif self.name == '<lambda>':
118      self._name_scope = 'lambda'
119    else:
120      # E.g. '_my_loss' => 'my_loss'
121      self._name_scope = self.name.strip('_')
122
123  def __call__(self, y_true, y_pred, sample_weight=None):
124    """Invokes the `Loss` instance.
125
126    Args:
127      y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`, except
128        sparse loss functions such as sparse categorical crossentropy where
129        shape = `[batch_size, d0, .. dN-1]`
130      y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`
131      sample_weight: Optional `sample_weight` acts as a coefficient for the
132        loss. If a scalar is provided, then the loss is simply scaled by the
133        given value. If `sample_weight` is a tensor of size `[batch_size]`, then
134        the total loss for each sample of the batch is rescaled by the
135        corresponding element in the `sample_weight` vector. If the shape of
136        `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted to
137        this shape), then each loss element of `y_pred` is scaled
138        by the corresponding value of `sample_weight`. (Note on `dN-1`: all loss
139        functions reduce by 1 dimension, usually axis=-1.)
140
141    Returns:
142      Weighted loss float `Tensor`. If `reduction` is `NONE`, this has
143        shape `[batch_size, d0, .. dN-1]`; otherwise, it is scalar. (Note `dN-1`
144        because all loss functions reduce by 1 dimension, usually axis=-1.)
145
146    Raises:
147      ValueError: If the shape of `sample_weight` is invalid.
148    """
149    # Name-scope sanitizing (e.g. mapping a wrapped lambda's '<lambda>' name
150    # to 'lambda') is handled by `_set_name_scope` above.
151    graph_ctx = tf_utils.graph_context_for_symbolic_tensors(
152        y_true, y_pred, sample_weight)
153    with K.name_scope(self._name_scope), graph_ctx:
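      # When executing eagerly, `call` can be used directly; in graph mode it
      # is routed through AutoGraph so plain Python control flow in a
      # user-defined `call` is converted when tracing.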
154      if context.executing_eagerly():
155        call_fn = self.call
156      else:
157        call_fn = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
158      losses = call_fn(y_true, y_pred)
159      return losses_utils.compute_weighted_loss(
160          losses, sample_weight, reduction=self._get_reduction())
161
162  @classmethod
163  def from_config(cls, config):
164    """Instantiates a `Loss` from its config (output of `get_config()`).
165
166    Args:
167        config: Output of `get_config()`.
168
169    Returns:
170        A `Loss` instance.
171    """
172    return cls(**config)
173
174  def get_config(self):
175    """Returns the config dictionary for a `Loss` instance."""
176    return {'reduction': self.reduction, 'name': self.name}
177
178  @abc.abstractmethod
179  @doc_controls.for_subclass_implementers
180  def call(self, y_true, y_pred):
181    """Invokes the `Loss` instance.
182
183    Args:
184      y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`, except
185        sparse loss functions such as sparse categorical crossentropy where
186        shape = `[batch_size, d0, .. dN-1]`
187      y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`
188
189    Returns:
190      Loss values with the shape `[batch_size, d0, .. dN-1]`.
191    """
192    raise NotImplementedError('Must be implemented in subclasses.')
193
194  def _get_reduction(self):
195    """Handles `AUTO` reduction cases and returns the reduction value."""
196    if (not self._allow_sum_over_batch_size and
197        distribution_strategy_context.has_strategy() and
198        (self.reduction == losses_utils.ReductionV2.AUTO or
199         self.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE)):
200      raise ValueError(
201          'Please use `tf.keras.losses.Reduction.SUM` or '
202          '`tf.keras.losses.Reduction.NONE` for loss reduction when losses are '
203          'used with `tf.distribute.Strategy` outside of the built-in training '
204          'loops. You can implement '
205          '`tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` using global batch '
206          'size like:\n```\nwith strategy.scope():\n'
207          '    loss_obj = tf.keras.losses.CategoricalCrossentropy('
208          'reduction=tf.keras.losses.Reduction.NONE)\n....\n'
209          '    loss = tf.reduce_sum(loss_obj(labels, predictions)) * '
210          '(1. / global_batch_size)\n```\nPlease see '
211          'https://www.tensorflow.org/tutorials/distribute/custom_training'
212          ' for more details.')
213
214    if self.reduction == losses_utils.ReductionV2.AUTO:
215      return losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
216    return self.reduction
217
218
219class LossFunctionWrapper(Loss):
220  """Wraps a loss function in the `Loss` class."""
221
222  def __init__(self,
223               fn,
224               reduction=losses_utils.ReductionV2.AUTO,
225               name=None,
226               **kwargs):
227    """Initializes `LossFunctionWrapper` class.
228
229    Args:
230      fn: The loss function to wrap, with signature `fn(y_true, y_pred,
231        **kwargs)`.
232      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
233        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
234        option will be determined by the usage context. For almost all cases
235        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
236        `tf.distribute.Strategy`, outside of built-in training loops such as
237        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
238        will raise an error. Please see this custom training [tutorial](
239          https://www.tensorflow.org/tutorials/distribute/custom_training) for
240            more details.
241      name: (Optional) name for the loss.
242      **kwargs: The keyword arguments that are passed on to `fn`.
243    """
244    super(LossFunctionWrapper, self).__init__(reduction=reduction, name=name)
245    self.fn = fn
246    self._fn_kwargs = kwargs
247
248  def call(self, y_true, y_pred):
249    """Invokes the `LossFunctionWrapper` instance.
250
251    Args:
252      y_true: Ground truth values.
253      y_pred: The predicted values.
254
255    Returns:
256      Loss values per sample.
257    """
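    # If both inputs are tensors, reconcile a possible trailing size-1 rank
    # mismatch between `y_pred` and `y_true` before applying `self.fn`.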
258    if tensor_util.is_tf_type(y_pred) and tensor_util.is_tf_type(y_true):
259      y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true)
260
261    ag_fn = autograph.tf_convert(self.fn, ag_ctx.control_status_ctx())
262    return ag_fn(y_true, y_pred, **self._fn_kwargs)
263
264  def get_config(self):
265    config = {}
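    # Evaluate tensor- or variable-valued kwargs to plain Python values so the
    # returned config stays serializable.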
266    for k, v in six.iteritems(self._fn_kwargs):
267      config[k] = K.eval(v) if tf_utils.is_tensor_or_variable(v) else v
268    base_config = super(LossFunctionWrapper, self).get_config()
269    return dict(list(base_config.items()) + list(config.items()))
270
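# Illustrative sketch (not part of the library source): a plain function can
# be wrapped so it gains the `Loss` reduction, naming and serialization
# behavior, e.g.
#
#   def mean_absolute_difference(y_true, y_pred):  # hypothetical example fn
#     return math_ops.reduce_mean(math_ops.abs(y_pred - y_true), axis=-1)
#
#   loss_obj = LossFunctionWrapper(mean_absolute_difference,
#                                  name='mean_absolute_difference')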
271
272@keras_export('keras.losses.MeanSquaredError')
273class MeanSquaredError(LossFunctionWrapper):
274  """Computes the mean of squares of errors between labels and predictions.
275
276  `loss = square(y_true - y_pred)`
277
278  Standalone usage:
279
280  >>> y_true = [[0., 1.], [0., 0.]]
281  >>> y_pred = [[1., 1.], [1., 0.]]
282  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
283  >>> mse = tf.keras.losses.MeanSquaredError()
284  >>> mse(y_true, y_pred).numpy()
285  0.5
286
287  >>> # Calling with 'sample_weight'.
288  >>> mse(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy()
289  0.25
290
291  >>> # Using 'sum' reduction type.
292  >>> mse = tf.keras.losses.MeanSquaredError(
293  ...     reduction=tf.keras.losses.Reduction.SUM)
294  >>> mse(y_true, y_pred).numpy()
295  1.0
296
297  >>> # Using 'none' reduction type.
298  >>> mse = tf.keras.losses.MeanSquaredError(
299  ...     reduction=tf.keras.losses.Reduction.NONE)
300  >>> mse(y_true, y_pred).numpy()
301  array([0.5, 0.5], dtype=float32)
302
303  Usage with the `compile()` API:
304
305  ```python
306  model.compile(optimizer='sgd', loss=tf.keras.losses.MeanSquaredError())
307  ```
308  """
309
310  def __init__(self,
311               reduction=losses_utils.ReductionV2.AUTO,
312               name='mean_squared_error'):
313    """Initializes `MeanSquaredError` instance.
314
315    Args:
316      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
317        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
318        option will be determined by the usage context. For almost all cases
319        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
320        `tf.distribute.Strategy`, outside of built-in training loops such as
321        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
322        will raise an error. Please see this custom training [tutorial](
323          https://www.tensorflow.org/tutorials/distribute/custom_training) for
324            more details.
325      name: Optional name for the op. Defaults to 'mean_squared_error'.
326    """
327    super(MeanSquaredError, self).__init__(
328        mean_squared_error, name=name, reduction=reduction)
329
330
331@keras_export('keras.losses.MeanAbsoluteError')
332class MeanAbsoluteError(LossFunctionWrapper):
333  """Computes the mean of absolute difference between labels and predictions.
334
335  `loss = abs(y_true - y_pred)`
336
337  Standalone usage:
338
339  >>> y_true = [[0., 1.], [0., 0.]]
340  >>> y_pred = [[1., 1.], [1., 0.]]
341  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
342  >>> mae = tf.keras.losses.MeanAbsoluteError()
343  >>> mae(y_true, y_pred).numpy()
344  0.5
345
346  >>> # Calling with 'sample_weight'.
347  >>> mae(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy()
348  0.25
349
350  >>> # Using 'sum' reduction type.
351  >>> mae = tf.keras.losses.MeanAbsoluteError(
352  ...     reduction=tf.keras.losses.Reduction.SUM)
353  >>> mae(y_true, y_pred).numpy()
354  1.0
355
356  >>> # Using 'none' reduction type.
357  >>> mae = tf.keras.losses.MeanAbsoluteError(
358  ...     reduction=tf.keras.losses.Reduction.NONE)
359  >>> mae(y_true, y_pred).numpy()
360  array([0.5, 0.5], dtype=float32)
361
362  Usage with the `compile()` API:
363
364  ```python
365  model.compile(optimizer='sgd', loss=tf.keras.losses.MeanAbsoluteError())
366  ```
367  """
368
369  def __init__(self,
370               reduction=losses_utils.ReductionV2.AUTO,
371               name='mean_absolute_error'):
372    """Initializes `MeanAbsoluteError` instance.
373
374    Args:
375      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
376        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
377        option will be determined by the usage context. For almost all cases
378        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
379        `tf.distribute.Strategy`, outside of built-in training loops such as
380        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
381        will raise an error. Please see this custom training [tutorial](
382          https://www.tensorflow.org/tutorials/distribute/custom_training) for
383            more details.
384      name: Optional name for the op. Defaults to 'mean_absolute_error'.
385    """
386    super(MeanAbsoluteError, self).__init__(
387        mean_absolute_error, name=name, reduction=reduction)
388
389
390@keras_export('keras.losses.MeanAbsolutePercentageError')
391class MeanAbsolutePercentageError(LossFunctionWrapper):
392  """Computes the mean absolute percentage error between `y_true` and `y_pred`.
393
394  `loss = 100 * abs(y_true - y_pred) / y_true`
395
396  Standalone usage:
397
398  >>> y_true = [[2., 1.], [2., 3.]]
399  >>> y_pred = [[1., 1.], [1., 0.]]
400  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
401  >>> mape = tf.keras.losses.MeanAbsolutePercentageError()
402  >>> mape(y_true, y_pred).numpy()
403  50.
404
405  >>> # Calling with 'sample_weight'.
406  >>> mape(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy()
407  20.
408
409  >>> # Using 'sum' reduction type.
410  >>> mape = tf.keras.losses.MeanAbsolutePercentageError(
411  ...     reduction=tf.keras.losses.Reduction.SUM)
412  >>> mape(y_true, y_pred).numpy()
413  100.
414
415  >>> # Using 'none' reduction type.
416  >>> mape = tf.keras.losses.MeanAbsolutePercentageError(
417  ...     reduction=tf.keras.losses.Reduction.NONE)
418  >>> mape(y_true, y_pred).numpy()
419  array([25., 75.], dtype=float32)
420
421  Usage with the `compile()` API:
422
423  ```python
424  model.compile(optimizer='sgd',
425                loss=tf.keras.losses.MeanAbsolutePercentageError())
426  ```
427  """
428
429  def __init__(self,
430               reduction=losses_utils.ReductionV2.AUTO,
431               name='mean_absolute_percentage_error'):
432    """Initializes `MeanAbsolutePercentageError` instance.
433
434    Args:
435      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
436        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
437        option will be determined by the usage context. For almost all cases
438        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
439        `tf.distribute.Strategy`, outside of built-in training loops such as
440        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
441        will raise an error. Please see this custom training [tutorial](
442          https://www.tensorflow.org/tutorials/distribute/custom_training) for
443            more details.
444      name: Optional name for the op. Defaults to
445        'mean_absolute_percentage_error'.
446    """
447    super(MeanAbsolutePercentageError, self).__init__(
448        mean_absolute_percentage_error, name=name, reduction=reduction)
449
450
451@keras_export('keras.losses.MeanSquaredLogarithmicError')
452class MeanSquaredLogarithmicError(LossFunctionWrapper):
453  """Computes the mean squared logarithmic error between `y_true` and `y_pred`.
454
455  `loss = square(log(y_true + 1.) - log(y_pred + 1.))`
456
457  Standalone usage:
458
459  >>> y_true = [[0., 1.], [0., 0.]]
460  >>> y_pred = [[1., 1.], [1., 0.]]
461  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
462  >>> msle = tf.keras.losses.MeanSquaredLogarithmicError()
463  >>> msle(y_true, y_pred).numpy()
464  0.240
465
466  >>> # Calling with 'sample_weight'.
467  >>> msle(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy()
468  0.120
469
470  >>> # Using 'sum' reduction type.
471  >>> msle = tf.keras.losses.MeanSquaredLogarithmicError(
472  ...     reduction=tf.keras.losses.Reduction.SUM)
473  >>> msle(y_true, y_pred).numpy()
474  0.480
475
476  >>> # Using 'none' reduction type.
477  >>> msle = tf.keras.losses.MeanSquaredLogarithmicError(
478  ...     reduction=tf.keras.losses.Reduction.NONE)
479  >>> msle(y_true, y_pred).numpy()
480  array([0.240, 0.240], dtype=float32)
481
482  Usage with the `compile()` API:
483
484  ```python
485  model.compile(optimizer='sgd',
486                loss=tf.keras.losses.MeanSquaredLogarithmicError())
487  ```
488  """
489
490  def __init__(self,
491               reduction=losses_utils.ReductionV2.AUTO,
492               name='mean_squared_logarithmic_error'):
493    """Initializes `MeanSquaredLogarithmicError` instance.
494
495    Args:
496      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
497        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
498        option will be determined by the usage context. For almost all cases
499        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
500        `tf.distribute.Strategy`, outside of built-in training loops such as
501        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
502        will raise an error. Please see this custom training [tutorial](
503          https://www.tensorflow.org/tutorials/distribute/custom_training) for
504            more details.
505      name: Optional name for the op. Defaults to
506        'mean_squared_logarithmic_error'.
507    """
508    super(MeanSquaredLogarithmicError, self).__init__(
509        mean_squared_logarithmic_error, name=name, reduction=reduction)
510
511
512@keras_export('keras.losses.BinaryCrossentropy')
513class BinaryCrossentropy(LossFunctionWrapper):
514  """Computes the cross-entropy loss between true labels and predicted labels.
515
516  Use this cross-entropy loss for binary (0 or 1) classification applications.
517  The loss function requires the following inputs:
518
519  - `y_true` (true label): This is either 0 or 1.
520  - `y_pred` (predicted value): This is the model's prediction, i.e., a single
521    floating-point value which either represents a
522    [logit](https://en.wikipedia.org/wiki/Logit) (i.e., a value in [-inf, inf]
523    when `from_logits=True`) or a probability (i.e., a value in [0., 1.] when
524    `from_logits=False`).
525
526  **Recommended Usage:** (set `from_logits=True`)
527
528  With `tf.keras` API:
529
530  ```python
531  model.compile(
532    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
533    ....
534  )
535  ```
536
537  As a standalone function:
538
539  >>> # Example 1: (batch_size = 1, number of samples = 4)
540  >>> y_true = [0, 1, 0, 0]
541  >>> y_pred = [-18.6, 0.51, 2.94, -12.8]
542  >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
543  >>> bce(y_true, y_pred).numpy()
544  0.865
545
546  >>> # Example 2: (batch_size = 2, number of samples = 4)
547  >>> y_true = [[0, 1], [0, 0]]
548  >>> y_pred = [[-18.6, 0.51], [2.94, -12.8]]
549  >>> # Using default 'auto'/'sum_over_batch_size' reduction type.
550  >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
551  >>> bce(y_true, y_pred).numpy()
552  0.865
553  >>> # Using 'sample_weight' attribute
554  >>> bce(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()
555  0.243
556  >>> # Using 'sum' reduction type.
557  >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
558  ...     reduction=tf.keras.losses.Reduction.SUM)
559  >>> bce(y_true, y_pred).numpy()
560  1.730
561  >>> # Using 'none' reduction type.
562  >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
563  ...     reduction=tf.keras.losses.Reduction.NONE)
564  >>> bce(y_true, y_pred).numpy()
565  array([0.235, 1.496], dtype=float32)
566
567  **Default Usage:** (set `from_logits=False`)
568
569  >>> # Make the following updates to the above "Recommended Usage" section
570  >>> # 1. Set `from_logits=False`
571  >>> tf.keras.losses.BinaryCrossentropy() # OR ...(from_logits=False)
572  >>> # 2. Update `y_pred` to use probabilities instead of logits
573  >>> y_pred = [0.6, 0.3, 0.2, 0.8] # OR [[0.6, 0.3], [0.2, 0.8]]
574  """
575
576  def __init__(self,
577               from_logits=False,
578               label_smoothing=0,
579               reduction=losses_utils.ReductionV2.AUTO,
580               name='binary_crossentropy'):
581    """Initializes `BinaryCrossentropy` instance.
582
583    Args:
584      from_logits: Whether to interpret `y_pred` as a tensor of
585        [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we
586          assume that `y_pred` contains probabilities (i.e., values in [0, 1]).
587          **Note - Using from_logits=True may be more numerically stable.**
588      label_smoothing: Float in [0, 1]. When 0, no smoothing occurs. When > 0,
589        we compute the loss between the predicted labels and a smoothed version
590        of the true labels, where the smoothing squeezes the labels towards 0.5.
591        Larger values of `label_smoothing` correspond to heavier smoothing.
592      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
593        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
594        option will be determined by the usage context. For almost all cases
595        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
596        `tf.distribute.Strategy`, outside of built-in training loops such as
597        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
598        will raise an error. Please see this custom training [tutorial](
599          https://www.tensorflow.org/tutorials/distribute/custom_training) for
600            more details.
601      name: (Optional) Name for the op. Defaults to 'binary_crossentropy'.
602    """
603    super(BinaryCrossentropy, self).__init__(
604        binary_crossentropy,
605        name=name,
606        reduction=reduction,
607        from_logits=from_logits,
608        label_smoothing=label_smoothing)
609    self.from_logits = from_logits
610
611
612@keras_export('keras.losses.CategoricalCrossentropy')
613class CategoricalCrossentropy(LossFunctionWrapper):
614  """Computes the crossentropy loss between the labels and predictions.
615
616  Use this crossentropy loss function when there are two or more label classes.
617  We expect labels to be provided in a `one_hot` representation. If you want to
618  provide labels as integers, please use `SparseCategoricalCrossentropy` loss.
619  There should be `# classes` floating point values per feature.
620
621  In the snippet below, there are `# classes` floating point values per
622  example. The shape of both `y_pred` and `y_true` is
623  `[batch_size, num_classes]`.
624
625  Standalone usage:
626
627  >>> y_true = [[0, 1, 0], [0, 0, 1]]
628  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
629  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
630  >>> cce = tf.keras.losses.CategoricalCrossentropy()
631  >>> cce(y_true, y_pred).numpy()
632  1.177
633
634  >>> # Calling with 'sample_weight'.
635  >>> cce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
636  0.814
637
638  >>> # Using 'sum' reduction type.
639  >>> cce = tf.keras.losses.CategoricalCrossentropy(
640  ...     reduction=tf.keras.losses.Reduction.SUM)
641  >>> cce(y_true, y_pred).numpy()
642  2.354
643
644  >>> # Using 'none' reduction type.
645  >>> cce = tf.keras.losses.CategoricalCrossentropy(
646  ...     reduction=tf.keras.losses.Reduction.NONE)
647  >>> cce(y_true, y_pred).numpy()
648  array([0.0513, 2.303], dtype=float32)
649
650  Usage with the `compile()` API:
651
652  ```python
653  model.compile(optimizer='sgd', loss=tf.keras.losses.CategoricalCrossentropy())
654  ```
655  """
656
657  def __init__(self,
658               from_logits=False,
659               label_smoothing=0,
660               reduction=losses_utils.ReductionV2.AUTO,
661               name='categorical_crossentropy'):
662    """Initializes `CategoricalCrossentropy` instance.
663
664    Args:
665      from_logits: Whether `y_pred` is expected to be a logits tensor. By
666        default, we assume that `y_pred` encodes a probability distribution.
667        **Note - Using from_logits=True is more numerically stable.**
668      label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,
669        meaning the confidence on label values are relaxed. For example, if
670        `0.1`, use `0.1 / num_classes` for non-target labels and
671        `0.9 + 0.1 / num_classes` for target labels.
672      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
673        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
674        option will be determined by the usage context. For almost all cases
675        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
676        `tf.distribute.Strategy`, outside of built-in training loops such as
677        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
678        will raise an error. Please see this custom training [tutorial](
679          https://www.tensorflow.org/tutorials/distribute/custom_training) for
680            more details.
681      name: Optional name for the op. Defaults to 'categorical_crossentropy'.
682    """
683    super(CategoricalCrossentropy, self).__init__(
684        categorical_crossentropy,
685        name=name,
686        reduction=reduction,
687        from_logits=from_logits,
688        label_smoothing=label_smoothing)
689
690
691@keras_export('keras.losses.SparseCategoricalCrossentropy')
692class SparseCategoricalCrossentropy(LossFunctionWrapper):
693  """Computes the crossentropy loss between the labels and predictions.
694
695  Use this crossentropy loss function when there are two or more label classes.
696  We expect labels to be provided as integers. If you want to provide labels
697  using `one-hot` representation, please use `CategoricalCrossentropy` loss.
698  There should be `# classes` floating point values per feature for `y_pred`
699  and a single floating point value per feature for `y_true`.
700
701  In the snippet below, there is a single floating point value per example for
702  `y_true` and `# classes` floating point values per example for `y_pred`.
703  The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is
704  `[batch_size, num_classes]`.
705
706  Standalone usage:
707
708  >>> y_true = [1, 2]
709  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
710  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
711  >>> scce = tf.keras.losses.SparseCategoricalCrossentropy()
712  >>> scce(y_true, y_pred).numpy()
713  1.177
714
715  >>> # Calling with 'sample_weight'.
716  >>> scce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
717  0.814
718
719  >>> # Using 'sum' reduction type.
720  >>> scce = tf.keras.losses.SparseCategoricalCrossentropy(
721  ...     reduction=tf.keras.losses.Reduction.SUM)
722  >>> scce(y_true, y_pred).numpy()
723  2.354
724
725  >>> # Using 'none' reduction type.
726  >>> scce = tf.keras.losses.SparseCategoricalCrossentropy(
727  ...     reduction=tf.keras.losses.Reduction.NONE)
728  >>> scce(y_true, y_pred).numpy()
729  array([0.0513, 2.303], dtype=float32)
730
731  Usage with the `compile()` API:
732
733  ```python
734  model.compile(optimizer='sgd',
735                loss=tf.keras.losses.SparseCategoricalCrossentropy())
736  ```
737  """
738
739  def __init__(self,
740               from_logits=False,
741               reduction=losses_utils.ReductionV2.AUTO,
742               name='sparse_categorical_crossentropy'):
743    """Initializes `SparseCategoricalCrossentropy` instance.
744
745    Args:
746      from_logits: Whether `y_pred` is expected to be a logits tensor. By
747        default, we assume that `y_pred` encodes a probability distribution.
748        **Note - Using from_logits=True may be more numerically stable.**
749      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
750        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
751        option will be determined by the usage context. For almost all cases
752        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
753        `tf.distribute.Strategy`, outside of built-in training loops such as
754        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
755        will raise an error. Please see this custom training [tutorial](
756          https://www.tensorflow.org/tutorials/distribute/custom_training) for
757            more details.
758      name: Optional name for the op. Defaults to
759        'sparse_categorical_crossentropy'.
760    """
761    super(SparseCategoricalCrossentropy, self).__init__(
762        sparse_categorical_crossentropy,
763        name=name,
764        reduction=reduction,
765        from_logits=from_logits)
766
767
768@keras_export('keras.losses.Hinge')
769class Hinge(LossFunctionWrapper):
770  """Computes the hinge loss between `y_true` and `y_pred`.
771
772  `loss = maximum(1 - y_true * y_pred, 0)`
773
774  `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
775  provided we will convert them to -1 or 1.
776
777  Standalone usage:
778
779  >>> y_true = [[0., 1.], [0., 0.]]
780  >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
781  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
782  >>> h = tf.keras.losses.Hinge()
783  >>> h(y_true, y_pred).numpy()
784  1.3
785
786  >>> # Calling with 'sample_weight'.
787  >>> h(y_true, y_pred, sample_weight=[1, 0]).numpy()
788  0.55
789
790  >>> # Using 'sum' reduction type.
791  >>> h = tf.keras.losses.Hinge(
792  ...     reduction=tf.keras.losses.Reduction.SUM)
793  >>> h(y_true, y_pred).numpy()
794  2.6
795
796  >>> # Using 'none' reduction type.
797  >>> h = tf.keras.losses.Hinge(
798  ...     reduction=tf.keras.losses.Reduction.NONE)
799  >>> h(y_true, y_pred).numpy()
800  array([1.1, 1.5], dtype=float32)
801
802  Usage with the `compile()` API:
803
804  ```python
805  model.compile(optimizer='sgd', loss=tf.keras.losses.Hinge())
806  ```
807  """
808
809  def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='hinge'):
810    """Initializes `Hinge` instance.
811
812    Args:
813      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
814        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
815        option will be determined by the usage context. For almost all cases
816        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
817        `tf.distribute.Strategy`, outside of built-in training loops such as
818        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
819        will raise an error. Please see this custom training [tutorial](
820          https://www.tensorflow.org/tutorials/distribute/custom_training) for
821            more details.
822      name: Optional name for the op. Defaults to 'hinge'.
823    """
824    super(Hinge, self).__init__(hinge, name=name, reduction=reduction)
825
826
827@keras_export('keras.losses.SquaredHinge')
828class SquaredHinge(LossFunctionWrapper):
829  """Computes the squared hinge loss between `y_true` and `y_pred`.
830
831  `loss = square(maximum(1 - y_true * y_pred, 0))`
832
833  `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
834  provided we will convert them to -1 or 1.
835
836  Standalone usage:
837
838  >>> y_true = [[0., 1.], [0., 0.]]
839  >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
840  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
841  >>> h = tf.keras.losses.SquaredHinge()
842  >>> h(y_true, y_pred).numpy()
843  1.86
844
845  >>> # Calling with 'sample_weight'.
846  >>> h(y_true, y_pred, sample_weight=[1, 0]).numpy()
847  0.73
848
849  >>> # Using 'sum' reduction type.
850  >>> h = tf.keras.losses.SquaredHinge(
851  ...     reduction=tf.keras.losses.Reduction.SUM)
852  >>> h(y_true, y_pred).numpy()
853  3.72
854
855  >>> # Using 'none' reduction type.
856  >>> h = tf.keras.losses.SquaredHinge(
857  ...     reduction=tf.keras.losses.Reduction.NONE)
858  >>> h(y_true, y_pred).numpy()
859  array([1.46, 2.26], dtype=float32)
860
861  Usage with the `compile()` API:
862
863  ```python
864  model.compile(optimizer='sgd', loss=tf.keras.losses.SquaredHinge())
865  ```
866  """
867
868  def __init__(self,
869               reduction=losses_utils.ReductionV2.AUTO,
870               name='squared_hinge'):
871    """Initializes `SquaredHinge` instance.
872
873    Args:
874      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
875        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
876        option will be determined by the usage context. For almost all cases
877        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
878        `tf.distribute.Strategy`, outside of built-in training loops such as
879        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
880        will raise an error. Please see this custom training [tutorial](
881          https://www.tensorflow.org/tutorials/distribute/custom_training) for
882            more details.
883      name: Optional name for the op. Defaults to 'squared_hinge'.
884    """
885    super(SquaredHinge, self).__init__(
886        squared_hinge, name=name, reduction=reduction)
887
888
889@keras_export('keras.losses.CategoricalHinge')
890class CategoricalHinge(LossFunctionWrapper):
891  """Computes the categorical hinge loss between `y_true` and `y_pred`.
892
893  `loss = maximum(neg - pos + 1, 0)`
894  where `neg=maximum((1-y_true)*y_pred)` and `pos=sum(y_true*y_pred)`
895
896  Standalone usage:
897
898  >>> y_true = [[0, 1], [0, 0]]
899  >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
900  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
901  >>> h = tf.keras.losses.CategoricalHinge()
902  >>> h(y_true, y_pred).numpy()
903  1.4
904
905  >>> # Calling with 'sample_weight'.
906  >>> h(y_true, y_pred, sample_weight=[1, 0]).numpy()
907  0.6
908
909  >>> # Using 'sum' reduction type.
910  >>> h = tf.keras.losses.CategoricalHinge(
911  ...     reduction=tf.keras.losses.Reduction.SUM)
912  >>> h(y_true, y_pred).numpy()
913  2.8
914
915  >>> # Using 'none' reduction type.
916  >>> h = tf.keras.losses.CategoricalHinge(
917  ...     reduction=tf.keras.losses.Reduction.NONE)
918  >>> h(y_true, y_pred).numpy()
919  array([1.2, 1.6], dtype=float32)
920
921  Usage with the `compile()` API:
922
923  ```python
924  model.compile(optimizer='sgd', loss=tf.keras.losses.CategoricalHinge())
925  ```
926  """
927
928  def __init__(self,
929               reduction=losses_utils.ReductionV2.AUTO,
930               name='categorical_hinge'):
931    """Initializes `CategoricalHinge` instance.
932
933    Args:
934      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
935        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
936        option will be determined by the usage context. For almost all cases
937        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
938        `tf.distribute.Strategy`, outside of built-in training loops such as
939        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
940        will raise an error. Please see this custom training [tutorial](
941          https://www.tensorflow.org/tutorials/distribute/custom_training) for
942            more details.
943      name: Optional name for the op. Defaults to 'categorical_hinge'.
944    """
945    super(CategoricalHinge, self).__init__(
946        categorical_hinge, name=name, reduction=reduction)
947
948
949@keras_export('keras.losses.Poisson')
950class Poisson(LossFunctionWrapper):
951  """Computes the Poisson loss between `y_true` and `y_pred`.
952
953  `loss = y_pred - y_true * log(y_pred)`
954
955  Standalone usage:
956
957  >>> y_true = [[0., 1.], [0., 0.]]
958  >>> y_pred = [[1., 1.], [0., 0.]]
959  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
960  >>> p = tf.keras.losses.Poisson()
961  >>> p(y_true, y_pred).numpy()
962  0.5
963
964  >>> # Calling with 'sample_weight'.
965  >>> p(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()
966  0.4
967
968  >>> # Using 'sum' reduction type.
969  >>> p = tf.keras.losses.Poisson(
970  ...     reduction=tf.keras.losses.Reduction.SUM)
971  >>> p(y_true, y_pred).numpy()
972  0.999
973
974  >>> # Using 'none' reduction type.
975  >>> p = tf.keras.losses.Poisson(
976  ...     reduction=tf.keras.losses.Reduction.NONE)
977  >>> p(y_true, y_pred).numpy()
978  array([0.999, 0.], dtype=float32)
979
980  Usage with the `compile()` API:
981
982  ```python
983  model.compile(optimizer='sgd', loss=tf.keras.losses.Poisson())
984  ```
985  """
986
987  def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='poisson'):
988    """Initializes `Poisson` instance.
989
990    Args:
991      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
992        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
993        option will be determined by the usage context. For almost all cases
994        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
995        `tf.distribute.Strategy`, outside of built-in training loops such as
996        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
997        will raise an error. Please see this custom training [tutorial](
998          https://www.tensorflow.org/tutorials/distribute/custom_training) for
999            more details.
1000      name: Optional name for the op. Defaults to 'poisson'.
1001    """
1002    super(Poisson, self).__init__(poisson, name=name, reduction=reduction)
1003
1004
1005@keras_export('keras.losses.LogCosh')
1006class LogCosh(LossFunctionWrapper):
1007  """Computes the logarithm of the hyperbolic cosine of the prediction error.
1008
1009  `logcosh = log((exp(x) + exp(-x))/2)`,
1010  where x is the error `y_pred - y_true`.
1011
1012  Standalone usage:
1013
1014  >>> y_true = [[0., 1.], [0., 0.]]
1015  >>> y_pred = [[1., 1.], [0., 0.]]
1016  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
1017  >>> l = tf.keras.losses.LogCosh()
1018  >>> l(y_true, y_pred).numpy()
1019  0.108
1020
1021  >>> # Calling with 'sample_weight'.
1022  >>> l(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()
1023  0.087
1024
1025  >>> # Using 'sum' reduction type.
1026  >>> l = tf.keras.losses.LogCosh(
1027  ...     reduction=tf.keras.losses.Reduction.SUM)
1028  >>> l(y_true, y_pred).numpy()
1029  0.217
1030
1031  >>> # Using 'none' reduction type.
1032  >>> l = tf.keras.losses.LogCosh(
1033  ...     reduction=tf.keras.losses.Reduction.NONE)
1034  >>> l(y_true, y_pred).numpy()
1035  array([0.217, 0.], dtype=float32)
1036
1037  Usage with the `compile()` API:
1038
1039  ```python
1040  model.compile(optimizer='sgd', loss=tf.keras.losses.LogCosh())
1041  ```
1042  """
1043
1044  def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='log_cosh'):
1045    """Initializes `LogCosh` instance.
1046
1047    Args:
1048      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
1049        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
1050        option will be determined by the usage context. For almost all cases
1051        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
1052        `tf.distribute.Strategy`, outside of built-in training loops such as
1053        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
1054        will raise an error. Please see this custom training [tutorial](
1055          https://www.tensorflow.org/tutorials/distribute/custom_training) for
1056            more details.
1057      name: Optional name for the op. Defaults to 'log_cosh'.
1058    """
1059    super(LogCosh, self).__init__(log_cosh, name=name, reduction=reduction)
1060
1061
1062@keras_export('keras.losses.KLDivergence')
1063class KLDivergence(LossFunctionWrapper):
1064  """Computes Kullback-Leibler divergence loss between `y_true` and `y_pred`.
1065
1066  `loss = y_true * log(y_true / y_pred)`
1067
1068  See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
1069
1070  Standalone usage:
1071
1072  >>> y_true = [[0, 1], [0, 0]]
1073  >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
1074  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
1075  >>> kl = tf.keras.losses.KLDivergence()
1076  >>> kl(y_true, y_pred).numpy()
1077  0.458
1078
1079  >>> # Calling with 'sample_weight'.
1080  >>> kl(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()
1081  0.366
1082
1083  >>> # Using 'sum' reduction type.
1084  >>> kl = tf.keras.losses.KLDivergence(
1085  ...     reduction=tf.keras.losses.Reduction.SUM)
1086  >>> kl(y_true, y_pred).numpy()
1087  0.916
1088
1089  >>> # Using 'none' reduction type.
1090  >>> kl = tf.keras.losses.KLDivergence(
1091  ...     reduction=tf.keras.losses.Reduction.NONE)
1092  >>> kl(y_true, y_pred).numpy()
1093  array([0.916, -3.08e-06], dtype=float32)
1094
1095  Usage with the `compile()` API:
1096
1097  ```python
1098  model.compile(optimizer='sgd', loss=tf.keras.losses.KLDivergence())
1099  ```
1100  """
1101
1102  def __init__(self,
1103               reduction=losses_utils.ReductionV2.AUTO,
1104               name='kl_divergence'):
1105    """Initializes `KLDivergence` instance.
1106
1107    Args:
1108      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
1109        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
1110        option will be determined by the usage context. For almost all cases
1111        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
1112        `tf.distribute.Strategy`, outside of built-in training loops such as
1113        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
1114        will raise an error. Please see this custom training [tutorial](
1115          https://www.tensorflow.org/tutorials/distribute/custom_training) for
1116            more details.
1117      name: Optional name for the op. Defaults to 'kl_divergence'.
1118    """
1119    super(KLDivergence, self).__init__(
1120        kl_divergence, name=name, reduction=reduction)
1121
1122
1123@keras_export('keras.losses.Huber')
1124class Huber(LossFunctionWrapper):
1125  """Computes the Huber loss between `y_true` and `y_pred`.
1126
1127  For each value x in `error = y_true - y_pred`:
1128
1129  ```
1130  loss = 0.5 * x^2                  if |x| <= d
1131  loss = 0.5 * d^2 + d * (|x| - d)  if |x| > d
1132  ```
1133  where d is `delta`. See: https://en.wikipedia.org/wiki/Huber_loss
1134
1135  Standalone usage:
1136
1137  >>> y_true = [[0, 1], [0, 0]]
1138  >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
1139  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
1140  >>> h = tf.keras.losses.Huber()
1141  >>> h(y_true, y_pred).numpy()
1142  0.155
1143
1144  >>> # Calling with 'sample_weight'.
1145  >>> h(y_true, y_pred, sample_weight=[1, 0]).numpy()
1146  0.09
1147
1148  >>> # Using 'sum' reduction type.
1149  >>> h = tf.keras.losses.Huber(
1150  ...     reduction=tf.keras.losses.Reduction.SUM)
1151  >>> h(y_true, y_pred).numpy()
1152  0.31
1153
1154  >>> # Using 'none' reduction type.
1155  >>> h = tf.keras.losses.Huber(
1156  ...     reduction=tf.keras.losses.Reduction.NONE)
1157  >>> h(y_true, y_pred).numpy()
1158  array([0.18, 0.13], dtype=float32)
1159
1160  Usage with the `compile()` API:
1161
1162  ```python
1163  model.compile(optimizer='sgd', loss=tf.keras.losses.Huber())
1164  ```
1165  """
1166
1167  def __init__(self,
1168               delta=1.0,
1169               reduction=losses_utils.ReductionV2.AUTO,
1170               name='huber_loss'):
1171    """Initializes `Huber` instance.
1172
1173    Args:
1174      delta: A float, the point where the Huber loss function changes from a
1175        quadratic to linear.
1176      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
1177        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
1178        option will be determined by the usage context. For almost all cases
1179        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
1180        `tf.distribute.Strategy`, outside of built-in training loops such as
1181        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
1182        will raise an error. Please see this custom training [tutorial](
1183          https://www.tensorflow.org/tutorials/distribute/custom_training) for
1184            more details.
1185      name: Optional name for the op. Defaults to 'huber_loss'.
1186    """
1187    super(Huber, self).__init__(
1188        huber, name=name, reduction=reduction, delta=delta)
1189
1190
1191@keras_export('keras.metrics.mean_squared_error', 'keras.metrics.mse',
1192              'keras.metrics.MSE', 'keras.losses.mean_squared_error',
1193              'keras.losses.mse', 'keras.losses.MSE')
1194@dispatch.add_dispatch_support
1195def mean_squared_error(y_true, y_pred):
1196  """Computes the mean squared error between labels and predictions.
1197
1198  After computing the squared distance between the inputs, the mean value over
1199  the last dimension is returned.
1200
1201  `loss = mean(square(y_true - y_pred), axis=-1)`
1202
1203  Standalone usage:
1204
1205  >>> y_true = np.random.randint(0, 2, size=(2, 3))
1206  >>> y_pred = np.random.random(size=(2, 3))
1207  >>> loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
1208  >>> assert loss.shape == (2,)
1209  >>> assert np.array_equal(
1210  ...     loss.numpy(), np.mean(np.square(y_true - y_pred), axis=-1))
1211
1212  Args:
1213    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
1214    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
1215
1216  Returns:
1217    Mean squared error values. shape = `[batch_size, d0, .. dN-1]`.
1218  """
1219  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
1220  y_true = math_ops.cast(y_true, y_pred.dtype)
1221  return K.mean(math_ops.squared_difference(y_pred, y_true), axis=-1)
1222
1223
1224def _ragged_tensor_apply_loss(loss_fn, y_true, y_pred):
1225  """Apply a loss function on a per batch basis.
1226
1227  Args:
1228    loss_fn: The loss function
1229    y_true: truth values (RaggedTensor)
1230    y_pred: predicted values (RaggedTensor)
1231
1232  Returns:
1233    Loss-function result. A dense tensor if the output has a single dimension
1234    (per-batch loss value); a ragged tensor otherwise.
1235  """
1236
1237  def rt_is_equiv_dense(rt):
1238    """Returns true if this RaggedTensor has uniform row lengths.
1239
1240    A RaggedTensor whose row lengths agree across all ragged dimensions can be
1241    converted to a dense tensor without loss of information.
1242
1243    Args:
1244      rt: RaggedTensor
1245    """
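    # E.g. nested row lengths of [2, 2] have zero variance, so the values can
    # be densified without padding; row lengths of [1, 2] keep the ragged path.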
1246    return math_ops.reduce_all([
1247        math_ops.equal(
1248            math_ops.reduce_variance(math_ops.cast(row_lens, K.floatx())),
1249            constant_op.constant([0.])) for row_lens in rt.nested_row_lengths()
1250    ])
1251
1252  def _convert_to_dense(inputs):
1253    return tuple(rt.to_tensor() for rt in inputs)
1254
1255  def _wrapper(inputs):
1256    _, y_pred = inputs
1257    if isinstance(y_pred, ragged_tensor.RaggedTensor):
1258      return control_flow_ops.cond(
1259          rt_is_equiv_dense(y_pred),
1260          lambda: loss_fn(*_convert_to_dense(inputs)), lambda: loss_fn(*inputs))
1261
1262    return loss_fn(*inputs)
1263
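  # The per-example loss shape drops the batch dimension and the final
  # (reduced) axis; if nothing remains, each example maps to a scalar loss.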
1264  lshape = y_pred.shape.as_list()[1:-1]
1265  if len(lshape) > 0:
1266    spec = ragged_tensor.RaggedTensorSpec(shape=lshape, dtype=y_pred.dtype)
1267  else:
1268    spec = tensor_spec.TensorSpec(shape=[], dtype=y_pred.dtype)
1269
1270  nested_splits_list = [rt.nested_row_splits for rt in (y_true, y_pred)]
1271  assertion_list = ragged_util.assert_splits_match(nested_splits_list)
1272  with ops.control_dependencies(assertion_list):
1273    return ragged_map_ops.map_fn(_wrapper, elems=(y_true, y_pred), dtype=spec)
1274
1275
1276@dispatch.dispatch_for_types(mean_squared_error, ragged_tensor.RaggedTensor)
1277def _ragged_tensor_mse(y_true, y_pred):
1278  """ Implements support for handling RaggedTensors.
1279
1280  Args:
1281    y_true: RaggedTensor truth values. shape = `[batch_size, d0, .. dN]`.
1282    y_pred: RaggedTensor predicted values. shape = `[batch_size, d0, .. dN]`.
1283
1284  Returns:
1285    Mean squared error values. shape = `[batch_size, d0, .. dN-1]`.
1286    When the number of dimensions of the batch feature vector [d0, .. dN] is
1287    greater than one, the return value is a RaggedTensor. Otherwise, a dense
1288    tensor with shape [batch_size] is returned.
1289  """
1290  return _ragged_tensor_apply_loss(mean_squared_error, y_true, y_pred)
1291
1292
1293@keras_export('keras.metrics.mean_absolute_error', 'keras.metrics.mae',
1294              'keras.metrics.MAE', 'keras.losses.mean_absolute_error',
1295              'keras.losses.mae', 'keras.losses.MAE')
1296@dispatch.add_dispatch_support
1297def mean_absolute_error(y_true, y_pred):
1298  """Computes the mean absolute error between labels and predictions.
1299
1300  `loss = mean(abs(y_true - y_pred), axis=-1)`
1301
1302  Standalone usage:
1303
1304  >>> y_true = np.random.randint(0, 2, size=(2, 3))
1305  >>> y_pred = np.random.random(size=(2, 3))
1306  >>> loss = tf.keras.losses.mean_absolute_error(y_true, y_pred)
1307  >>> assert loss.shape == (2,)
1308  >>> assert np.array_equal(
1309  ...     loss.numpy(), np.mean(np.abs(y_true - y_pred), axis=-1))
1310
1311  Args:
1312    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
1313    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
1314
1315  Returns:
1316    Mean absolute error values. shape = `[batch_size, d0, .. dN-1]`.
1317  """
1318  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
1319  y_true = math_ops.cast(y_true, y_pred.dtype)
1320  return K.mean(math_ops.abs(y_pred - y_true), axis=-1)
1321
1322
1323@dispatch.dispatch_for_types(mean_absolute_error, ragged_tensor.RaggedTensor)
1324def _ragged_tensor_mae(y_true, y_pred):
1325  """ RaggedTensor adapter for mean_absolute_error"""
1326  return _ragged_tensor_apply_loss(mean_absolute_error, y_true, y_pred)
1327
1328
1329@keras_export('keras.metrics.mean_absolute_percentage_error',
1330              'keras.metrics.mape', 'keras.metrics.MAPE',
1331              'keras.losses.mean_absolute_percentage_error',
1332              'keras.losses.mape', 'keras.losses.MAPE')
1333@dispatch.add_dispatch_support
1334def mean_absolute_percentage_error(y_true, y_pred):
1335  """Computes the mean absolute percentage error between `y_true` and `y_pred`.
1336
1337  `loss = 100 * mean(abs((y_true - y_pred) / y_true), axis=-1)`
1338
1339  Standalone usage:
1340
1341  >>> y_true = np.random.random(size=(2, 3))
1342  >>> y_true = np.maximum(y_true, 1e-7)  # Prevent division by zero
1343  >>> y_pred = np.random.random(size=(2, 3))
1344  >>> loss = tf.keras.losses.mean_absolute_percentage_error(y_true, y_pred)
1345  >>> assert loss.shape == (2,)
1346  >>> assert np.array_equal(
1347  ...     loss.numpy(),
1348  ...     100. * np.mean(np.abs((y_true - y_pred) / y_true), axis=-1))
1349
1350  Args:
1351    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
1352    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
1353
1354  Returns:
1355    Mean absolute percentage error values. shape = `[batch_size, d0, .. dN-1]`.
1356  """
1357  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
1358  y_true = math_ops.cast(y_true, y_pred.dtype)
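  # Clip |y_true| to at least K.epsilon() to guard against division by zero.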
1359  diff = math_ops.abs(
1360      (y_true - y_pred) / K.maximum(math_ops.abs(y_true), K.epsilon()))
1361  return 100. * K.mean(diff, axis=-1)
1362
1363
1364@keras_export('keras.metrics.mean_squared_logarithmic_error',
1365              'keras.metrics.msle', 'keras.metrics.MSLE',
1366              'keras.losses.mean_squared_logarithmic_error',
1367              'keras.losses.msle', 'keras.losses.MSLE')
1368@dispatch.add_dispatch_support
1369def mean_squared_logarithmic_error(y_true, y_pred):
1370  """Computes the mean squared logarithmic error between `y_true` and `y_pred`.
1371
1372  `loss = mean(square(log(y_true + 1) - log(y_pred + 1)), axis=-1)`
1373
1374  Standalone usage:
1375
1376  >>> y_true = np.random.randint(0, 2, size=(2, 3))
1377  >>> y_pred = np.random.random(size=(2, 3))
1378  >>> loss = tf.keras.losses.mean_squared_logarithmic_error(y_true, y_pred)
1379  >>> assert loss.shape == (2,)
1380  >>> y_true = np.maximum(y_true, 1e-7)
1381  >>> y_pred = np.maximum(y_pred, 1e-7)
1382  >>> assert np.allclose(
1383  ...     loss.numpy(),
1384  ...     np.mean(
1385  ...         np.square(np.log(y_true + 1.) - np.log(y_pred + 1.)), axis=-1))
1386
1387  Args:
1388    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
1389    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
1390
1391  Returns:
1392    Mean squared logarithmic error values. shape = `[batch_size, d0, .. dN-1]`.
1393  """
1394  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
1395  y_true = math_ops.cast(y_true, y_pred.dtype)
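  # Clip to at least K.epsilon() so the logs stay finite for zero (or negative)
  # inputs.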
1396  first_log = math_ops.log(K.maximum(y_pred, K.epsilon()) + 1.)
1397  second_log = math_ops.log(K.maximum(y_true, K.epsilon()) + 1.)
1398  return K.mean(math_ops.squared_difference(first_log, second_log), axis=-1)
1399
1400
1401def _maybe_convert_labels(y_true):
1402  """Converts binary labels into -1/1."""
1403  are_zeros = math_ops.equal(y_true, 0)
1404  are_ones = math_ops.equal(y_true, 1)
1405  is_binary = math_ops.reduce_all(math_ops.logical_or(are_zeros, are_ones))
1406
1407  def _convert_binary_labels():
1408    # Convert the binary labels to -1 or 1.
1409    return 2. * y_true - 1.
1410
1411  updated_y_true = smart_cond.smart_cond(is_binary, _convert_binary_labels,
1412                                         lambda: y_true)
1413  return updated_y_true
1414
1415
1416@keras_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge')
1417@dispatch.add_dispatch_support
1418def squared_hinge(y_true, y_pred):
1419  """Computes the squared hinge loss between `y_true` and `y_pred`.
1420
1421  `loss = mean(square(maximum(1 - y_true * y_pred, 0)), axis=-1)`
1422
1423  Standalone usage:
1424
1425  >>> y_true = np.random.choice([-1, 1], size=(2, 3))
1426  >>> y_pred = np.random.random(size=(2, 3))
1427  >>> loss = tf.keras.losses.squared_hinge(y_true, y_pred)
1428  >>> assert loss.shape == (2,)
1429  >>> assert np.array_equal(
1430  ...     loss.numpy(),
1431  ...     np.mean(np.square(np.maximum(1. - y_true * y_pred, 0.)), axis=-1))
1432
1433  Args:
1434    y_true: The ground truth values. `y_true` values are expected to be -1 or 1.
1435      If binary (0 or 1) labels are provided, they will be converted to -1 or 1.
1436      shape = `[batch_size, d0, .. dN]`.
1437    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
1438
1439  Returns:
1440     Squared hinge loss values. shape = `[batch_size, d0, .. dN-1]`.
1441  """
1442  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
1443  y_true = math_ops.cast(y_true, y_pred.dtype)
1444  y_true = _maybe_convert_labels(y_true)
1445  return K.mean(
1446      math_ops.square(math_ops.maximum(1. - y_true * y_pred, 0.)), axis=-1)
1447
1448
1449@keras_export('keras.metrics.hinge', 'keras.losses.hinge')
1450@dispatch.add_dispatch_support
1451def hinge(y_true, y_pred):
1452  """Computes the hinge loss between `y_true` and `y_pred`.
1453
1454  `loss = mean(maximum(1 - y_true * y_pred, 0), axis=-1)`
1455
1456  Standalone usage:
1457
1458  >>> y_true = np.random.choice([-1, 1], size=(2, 3))
1459  >>> y_pred = np.random.random(size=(2, 3))
1460  >>> loss = tf.keras.losses.hinge(y_true, y_pred)
1461  >>> assert loss.shape == (2,)
1462  >>> assert np.array_equal(
1463  ...     loss.numpy(),
1464  ...     np.mean(np.maximum(1. - y_true * y_pred, 0.), axis=-1))
1465
1466  Args:
1467    y_true: The ground truth values. `y_true` values are expected to be -1 or 1.
1468      If binary (0 or 1) labels are provided, they will be converted to -1 or 1.
1469      shape = `[batch_size, d0, .. dN]`.
1470    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
1471
1472  Returns:
1473    Hinge loss values. shape = `[batch_size, d0, .. dN-1]`.
1474  """
1475  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
1476  y_true = math_ops.cast(y_true, y_pred.dtype)
1477  y_true = _maybe_convert_labels(y_true)
1478  return K.mean(math_ops.maximum(1. - y_true * y_pred, 0.), axis=-1)
1479
1480
1481@keras_export('keras.losses.categorical_hinge')
1482@dispatch.add_dispatch_support
1483def categorical_hinge(y_true, y_pred):
1484  """Computes the categorical hinge loss between `y_true` and `y_pred`.
1485
1486  `loss = maximum(neg - pos + 1, 0)`
1487  where `neg=maximum((1-y_true)*y_pred)` and `pos=sum(y_true*y_pred)`
1488
1489  Standalone usage:
1490
1491  >>> y_true = np.random.randint(0, 3, size=(2,))
1492  >>> y_true = tf.keras.utils.to_categorical(y_true, num_classes=3)
1493  >>> y_pred = np.random.random(size=(2, 3))
1494  >>> loss = tf.keras.losses.categorical_hinge(y_true, y_pred)
1495  >>> assert loss.shape == (2,)
1496  >>> pos = np.sum(y_true * y_pred, axis=-1)
1497  >>> neg = np.amax((1. - y_true) * y_pred, axis=-1)
1498  >>> assert np.array_equal(loss.numpy(), np.maximum(0., neg - pos + 1.))
1499
1500  Args:
1501    y_true: The ground truth values. `y_true` values are expected to be
1502      either `{-1, +1}` or `{0, 1}` (i.e. a one-hot-encoded tensor).
1503    y_pred: The predicted values.
1504
1505  Returns:
1506    Categorical hinge loss values.
1507  """
1508  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
1509  y_true = math_ops.cast(y_true, y_pred.dtype)
1510  pos = math_ops.reduce_sum(y_true * y_pred, axis=-1)
1511  neg = math_ops.reduce_max((1. - y_true) * y_pred, axis=-1)
1512  zero = math_ops.cast(0., y_pred.dtype)
1513  return math_ops.maximum(neg - pos + 1., zero)
1514
1515
1516@keras_export('keras.losses.huber', v1=[])
1517@dispatch.add_dispatch_support
1518def huber(y_true, y_pred, delta=1.0):
1519  """Computes Huber loss value.
1520
1521  For each value x in `error = y_true - y_pred`:
1522
1523  ```
1524  loss = 0.5 * x^2                  if |x| <= d
1525  loss = 0.5 * d^2 + d * (|x| - d)  if |x| > d
1526  ```
1527  where d is `delta`. See: https://en.wikipedia.org/wiki/Huber_loss
1528
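  Standalone usage (a minimal illustrative sketch; `np` and `tf` imports are
  assumed, as in the other examples in this module):

  >>> y_true = np.random.random(size=(2, 3))
  >>> y_pred = np.random.random(size=(2, 3))
  >>> loss = tf.keras.losses.huber(y_true, y_pred)
  >>> assert loss.shape == (2,)
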
1529  Args:
1530    y_true: tensor of true targets.
1531    y_pred: tensor of predicted targets.
1532    delta: A float, the point where the Huber loss function changes from a
1533      quadratic to linear.
1534
1535  Returns:
1536    Tensor with one scalar loss entry per sample.
1537  """
1538  y_pred = math_ops.cast(y_pred, dtype=K.floatx())
1539  y_true = math_ops.cast(y_true, dtype=K.floatx())
1540  delta = math_ops.cast(delta, dtype=K.floatx())
1541  error = math_ops.subtract(y_pred, y_true)
1542  abs_error = math_ops.abs(error)
1543  half = ops.convert_to_tensor_v2_with_dispatch(0.5, dtype=abs_error.dtype)
1544  return K.mean(
1545      array_ops.where_v2(
1546          abs_error <= delta, half * math_ops.pow(error, 2),
1547          half * math_ops.pow(delta, 2) + delta * (abs_error - delta)),
1548      axis=-1)
1549
1550
1551@keras_export('keras.losses.log_cosh', 'keras.losses.logcosh',
1552              'keras.metrics.log_cosh', 'keras.metrics.logcosh')
1553@dispatch.add_dispatch_support
1554def log_cosh(y_true, y_pred):
1555  """Logarithm of the hyperbolic cosine of the prediction error.
1556
1557  `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and
1558  to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly
1559  like the mean squared error, but will not be so strongly affected by the
1560  occasional wildly incorrect prediction.
1561
1562  Standalone usage:
1563
1564  >>> y_true = np.random.random(size=(2, 3))
1565  >>> y_pred = np.random.random(size=(2, 3))
1566  >>> loss = tf.keras.losses.logcosh(y_true, y_pred)
1567  >>> assert loss.shape == (2,)
1568  >>> x = y_pred - y_true
1569  >>> assert np.allclose(
1570  ...     loss.numpy(),
1571  ...     np.mean(x + np.log(np.exp(-2. * x) + 1.) - np.log(2.), axis=-1),
1572  ...     atol=1e-5)
1573
1574  Args:
1575    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
1576    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
1577
1578  Returns:
1579    Logcosh error values. shape = `[batch_size, d0, .. dN-1]`.
1580  """
1581  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
1582  y_true = math_ops.cast(y_true, y_pred.dtype)
1583
1584  def _logcosh(x):
1585    return x + nn.softplus(-2. * x) - math_ops.cast(math_ops.log(2.), x.dtype)
1586
1587  return K.mean(_logcosh(y_pred - y_true), axis=-1)
1588
1589
1590@keras_export('keras.metrics.categorical_crossentropy',
1591              'keras.losses.categorical_crossentropy')
1592@dispatch.add_dispatch_support
1593def categorical_crossentropy(y_true,
1594                             y_pred,
1595                             from_logits=False,
1596                             label_smoothing=0):
1597  """Computes the categorical crossentropy loss.
1598
1599  Standalone usage:
1600
1601  >>> y_true = [[0, 1, 0], [0, 0, 1]]
1602  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
1603  >>> loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
1604  >>> assert loss.shape == (2,)
1605  >>> loss.numpy()
1606  array([0.0513, 2.303], dtype=float32)
1607
1608  Args:
1609    y_true: Tensor of one-hot true targets.
1610    y_pred: Tensor of predicted targets.
1611    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
1612      we assume that `y_pred` encodes a probability distribution.
1613    label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For
1614      example, if `0.1`, use `0.1 / num_classes` for non-target labels
1615      and `0.9 + 0.1 / num_classes` for target labels.
1616
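  As a worked example of the smoothing above: with `label_smoothing=0.1` and
  3 classes, a target entry of `1` becomes `1 * (1 - 0.1) + 0.1 / 3 ≈ 0.933`
  and a non-target entry of `0` becomes `0.1 / 3 ≈ 0.033`.
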
1617  Returns:
1618    Categorical crossentropy loss value.
1619  """
1620  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
1621  y_true = math_ops.cast(y_true, y_pred.dtype)
1622  label_smoothing = ops.convert_to_tensor_v2_with_dispatch(
1623      label_smoothing, dtype=K.floatx())
1624
1625  def _smooth_labels():
1626    num_classes = math_ops.cast(array_ops.shape(y_true)[-1], y_pred.dtype)
1627    return y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes)
1628
1629  y_true = smart_cond.smart_cond(label_smoothing, _smooth_labels,
1630                                 lambda: y_true)
1631  return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
1632
1633
1634@dispatch.dispatch_for_types(categorical_crossentropy,
1635                             ragged_tensor.RaggedTensor)
1636def _ragged_tensor_categorical_crossentropy(y_true,
1637                                            y_pred,
1638                                            from_logits=False,
1639                                            label_smoothing=0):
1640  """ Implements support for handling RaggedTensors.
1641
1642      Expected shape: (batch, sequence_len, n_classes) with sequence_len
1643      being variable per batch.
1644      Return shape: (batch, sequence_len).
1645
1646      When used by CategoricalCrossentropy() with the default reduction
1647      (SUM_OVER_BATCH_SIZE), the reduction averages the loss over the
1648      number of elements independent of the batch. E.g. if the RaggedTensor
1649      has 2 batches with [2, 1] values respectively, the resulting loss is
1650      the sum of the individual loss values divided by 3.
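
      A minimal sketch (hypothetical values; `tf` is assumed to be imported):

      ```python
      y_true = tf.ragged.constant([[[0., 1.], [1., 0.]], [[0., 1.]]])
      y_pred = tf.ragged.constant([[[0.1, 0.9], [0.8, 0.2]], [[0.2, 0.8]]])
      # Dispatches to this adapter and yields one loss value per timestep,
      # i.e. shape (batch, sequence_len) as described above.
      loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
      ```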
1651  """
1652  fn = functools.partial(
1653      categorical_crossentropy,
1654      from_logits=from_logits,
1655      label_smoothing=label_smoothing)
1656  return _ragged_tensor_apply_loss(fn, y_true, y_pred)
1657
1658
1659@keras_export('keras.metrics.sparse_categorical_crossentropy',
1660              'keras.losses.sparse_categorical_crossentropy')
1661@dispatch.add_dispatch_support
1662def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
1663  """Computes the sparse categorical crossentropy loss.
1664
1665  Standalone usage:
1666
1667  >>> y_true = [1, 2]
1668  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
1669  >>> loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
1670  >>> assert loss.shape == (2,)
1671  >>> loss.numpy()
1672  array([0.0513, 2.303], dtype=float32)
1673
1674  Args:
1675    y_true: Ground truth values.
1676    y_pred: The predicted values.
1677    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
1678      we assume that `y_pred` encodes a probability distribution.
1679    axis: (Optional) Defaults to -1. The dimension along which the entropy is
1680      computed.
1681
1682  Returns:
1683    Sparse categorical crossentropy loss value.
1684  """
1685  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
1686  y_true = math_ops.cast(y_true, y_pred.dtype)
1687  return K.sparse_categorical_crossentropy(
1688      y_true, y_pred, from_logits=from_logits, axis=axis)
1689
1690
1691@keras_export('keras.metrics.binary_crossentropy',
1692              'keras.losses.binary_crossentropy')
1693@dispatch.add_dispatch_support
1694def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
1695  """Computes the binary crossentropy loss.
1696
1697  Standalone usage:
1698
1699  >>> y_true = [[0, 1], [0, 0]]
1700  >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
1701  >>> loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
1702  >>> assert loss.shape == (2,)
1703  >>> loss.numpy()
1704  array([0.916 , 0.714], dtype=float32)
1705
1706  Args:
1707    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
1708    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
1709    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
1710      we assume that `y_pred` encodes a probability distribution.
1711    label_smoothing: Float in [0, 1]. If > `0` then smooth the labels by
1712      squeezing them towards 0.5. That is, using `1. - 0.5 * label_smoothing`
1713      for the target class and `0.5 * label_smoothing` for the non-target class.
1714
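  As a worked example of the smoothing above: with `label_smoothing=0.2`, a
  target of `1` becomes `1 * (1 - 0.2) + 0.5 * 0.2 = 0.9` and a target of `0`
  becomes `0.5 * 0.2 = 0.1`.
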
1715  Returns:
1716    Binary crossentropy loss value. shape = `[batch_size, d0, .. dN-1]`.
1717  """
1718  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
1719  y_true = math_ops.cast(y_true, y_pred.dtype)
1720  label_smoothing = ops.convert_to_tensor_v2_with_dispatch(
1721      label_smoothing, dtype=K.floatx())
1722
1723  def _smooth_labels():
1724    return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing
1725
1726  y_true = smart_cond.smart_cond(label_smoothing, _smooth_labels,
1727                                 lambda: y_true)
1728  return K.mean(
1729      K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)
1730
1731
1732@dispatch.dispatch_for_types(binary_crossentropy, ragged_tensor.RaggedTensor)
1733def _ragged_tensor_binary_crossentropy(y_true,
1734                                       y_pred,
1735                                       from_logits=False,
1736                                       label_smoothing=0):
1737  """ Implements support for handling RaggedTensors.
1738
1739      Expected shape: (batch, sequence_len) with sequence_len being variable
1740      per batch.
1741      Return shape: (batch,); returns the per batch mean of the loss values.
1742
1743      When used by BinaryCrossentropy() with the default reduction
1744      (SUM_OVER_BATCH_SIZE), the reduction averages the per batch losses over
1745      the number of batches.
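
      A minimal sketch (hypothetical values; `tf` is assumed to be imported):

      ```python
      y_true = tf.ragged.constant([[1., 0., 1.], [1.]])
      y_pred = tf.ragged.constant([[0.9, 0.2, 0.8], [0.6]])
      # Dispatches to this adapter; one mean loss value per batch row,
      # i.e. shape (batch,) as described above.
      loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
      ```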
1746  """
1747  fn = functools.partial(
1748      binary_crossentropy,
1749      from_logits=from_logits,
1750      label_smoothing=label_smoothing)
1751  return _ragged_tensor_apply_loss(fn, y_true, y_pred)
1752
1753
1754@keras_export('keras.metrics.kl_divergence',
1755              'keras.metrics.kullback_leibler_divergence', 'keras.metrics.kld',
1756              'keras.metrics.KLD', 'keras.losses.kl_divergence',
1757              'keras.losses.kullback_leibler_divergence', 'keras.losses.kld',
1758              'keras.losses.KLD')
1759@dispatch.add_dispatch_support
1760def kl_divergence(y_true, y_pred):
1761  """Computes Kullback-Leibler divergence loss between `y_true` and `y_pred`.
1762
1763  `loss = y_true * log(y_true / y_pred)`
1764
1765  See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
1766
1767  Standalone usage:
1768
1769  >>> y_true = np.random.randint(0, 2, size=(2, 3)).astype(np.float64)
1770  >>> y_pred = np.random.random(size=(2, 3))
1771  >>> loss = tf.keras.losses.kullback_leibler_divergence(y_true, y_pred)
1772  >>> assert loss.shape == (2,)
1773  >>> y_true = tf.keras.backend.clip(y_true, 1e-7, 1)
1774  >>> y_pred = tf.keras.backend.clip(y_pred, 1e-7, 1)
1775  >>> assert np.array_equal(
1776  ...     loss.numpy(), np.sum(y_true * np.log(y_true / y_pred), axis=-1))
1777
1778  Args:
1779    y_true: Tensor of true targets.
1780    y_pred: Tensor of predicted targets.
1781
1782  Returns:
1783    A `Tensor` with loss.
1784
1785  Raises:
1786    TypeError: If `y_true` cannot be cast to the `y_pred.dtype`.
1787  """
1788  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
1789  y_true = math_ops.cast(y_true, y_pred.dtype)
1790  y_true = K.clip(y_true, K.epsilon(), 1)
1791  y_pred = K.clip(y_pred, K.epsilon(), 1)
1792  return math_ops.reduce_sum(y_true * math_ops.log(y_true / y_pred), axis=-1)
1793
1794
1795@keras_export('keras.metrics.poisson', 'keras.losses.poisson')
1796@dispatch.add_dispatch_support
1797def poisson(y_true, y_pred):
1798  """Computes the Poisson loss between y_true and y_pred.
1799
1800  The Poisson loss is the mean of the elements of the `Tensor`
1801  `y_pred - y_true * log(y_pred)`.
1802
1803  Standalone usage:
1804
1805  >>> y_true = np.random.randint(0, 2, size=(2, 3))
1806  >>> y_pred = np.random.random(size=(2, 3))
1807  >>> loss = tf.keras.losses.poisson(y_true, y_pred)
1808  >>> assert loss.shape == (2,)
1809  >>> y_pred = y_pred + 1e-7
1810  >>> assert np.allclose(
1811  ...     loss.numpy(), np.mean(y_pred - y_true * np.log(y_pred), axis=-1),
1812  ...     atol=1e-5)
1813
1814  Args:
1815    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
1816    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
1817
1818  Returns:
1819     Poisson loss value. shape = `[batch_size, d0, .. dN-1]`.
1820
1821  Raises:
1822    InvalidArgumentError: If `y_true` and `y_pred` have incompatible shapes.
1823  """
1824  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
1825  y_true = math_ops.cast(y_true, y_pred.dtype)
1826  return K.mean(y_pred - y_true * math_ops.log(y_pred + K.epsilon()), axis=-1)
1827
1828
1829@keras_export(
1830    'keras.losses.cosine_similarity',
1831    v1=[
1832        'keras.metrics.cosine_proximity',
1833        'keras.metrics.cosine',
1834        'keras.losses.cosine_proximity',
1835        'keras.losses.cosine',
1836        'keras.losses.cosine_similarity',
1837    ])
1838@dispatch.add_dispatch_support
1839def cosine_similarity(y_true, y_pred, axis=-1):
1840  """Computes the cosine similarity between labels and predictions.
1841
1842  Note that it is a number between -1 and 1. When it is a negative number
1843  between -1 and 0, 0 indicates orthogonality and values closer to -1
1844  indicate greater similarity. Values closer to 1 indicate greater
1845  dissimilarity. This makes it usable as a loss function in a setting
1846  where you try to maximize the proximity between predictions and
1847  targets. If either `y_true` or `y_pred` is a zero vector, cosine
1848  similarity will be 0 regardless of the proximity between predictions
1849  and targets.
1850
1851  `loss = -sum(l2_norm(y_true) * l2_norm(y_pred))`
1852
1853  Standalone usage:
1854
1855  >>> y_true = [[0., 1.], [1., 1.], [1., 1.]]
1856  >>> y_pred = [[1., 0.], [1., 1.], [-1., -1.]]
1857  >>> loss = tf.keras.losses.cosine_similarity(y_true, y_pred, axis=1)
1858  >>> loss.numpy()
1859  array([-0., -0.999, 0.999], dtype=float32)
1860
1861  Args:
1862    y_true: Tensor of true targets.
1863    y_pred: Tensor of predicted targets.
1864    axis: Axis along which to determine similarity.
1865
1866  Returns:
1867    Cosine similarity tensor.
1868  """
1869  y_true = nn.l2_normalize(y_true, axis=axis)
1870  y_pred = nn.l2_normalize(y_pred, axis=axis)
1871  return -math_ops.reduce_sum(y_true * y_pred, axis=axis)
1872
1873
1874@keras_export('keras.losses.CosineSimilarity')
1875class CosineSimilarity(LossFunctionWrapper):
1876  """Computes the cosine similarity between labels and predictions.
1877
1878  Note that it is a number between -1 and 1. When it is a negative number
1879  between -1 and 0, 0 indicates orthogonality and values closer to -1
1880  indicate greater similarity. Values closer to 1 indicate greater
1881  dissimilarity. This makes it usable as a loss function in a setting
1882  where you try to maximize the proximity between predictions and targets.
1883  If either `y_true` or `y_pred` is a zero vector, cosine similarity will be 0
1884  regardless of the proximity between predictions and targets.
1885
1886  `loss = -sum(l2_norm(y_true) * l2_norm(y_pred))`
1887
1888  Standalone usage:
1889
1890  >>> y_true = [[0., 1.], [1., 1.]]
1891  >>> y_pred = [[1., 0.], [1., 1.]]
1892  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
1893  >>> cosine_loss = tf.keras.losses.CosineSimilarity(axis=1)
1894  >>> # l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]]
1895  >>> # l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]]
1896  >>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]]
1897  >>> # loss = -mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))
1898  >>> #      = -((0. + 0.) + (0.5 + 0.5)) / 2
1899  >>> cosine_loss(y_true, y_pred).numpy()
1900  -0.5
1901
1902  >>> # Calling with 'sample_weight'.
1903  >>> cosine_loss(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()
1904  -0.0999
1905
1906  >>> # Using 'sum' reduction type.
1907  >>> cosine_loss = tf.keras.losses.CosineSimilarity(axis=1,
1908  ...     reduction=tf.keras.losses.Reduction.SUM)
1909  >>> cosine_loss(y_true, y_pred).numpy()
1910  -0.999
1911
1912  >>> # Using 'none' reduction type.
1913  >>> cosine_loss = tf.keras.losses.CosineSimilarity(axis=1,
1914  ...     reduction=tf.keras.losses.Reduction.NONE)
1915  >>> cosine_loss(y_true, y_pred).numpy()
1916  array([-0., -0.999], dtype=float32)
1917
1918  Usage with the `compile()` API:
1919
1920  ```python
1921  model.compile(optimizer='sgd', loss=tf.keras.losses.CosineSimilarity(axis=1))
1922  ```
1923
1924  Args:
1925    axis: (Optional) Defaults to -1. The dimension along which the cosine
1926      similarity is computed.
1927    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss.
1928      Default value is `AUTO`. `AUTO` indicates that the reduction option will
1929      be determined by the usage context. For almost all cases this defaults to
1930      `SUM_OVER_BATCH_SIZE`. When used with `tf.distribute.Strategy`, outside of
1931      built-in training loops such as `tf.keras` `compile` and `fit`, using
1932      `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this
1933      custom training [tutorial]
1934      (https://www.tensorflow.org/tutorials/distribute/custom_training) for more
1935        details.
1936    name: Optional name for the op.
1937  """
1938
1939  def __init__(self,
1940               axis=-1,
1941               reduction=losses_utils.ReductionV2.AUTO,
1942               name='cosine_similarity'):
1943    super(CosineSimilarity, self).__init__(
1944        cosine_similarity, reduction=reduction, name=name, axis=axis)
1945
1946
1947# Aliases.
1948
1949bce = BCE = binary_crossentropy
1950mse = MSE = mean_squared_error
1951mae = MAE = mean_absolute_error
1952mape = MAPE = mean_absolute_percentage_error
1953msle = MSLE = mean_squared_logarithmic_error
1954kld = KLD = kullback_leibler_divergence = kl_divergence
1955logcosh = log_cosh
1956huber_loss = huber
1957
1958
1959def is_categorical_crossentropy(loss):
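  """Returns whether `loss` is, wraps, or names categorical crossentropy."""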
1960  result = ((isinstance(loss, CategoricalCrossentropy) or
1961             (isinstance(loss, LossFunctionWrapper) and
1962              loss.fn == categorical_crossentropy) or
1963             (hasattr(loss, '__name__') and
1964              loss.__name__ == 'categorical_crossentropy') or
1965             (loss == 'categorical_crossentropy')))
1966  return result
1967
1968
1969@keras_export('keras.losses.serialize')
1970def serialize(loss):
1971  """Serializes loss function or `Loss` instance.
1972
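  A minimal usage sketch (illustrative):

  >>> loss = tf.keras.losses.MeanSquaredError()
  >>> config = tf.keras.losses.serialize(loss)
  >>> config['class_name']
  'MeanSquaredError'
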
1973  Args:
1974    loss: A Keras `Loss` instance or a loss function.
1975
1976  Returns:
1977    Loss configuration dictionary.
1978  """
1979  return serialize_keras_object(loss)
1980
1981
1982@keras_export('keras.losses.deserialize')
1983def deserialize(name, custom_objects=None):
1984  """Deserializes a serialized loss class/function instance.
1985
1986  Args:
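  A minimal usage sketch (illustrative):

  >>> loss = tf.keras.losses.deserialize('mean_squared_error')
  >>> callable(loss)
  True
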
1987      name: Loss configuration.
1988      custom_objects: Optional dictionary mapping names (strings) to custom
1989        objects (classes and functions) to be considered during deserialization.
1990
1991  Returns:
1992      A Keras `Loss` instance or a loss function.
1993  """
1994  return deserialize_keras_object(
1995      name,
1996      module_objects=globals(),
1997      custom_objects=custom_objects,
1998      printable_module_name='loss function')
1999
2000
2001@keras_export('keras.losses.get')
2002def get(identifier):
2003  """Retrieves a Keras loss as a `function`/`Loss` class instance.
2004
2005  The `identifier` may be the string name of a loss function or `Loss` class.
2006
2007  >>> loss = tf.keras.losses.get("categorical_crossentropy")
2008  >>> type(loss)
2009  <class 'function'>
2010  >>> loss = tf.keras.losses.get("CategoricalCrossentropy")
2011  >>> type(loss)
2012  <class '...tensorflow.python.keras.losses.CategoricalCrossentropy'>
2013
2014  You can also specify the `config` of the loss to this function by passing a
2015  dict containing `class_name` and `config` as an identifier. Also note that
2016  the `class_name` must map to a `Loss` class.
2017
2018  >>> identifier = {"class_name": "CategoricalCrossentropy",
2019  ...               "config": {"from_logits": True}}
2020  >>> loss = tf.keras.losses.get(identifier)
2021  >>> type(loss)
2022  <class '...tensorflow.python.keras.losses.CategoricalCrossentropy'>
2023
2024  Args:
2025    identifier: A loss identifier. One of `None`, a string name of a loss
2026      function/class, a loss configuration dictionary, a loss function, or a
2027      loss class instance.
2028
2029  Returns:
2030    A Keras loss as a `function`/`Loss` class instance.
2031
2032  Raises:
2033    ValueError: If `identifier` cannot be interpreted.
2034  """
2035  if identifier is None:
2036    return None
2037  if isinstance(identifier, six.string_types):
2038    identifier = str(identifier)
2039    return deserialize(identifier)
2040  if isinstance(identifier, dict):
2041    return deserialize(identifier)
2042  elif callable(identifier):
2043    return identifier
2044  else:
2045    raise ValueError(
2046        'Could not interpret loss function identifier: {}'.format(identifier))
2047
2048
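# Label dtypes that the listed losses expect; callers can use this mapping to
# cast `y_true` to the right dtype before computing the loss.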
2049LABEL_DTYPES_FOR_LOSSES = {
2050    losses_impl.sparse_softmax_cross_entropy: 'int32',
2051    sparse_categorical_crossentropy: 'int32'
2052}
2053