# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Loss operations for use in neural networks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.losses import util
from tensorflow.python.util import dispatch
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["losses.Reduction"])
class Reduction(object):
  """Types of loss reduction.

  Contains the following values:

  * `NONE`: Un-reduced weighted losses with the same shape as input.
  * `SUM`: Scalar sum of weighted losses.
  * `MEAN`: Scalar `SUM` divided by sum of weights. DEPRECATED.
  * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
  * `SUM_OVER_NONZERO_WEIGHTS`: Scalar `SUM` divided by number of non-zero
     weights. DEPRECATED.
  * `SUM_BY_NONZERO_WEIGHTS`: Same as `SUM_OVER_NONZERO_WEIGHTS`. DEPRECATED.
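
  A minimal sketch of how the reduction mode changes the result (values are
  illustrative; the `tf.compat.v1` aliases are assumed):

  ```
  import tensorflow as tf

  labels = [[0.0, 1.0]]
  predictions = [[0.2, 0.7]]
  per_element = tf.compat.v1.losses.mean_squared_error(
      labels, predictions,
      reduction=tf.compat.v1.losses.Reduction.NONE)  # shape [1, 2]
  total = tf.compat.v1.losses.mean_squared_error(
      labels, predictions,
      reduction=tf.compat.v1.losses.Reduction.SUM)   # 0.04 + 0.09 = 0.13
  ```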
51  """
52
53  NONE = "none"
54  SUM = "weighted_sum"
55  SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size"
56  MEAN = "weighted_mean"
57  SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights"
58  SUM_OVER_NONZERO_WEIGHTS = SUM_BY_NONZERO_WEIGHTS
59
60  @classmethod
61  def all(cls):
62    return (
63        cls.NONE,
64        cls.SUM,
65        cls.MEAN,
66        cls.SUM_OVER_BATCH_SIZE,
67        cls.SUM_OVER_NONZERO_WEIGHTS,
68        cls.SUM_BY_NONZERO_WEIGHTS)
69
70  @classmethod
71  def validate(cls, key):
72    if key not in cls.all():
73      raise ValueError("Invalid Reduction Key %s." % key)
74
75
76def _safe_mean(losses, num_present):
77  """Computes a safe mean of the losses.
78
79  Args:
80    losses: `Tensor` whose elements contain individual loss measurements.
81    num_present: The number of measurable elements in `losses`.
82
83  Returns:
84    A scalar representing the mean of `losses`. If `num_present` is zero,
85      then zero is returned.
86  """
87  total_loss = math_ops.reduce_sum(losses)
88  return math_ops.div_no_nan(total_loss, num_present, name="value")
89
90
91def _num_present(losses, weights, per_batch=False):
92  """Computes the number of elements in the loss function induced by `weights`.
93
94  A given weights tensor induces different numbers of usable elements in the
95  `losses` tensor. The `weights` tensor is broadcast across `losses` for all
96  possible dimensions. For example, if `losses` is a tensor of dimension
97  `[4, 5, 6, 3]` and `weights` is a tensor of shape `[4, 5]`, then `weights` is,
98  in effect, tiled to match the shape of `losses`. Following this effective
99  tile, the total number of present elements is the number of non-zero weights.
100
101  Args:
102    losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
103    weights: `Tensor` of shape `[]`, `[batch_size]` or
104      `[batch_size, d1, ... dK]`, where K < N.
105    per_batch: Whether to return the number of elements per batch or as a sum
106      total.
107
108  Returns:
109    The number of present (non-zero) elements in the losses tensor. If
110      `per_batch` is `True`, the value is returned as a tensor of size
111      `[batch_size]`. Otherwise, a single scalar tensor is returned.
112  """
113  if ((isinstance(weights, float) and weights != 0.0) or
114      (context.executing_eagerly() and weights._rank() == 0  # pylint: disable=protected-access
115       and not math_ops.equal(weights, 0.0))):
116    return _num_elements(losses)
117  with ops.name_scope(None, "num_present", (losses, weights)) as scope:
118    weights = math_ops.cast(weights, dtype=dtypes.float32)
119    present = array_ops.where(
120        math_ops.equal(weights, 0.0),
121        array_ops.zeros_like(weights),
122        array_ops.ones_like(weights))
123    present = weights_broadcast_ops.broadcast_weights(present, losses)
124    if per_batch:
125      return math_ops.reduce_sum(
126          present,
127          axis=math_ops.range(1, array_ops.rank(present)),
128          keepdims=True,
129          name=scope)
130    return math_ops.reduce_sum(present, name=scope)
131
132
133def _num_elements(losses):
134  """Computes the number of elements in `losses` tensor."""
135  with ops.name_scope(None, "num_elements", values=[losses]) as scope:
136    return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype)
137
138
139@tf_export(v1=["losses.compute_weighted_loss"])
140@dispatch.add_dispatch_support
141def compute_weighted_loss(
142    losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES,
143    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
144  """Computes the weighted loss.
145
146  Args:
147    losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
148    weights: Optional `Tensor` whose rank is either 0, or the same rank as
149      `losses`, and must be broadcastable to `losses` (i.e., all dimensions must
150      be either `1`, or the same as the corresponding `losses` dimension).
151    scope: the scope for the operations performed in computing the loss.
152    loss_collection: the loss will be added to these collections.
153    reduction: Type of reduction to apply to loss.
154
155  Returns:
156    Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
157    `NONE`, this has the same shape as `losses`; otherwise, it is scalar.
158
159  Raises:
160    ValueError: If `weights` is `None` or the shape is not compatible with
161      `losses`, or if the number of dimensions (rank) of either `losses` or
162      `weights` is missing.
163
164  Note:
    When calculating the gradient of a weighted loss, contributions from
    both `losses` and `weights` are considered. If your `weights` depend
    on some model parameters but you do not want this to affect the loss
    gradient, you need to apply `tf.stop_gradient` to `weights` before
    passing them to `compute_weighted_loss`.

  @compatibility(eager)
  The `loss_collection` argument is ignored when executing eagerly. Consider
  holding on to the return value or collecting losses via a `tf.keras.Model`.
  @end_compatibility
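
  A minimal usage sketch (values are illustrative; the `tf.compat.v1` alias is
  assumed):

  ```
  import tensorflow as tf

  losses = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  weights = tf.constant([[1.0], [0.0]])  # mask out the second sample
  loss = tf.compat.v1.losses.compute_weighted_loss(losses, weights)
  # With the default SUM_BY_NONZERO_WEIGHTS reduction:
  # (1.0 + 2.0) / 2 = 1.5
  ```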
175  """
176  Reduction.validate(reduction)
177  with ops.name_scope(scope, "weighted_loss", (losses, weights)):
178    # Save the `reduction` argument for loss normalization when distributing
179    # to multiple replicas. Used only for estimator + v1 optimizer flow.
180    ops.get_default_graph()._last_loss_reduction = reduction  # pylint: disable=protected-access
181
182    with ops.control_dependencies((
183        weights_broadcast_ops.assert_broadcastable(weights, losses),)):
184      losses = ops.convert_to_tensor(losses)
185      input_dtype = losses.dtype
186      losses = math_ops.cast(losses, dtype=dtypes.float32)
187      weights = math_ops.cast(weights, dtype=dtypes.float32)
188      weighted_losses = math_ops.multiply(losses, weights)
189      if reduction == Reduction.NONE:
190        loss = weighted_losses
191      else:
192        loss = math_ops.reduce_sum(weighted_losses)
193        if reduction == Reduction.MEAN:
194          loss = _safe_mean(
195              loss, math_ops.reduce_sum(array_ops.ones_like(losses) * weights))
196        elif (reduction == Reduction.SUM_BY_NONZERO_WEIGHTS or
197              reduction == Reduction.SUM_OVER_NONZERO_WEIGHTS):
198          loss = _safe_mean(loss, _num_present(losses, weights))
199        elif reduction == Reduction.SUM_OVER_BATCH_SIZE:
200          loss = _safe_mean(loss, _num_elements(losses))
201
202      # Convert the result back to the input type.
203      loss = math_ops.cast(loss, input_dtype)
204      util.add_loss(loss, loss_collection)
205      return loss
206
207
208@tf_export(v1=["losses.absolute_difference"])
209@dispatch.add_dispatch_support
210def absolute_difference(
211    labels, predictions, weights=1.0, scope=None,
212    loss_collection=ops.GraphKeys.LOSSES,
213    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
214  """Adds an Absolute Difference loss to the training procedure.
215
216  `weights` acts as a coefficient for the loss. If a scalar is provided, then
217  the loss is simply scaled by the given value. If `weights` is a `Tensor` of
218  shape `[batch_size]`, then the total loss for each sample of the batch is
219  rescaled by the corresponding element in the `weights` vector. If the shape of
220  `weights` matches the shape of `predictions`, then the loss of each
221  measurable element of `predictions` is scaled by the corresponding value of
222  `weights`.
223
224  Args:
225    labels: The ground truth output tensor, same dimensions as 'predictions'.
226    predictions: The predicted outputs.
227    weights: Optional `Tensor` whose rank is either 0, or the same rank as
228      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
229      be either `1`, or the same as the corresponding `losses` dimension).
230    scope: The scope for the operations performed in computing the loss.
231    loss_collection: collection to which this loss will be added.
232    reduction: Type of reduction to apply to loss.
233
234  Returns:
235    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
236    shape as `labels`; otherwise, it is scalar.
237
238  Raises:
239    ValueError: If the shape of `predictions` doesn't match that of
240      `labels` or if the shape of `weights` is invalid or if `labels`
241      or `predictions` is None.
242
243  @compatibility(eager)
244  The `loss_collection` argument is ignored when executing eagerly. Consider
245  holding on to the return value or collecting losses via a `tf.keras.Model`.
246  @end_compatibility
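
  A minimal usage sketch (values are illustrative; the `tf.compat.v1` alias is
  assumed):

  ```
  import tensorflow as tf

  loss = tf.compat.v1.losses.absolute_difference(
      labels=[[0.0, 1.0]], predictions=[[0.3, 0.7]])
  # |0.3 - 0.0| = 0.3 and |0.7 - 1.0| = 0.3, averaged over the two
  # elements: 0.3
  ```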
247  """
248  if labels is None:
249    raise ValueError("labels must not be None.")
250  if predictions is None:
251    raise ValueError("predictions must not be None.")
252  with ops.name_scope(scope, "absolute_difference",
253                      (predictions, labels, weights)) as scope:
254    predictions = math_ops.cast(predictions, dtype=dtypes.float32)
255    labels = math_ops.cast(labels, dtype=dtypes.float32)
256    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
257    losses = math_ops.abs(math_ops.subtract(predictions, labels))
258    return compute_weighted_loss(
259        losses, weights, scope, loss_collection, reduction=reduction)
260
261
262@tf_export(v1=["losses.cosine_distance"])
263@dispatch.add_dispatch_support
264@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
265def cosine_distance(
266    labels, predictions, axis=None, weights=1.0, scope=None,
267    loss_collection=ops.GraphKeys.LOSSES,
268    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS,
269    dim=None):
270  """Adds a cosine-distance loss to the training procedure.
271
272  Note that the function assumes that `predictions` and `labels` are already
273  unit-normalized.
274
275  Args:
276    labels: `Tensor` whose shape matches 'predictions'
277    predictions: An arbitrary matrix.
278    axis: The dimension along which the cosine distance is computed.
279    weights: Optional `Tensor` whose rank is either 0, or the same rank as
280      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
281      be either `1`, or the same as the corresponding `losses` dimension).
282    scope: The scope for the operations performed in computing the loss.
283    loss_collection: collection to which this loss will be added.
284    reduction: Type of reduction to apply to loss.
285    dim: The old (deprecated) name for `axis`.
286
287  Returns:
288    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
289    shape as `labels`; otherwise, it is scalar.
290
291  Raises:
292    ValueError: If `predictions` shape doesn't match `labels` shape, or
293      `axis`, `labels`, `predictions` or `weights` is `None`.
294
295  @compatibility(eager)
296  The `loss_collection` argument is ignored when executing eagerly. Consider
297  holding on to the return value or collecting losses via a `tf.keras.Model`.
298  @end_compatibility
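
  A minimal usage sketch (values are illustrative; the `tf.compat.v1` alias is
  assumed):

  ```
  import tensorflow as tf

  labels = [[1.0, 0.0]]       # already unit-normalized
  predictions = [[0.0, 1.0]]  # already unit-normalized
  loss = tf.compat.v1.losses.cosine_distance(labels, predictions, axis=1)
  # Orthogonal unit vectors have zero dot product, so the distance is 1.0.
  ```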
299  """
300  axis = deprecated_argument_lookup("axis", axis, "dim", dim)
301  if axis is None:
302    raise ValueError("You must specify 'axis'.")
303  if labels is None:
304    raise ValueError("labels must not be None.")
305  if predictions is None:
306    raise ValueError("predictions must not be None.")
307  with ops.name_scope(scope, "cosine_distance_loss",
308                      (predictions, labels, weights)) as scope:
309    predictions = math_ops.cast(predictions, dtype=dtypes.float32)
310    labels = math_ops.cast(labels, dtype=dtypes.float32)
311    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
312
313    radial_diffs = math_ops.multiply(predictions, labels)
314    losses = 1 - math_ops.reduce_sum(radial_diffs, axis=(axis,), keepdims=True)
315    return compute_weighted_loss(
316        losses, weights, scope, loss_collection, reduction=reduction)
317
318
319@tf_export(v1=["losses.hinge_loss"])
320@dispatch.add_dispatch_support
321def hinge_loss(labels, logits, weights=1.0, scope=None,
322               loss_collection=ops.GraphKeys.LOSSES,
323               reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
324  """Adds a hinge loss to the training procedure.
325
326  Args:
327    labels: The ground truth output tensor. Its shape should match the shape of
328      logits. The values of the tensor are expected to be 0.0 or 1.0. Internally
329      the {0,1} labels are converted to {-1,1} when calculating the hinge loss.
330    logits: The logits, a float tensor. Note that logits are assumed to be
331      unbounded and 0-centered. A value > 0 (resp. < 0) is considered a positive
332      (resp. negative) binary prediction.
333    weights: Optional `Tensor` whose rank is either 0, or the same rank as
334      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
335      be either `1`, or the same as the corresponding `losses` dimension).
336    scope: The scope for the operations performed in computing the loss.
337    loss_collection: collection to which the loss will be added.
338    reduction: Type of reduction to apply to loss.
339
340  Returns:
341    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
342    shape as `labels`; otherwise, it is scalar.
343
344  Raises:
345    ValueError: If the shapes of `logits` and `labels` don't match or
346      if `labels` or `logits` is None.
347
348  @compatibility(eager)
349  The `loss_collection` argument is ignored when executing eagerly. Consider
350  holding on to the return value or collecting losses via a `tf.keras.Model`.
351  @end_compatibility
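
  A minimal usage sketch (values are illustrative; the `tf.compat.v1` alias is
  assumed):

  ```
  import tensorflow as tf

  labels = [0.0, 1.0]   # {0, 1} labels, converted internally to {-1, 1}
  logits = [0.5, -0.3]
  loss = tf.compat.v1.losses.hinge_loss(labels, logits)
  # relu(1 - (-1) * 0.5) = 1.5 and relu(1 - 1 * (-0.3)) = 1.3,
  # averaged to 1.4
  ```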
352  """
353  if labels is None:
354    raise ValueError("labels must not be None.")
355  if logits is None:
356    raise ValueError("logits must not be None.")
357  with ops.name_scope(scope, "hinge_loss", (logits, labels, weights)) as scope:
358    logits = math_ops.cast(logits, dtype=dtypes.float32)
359    labels = math_ops.cast(labels, dtype=dtypes.float32)
360    logits.get_shape().assert_is_compatible_with(labels.get_shape())
361    # We first need to convert binary labels to -1/1 labels (as floats).
362    all_ones = array_ops.ones_like(labels)
363    labels = math_ops.subtract(2 * labels, all_ones)
364    losses = nn_ops.relu(
365        math_ops.subtract(all_ones, math_ops.multiply(labels, logits)))
366    return compute_weighted_loss(
367        losses, weights, scope, loss_collection, reduction=reduction)
368
369
370@tf_export(v1=["losses.huber_loss"])
371@dispatch.add_dispatch_support
372def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
373               loss_collection=ops.GraphKeys.LOSSES,
374               reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
375  """Adds a [Huber Loss](https://en.wikipedia.org/wiki/Huber_loss) term to the training procedure.
376
377  For each value x in `error=labels-predictions`, the following is calculated:
378
379  ```
380    0.5 * x^2                  if |x| <= d
381    0.5 * d^2 + d * (|x| - d)  if |x| > d
382  ```
383
384  where d is `delta`.
385
386  `weights` acts as a coefficient for the loss. If a scalar is provided, then
387  the loss is simply scaled by the given value. If `weights` is a tensor of size
388  `[batch_size]`, then the total loss for each sample of the batch is rescaled
389  by the corresponding element in the `weights` vector. If the shape of
390  `weights` matches the shape of `predictions`, then the loss of each
391  measurable element of `predictions` is scaled by the corresponding value of
392  `weights`.
393
394  Args:
395    labels: The ground truth output tensor, same dimensions as 'predictions'.
396    predictions: The predicted outputs.
397    weights: Optional `Tensor` whose rank is either 0, or the same rank as
398      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
399      be either `1`, or the same as the corresponding `losses` dimension).
400    delta: `float`, the point where the huber loss function changes from a
401      quadratic to linear.
402    scope: The scope for the operations performed in computing the loss.
403    loss_collection: collection to which the loss will be added.
404    reduction: Type of reduction to apply to loss.
405
406  Returns:
407    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
408    shape as `labels`; otherwise, it is scalar.
409
410  Raises:
411    ValueError: If the shape of `predictions` doesn't match that of `labels` or
412      if the shape of `weights` is invalid.  Also if `labels` or
413     `predictions` is None.
414
415  @compatibility(eager)
416  The `loss_collection` argument is ignored when executing eagerly. Consider
417  holding on to the return value or collecting losses via a `tf.keras.Model`.
418  @end_compatibility
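
  A minimal usage sketch (values are illustrative; the `tf.compat.v1` alias is
  assumed):

  ```
  import tensorflow as tf

  loss = tf.compat.v1.losses.huber_loss(
      labels=[0.0, 0.0], predictions=[0.5, 2.0], delta=1.0)
  # The small error stays quadratic: 0.5 * 0.5^2 = 0.125.
  # The large error becomes linear: 0.5 * 1^2 + 1 * (2 - 1) = 1.5.
  # Averaged over the two elements: 0.8125
  ```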
419  """
420  if labels is None:
421    raise ValueError("labels must not be None.")
422  if predictions is None:
423    raise ValueError("predictions must not be None.")
424  with ops.name_scope(scope, "huber_loss",
425                      (predictions, labels, weights)) as scope:
426    predictions = math_ops.cast(predictions, dtype=dtypes.float32)
427    labels = math_ops.cast(labels, dtype=dtypes.float32)
428    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
429    error = math_ops.subtract(predictions, labels)
430    abs_error = math_ops.abs(error)
431    quadratic = math_ops.minimum(abs_error, delta)
432    # The following expression is the same in value as
433    # tf.maximum(abs_error - delta, 0), but importantly the gradient for the
434    # expression when abs_error == delta is 0 (for tf.maximum it would be 1).
435    # This is necessary to avoid doubling the gradient, since there is already a
436    # nonzero contribution to the gradient from the quadratic term.
437    linear = math_ops.subtract(abs_error, quadratic)
438    losses = math_ops.add(
439        math_ops.multiply(
440            ops.convert_to_tensor(0.5, dtype=quadratic.dtype),
441            math_ops.multiply(quadratic, quadratic)),
442        math_ops.multiply(delta, linear))
443    return compute_weighted_loss(
444        losses, weights, scope, loss_collection, reduction=reduction)
445
446
447@tf_export(v1=["losses.log_loss"])
448@dispatch.add_dispatch_support
449def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
450             loss_collection=ops.GraphKeys.LOSSES,
451             reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
452  """Adds a Log Loss term to the training procedure.
453
454  `weights` acts as a coefficient for the loss. If a scalar is provided, then
455  the loss is simply scaled by the given value. If `weights` is a tensor of size
456  `[batch_size]`, then the total loss for each sample of the batch is rescaled
457  by the corresponding element in the `weights` vector. If the shape of
458  `weights` matches the shape of `predictions`, then the loss of each
459  measurable element of `predictions` is scaled by the corresponding value of
460  `weights`.
461
462  Args:
463    labels: The ground truth output tensor, same dimensions as 'predictions'.
464    predictions: The predicted outputs.
465    weights: Optional `Tensor` whose rank is either 0, or the same rank as
466      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
467      be either `1`, or the same as the corresponding `losses` dimension).
468    epsilon: A small increment to add to avoid taking a log of zero.
469    scope: The scope for the operations performed in computing the loss.
470    loss_collection: collection to which the loss will be added.
471    reduction: Type of reduction to apply to loss.
472
473  Returns:
474    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
475    shape as `labels`; otherwise, it is scalar.
476
477  Raises:
478    ValueError: If the shape of `predictions` doesn't match that of `labels` or
479      if the shape of `weights` is invalid.  Also if `labels` or `predictions`
480      is None.
481
482  @compatibility(eager)
483  The `loss_collection` argument is ignored when executing eagerly. Consider
484  holding on to the return value or collecting losses via a `tf.keras.Model`.
485  @end_compatibility
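
  A minimal usage sketch (values are illustrative; the `tf.compat.v1` alias is
  assumed):

  ```
  import tensorflow as tf

  loss = tf.compat.v1.losses.log_loss(
      labels=[1.0, 0.0], predictions=[0.9, 0.2])
  # (-log(0.9) - log(0.8)) / 2, roughly 0.164
  ```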
486  """
487  if labels is None:
488    raise ValueError("labels must not be None.")
489  if predictions is None:
490    raise ValueError("predictions must not be None.")
491  with ops.name_scope(scope, "log_loss",
492                      (predictions, labels, weights)) as scope:
493    predictions = math_ops.cast(predictions, dtype=dtypes.float32)
494    labels = math_ops.cast(labels, dtype=dtypes.float32)
495    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
496    losses = -math_ops.multiply(
497        labels,
498        math_ops.log(predictions + epsilon)) - math_ops.multiply(
499            (1 - labels), math_ops.log(1 - predictions + epsilon))
500    return compute_weighted_loss(
501        losses, weights, scope, loss_collection, reduction=reduction)
502
503
504# TODO(b/37208492): Add reduction arg.
505@tf_export(v1=["losses.mean_pairwise_squared_error"])
506@dispatch.add_dispatch_support
507def mean_pairwise_squared_error(
508    labels, predictions, weights=1.0, scope=None,
509    loss_collection=ops.GraphKeys.LOSSES):
510  """Adds a pairwise-errors-squared loss to the training procedure.
511
512  Unlike `mean_squared_error`, which is a measure of the differences between
513  corresponding elements of `predictions` and `labels`,
514  `mean_pairwise_squared_error` is a measure of the differences between pairs of
515  corresponding elements of `predictions` and `labels`.
516
  For example, if `labels`=[a, b, c] and `predictions`=[x, y, z], three
  pairs of differences are summed to compute the loss:
    loss = [ ((a-b) - (x-y)).^2 + ((a-c) - (x-z)).^2 + ((b-c) - (y-z)).^2 ] / 3

  Note that since the inputs are of shape `[batch_size, d0, ... dN]`, the
  corresponding pairs are computed within each batch sample but not across
  samples within a batch. For example, if `predictions` represents a batch of
  16 grayscale images of dimension [batch_size, 100, 200], then the set of pairs
  is drawn from each image, but not across images.

  `weights` acts as a coefficient for the loss. If a scalar is provided, then
  the loss is simply scaled by the given value. If `weights` is a tensor of size
  `[batch_size]`, then the total loss for each sample of the batch is rescaled
  by the corresponding element in the `weights` vector.

  Args:
    labels: The ground truth output tensor, whose shape must match the shape of
      `predictions`.
    predictions: The predicted outputs, a tensor of size
      `[batch_size, d0, .. dN]` where N+1 is the total number of dimensions in
      `predictions`.
    weights: Coefficients for the loss: a scalar, a tensor of shape
      `[batch_size]` or a tensor whose shape matches `predictions`.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which the loss will be added.

  Returns:
    A scalar `Tensor` that returns the weighted loss.

  Raises:
    ValueError: If the shape of `predictions` doesn't match that of `labels` or
      if the shape of `weights` is invalid.  Also if `labels` or `predictions`
      is None.

  @compatibility(eager)
  The `loss_collection` argument is ignored when executing eagerly. Consider
  holding on to the return value or collecting losses via a `tf.keras.Model`.
  @end_compatibility
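
  A minimal usage sketch (values are illustrative; the `tf.compat.v1` alias is
  assumed):

  ```
  import tensorflow as tf

  labels = [[1.0, 2.0, 3.0]]
  predictions = [[1.5, 2.0, 2.5]]
  loss = tf.compat.v1.losses.mean_pairwise_squared_error(labels, predictions)
  # The three pairwise difference terms are -0.5, -1.0 and -0.5, so the
  # loss is (0.25 + 1.0 + 0.25) / 3 = 0.5.
  ```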
555  """
556  if labels is None:
557    raise ValueError("labels must not be None.")
558  if predictions is None:
559    raise ValueError("predictions must not be None.")
560  with ops.name_scope(scope, "mean_pairwise_squared_error",
561                      (predictions, labels, weights)) as scope:
562    weights = math_ops.cast(weights, dtype=dtypes.float32)
563    labels = math_ops.cast(labels, dtype=dtypes.float32)
564    with ops.control_dependencies((
565        weights_broadcast_ops.assert_broadcastable(weights, labels),)):
566      predictions = math_ops.cast(predictions, dtype=dtypes.float32)
567      predictions.get_shape().assert_is_compatible_with(labels.get_shape())
568
569      diffs = math_ops.subtract(predictions, labels)
570
571      axis = math_ops.range(1, array_ops.rank(diffs))
572
573      sum_squares_diff_per_batch = math_ops.reduce_sum(
574          math_ops.square(diffs), axis=axis, keepdims=True)
575      num_present_per_batch = _num_present(diffs, weights, per_batch=True)
576
577      term1 = 2.0 * math_ops.div_no_nan(
578          sum_squares_diff_per_batch,
579          math_ops.maximum(num_present_per_batch - 1, 0),
580          name="value")
581
582      sum_diff = math_ops.reduce_sum(diffs, axis=axis, keepdims=True)
583      term2 = 2.0 * math_ops.div_no_nan(
584          math_ops.square(sum_diff),
585          math_ops.maximum(
586              math_ops.multiply(num_present_per_batch,
587                                num_present_per_batch - 1), 0),
588          name="value")
589
590      weighted_losses = math_ops.multiply(term1 - term2, weights)
591      loss = math_ops.reduce_sum(weighted_losses)
592
593      mean_loss = array_ops.where(
594          math_ops.reduce_sum(num_present_per_batch) > 0,
595          loss,
596          array_ops.zeros_like(loss),
597          name="value")
598      util.add_loss(mean_loss, loss_collection)
599      return mean_loss
600
601
602@tf_export(v1=["losses.mean_squared_error"])
603@dispatch.add_dispatch_support
604def mean_squared_error(
605    labels, predictions, weights=1.0, scope=None,
606    loss_collection=ops.GraphKeys.LOSSES,
607    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
608  """Adds a Sum-of-Squares loss to the training procedure.
609
610  `weights` acts as a coefficient for the loss. If a scalar is provided, then
611  the loss is simply scaled by the given value. If `weights` is a tensor of size
612  `[batch_size]`, then the total loss for each sample of the batch is rescaled
613  by the corresponding element in the `weights` vector. If the shape of
614  `weights` matches the shape of `predictions`, then the loss of each
615  measurable element of `predictions` is scaled by the corresponding value of
616  `weights`.
617
618  Args:
619    labels: The ground truth output tensor, same dimensions as 'predictions'.
620    predictions: The predicted outputs.
621    weights: Optional `Tensor` whose rank is either 0, or the same rank as
622      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
623      be either `1`, or the same as the corresponding `losses` dimension).
624    scope: The scope for the operations performed in computing the loss.
625    loss_collection: collection to which the loss will be added.
626    reduction: Type of reduction to apply to loss.
627
628  Returns:
629    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
630    shape as `labels`; otherwise, it is scalar.
631
632  Raises:
633    ValueError: If the shape of `predictions` doesn't match that of `labels` or
634      if the shape of `weights` is invalid.  Also if `labels` or `predictions`
635      is None.
636
637  @compatibility(eager)
638  The `loss_collection` argument is ignored when executing eagerly. Consider
639  holding on to the return value or collecting losses via a `tf.keras.Model`.
640  @end_compatibility
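
  A minimal usage sketch (values are illustrative; the `tf.compat.v1` alias is
  assumed):

  ```
  import tensorflow as tf

  loss = tf.compat.v1.losses.mean_squared_error(
      labels=[[0.0, 1.0]], predictions=[[0.5, 0.5]])
  # (0.25 + 0.25) / 2 = 0.25
  ```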
641  """
642  if labels is None:
643    raise ValueError("labels must not be None.")
644  if predictions is None:
645    raise ValueError("predictions must not be None.")
646  with ops.name_scope(scope, "mean_squared_error",
647                      (predictions, labels, weights)) as scope:
648    predictions = math_ops.cast(predictions, dtype=dtypes.float32)
649    labels = math_ops.cast(labels, dtype=dtypes.float32)
650    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
651    losses = math_ops.squared_difference(predictions, labels)
652    return compute_weighted_loss(
653        losses, weights, scope, loss_collection, reduction=reduction)
654
655
656@tf_export(v1=["losses.sigmoid_cross_entropy"])
657@dispatch.add_dispatch_support
658def sigmoid_cross_entropy(
659    multi_class_labels, logits, weights=1.0, label_smoothing=0, scope=None,
660    loss_collection=ops.GraphKeys.LOSSES,
661    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
662  """Creates a cross-entropy loss using tf.nn.sigmoid_cross_entropy_with_logits.
663
664  `weights` acts as a coefficient for the loss. If a scalar is provided,
665  then the loss is simply scaled by the given value. If `weights` is a
666  tensor of shape `[batch_size]`, then the loss weights apply to each
667  corresponding sample.
668
669  If `label_smoothing` is nonzero, smooth the labels towards 1/2:
670
671      new_multiclass_labels = multiclass_labels * (1 - label_smoothing)
672                              + 0.5 * label_smoothing
673
674  Args:
675    multi_class_labels: `[batch_size, num_classes]` target integer labels in
676      `{0, 1}`.
677    logits: Float `[batch_size, num_classes]` logits outputs of the network.
678    weights: Optional `Tensor` whose rank is either 0, or the same rank as
679    `multi_class_labels`, and must be broadcastable to `multi_class_labels`
680    (i.e., all dimensions must be either `1`, or the same as the
681    corresponding `losses` dimension).
682    label_smoothing: If greater than `0` then smooth the labels.
683    scope: The scope for the operations performed in computing the loss.
684    loss_collection: collection to which the loss will be added.
685    reduction: Type of reduction to apply to loss.
686
687  Returns:
688    Weighted loss `Tensor` of the same type as `logits`. If `reduction` is
689    `NONE`, this has the same shape as `logits`; otherwise, it is scalar.
690
691  Raises:
692    ValueError: If the shape of `logits` doesn't match that of
693      `multi_class_labels` or if the shape of `weights` is invalid, or if
694      `weights` is None.  Also if `multi_class_labels` or `logits` is None.
695
696  @compatibility(eager)
697  The `loss_collection` argument is ignored when executing eagerly. Consider
698  holding on to the return value or collecting losses via a `tf.keras.Model`.
699  @end_compatibility
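
  A minimal usage sketch (values are illustrative; the `tf.compat.v1` alias is
  assumed):

  ```
  import tensorflow as tf

  multi_class_labels = [[1, 0], [0, 1]]
  logits = [[2.0, -1.0], [-1.0, 2.0]]
  loss = tf.compat.v1.losses.sigmoid_cross_entropy(multi_class_labels, logits)
  # Scalar sigmoid cross-entropy averaged over the four logit positions.
  ```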
700  """
701  if multi_class_labels is None:
702    raise ValueError("multi_class_labels must not be None.")
703  if logits is None:
704    raise ValueError("logits must not be None.")
705  with ops.name_scope(scope, "sigmoid_cross_entropy_loss",
706                      (logits, multi_class_labels, weights)) as scope:
707    logits = ops.convert_to_tensor(logits)
708    multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype)
709    logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape())
710
711    if label_smoothing > 0:
712      multi_class_labels = (multi_class_labels * (1 - label_smoothing) +
713                            0.5 * label_smoothing)
714
715    losses = nn.sigmoid_cross_entropy_with_logits(labels=multi_class_labels,
716                                                  logits=logits,
717                                                  name="xentropy")
718    return compute_weighted_loss(
719        losses, weights, scope, loss_collection, reduction=reduction)
720
721
722@tf_export(v1=["losses.softmax_cross_entropy"])
723@dispatch.add_dispatch_support
724def softmax_cross_entropy(
725    onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None,
726    loss_collection=ops.GraphKeys.LOSSES,
727    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
728  """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits_v2.
729
730  `weights` acts as a coefficient for the loss. If a scalar is provided,
731  then the loss is simply scaled by the given value. If `weights` is a
732  tensor of shape `[batch_size]`, then the loss weights apply to each
733  corresponding sample.
734
735  If `label_smoothing` is nonzero, smooth the labels towards 1/num_classes:
736      new_onehot_labels = onehot_labels * (1 - label_smoothing)
737                          + label_smoothing / num_classes
738
739  Note that `onehot_labels` and `logits` must have the same shape,
740  e.g. `[batch_size, num_classes]`. The shape of `weights` must be
741  broadcastable to loss, whose shape is decided by the shape of `logits`.
742  In case the shape of `logits` is `[batch_size, num_classes]`, loss is
743  a `Tensor` of shape `[batch_size]`.
744
745  Args:
746    onehot_labels: One-hot-encoded labels.
747    logits: Logits outputs of the network.
748    weights: Optional `Tensor` that is broadcastable to loss.
749    label_smoothing: If greater than 0 then smooth the labels.
750    scope: the scope for the operations performed in computing the loss.
751    loss_collection: collection to which the loss will be added.
752    reduction: Type of reduction to apply to loss.
753
754  Returns:
755    Weighted loss `Tensor` of the same type as `logits`. If `reduction` is
756    `NONE`, this has shape `[batch_size]`; otherwise, it is scalar.
757
758  Raises:
759    ValueError: If the shape of `logits` doesn't match that of `onehot_labels`
760      or if the shape of `weights` is invalid or if `weights` is None.  Also if
761      `onehot_labels` or `logits` is None.
762
763  @compatibility(eager)
764  The `loss_collection` argument is ignored when executing eagerly. Consider
765  holding on to the return value or collecting losses via a `tf.keras.Model`.
766  @end_compatibility
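
  A minimal usage sketch (values are illustrative; the `tf.compat.v1` alias is
  assumed):

  ```
  import tensorflow as tf

  onehot_labels = [[0.0, 1.0, 0.0]]
  logits = [[1.0, 2.0, 0.5]]
  loss = tf.compat.v1.losses.softmax_cross_entropy(onehot_labels, logits)
  # Scalar softmax cross-entropy for the single example.
  ```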
767  """
768  if onehot_labels is None:
769    raise ValueError("onehot_labels must not be None.")
770  if logits is None:
771    raise ValueError("logits must not be None.")
772  with ops.name_scope(scope, "softmax_cross_entropy_loss",
773                      (logits, onehot_labels, weights)) as scope:
774    logits = ops.convert_to_tensor(logits)
775    onehot_labels = math_ops.cast(onehot_labels, logits.dtype)
776    logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape())
777
778    if label_smoothing > 0:
779      num_classes = math_ops.cast(
780          array_ops.shape(onehot_labels)[-1], logits.dtype)
781      smooth_positives = 1.0 - label_smoothing
782      smooth_negatives = label_smoothing / num_classes
783      onehot_labels = onehot_labels * smooth_positives + smooth_negatives
784
785    onehot_labels = array_ops.stop_gradient(
786        onehot_labels, name="labels_stop_gradient")
787    losses = nn.softmax_cross_entropy_with_logits_v2(
788        labels=onehot_labels, logits=logits, name="xentropy")
789
790    return compute_weighted_loss(
791        losses, weights, scope, loss_collection, reduction=reduction)
792
793
794# TODO(ptucker): Merge this with similar method in metrics_impl.
795def _remove_squeezable_dimensions(
796    labels, predictions, weights=None, expected_rank_diff=0):
797  """Internal version of _remove_squeezable_dimensions which handles weights.
798
799  Squeezes `predictions` and `labels` if their ranks differ from expected by
800  exactly 1.
  Squeezes `weights` if its rank is 1 more than the new rank of `predictions`.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    weights: Optional weight `Tensor`. It will be squeezed if it's not scalar,
      and its rank is 1 more than the new rank of `labels`.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.

  Returns:
    Tuple of `predictions`, `labels` and `weights`, possibly with the last
    dimension squeezed.
  """
  labels, predictions = confusion_matrix.remove_squeezable_dimensions(
      labels, predictions, expected_rank_diff=expected_rank_diff)

  if weights is not None:
    weights = ops.convert_to_tensor(weights)
    labels_rank = labels.get_shape().ndims
    weights_shape = weights.get_shape()
    weights_rank = weights_shape.ndims

    if (labels_rank is not None) and (weights_rank is not None):
      # Use static rank.
      rank_diff = weights_rank - labels_rank
      if rank_diff == 1:
        weights = array_ops.squeeze(weights, [-1])
      return labels, predictions, weights

    # Use dynamic rank.
    rank_diff = array_ops.rank(weights) - array_ops.rank(labels)
    if (weights_rank is None) or (
        weights_rank > 0 and weights_shape.dims[-1].is_compatible_with(1)):
      weights = control_flow_ops.cond(
          math_ops.equal(1, rank_diff),
          lambda: array_ops.squeeze(weights, [-1]),
          lambda: weights)

  return labels, predictions, weights


@tf_export(v1=["losses.sparse_softmax_cross_entropy"])
@dispatch.add_dispatch_support
def sparse_softmax_cross_entropy(
    labels, logits, weights=1.0, scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
  """Cross-entropy loss using `tf.nn.sparse_softmax_cross_entropy_with_logits`.

  `weights` acts as a coefficient for the loss. If a scalar is provided,
  then the loss is simply scaled by the given value. If `weights` is a
  tensor of shape `[batch_size]`, then the loss weights apply to each
  corresponding sample.

  Args:
    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
      `labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
      must be an index in `[0, num_classes)`. Other values will raise an
      exception when this op is run on CPU, and return `NaN` for corresponding
      loss and gradient rows on GPU.
    logits: Unscaled log probabilities of shape
      `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32` or
      `float64`.
    weights: Coefficients for the loss. This must be scalar or broadcastable to
      `labels` (i.e. same rank and each dimension is either 1 or the same).
    scope: the scope for the operations performed in computing the loss.
    loss_collection: collection to which the loss will be added.
    reduction: Type of reduction to apply to loss.

  Returns:
    Weighted loss `Tensor` of the same type as `logits`. If `reduction` is
    `NONE`, this has the same shape as `labels`; otherwise, it is scalar.

  Raises:
    ValueError: If the shapes of `logits`, `labels`, and `weights` are
      incompatible, or if any of them are None.

  @compatibility(eager)
  The `loss_collection` argument is ignored when executing eagerly. Consider
  holding on to the return value or collecting losses via a `tf.keras.Model`.
  @end_compatibility
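
  A minimal usage sketch (values are illustrative; the `tf.compat.v1` alias is
  assumed):

  ```
  import tensorflow as tf

  labels = [1, 0]   # class indices, not one-hot
  logits = [[1.0, 2.0, 0.5], [2.0, 0.5, 1.0]]
  loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(labels, logits)
  # Scalar cross-entropy averaged over the two examples.
  ```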
885  """
886  if labels is None:
887    raise ValueError("labels must not be None.")
888  if logits is None:
889    raise ValueError("logits must not be None.")
890  with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
891                      (logits, labels, weights)) as scope:
892    # As documented above in Args, labels contain class IDs and logits contains
893    # 1 probability per class ID, so we expect rank(logits) - rank(labels) == 1;
894    # therefore, expected_rank_diff=1.
895    labels, logits, weights = _remove_squeezable_dimensions(
896        labels, logits, weights, expected_rank_diff=1)
897    losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
898                                                         logits=logits,
899                                                         name="xentropy")
900    return compute_weighted_loss(
901        losses, weights, scope, loss_collection, reduction=reduction)
902