1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of Loss operations for use in neural networks."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.eager import context
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import confusion_matrix
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops import nn
29from tensorflow.python.ops import nn_ops
30from tensorflow.python.ops import weights_broadcast_ops
31from tensorflow.python.ops.losses import util
32from tensorflow.python.util import dispatch
33from tensorflow.python.util.deprecation import deprecated_args
34from tensorflow.python.util.deprecation import deprecated_argument_lookup
35from tensorflow.python.util.tf_export import tf_export
36
37
38@tf_export(v1=["losses.Reduction"])
39class Reduction(object):
40  """Types of loss reduction.
41
42  Contains the following values:
43
44  * `NONE`: Un-reduced weighted losses with the same shape as input.
45  * `SUM`: Scalar sum of weighted losses.
46  * `MEAN`: Scalar `SUM` divided by sum of weights. DEPRECATED.
47  * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
48  * `SUM_OVER_NONZERO_WEIGHTS`: Scalar `SUM` divided by number of non-zero
49     weights. DEPRECATED.
50  * `SUM_BY_NONZERO_WEIGHTS`: Same as `SUM_OVER_NONZERO_WEIGHTS`. DEPRECATED.
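
  For example, a minimal sketch of how the reductions differ for the same
  weighted losses (the numbers below are worked out by hand for this sketch):

  ```python
  losses = tf.constant([1.0, 2.0, 3.0, 0.0])
  weights = tf.constant([1.0, 1.0, 0.0, 1.0])
  # Weighted losses are [1.0, 2.0, 0.0, 0.0], so:
  #   SUM                    -> 3.0
  #   MEAN                   -> 3.0 / 3.0 (sum of weights)     = 1.0
  #   SUM_OVER_BATCH_SIZE    -> 3.0 / 4   (number of elements) = 0.75
  #   SUM_BY_NONZERO_WEIGHTS -> 3.0 / 3   (non-zero weights)   = 1.0
  loss = tf.compat.v1.losses.compute_weighted_loss(
      losses, weights, reduction=tf.compat.v1.losses.Reduction.SUM)
  ```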
51  """
52
53  NONE = "none"
54  SUM = "weighted_sum"
55  SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size"
56  MEAN = "weighted_mean"
57  SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights"
58  SUM_OVER_NONZERO_WEIGHTS = SUM_BY_NONZERO_WEIGHTS
59
60  @classmethod
61  def all(cls):
62    return (
63        cls.NONE,
64        cls.SUM,
65        cls.MEAN,
66        cls.SUM_OVER_BATCH_SIZE,
67        cls.SUM_OVER_NONZERO_WEIGHTS,
68        cls.SUM_BY_NONZERO_WEIGHTS)
69
70  @classmethod
71  def validate(cls, key):
72    if key not in cls.all():
73      raise ValueError("Invalid Reduction Key %s." % key)
74
75
76def _safe_mean(losses, num_present):
77  """Computes a safe mean of the losses.
78
79  Args:
80    losses: `Tensor` whose elements contain individual loss measurements.
81    num_present: The number of measurable elements in `losses`.
82
83  Returns:
84    A scalar representing the mean of `losses`. If `num_present` is zero,
85      then zero is returned.
86  """
87  total_loss = math_ops.reduce_sum(losses)
88  return math_ops.div_no_nan(total_loss, num_present, name="value")
89
90
91def _num_present(losses, weights, per_batch=False):
92  """Computes the number of elements in the loss function induced by `weights`.
93
94  A given weights tensor induces different numbers of usable elements in the
95  `losses` tensor. The `weights` tensor is broadcast across `losses` for all
96  possible dimensions. For example, if `losses` is a tensor of dimension
97  `[4, 5, 6, 3]` and `weights` is a tensor of shape `[4, 5]`, then `weights` is,
98  in effect, tiled to match the shape of `losses`. Following this effective
99  tile, the total number of present elements is the number of non-zero weights.
100
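  For instance, a small sketch of the broadcasting described above (shapes and
  values are illustrative only):

  ```python
  losses = tf.ones([2, 3])
  weights = tf.constant([[1.0], [0.0]])  # broadcast across the last dimension
  # The zero weight masks out an entire row of 3 elements, so the count of
  # present elements here would be 3.0.
  num = _num_present(losses, weights)
  ```
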
101  Args:
102    losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
103    weights: `Tensor` of shape `[]`, `[batch_size]` or
104      `[batch_size, d1, ... dK]`, where K < N.
105    per_batch: Whether to return the number of elements per batch or as a sum
106      total.
107
108  Returns:
109    The number of present (non-zero) elements in the losses tensor. If
110      `per_batch` is `True`, the value is returned as a tensor of size
111      `[batch_size]`. Otherwise, a single scalar tensor is returned.
112  """
113  if ((isinstance(weights, float) and weights != 0.0) or
114      (context.executing_eagerly() and weights._rank() == 0  # pylint: disable=protected-access
115       and not math_ops.equal(weights, 0.0))):
116    return _num_elements(losses)
117  with ops.name_scope(None, "num_present", (losses, weights)) as scope:
118    weights = math_ops.cast(weights, dtype=dtypes.float32)
119    present = array_ops.where(
120        math_ops.equal(weights, 0.0),
121        array_ops.zeros_like(weights),
122        array_ops.ones_like(weights))
123    present = weights_broadcast_ops.broadcast_weights(present, losses)
124    if per_batch:
125      return math_ops.reduce_sum(
126          present,
127          axis=math_ops.range(1, array_ops.rank(present)),
128          keepdims=True,
129          name=scope)
130    return math_ops.reduce_sum(present, name=scope)
131
132
133def _num_elements(losses):
134  """Computes the number of elements in `losses` tensor."""
135  with ops.name_scope(None, "num_elements", values=[losses]) as scope:
136    return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype)
137
138
139@tf_export(v1=["losses.compute_weighted_loss"])
140@dispatch.add_dispatch_support
141def compute_weighted_loss(
142    losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES,
143    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
144  """Computes the weighted loss.
145
146  Args:
147    losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
148    weights: Optional `Tensor` whose rank is either 0, or the same rank as
149      `losses`, and must be broadcastable to `losses` (i.e., all dimensions must
150      be either `1`, or the same as the corresponding `losses` dimension).
151    scope: the scope for the operations performed in computing the loss.
152    loss_collection: the loss will be added to these collections.
153    reduction: Type of reduction to apply to loss.
154
155  Returns:
156    Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
157    `NONE`, this has the same shape as `losses`; otherwise, it is scalar.
158
159  Raises:
160    ValueError: If `weights` is `None` or the shape is not compatible with
161      `losses`, or if the number of dimensions (rank) of either `losses` or
162      `weights` is missing.
163
164  Note:
165    When calculating the gradient of a weighted loss, contributions from
166    both `losses` and `weights` are considered. If your `weights` depend
167    on some model parameters but you do not want this to affect the loss
168    gradient, you need to apply `tf.stop_gradient` to `weights` before
169    passing them to `compute_weighted_loss`.
170
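  As an illustrative sketch of that note (values are arbitrary), weights
  derived from model outputs can be detached before they are passed in:

  ```python
  per_example_losses = tf.constant([0.5, 1.0, 2.0])
  confidence = tf.constant([0.9, 0.5, 0.1])  # e.g. produced by the model
  loss = tf.compat.v1.losses.compute_weighted_loss(
      per_example_losses, weights=tf.stop_gradient(confidence))
  ```
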
171  @compatibility(eager)
172  The `loss_collection` argument is ignored when executing eagerly. Consider
173  holding on to the return value or collecting losses via a `tf.keras.Model`.
174  @end_compatibility
175  """
176  Reduction.validate(reduction)
177  with ops.name_scope(scope, "weighted_loss", (losses, weights)):
178    # Save the `reduction` argument for loss normalization when distributing
179    # to multiple replicas. Used only for estimator + v1 optimizer flow.
180    ops.get_default_graph()._last_loss_reduction = reduction  # pylint: disable=protected-access
181
182    def compute_loss(losses, weights, loss_collection, reduction):
183      losses = ops.convert_to_tensor(losses)
184      input_dtype = losses.dtype
185      losses = math_ops.cast(losses, dtype=dtypes.float32)
186      weights = math_ops.cast(weights, dtype=dtypes.float32)
187      weighted_losses = math_ops.multiply(losses, weights)
188      if reduction == Reduction.NONE:
189        loss = weighted_losses
190      else:
191        loss = math_ops.reduce_sum(weighted_losses)
192        if reduction == Reduction.MEAN:
193          loss = _safe_mean(
194              loss, math_ops.reduce_sum(array_ops.ones_like(losses) * weights))
195        elif (reduction == Reduction.SUM_BY_NONZERO_WEIGHTS or
196              reduction == Reduction.SUM_OVER_NONZERO_WEIGHTS):
197          loss = _safe_mean(loss, _num_present(losses, weights))
198        elif reduction == Reduction.SUM_OVER_BATCH_SIZE:
199          loss = _safe_mean(loss, _num_elements(losses))
200
201      # Convert the result back to the input type.
202      loss = math_ops.cast(loss, input_dtype)
203      util.add_loss(loss, loss_collection)
204      return loss
205
206    # Skip the assert_broadcastable in XLA context because asserts are not
207    # supported so it only causes unnecessary ops. Also skip it because it uses
208    # a DenseToDenseSetOperation op that is incompatible with XLA when
209    # the shape(s) are dynamic.
210    if control_flow_ops.get_enclosing_xla_context() is not None:
211      return compute_loss(losses, weights, loss_collection, reduction)
212    else:
213      with ops.control_dependencies(
214          (weights_broadcast_ops.assert_broadcastable(weights, losses),)):
215        return compute_loss(losses, weights, loss_collection, reduction)
216
217
218@tf_export(v1=["losses.absolute_difference"])
219@dispatch.add_dispatch_support
220def absolute_difference(
221    labels, predictions, weights=1.0, scope=None,
222    loss_collection=ops.GraphKeys.LOSSES,
223    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
224  """Adds an Absolute Difference loss to the training procedure.
225
226  `weights` acts as a coefficient for the loss. If a scalar is provided, then
227  the loss is simply scaled by the given value. If `weights` is a `Tensor` of
228  shape `[batch_size]`, then the total loss for each sample of the batch is
229  rescaled by the corresponding element in the `weights` vector. If the shape of
230  `weights` matches the shape of `predictions`, then the loss of each
231  measurable element of `predictions` is scaled by the corresponding value of
232  `weights`.
233
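  A minimal usage sketch (tensor values are arbitrary):

  ```python
  labels = tf.constant([[0.0], [1.0], [2.0]])
  predictions = tf.constant([[0.5], [1.0], [3.0]])
  loss = tf.compat.v1.losses.absolute_difference(
      labels, predictions, weights=2.0)
  ```
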
234  Args:
235    labels: The ground truth output tensor, same dimensions as `predictions`.
236    predictions: The predicted outputs.
237    weights: Optional `Tensor` whose rank is either 0, or the same rank as
238      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
239      be either `1`, or the same as the corresponding `labels` dimension).
240    scope: The scope for the operations performed in computing the loss.
241    loss_collection: collection to which this loss will be added.
242    reduction: Type of reduction to apply to loss.
243
244  Returns:
245    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
246    shape as `labels`; otherwise, it is scalar.
247
248  Raises:
249    ValueError: If the shape of `predictions` doesn't match that of
250      `labels` or if the shape of `weights` is invalid or if `labels`
251      or `predictions` is None.
252
253  @compatibility(eager)
254  The `loss_collection` argument is ignored when executing eagerly. Consider
255  holding on to the return value or collecting losses via a `tf.keras.Model`.
256  @end_compatibility
257  """
258  if labels is None:
259    raise ValueError("labels must not be None.")
260  if predictions is None:
261    raise ValueError("predictions must not be None.")
262  with ops.name_scope(scope, "absolute_difference",
263                      (predictions, labels, weights)) as scope:
264    predictions = math_ops.cast(predictions, dtype=dtypes.float32)
265    labels = math_ops.cast(labels, dtype=dtypes.float32)
266    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
267    losses = math_ops.abs(math_ops.subtract(predictions, labels))
268    return compute_weighted_loss(
269        losses, weights, scope, loss_collection, reduction=reduction)
270
271
272@tf_export(v1=["losses.cosine_distance"])
273@dispatch.add_dispatch_support
274@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
275def cosine_distance(
276    labels, predictions, axis=None, weights=1.0, scope=None,
277    loss_collection=ops.GraphKeys.LOSSES,
278    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS,
279    dim=None):
280  """Adds a cosine-distance loss to the training procedure.
281
282  Note that the function assumes that `predictions` and `labels` are already
283  unit-normalized.
284
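  A minimal usage sketch; note the explicit normalization, which this function
  does not perform itself:

  ```python
  labels = tf.math.l2_normalize([[1.0, 0.0], [0.0, 1.0]], axis=1)
  predictions = tf.math.l2_normalize([[0.8, 0.6], [0.0, 1.0]], axis=1)
  loss = tf.compat.v1.losses.cosine_distance(labels, predictions, axis=1)
  ```
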
285  Args:
286    labels: `Tensor` whose shape matches that of `predictions`.
287    predictions: An arbitrary matrix.
288    axis: The dimension along which the cosine distance is computed.
289    weights: Optional `Tensor` whose rank is either 0, or the same rank as
290      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
291      be either `1`, or the same as the corresponding `labels` dimension).
292    scope: The scope for the operations performed in computing the loss.
293    loss_collection: collection to which this loss will be added.
294    reduction: Type of reduction to apply to loss.
295    dim: The old (deprecated) name for `axis`.
296
297  Returns:
298    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
299    shape as `labels`; otherwise, it is scalar.
300
301  Raises:
302    ValueError: If `predictions` shape doesn't match `labels` shape, or
303      `axis`, `labels`, `predictions` or `weights` is `None`.
304
305  @compatibility(eager)
306  The `loss_collection` argument is ignored when executing eagerly. Consider
307  holding on to the return value or collecting losses via a `tf.keras.Model`.
308  @end_compatibility
309  """
310  axis = deprecated_argument_lookup("axis", axis, "dim", dim)
311  if axis is None:
312    raise ValueError("You must specify 'axis'.")
313  if labels is None:
314    raise ValueError("labels must not be None.")
315  if predictions is None:
316    raise ValueError("predictions must not be None.")
317  with ops.name_scope(scope, "cosine_distance_loss",
318                      (predictions, labels, weights)) as scope:
319    predictions = math_ops.cast(predictions, dtype=dtypes.float32)
320    labels = math_ops.cast(labels, dtype=dtypes.float32)
321    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
322
323    radial_diffs = math_ops.multiply(predictions, labels)
324    losses = 1 - math_ops.reduce_sum(radial_diffs, axis=(axis,), keepdims=True)
325    return compute_weighted_loss(
326        losses, weights, scope, loss_collection, reduction=reduction)
327
328
329@tf_export(v1=["losses.hinge_loss"])
330@dispatch.add_dispatch_support
331def hinge_loss(labels, logits, weights=1.0, scope=None,
332               loss_collection=ops.GraphKeys.LOSSES,
333               reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
334  """Adds a hinge loss to the training procedure.
335
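  A minimal usage sketch, with `{0, 1}` labels and unbounded logits (values are
  arbitrary):

  ```python
  labels = tf.constant([0.0, 1.0, 1.0])
  logits = tf.constant([-1.5, 2.0, -0.3])  # raw scores, 0-centered
  loss = tf.compat.v1.losses.hinge_loss(labels, logits)
  ```
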
336  Args:
337    labels: The ground truth output tensor. Its shape should match the shape of
338      logits. The values of the tensor are expected to be 0.0 or 1.0. Internally
339      the {0,1} labels are converted to {-1,1} when calculating the hinge loss.
340    logits: The logits, a float tensor. Note that logits are assumed to be
341      unbounded and 0-centered. A value > 0 (resp. < 0) is considered a positive
342      (resp. negative) binary prediction.
343    weights: Optional `Tensor` whose rank is either 0, or the same rank as
344      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
345      be either `1`, or the same as the corresponding `labels` dimension).
346    scope: The scope for the operations performed in computing the loss.
347    loss_collection: collection to which the loss will be added.
348    reduction: Type of reduction to apply to loss.
349
350  Returns:
351    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
352    shape as `labels`; otherwise, it is scalar.
353
354  Raises:
355    ValueError: If the shapes of `logits` and `labels` don't match or
356      if `labels` or `logits` is None.
357
358  @compatibility(eager)
359  The `loss_collection` argument is ignored when executing eagerly. Consider
360  holding on to the return value or collecting losses via a `tf.keras.Model`.
361  @end_compatibility
362  """
363  if labels is None:
364    raise ValueError("labels must not be None.")
365  if logits is None:
366    raise ValueError("logits must not be None.")
367  with ops.name_scope(scope, "hinge_loss", (logits, labels, weights)) as scope:
368    logits = math_ops.cast(logits, dtype=dtypes.float32)
369    labels = math_ops.cast(labels, dtype=dtypes.float32)
370    logits.get_shape().assert_is_compatible_with(labels.get_shape())
371    # We first need to convert binary labels to -1/1 labels (as floats).
372    all_ones = array_ops.ones_like(labels)
373    labels = math_ops.subtract(2 * labels, all_ones)
374    losses = nn_ops.relu(
375        math_ops.subtract(all_ones, math_ops.multiply(labels, logits)))
376    return compute_weighted_loss(
377        losses, weights, scope, loss_collection, reduction=reduction)
378
379
380@tf_export(v1=["losses.huber_loss"])
381@dispatch.add_dispatch_support
382def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
383               loss_collection=ops.GraphKeys.LOSSES,
384               reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
385  """Adds a [Huber Loss](https://en.wikipedia.org/wiki/Huber_loss) term to the training procedure.
386
387  For each value x in `error=labels-predictions`, the following is calculated:
388
389  ```
390    0.5 * x^2                  if |x| <= d
391    0.5 * d^2 + d * (|x| - d)  if |x| > d
392  ```
393
394  where d is `delta`.
395
396  `weights` acts as a coefficient for the loss. If a scalar is provided, then
397  the loss is simply scaled by the given value. If `weights` is a tensor of size
398  `[batch_size]`, then the total loss for each sample of the batch is rescaled
399  by the corresponding element in the `weights` vector. If the shape of
400  `weights` matches the shape of `predictions`, then the loss of each
401  measurable element of `predictions` is scaled by the corresponding value of
402  `weights`.
403
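  A minimal usage sketch; with `delta=1.0` the first two errors fall in the
  quadratic region and the last one in the linear region (values are arbitrary):

  ```python
  labels = tf.constant([0.0, 1.0, 4.0])
  predictions = tf.constant([0.2, 1.5, 1.0])
  loss = tf.compat.v1.losses.huber_loss(labels, predictions, delta=1.0)
  ```
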
404  Args:
405    labels: The ground truth output tensor, same dimensions as `predictions`.
406    predictions: The predicted outputs.
407    weights: Optional `Tensor` whose rank is either 0, or the same rank as
408      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
409      be either `1`, or the same as the corresponding `labels` dimension).
410    delta: `float`, the point where the huber loss function changes from a
411      quadratic to linear.
412    scope: The scope for the operations performed in computing the loss.
413    loss_collection: collection to which the loss will be added.
414    reduction: Type of reduction to apply to loss.
415
416  Returns:
417    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
418    shape as `labels`; otherwise, it is scalar.
419
420  Raises:
421    ValueError: If the shape of `predictions` doesn't match that of `labels` or
422      if the shape of `weights` is invalid.  Also if `labels` or
423      `predictions` is None.
424
425  @compatibility(eager)
426  The `loss_collection` argument is ignored when executing eagerly. Consider
427  holding on to the return value or collecting losses via a `tf.keras.Model`.
428  @end_compatibility
429  """
430  if labels is None:
431    raise ValueError("labels must not be None.")
432  if predictions is None:
433    raise ValueError("predictions must not be None.")
434  with ops.name_scope(scope, "huber_loss",
435                      (predictions, labels, weights)) as scope:
436    predictions = math_ops.cast(predictions, dtype=dtypes.float32)
437    labels = math_ops.cast(labels, dtype=dtypes.float32)
438    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
439    error = math_ops.subtract(predictions, labels)
440    abs_error = math_ops.abs(error)
441    quadratic = math_ops.minimum(abs_error, delta)
442    # The following expression is the same in value as
443    # tf.maximum(abs_error - delta, 0), but importantly the gradient for the
444    # expression when abs_error == delta is 0 (for tf.maximum it would be 1).
445    # This is necessary to avoid doubling the gradient, since there is already a
446    # nonzero contribution to the gradient from the quadratic term.
447    linear = math_ops.subtract(abs_error, quadratic)
448    losses = math_ops.add(
449        math_ops.multiply(
450            ops.convert_to_tensor(0.5, dtype=quadratic.dtype),
451            math_ops.multiply(quadratic, quadratic)),
452        math_ops.multiply(delta, linear))
453    return compute_weighted_loss(
454        losses, weights, scope, loss_collection, reduction=reduction)
455
456
457@tf_export(v1=["losses.log_loss"])
458@dispatch.add_dispatch_support
459def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
460             loss_collection=ops.GraphKeys.LOSSES,
461             reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
462  """Adds a Log Loss term to the training procedure.
463
464  `weights` acts as a coefficient for the loss. If a scalar is provided, then
465  the loss is simply scaled by the given value. If `weights` is a tensor of size
466  `[batch_size]`, then the total loss for each sample of the batch is rescaled
467  by the corresponding element in the `weights` vector. If the shape of
468  `weights` matches the shape of `predictions`, then the loss of each
469  measurable element of `predictions` is scaled by the corresponding value of
470  `weights`.
471
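  A minimal usage sketch; `predictions` are expected to be probabilities
  (values are arbitrary):

  ```python
  labels = tf.constant([[1.0], [0.0]])
  predictions = tf.constant([[0.9], [0.2]])  # values in [0, 1]
  loss = tf.compat.v1.losses.log_loss(labels, predictions)
  ```
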
472  Args:
473    labels: The ground truth output tensor, same dimensions as `predictions`.
474    predictions: The predicted outputs.
475    weights: Optional `Tensor` whose rank is either 0, or the same rank as
476      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
477      be either `1`, or the same as the corresponding `labels` dimension).
478    epsilon: A small increment to add to avoid taking a log of zero.
479    scope: The scope for the operations performed in computing the loss.
480    loss_collection: collection to which the loss will be added.
481    reduction: Type of reduction to apply to loss.
482
483  Returns:
484    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
485    shape as `labels`; otherwise, it is scalar.
486
487  Raises:
488    ValueError: If the shape of `predictions` doesn't match that of `labels` or
489      if the shape of `weights` is invalid.  Also if `labels` or `predictions`
490      is None.
491
492  @compatibility(eager)
493  The `loss_collection` argument is ignored when executing eagerly. Consider
494  holding on to the return value or collecting losses via a `tf.keras.Model`.
495  @end_compatibility
496  """
497  if labels is None:
498    raise ValueError("labels must not be None.")
499  if predictions is None:
500    raise ValueError("predictions must not be None.")
501  with ops.name_scope(scope, "log_loss",
502                      (predictions, labels, weights)) as scope:
503    predictions = math_ops.cast(predictions, dtype=dtypes.float32)
504    labels = math_ops.cast(labels, dtype=dtypes.float32)
505    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
506    losses = -math_ops.multiply(
507        labels,
508        math_ops.log(predictions + epsilon)) - math_ops.multiply(
509            (1 - labels), math_ops.log(1 - predictions + epsilon))
510    return compute_weighted_loss(
511        losses, weights, scope, loss_collection, reduction=reduction)
512
513
514# TODO(b/37208492): Add reduction arg.
515@tf_export(v1=["losses.mean_pairwise_squared_error"])
516@dispatch.add_dispatch_support
517def mean_pairwise_squared_error(
518    labels, predictions, weights=1.0, scope=None,
519    loss_collection=ops.GraphKeys.LOSSES):
520  """Adds a pairwise-errors-squared loss to the training procedure.
521
522  Unlike `mean_squared_error`, which is a measure of the differences between
523  corresponding elements of `predictions` and `labels`,
524  `mean_pairwise_squared_error` is a measure of the differences between pairs of
525  corresponding elements of `predictions` and `labels`.
526
527  For example, if `labels`=[a, b, c] and `predictions`=[x, y, z], three
528  pairs of differences are summed to compute the loss:
529    loss = [ ((a-b) - (x-y))^2 + ((a-c) - (x-z))^2 + ((b-c) - (y-z))^2 ] / 3
530
531  Note that since the inputs are of shape `[batch_size, d0, ... dN]`, the
532  corresponding pairs are computed within each batch sample but not across
533  samples within a batch. For example, if `predictions` represents a batch of
534  16 grayscale images of shape `[16, 100, 200]`, then the set of pairs
535  is drawn from each image, but not across images.
536
537  `weights` acts as a coefficient for the loss. If a scalar is provided, then
538  the loss is simply scaled by the given value. If `weights` is a tensor of size
539  `[batch_size]`, then the total loss for each sample of the batch is rescaled
540  by the corresponding element in the `weights` vector.
541
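  A minimal usage sketch; only pairwise differences within each sample
  contribute to the loss (values are arbitrary):

  ```python
  labels = tf.constant([[1.0, 2.0, 4.0]])
  predictions = tf.constant([[1.5, 2.0, 3.0]])
  loss = tf.compat.v1.losses.mean_pairwise_squared_error(labels, predictions)
  ```
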
542  Args:
543    labels: The ground truth output tensor, whose shape must match the shape of
544      `predictions`.
545    predictions: The predicted outputs, a tensor of size
546      `[batch_size, d0, .. dN]` where N+1 is the total number of dimensions in
547      `predictions`.
548    weights: Coefficients for the loss: a scalar, a tensor of shape
549      `[batch_size]` or a tensor whose shape matches `predictions`.
550    scope: The scope for the operations performed in computing the loss.
551    loss_collection: collection to which the loss will be added.
552
553  Returns:
554    A scalar `Tensor` that returns the weighted loss.
555
556  Raises:
557    ValueError: If the shape of `predictions` doesn't match that of `labels` or
558      if the shape of `weights` is invalid.  Also if `labels` or `predictions`
559      is None.
560
561  @compatibility(eager)
562  The `loss_collection` argument is ignored when executing eagerly. Consider
563  holding on to the return value or collecting losses via a `tf.keras.Model`.
564  @end_compatibility
565  """
566  if labels is None:
567    raise ValueError("labels must not be None.")
568  if predictions is None:
569    raise ValueError("predictions must not be None.")
570  with ops.name_scope(scope, "mean_pairwise_squared_error",
571                      (predictions, labels, weights)) as scope:
572    weights = math_ops.cast(weights, dtype=dtypes.float32)
573    labels = math_ops.cast(labels, dtype=dtypes.float32)
574
575    def compute_loss(labels, predictions, weights, loss_collection):
576      predictions = math_ops.cast(predictions, dtype=dtypes.float32)
577      predictions.get_shape().assert_is_compatible_with(labels.get_shape())
578
579      diffs = math_ops.subtract(predictions, labels)
580
581      axis = math_ops.range(1, array_ops.rank(diffs))
582
583      sum_squares_diff_per_batch = math_ops.reduce_sum(
584          math_ops.square(diffs), axis=axis, keepdims=True)
585      num_present_per_batch = _num_present(diffs, weights, per_batch=True)
586
587      term1 = 2.0 * math_ops.div_no_nan(
588          sum_squares_diff_per_batch,
589          math_ops.maximum(num_present_per_batch - 1, 0),
590          name="value")
591
592      sum_diff = math_ops.reduce_sum(diffs, axis=axis, keepdims=True)
593      term2 = 2.0 * math_ops.div_no_nan(
594          math_ops.square(sum_diff),
595          math_ops.maximum(
596              math_ops.multiply(num_present_per_batch,
597                                num_present_per_batch - 1), 0),
598          name="value")
599
600      weighted_losses = math_ops.multiply(term1 - term2, weights)
601      loss = math_ops.reduce_sum(weighted_losses)
602
603      mean_loss = array_ops.where(
604          math_ops.reduce_sum(num_present_per_batch) > 0,
605          loss,
606          array_ops.zeros_like(loss),
607          name="value")
608      util.add_loss(mean_loss, loss_collection)
609      return mean_loss
610
611    # Skip the assert_broadcastable in XLA context because asserts are not
612    # supported so it only causes unnecessary ops. Also skip it because it uses
613    # a DenseToDenseSetOperation op that is incompatible with XLA when
614    # the shape(s) are dynamic.
615    if control_flow_ops.get_enclosing_xla_context() is not None:
616      return compute_loss(labels, predictions, weights, loss_collection)
617    else:
618      with ops.control_dependencies(
619          (weights_broadcast_ops.assert_broadcastable(weights, labels),)):
620        return compute_loss(labels, predictions, weights, loss_collection)
621
622
623@tf_export(v1=["losses.mean_squared_error"])
624@dispatch.add_dispatch_support
625def mean_squared_error(
626    labels, predictions, weights=1.0, scope=None,
627    loss_collection=ops.GraphKeys.LOSSES,
628    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
629  """Adds a Sum-of-Squares loss to the training procedure.
630
631  `weights` acts as a coefficient for the loss. If a scalar is provided, then
632  the loss is simply scaled by the given value. If `weights` is a tensor of size
633  `[batch_size]`, then the total loss for each sample of the batch is rescaled
634  by the corresponding element in the `weights` vector. If the shape of
635  `weights` matches the shape of `predictions`, then the loss of each
636  measurable element of `predictions` is scaled by the corresponding value of
637  `weights`.
638
639  Args:
640    labels: The ground truth output tensor, same dimensions as `predictions`.
641    predictions: The predicted outputs.
642    weights: Optional `Tensor` whose rank is either 0, or the same rank as
643      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
644      be either `1`, or the same as the corresponding `labels` dimension).
645    scope: The scope for the operations performed in computing the loss.
646    loss_collection: collection to which the loss will be added.
647    reduction: Type of reduction to apply to loss.
648
649  Returns:
650    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
651    shape as `labels`; otherwise, it is scalar.
652
653  Raises:
654    ValueError: If the shape of `predictions` doesn't match that of `labels` or
655      if the shape of `weights` is invalid.  Also if `labels` or `predictions`
656      is None.
657
658  @compatibility(TF2)
659
660  `tf.compat.v1.losses.mean_squared_error` is mostly compatible with eager
661  execution and `tf.function`. But, the `loss_collection` argument is
662  ignored when executing eagerly and no loss will be written to the loss
663  collections. You will need to either hold on to the return value manually
664  or rely on `tf.keras.Model` loss tracking.
665
666
667  To switch to native TF2 style, instantiate the
668   `tf.keras.losses.MeanSquaredError` class and call the object instead.
669
670
671  #### Structural Mapping to Native TF2
672
673  Before:
674
675  ```python
676  loss = tf.compat.v1.losses.mean_squared_error(
677    labels=labels,
678    predictions=predictions,
679    weights=weights,
680    reduction=reduction)
681  ```
682
683  After:
684
685  ```python
686  loss_fn = tf.keras.losses.MeanSquaredError(
687    reduction=reduction)
688  loss = loss_fn(
689    y_true=labels,
690    y_pred=predictions,
691    sample_weight=weights)
692  ```
693
694  #### How to Map Arguments
695
696  | TF1 Arg Name          | TF2 Arg Name     | Note                       |
697  | :-------------------- | :--------------- | :------------------------- |
698  | `labels`              | `y_true`         | In `__call__()` method     |
699  | `predictions`         | `y_pred`         | In `__call__()` method     |
700  | `weights`             | `sample_weight`  | In `__call__()` method.    |
701  : : : The shape requirements for `sample_weight` is different from      :
702  : : : `weights`. Please check the [argument definition][api_docs] for   :
703  : : : details.                                                          :
704  | `scope`               | Not supported    | -                          |
705  | `loss_collection`     | Not supported    | Losses should be tracked   |
706  : : : explicitly or with Keras APIs, for example, [add_loss][add_loss], :
707  : : : instead of via collections                                        :
708  | `reduction`           | `reduction`      | In constructor. Value of   |
709  : : : `tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE`,              :
710  : : : `tf.compat.v1.losses.Reduction.SUM`,                              :
711  : : : `tf.compat.v1.losses.Reduction.NONE` in                           :
712  : : : `tf.compat.v1.losses.mean_squared_error` correspond to            :
713  : : : `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE`,                  :
714  : : : `tf.keras.losses.Reduction.SUM`,                                  :
715  : : : `tf.keras.losses.Reduction.NONE`, respectively. If you            :
716  : : : used other value for `reduction`, including the default value     :
717  : : :  `tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS`, there is :
718  : : : no directly corresponding value. Please modify the loss           :
719  : : : implementation manually.                                          :
720
721  [add_loss]:https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_loss
722  [api_docs]:https://www.tensorflow.org/api_docs/python/tf/keras/losses/MeanSquaredError#__call__
723
724
725  #### Before & After Usage Example
726
727  Before:
728
729  >>> y_true = [1, 2, 3]
730  >>> y_pred = [1, 3, 5]
731  >>> weights = [0, 1, 0.25]
732  >>> # samples with zero-weight are excluded from calculation when `reduction`
733  >>> # argument is set to default value `Reduction.SUM_BY_NONZERO_WEIGHTS`
734  >>> tf.compat.v1.losses.mean_squared_error(
735  ...    labels=y_true,
736  ...    predictions=y_pred,
737  ...    weights=weights).numpy()
738  1.0
739
740  >>> tf.compat.v1.losses.mean_squared_error(
741  ...    labels=y_true,
742  ...    predictions=y_pred,
743  ...    weights=weights,
744  ...    reduction=tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE).numpy()
745  0.66667
746
747  After:
748
749  >>> y_true = [[1.0], [2.0], [3.0]]
750  >>> y_pred = [[1.0], [3.0], [5.0]]
751  >>> weights = [1, 1, 0.25]
752  >>> mse = tf.keras.losses.MeanSquaredError(
753  ...    reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
754  >>> mse(y_true=y_true, y_pred=y_pred, sample_weight=weights).numpy()
755  0.66667
756
757  @end_compatibility
758  """
759  if labels is None:
760    raise ValueError("labels must not be None.")
761  if predictions is None:
762    raise ValueError("predictions must not be None.")
763  with ops.name_scope(scope, "mean_squared_error",
764                      (predictions, labels, weights)) as scope:
765    predictions = math_ops.cast(predictions, dtype=dtypes.float32)
766    labels = math_ops.cast(labels, dtype=dtypes.float32)
767    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
768    losses = math_ops.squared_difference(predictions, labels)
769    return compute_weighted_loss(
770        losses, weights, scope, loss_collection, reduction=reduction)
771
772
773@tf_export(v1=["losses.sigmoid_cross_entropy"])
774@dispatch.add_dispatch_support
775def sigmoid_cross_entropy(
776    multi_class_labels, logits, weights=1.0, label_smoothing=0, scope=None,
777    loss_collection=ops.GraphKeys.LOSSES,
778    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
779  """Creates a cross-entropy loss using tf.nn.sigmoid_cross_entropy_with_logits.
780
781  `weights` acts as a coefficient for the loss. If a scalar is provided,
782  then the loss is simply scaled by the given value. If `weights` is a
783  tensor of shape `[batch_size]`, then the loss weights apply to each
784  corresponding sample.
785
786  If `label_smoothing` is nonzero, smooth the labels towards 1/2:
787
788      new_multi_class_labels = multi_class_labels * (1 - label_smoothing)
789                               + 0.5 * label_smoothing
790
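  A minimal usage sketch with multi-label targets, where each class is an
  independent binary decision (values are arbitrary):

  ```python
  multi_class_labels = tf.constant([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
  logits = tf.constant([[2.0, -1.0, 0.5], [-0.3, 1.2, -2.0]])
  loss = tf.compat.v1.losses.sigmoid_cross_entropy(
      multi_class_labels, logits, label_smoothing=0.1)
  ```
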
791  Args:
792    multi_class_labels: `[batch_size, num_classes]` target integer labels in
793      `{0, 1}`.
794    logits: Float `[batch_size, num_classes]` logits outputs of the network.
795    weights: Optional `Tensor` whose rank is either 0, or the same rank as
796      `multi_class_labels`, and must be broadcastable to `multi_class_labels`
797      (i.e., all dimensions must be either `1`, or the same as the
798      corresponding `multi_class_labels` dimension).
799    label_smoothing: If greater than `0` then smooth the labels.
800    scope: The scope for the operations performed in computing the loss.
801    loss_collection: collection to which the loss will be added.
802    reduction: Type of reduction to apply to loss.
803
804  Returns:
805    Weighted loss `Tensor` of the same type as `logits`. If `reduction` is
806    `NONE`, this has the same shape as `logits`; otherwise, it is scalar.
807
808  Raises:
809    ValueError: If the shape of `logits` doesn't match that of
810      `multi_class_labels` or if the shape of `weights` is invalid, or if
811      `weights` is None.  Also if `multi_class_labels` or `logits` is None.
812
813  @compatibility(eager)
814  The `loss_collection` argument is ignored when executing eagerly. Consider
815  holding on to the return value or collecting losses via a `tf.keras.Model`.
816  @end_compatibility
817  """
818  if multi_class_labels is None:
819    raise ValueError("multi_class_labels must not be None.")
820  if logits is None:
821    raise ValueError("logits must not be None.")
822  with ops.name_scope(scope, "sigmoid_cross_entropy_loss",
823                      (logits, multi_class_labels, weights)) as scope:
824    logits = ops.convert_to_tensor(logits)
825    multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype)
826    logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape())
827
828    if label_smoothing > 0:
829      multi_class_labels = (multi_class_labels * (1 - label_smoothing) +
830                            0.5 * label_smoothing)
831
832    losses = nn.sigmoid_cross_entropy_with_logits(labels=multi_class_labels,
833                                                  logits=logits,
834                                                  name="xentropy")
835    return compute_weighted_loss(
836        losses, weights, scope, loss_collection, reduction=reduction)
837
838
839@tf_export(v1=["losses.softmax_cross_entropy"])
840@dispatch.add_dispatch_support
841def softmax_cross_entropy(
842    onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None,
843    loss_collection=ops.GraphKeys.LOSSES,
844    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
845  r"""Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits_v2.
846
847  `weights` acts as a coefficient for the loss. If a scalar is provided,
848  then the loss is simply scaled by the given value. If `weights` is a
849  tensor of shape `[batch_size]`, then the loss weights apply to each
850  corresponding sample.
851
852  If `label_smoothing` is nonzero, smooth the labels towards 1/num_classes:
853      new_onehot_labels = onehot_labels * (1 - label_smoothing)
854                          + label_smoothing / num_classes
855
856  Note that `onehot_labels` and `logits` must have the same shape,
857  e.g. `[batch_size, num_classes]`. The shape of `weights` must be
858  broadcastable to the loss, whose shape is determined by the shape of
859  `logits`. If `logits` has shape `[batch_size, num_classes]`, the loss is
860  a `Tensor` of shape `[batch_size]`.
861
862  Args:
863    onehot_labels: One-hot-encoded labels.
864    logits: Logits outputs of the network.
865    weights: Optional `Tensor` that is broadcastable to loss.
866    label_smoothing: If greater than 0 then smooth the labels.
867    scope: the scope for the operations performed in computing the loss.
868    loss_collection: collection to which the loss will be added.
869    reduction: Type of reduction to apply to loss.
870
871  Returns:
872    Weighted loss `Tensor` of the same type as `logits`. If `reduction` is
873    `NONE`, this has shape `[batch_size]`; otherwise, it is scalar.
874
875  Raises:
876    ValueError: If the shape of `logits` doesn't match that of `onehot_labels`
877      or if the shape of `weights` is invalid or if `weights` is None.  Also if
878      `onehot_labels` or `logits` is None.
879
880  @compatibility(TF2)
881
882  `tf.compat.v1.losses.softmax_cross_entropy` is mostly compatible with eager
883  execution and `tf.function`. But, the `loss_collection` argument is
884  ignored when executing eagerly and no loss will be written to the loss
885  collections. You will need to either hold on to the return value manually
886  or rely on `tf.keras.Model` loss tracking.
887
888
889  To switch to native TF2 style, instantiate the
890   `tf.keras.losses.CategoricalCrossentropy` class with `from_logits` set
891  as `True` and call the object instead.
892
893
894  #### Structural Mapping to Native TF2
895
896  Before:
897
898  ```python
899  loss = tf.compat.v1.losses.softmax_cross_entropy(
900    onehot_labels=onehot_labels,
901    logits=logits,
902    weights=weights,
903    label_smoothing=smoothing)
904  ```
905
906  After:
907
908  ```python
909  loss_fn = tf.keras.losses.CategoricalCrossentropy(
910    from_logits=True,
911    label_smoothing=smoothing)
912  loss = loss_fn(
913    y_true=onehot_labels,
914    y_pred=logits,
915    sample_weight=weights)
916  ```
917
918  #### How to Map Arguments
919
920  | TF1 Arg Name          | TF2 Arg Name     | Note                       |
921  | :-------------------- | :--------------- | :------------------------- |
922  |  -                    | `from_logits`    | Set `from_logits` as True  |
923  :                       :                  : to have identical behavior :
924  | `onehot_labels`       | `y_true`         | In `__call__()` method     |
925  | `logits`              | `y_pred`         | In `__call__()` method     |
926  | `weights`             | `sample_weight`  | In `__call__()` method     |
927  | `label_smoothing`     | `label_smoothing`| In constructor             |
928  | `scope`               | Not supported    | -                          |
929  | `loss_collection`     | Not supported    | Losses should be tracked   |
930  :                       :                  : explicitly or with Keras   :
931  :                       :                  : APIs, for example,         :
932  :                       :                  : [add_loss][add_loss],      :
933  :                       :                  : instead of via collections :
934  | `reduction`           | `reduction`      | In constructor. Value of   |
935  : : : `tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE`,              :
936  : : : `tf.compat.v1.losses.Reduction.SUM`,                              :
937  : : : `tf.compat.v1.losses.Reduction.NONE` in                           :
938  : : : `tf.compat.v1.losses.softmax_cross_entropy` correspond to         :
939  : : : `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE`,                  :
940  : : : `tf.keras.losses.Reduction.SUM`,                                  :
941  : : : `tf.keras.losses.Reduction.NONE`, respectively. If you            :
942  : : : used other value for `reduction`, including the default value     :
943  : : :  `tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS`, there is :
944  : : : no directly corresponding value. Please modify the loss           :
945  : : : implementation manually.                                          :
946
947  [add_loss]:https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_loss
948
949
950  #### Before & After Usage Example
951
952  Before:
953
954  >>> y_true = [[0, 1, 0], [0, 0, 1]]
955  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
956  >>> weights = [0.3, 0.7]
957  >>> smoothing = 0.2
958  >>> tf.compat.v1.losses.softmax_cross_entropy(y_true, y_pred, weights=weights,
959  ...   label_smoothing=smoothing).numpy()
960  0.57618
961
962  After:
963
964  >>> cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True,
965  ...   label_smoothing=smoothing)
966  >>> cce(y_true, y_pred, sample_weight=weights).numpy()
967  0.57618
968
969  @end_compatibility
970  """
971  if onehot_labels is None:
972    raise ValueError("onehot_labels must not be None.")
973  if logits is None:
974    raise ValueError("logits must not be None.")
975  with ops.name_scope(scope, "softmax_cross_entropy_loss",
976                      (logits, onehot_labels, weights)) as scope:
977    logits = ops.convert_to_tensor(logits)
978    onehot_labels = math_ops.cast(onehot_labels, logits.dtype)
979    logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape())
980
981    if label_smoothing > 0:
982      num_classes = math_ops.cast(
983          array_ops.shape(onehot_labels)[-1], logits.dtype)
984      smooth_positives = 1.0 - label_smoothing
985      smooth_negatives = label_smoothing / num_classes
986      onehot_labels = onehot_labels * smooth_positives + smooth_negatives
987
988    onehot_labels = array_ops.stop_gradient(
989        onehot_labels, name="labels_stop_gradient")
990    losses = nn.softmax_cross_entropy_with_logits_v2(
991        labels=onehot_labels, logits=logits, name="xentropy")
992
993    return compute_weighted_loss(
994        losses, weights, scope, loss_collection, reduction=reduction)
995
996
997# TODO(ptucker): Merge this with similar method in metrics_impl.
998def _remove_squeezable_dimensions(
999    labels, predictions, weights=None, expected_rank_diff=0):
1000  """Internal version of _remove_squeezable_dimensions which handles weights.
1001
1002  Squeezes `predictions` and `labels` if their rank difference differs from
1003  `expected_rank_diff` by exactly 1.
1004  Squeezes `weights` if its rank is 1 more than the new rank of `predictions`.
1005
1006  This will use static shape if available. Otherwise, it will add graph
1007  operations, which could result in a performance hit.
1008
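  A small sketch of the intended behavior (shapes are illustrative):

  ```python
  labels = tf.constant([[1], [0]])       # shape [2, 1]
  predictions = tf.constant([0.2, 0.7])  # shape [2]
  # With the default expected_rank_diff=0, `labels` would be squeezed to
  # shape [2] to match `predictions`.
  labels, predictions, _ = _remove_squeezable_dimensions(labels, predictions)
  ```
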
1009  Args:
1010    labels: Label values, a `Tensor` whose dimensions match `predictions`.
1011    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
1012    weights: Optional weight `Tensor`. It will be squeezed if it's not scalar,
1013      and its rank is 1 more than the new rank of `labels`.
1014    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
1015
1016  Returns:
1017    Tuple of `predictions`, `labels` and `weights`, possibly with the last
1018    dimension squeezed.
1019  """
1020  labels, predictions = confusion_matrix.remove_squeezable_dimensions(
1021      labels, predictions, expected_rank_diff=expected_rank_diff)
1022
1023  if weights is not None:
1024    weights = ops.convert_to_tensor(weights)
1025    labels_rank = labels.get_shape().ndims
1026    weights_shape = weights.get_shape()
1027    weights_rank = weights_shape.ndims
1028
1029    if (labels_rank is not None) and (weights_rank is not None):
1030      # Use static rank.
1031      rank_diff = weights_rank - labels_rank
1032      if rank_diff == 1:
1033        weights = array_ops.squeeze(weights, [-1])
1034      return labels, predictions, weights
1035
1036    # Use dynamic rank.
1037    rank_diff = array_ops.rank(weights) - array_ops.rank(labels)
1038    if (weights_rank is None) or (
1039        weights_rank > 0 and weights_shape.dims[-1].is_compatible_with(1)):
1040      weights = control_flow_ops.cond(
1041          math_ops.equal(1, rank_diff),
1042          lambda: array_ops.squeeze(weights, [-1]),
1043          lambda: weights)
1044
1045  return labels, predictions, weights
1046
1047
1048@tf_export(v1=["losses.sparse_softmax_cross_entropy"])
1049@dispatch.add_dispatch_support
1050def sparse_softmax_cross_entropy(
1051    labels, logits, weights=1.0, scope=None,
1052    loss_collection=ops.GraphKeys.LOSSES,
1053    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
1054  """Cross-entropy loss using `tf.nn.sparse_softmax_cross_entropy_with_logits`.
1055
1056  `weights` acts as a coefficient for the loss. If a scalar is provided,
1057  then the loss is simply scaled by the given value. If `weights` is a
1058  tensor of shape `[batch_size]`, then the loss weights apply to each
1059  corresponding sample.
1060
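  A minimal usage sketch with integer class indices rather than one-hot labels
  (values are arbitrary):

  ```python
  labels = tf.constant([1, 2])
  logits = tf.constant([[1.0, 2.0, 0.5],
                        [0.3, 0.1, 2.2]])
  loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(labels, logits)
  ```
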
1061  Args:
1062    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
1063      `labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
1064      must be an index in `[0, num_classes)`. Other values will raise an
1065      exception when this op is run on CPU, and return `NaN` for corresponding
1066      loss and gradient rows on GPU.
1067    logits: Unscaled log probabilities of shape
1068      `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32` or
1069      `float64`.
1070    weights: Coefficients for the loss. This must be scalar or broadcastable to
1071      `labels` (i.e. same rank and each dimension is either 1 or the same).
1072    scope: the scope for the operations performed in computing the loss.
1073    loss_collection: collection to which the loss will be added.
1074    reduction: Type of reduction to apply to loss.
1075
1076  Returns:
1077    Weighted loss `Tensor` of the same type as `logits`. If `reduction` is
1078    `NONE`, this has the same shape as `labels`; otherwise, it is scalar.
1079
1080  Raises:
1081    ValueError: If the shapes of `logits`, `labels`, and `weights` are
1082      incompatible, or if any of them are None.
1083
1084  @compatibility(eager)
1085  The `loss_collection` argument is ignored when executing eagerly. Consider
1086  holding on to the return value or collecting losses via a `tf.keras.Model`.
1087  @end_compatibility
1088  """
1089  if labels is None:
1090    raise ValueError("labels must not be None.")
1091  if logits is None:
1092    raise ValueError("logits must not be None.")
1093  with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
1094                      (logits, labels, weights)) as scope:
1095    # As documented above in Args, labels contain class IDs and logits contain
1096    # one value per class ID, so we expect rank(logits) - rank(labels) == 1;
1097    # therefore, expected_rank_diff=1.
1098    labels, logits, weights = _remove_squeezable_dimensions(
1099        labels, logits, weights, expected_rank_diff=1)
1100    losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
1101                                                         logits=logits,
1102                                                         name="xentropy")
1103    return compute_weighted_loss(
1104        losses, weights, scope, loss_collection, reduction=reduction)
1105