# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of tf.metrics module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sets
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


def metric_variable(shape, dtype, validate_shape=True, name=None):
  """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections.

  If running in a `DistributionStrategy` context, the variable will be
  "sync on read". This means:

  *   The returned object will be a container with separate variables
      per replica of the model.

  *   When writing to the variable, e.g. using `assign_add` in a metric
      update, the update will be applied to the variable local to the
      replica.

  *   To get a metric's result value, we need to sum the variable values
      across the replicas before computing the final answer. Furthermore,
      the final answer should be computed once instead of in every
      replica. Both of these are accomplished by running the computation
      of the final result value inside
      `distribution_strategy_context.get_replica_context().merge_call(fn)`.
      Inside the `merge_call()`, ops are only added to the graph once
      and access to a sync on read variable in a computation returns
      the sum across all replicas.

  Args:
    shape: Shape of the created variable.
    dtype: Type of the created variable.
    validate_shape: (Optional) Whether shape validation is enabled for
      the created variable.
    name: (Optional) String name of the created variable.

  Returns:
    A (non-trainable) variable initialized to zero, or if inside a
    `DistributionStrategy` scope a sync on read variable container.
  """
  # Note that synchronization "ON_READ" implies trainable=False.
  return variable_scope.variable(
      lambda: array_ops.zeros(shape, dtype),
      collections=[
          ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
      ],
      validate_shape=validate_shape,
      synchronization=variable_scope.VariableSynchronization.ON_READ,
      aggregation=variable_scope.VariableAggregation.SUM,
      name=name)
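

# Example (illustrative, comment-only sketch; `some_values` is a placeholder
# name). A metric variable is just a zero-initialized, non-trainable local
# accumulator:
#
#   total = metric_variable([], dtypes.float32, name='total')
#   update = state_ops.assign_add(total, math_ops.reduce_sum(some_values))
#
# Under a `DistributionStrategy`, each replica updates its own copy, and a
# read inside `merge_call` returns the sum across replicas.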


def _remove_squeezable_dimensions(predictions, labels, weights):
  """Squeeze or expand last dim if needed.

  Squeezes last dim of `predictions` or `labels` if their rank differs by 1
  (using confusion_matrix.remove_squeezable_dimensions).
  Squeezes or expands last dim of `weights` if its rank differs by 1 from the
  new rank of `predictions`.

  If `weights` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Optional label `Tensor` whose dimensions match `predictions`.
    weights: Optional weight scalar or `Tensor` whose dimensions match
      `predictions`.

  Returns:
    Tuple of `predictions`, `labels` and `weights`. Each of them possibly has
    the last dimension squeezed, `weights` could be extended by one dimension.
  """
  predictions = ops.convert_to_tensor(predictions)
  if labels is not None:
    labels, predictions = confusion_matrix.remove_squeezable_dimensions(
        labels, predictions)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  if weights is None:
    return predictions, labels, None

  weights = ops.convert_to_tensor(weights)
  weights_shape = weights.get_shape()
  weights_rank = weights_shape.ndims
  if weights_rank == 0:
    return predictions, labels, weights

  predictions_shape = predictions.get_shape()
  predictions_rank = predictions_shape.ndims
  if (predictions_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - predictions_rank == 1:
      weights = array_ops.squeeze(weights, [-1])
    elif predictions_rank - weights_rank == 1:
      weights = array_ops.expand_dims(weights, [-1])
  else:
    # Use dynamic rank.
    weights_rank_tensor = array_ops.rank(weights)
    rank_diff = weights_rank_tensor - array_ops.rank(predictions)

    def _maybe_expand_weights():
      return control_flow_ops.cond(
          math_ops.equal(rank_diff, -1),
          lambda: array_ops.expand_dims(weights, [-1]), lambda: weights)

    # Don't attempt squeeze if it will fail based on static check.
    if ((weights_rank is not None) and
        (not weights_shape.dims[-1].is_compatible_with(1))):
      maybe_squeeze_weights = lambda: weights
    else:
      maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1])

    def _maybe_adjust_weights():
      return control_flow_ops.cond(
          math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
          _maybe_expand_weights)

    # If weights are scalar, do nothing. Otherwise, try to add or remove a
    # dimension to match predictions.
    weights = control_flow_ops.cond(
        math_ops.equal(weights_rank_tensor, 0), lambda: weights,
        _maybe_adjust_weights)
  return predictions, labels, weights
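

# Illustrative shape behavior (comment-only sketch; `p`, `l` and `w` are
# placeholder names):
#
#   p = array_ops.zeros([4, 1])  # rank 2
#   l = array_ops.zeros([4])     # rank 1
#   w = array_ops.zeros([4, 1])  # rank 2
#   p2, l2, w2 = _remove_squeezable_dimensions(p, l, w)
#   # p2 and w2 both end up with shape [4]: the last dim of `p` is squeezed
#   # to match `l`, then `w` is squeezed to match the new rank of `p2`. A
#   # scalar `w` would be returned unchanged.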


def _maybe_expand_labels(labels, predictions):
  """If necessary, expand `labels` along last dimension to match `predictions`.

  Args:
    labels: `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN]. The latter implies
      num_labels=1, in which case the result is an expanded `labels` with shape
      [D1, ... DN, 1].
    predictions: `Tensor` with shape [D1, ... DN, num_classes].

  Returns:
    `labels` with the same rank as `predictions`.

  Raises:
    ValueError: if `labels` has invalid shape.
  """
  with ops.name_scope(None, 'expand_labels', (labels, predictions)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)

    # If sparse, expand sparse shape.
    if isinstance(labels, sparse_tensor.SparseTensor):
      return control_flow_ops.cond(
          math_ops.equal(
              array_ops.rank(predictions),
              array_ops.size(labels.dense_shape) + 1),
          lambda: sparse_ops.sparse_reshape(  # pylint: disable=g-long-lambda
              labels,
              shape=array_ops.concat((labels.dense_shape, (1,)), 0),
              name=scope),
          lambda: labels)

    # Otherwise, try to use static shape.
    labels_rank = labels.get_shape().ndims
    if labels_rank is not None:
      predictions_rank = predictions.get_shape().ndims
      if predictions_rank is not None:
        if predictions_rank == labels_rank:
          return labels
        if predictions_rank == labels_rank + 1:
          return array_ops.expand_dims(labels, -1, name=scope)
        raise ValueError(
            'Unexpected labels shape %s for predictions shape %s.' %
            (labels.get_shape(), predictions.get_shape()))

    # Otherwise, use dynamic shape.
    return control_flow_ops.cond(
        math_ops.equal(array_ops.rank(predictions),
                       array_ops.rank(labels) + 1),
        lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)
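

# Illustrative sketch (comment-only): labels of shape [D1, ... DN] gain a
# trailing dimension to match predictions of rank N+1, e.g.
#
#   labels = array_ops.zeros([8], dtypes.int64)           # shape [8]
#   predictions = array_ops.zeros([8, 5])                 # shape [8, 5]
#   expanded = _maybe_expand_labels(labels, predictions)  # shape [8, 1]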


def _safe_scalar_div(numerator, denominator, name):
  """Divides two values, returning 0 if the denominator is 0.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  numerator.get_shape().with_rank_at_most(1)
  denominator.get_shape().with_rank_at_most(1)
  return math_ops.div_no_nan(numerator, denominator, name=name)
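

# Illustrative sketch (comment-only): the division falls back to 0 on a zero
# denominator, e.g.
#
#   num = ops.convert_to_tensor(1.0, dtypes.float64)
#   den = ops.convert_to_tensor(0.0, dtypes.float64)
#   _safe_scalar_div(num, den, 'safe_div')  # evaluates to 0.0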


def _streaming_confusion_matrix(labels, predictions, num_classes,
                                weights=None):
  """Calculate a streaming confusion matrix.

  Calculates a confusion matrix. For estimation over a stream of data,
  the function creates an `update_op` operation.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).

  Returns:
    total_cm: A `Tensor` representing the confusion matrix.
    update_op: An operation that increments the confusion matrix.
  """
  # Local variable to accumulate the predictions in the confusion matrix.
  total_cm = metric_variable(
      [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix')

  # Cast the type to int64 required by confusion_matrix_ops.
  predictions = math_ops.cast(predictions, dtypes.int64)
  labels = math_ops.cast(labels, dtypes.int64)
  num_classes = math_ops.cast(num_classes, dtypes.int64)

  # Flatten the input if its rank > 1.
  if predictions.get_shape().ndims > 1:
    predictions = array_ops.reshape(predictions, [-1])

  if labels.get_shape().ndims > 1:
    labels = array_ops.reshape(labels, [-1])

  if (weights is not None) and (weights.get_shape().ndims > 1):
    weights = array_ops.reshape(weights, [-1])

  # Accumulate the prediction to current confusion matrix.
  current_cm = confusion_matrix.confusion_matrix(
      labels, predictions, num_classes, weights=weights, dtype=dtypes.float64)
  update_op = state_ops.assign_add(total_cm, current_cm)
  return total_cm, update_op
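

# Illustrative sketch (comment-only; assumes `tf` is `tensorflow.compat.v1`
# and `sess` is a `tf.Session` in graph mode):
#
#   labels = tf.constant([0, 1, 1])
#   predictions = tf.constant([0, 0, 1])
#   total_cm, update_op = _streaming_confusion_matrix(
#       labels, predictions, num_classes=2)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)  # total_cm is now [[1, 0], [1, 1]]
#   sess.run(update_op)  # accumulates again: [[2, 0], [2, 2]]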


def _aggregate_across_replicas(metrics_collections, metric_value_fn, *args):
  """Aggregate metric value across replicas."""
  def fn(distribution, *a):
    """Call `metric_value_fn` in the correct control flow context."""
    if hasattr(distribution.extended, '_outer_control_flow_context'):
      # If there was an outer context captured before this method was called,
      # then we enter that context to create the metric value op. If the
      # captured context is `None`, ops.control_dependencies(None) gives the
      # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
      # captured context.
      # This special handling is needed because sometimes the metric is created
      # inside a while_loop (and perhaps a TPU rewrite context). But we don't
      # want the value op to be evaluated every step or on the TPU. So we
      # create it outside so that it can be evaluated at the end on the host,
      # once the update ops have been evaluated.

      # pylint: disable=protected-access
      if distribution.extended._outer_control_flow_context is None:
        with ops.control_dependencies(None):
          metric_value = metric_value_fn(distribution, *a)
      else:
        distribution.extended._outer_control_flow_context.Enter()
        metric_value = metric_value_fn(distribution, *a)
        distribution.extended._outer_control_flow_context.Exit()
        # pylint: enable=protected-access
    else:
      metric_value = metric_value_fn(distribution, *a)
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric_value)
    return metric_value

  return distribution_strategy_context.get_replica_context().merge_call(
      fn, args=args)
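

# Illustrative sketch (comment-only): result functions are routed through
# `merge_call` so the final value is computed once rather than per replica.
# This mirrors the use in `mean` below, where `total` and `count` are
# sync-on-read variables whose reads inside `merge_call` yield cross-replica
# sums:
#
#   def compute_mean(_, t, c):
#     return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value')
#
#   mean_t = _aggregate_across_replicas(None, compute_mean, total, count)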


@tf_export(v1=['metrics.mean'])
def mean(values,
         weights=None,
         metrics_collections=None,
         updates_collections=None,
         name=None):
  """Computes the (weighted) mean of the given values.

  The `mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean` which is an idempotent operation that simply divides
  `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'mean', (values, weights)):
    values = math_ops.cast(values, dtypes.float32)

    total = metric_variable([], dtypes.float32, name='total')
    count = metric_variable([], dtypes.float32, name='count')

    if weights is None:
      num_values = math_ops.cast(array_ops.size(values), dtypes.float32)
    else:
      values, _, weights = _remove_squeezable_dimensions(
          predictions=values, labels=None, weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.cast(weights, dtypes.float32), values)
      values = math_ops.multiply(values, weights)
      num_values = math_ops.reduce_sum(weights)

    update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
    with ops.control_dependencies([values]):
      update_count_op = state_ops.assign_add(count, num_values)

    def compute_mean(_, t, c):
      return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value')

    mean_t = _aggregate_across_replicas(
        metrics_collections, compute_mean, total, count)
    update_op = math_ops.div_no_nan(
        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_t, update_op
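

# Example usage (illustrative, comment-only sketch; assumes `tf` is
# `tensorflow.compat.v1` in graph mode):
#
#   values = tf.placeholder(tf.float32, shape=[None])
#   mean_t, update_op = tf.metrics.mean(values)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op, feed_dict={values: [1., 2., 3.]})
#     sess.run(update_op, feed_dict={values: [4.]})
#     print(sess.run(mean_t))  # ==> 2.5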


@tf_export(v1=['metrics.accuracy'])
def accuracy(labels,
             predictions,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculates how often `predictions` matches `labels`.

  The `accuracy` function creates two local variables, `total` and
  `count` that are used to compute the frequency with which `predictions`
  matches `labels`. This frequency is ultimately returned as `accuracy`: an
  idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `accuracy`.
  Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
  where the corresponding elements of `predictions` and `labels` match and 0.0
  otherwise. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `is_correct`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    predictions: The predicted values, a `Tensor` of any shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.accuracy is not supported when eager '
                       'execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  if labels.dtype != predictions.dtype:
    predictions = math_ops.cast(predictions, labels.dtype)
  is_correct = math_ops.cast(
      math_ops.equal(predictions, labels), dtypes.float32)
  return mean(is_correct, weights, metrics_collections, updates_collections,
              name or 'accuracy')
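

# Example usage (illustrative, comment-only sketch; assumes `tf` is
# `tensorflow.compat.v1` and `sess` is a `tf.Session`):
#
#   labels = tf.constant([1, 1, 0, 0])
#   predictions = tf.constant([1, 0, 0, 0])
#   acc, update_op = tf.metrics.accuracy(labels, predictions)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)  # ==> 0.75 (3 of 4 predictions match)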


def _confusion_matrix_at_thresholds(labels,
                                    predictions,
                                    thresholds,
                                    weights=None,
                                    includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positive[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
        default to all four.

  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
        `includes`.
    update_ops: Dict of operations that increments the `values`. Keys are from
        `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  all_includes = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = all_includes
  else:
    for include in includes:
      if include not in all_includes:
        raise ValueError('Invalid key: %s.' % include)

  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          predictions,
          math_ops.cast(0.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]'),
      check_ops.assert_less_equal(
          predictions,
          math_ops.cast(1.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]')
  ]):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtypes.float32),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

  num_thresholds = len(thresholds)

  # Reshape predictions and labels.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Use static shape if known.
  num_predictions = predictions_2d.get_shape().as_list()[0]

  # Otherwise use dynamic shape.
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
      array_ops.stack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
      thresh_tiled)
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds.
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  if weights is not None:
    weights = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(weights, dtypes.float32), predictions)
    weights_tiled = array_ops.tile(
        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
  else:
    weights_tiled = None

  values = {}
  update_ops = {}

  if 'tp' in includes:
    true_p = metric_variable(
        [num_thresholds], dtypes.float32, name='true_positives')
    is_true_positive = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_true_positive *= weights_tiled
    update_ops['tp'] = state_ops.assign_add(true_p,
                                            math_ops.reduce_sum(
                                                is_true_positive, 1))
    values['tp'] = true_p

  if 'fn' in includes:
    false_n = metric_variable(
        [num_thresholds], dtypes.float32, name='false_negatives')
    is_false_negative = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_false_negative *= weights_tiled
    update_ops['fn'] = state_ops.assign_add(false_n,
                                            math_ops.reduce_sum(
                                                is_false_negative, 1))
    values['fn'] = false_n

  if 'tn' in includes:
    true_n = metric_variable(
        [num_thresholds], dtypes.float32, name='true_negatives')
    is_true_negative = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_true_negative *= weights_tiled
    update_ops['tn'] = state_ops.assign_add(true_n,
                                            math_ops.reduce_sum(
                                                is_true_negative, 1))
    values['tn'] = true_n

  if 'fp' in includes:
    false_p = metric_variable(
        [num_thresholds], dtypes.float32, name='false_positives')
    is_false_positive = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_false_positive *= weights_tiled
    update_ops['fp'] = state_ops.assign_add(false_p,
                                            math_ops.reduce_sum(
                                                is_false_positive, 1))
    values['fp'] = false_p

  return values, update_ops
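

# Illustrative sketch (comment-only; assumes `tf` is `tensorflow.compat.v1`
# and `sess` is a `tf.Session`): with a single threshold of 0.5,
#
#   labels = tf.constant([True, False, True, False])
#   predictions = tf.constant([0.9, 0.6, 0.4, 0.2])
#   values, update_ops = _confusion_matrix_at_thresholds(
#       labels, predictions, thresholds=[0.5])
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_ops)
#   # values['tp'], values['fp'], values['fn'] and values['tn'] each hold [1.]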


def _aggregate_variable(v, collections):
  f = lambda distribution, value: distribution.extended.read_var(value)
  return _aggregate_across_replicas(collections, f, v)


@tf_export(v1=['metrics.auc'])
def auc(labels,
        predictions,
        weights=None,
        num_thresholds=200,
        metrics_collections=None,
        updates_collections=None,
        curve='ROC',
        name=None,
        summation_method='trapezoidal'):
  """Computes the approximate AUC via a Riemann sum.

  The `auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall values
  (computed using the aforementioned variables). The `num_thresholds` variable
  controls the degree of discretization with larger numbers of thresholds more
  closely approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case. Setting `summation_method`
  to 'minoring' or 'majoring' can help quantify the error in the approximation
  by providing a lower or upper bound estimate of the AUC.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.
    summation_method: Specifies the Riemann summation method used
      (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that
      applies the trapezoidal rule; 'careful_interpolation', a variant of it
      differing only by a more correct interpolation scheme for PR-AUC -
      interpolating (true/false) positives but not the ratio that is precision;
      'minoring' that applies left summation for increasing intervals and right
      summation for decreasing intervals; 'majoring' that does the opposite.
      Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
      (to be deprecated soon) as it applies the same method for ROC, and a
      better one (see Davis & Goadrich 2006 for details) for the PR curve.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.auc is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'auc',
                                     (labels, predictions, weights)):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def interpolate_pr_auc(tp, fp, fn):
      """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.

      Note here we derive & use a closed formula not present in the paper
      - as follows:
      Modeling all of TP (true positive weight),
      FP (false positive weight) and their sum P = TP + FP (positive weight)
      as varying linearly within each interval [A, B] between successive
      thresholds, we get
        Precision = (TP_A + slope * (P - P_A)) / P
      with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A).
      The area within the interval is thus (slope / total_pos_weight) times
        int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
        int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
      where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
        int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
      Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
        slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
      where dTP == TP_B - TP_A.
      Note that when P_A == 0 the above calculation simplifies into
        int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
      which is really equivalent to imputing constant precision throughout the
      first bucket having >0 true positives.

      Args:
        tp: true positive counts
        fp: false positive counts
        fn: false negative counts
      Returns:
        pr_auc: an approximation of the area under the P-R curve.
      """
      dtp = tp[:num_thresholds - 1] - tp[1:]
      p = tp + fp
      prec_slope = math_ops.div_no_nan(
          dtp,
          math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0),
          name='prec_slope')
      intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
      safe_p_ratio = array_ops.where(
          math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
          math_ops.div_no_nan(
              p[:num_thresholds - 1],
              math_ops.maximum(p[1:], 0),
              name='recall_relative_ratio'), array_ops.ones_like(p[1:]))
      return math_ops.reduce_sum(
          math_ops.div_no_nan(
              prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
              math_ops.maximum(tp[1:] + fn[1:], 0),
              name='pr_auc_increment'),
          name='interpolate_pr_auc')

    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      if curve == 'PR':
        if summation_method == 'trapezoidal':
          logging.warning(
              'Trapezoidal rule is known to produce incorrect PR-AUCs; '
              'please switch to "careful_interpolation" instead.')
        elif summation_method == 'careful_interpolation':
          # This one is a bit tricky and is handled separately.
          return interpolate_pr_auc(tp, fp, fn)
      rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        x = fp_rate
        y = rec
      else:  # curve == 'PR'.
        prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
        x = rec
        y = prec
      if summation_method in ('trapezoidal', 'careful_interpolation'):
        # Note that the case ('PR', 'careful_interpolation') has been handled
        # above.
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              (y[:num_thresholds - 1] + y[1:]) / 2.),
            name=name)
      elif summation_method == 'minoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.minimum(y[:num_thresholds - 1], y[1:])),
            name=name)
      elif summation_method == 'majoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.maximum(y[:num_thresholds - 1], y[1:])),
            name=name)
      else:
        raise ValueError('Invalid summation_method: %s' % summation_method)

    # sum up the areas of all the trapeziums
    def compute_auc_value(_, values):
      return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'],
                         'value')

    auc_value = _aggregate_across_replicas(
        metrics_collections, compute_auc_value, values)
    update_op = compute_auc(update_ops['tp'], update_ops['fn'],
                            update_ops['tn'], update_ops['fp'], 'update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return auc_value, update_op
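

# Example usage (illustrative, comment-only sketch; assumes `tf` is
# `tensorflow.compat.v1` and `sess` is a `tf.Session`):
#
#   labels = tf.constant([0, 0, 1, 1])
#   predictions = tf.constant([0.1, 0.4, 0.35, 0.8])
#   auc_value, update_op = tf.metrics.auc(labels, predictions)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   sess.run(auc_value)  # ==> approximately 0.75 (the exact ROC AUC here)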


@tf_export(v1=['metrics.mean_absolute_error'])
def mean_absolute_error(labels,
                        predictions,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the mean absolute error between the labels and predictions.

  The `mean_absolute_error` function creates two local variables,
  `total` and `count` that are used to compute the mean absolute error. This
  average is weighted by `weights`, and it is ultimately returned as
  `mean_absolute_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
  absolute value of the differences between `predictions` and `labels`. Then
  `update_op` increments `total` with the reduced sum of the product of
  `weights` and `absolute_errors`, and it increments `count` with the reduced
  sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_absolute_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_absolute_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
                       'when eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  absolute_errors = math_ops.abs(predictions - labels)
  return mean(absolute_errors, weights, metrics_collections,
              updates_collections, name or 'mean_absolute_error')
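

# Example usage (illustrative, comment-only sketch; assumes `tf` is
# `tensorflow.compat.v1` and `sess` is a `tf.Session`):
#
#   labels = tf.constant([1., 2., 3.])
#   predictions = tf.constant([1.5, 2.5, 2.5])
#   mae, update_op = tf.metrics.mean_absolute_error(labels, predictions)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)  # ==> 0.5 (each example is off by 0.5)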


@tf_export(v1=['metrics.mean_cosine_distance'])
def mean_cosine_distance(labels,
                         predictions,
                         dim,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes the cosine distance between the labels and predictions.

  The `mean_cosine_distance` function creates two local variables,
  `total` and `count` that are used to compute the average cosine distance
  between `predictions` and `labels`. This average is weighted by `weights`,
  and it is ultimately returned as `mean_distance`, which is an idempotent
  operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_distance`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of arbitrary shape.
    predictions: A `Tensor` of the same shape as `labels`.
    dim: The dimension along which the cosine distance is computed.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension). Also,
      dimension `dim` must be `1`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  radial_diffs = math_ops.multiply(predictions, labels)
  radial_diffs = math_ops.reduce_sum(
      radial_diffs, axis=[
          dim,
      ], keepdims=True)
  mean_distance, update_op = mean(radial_diffs, weights, None, None, name or
                                  'mean_cosine_distance')
  mean_distance = math_ops.subtract(1.0, mean_distance)
  update_op = math_ops.subtract(1.0, update_op)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op
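

# Example usage (illustrative, comment-only sketch; assumes `tf` is
# `tensorflow.compat.v1` and `sess` is a `tf.Session`). Inputs are expected
# to be unit-normalized along `dim`:
#
#   labels = tf.constant([[1., 0.], [0., 1.]])
#   predictions = tf.constant([[1., 0.], [1., 0.]])
#   dist, update_op = tf.metrics.mean_cosine_distance(
#       labels, predictions, dim=1)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)  # ==> 0.5 (cosine similarities are 1 and 0)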


@tf_export(v1=['metrics.mean_per_class_accuracy'])
def mean_per_class_accuracy(labels,
                            predictions,
                            num_classes,
                            weights=None,
                            metrics_collections=None,
                            updates_collections=None,
                            name=None):
  """Calculates the mean of the per-class accuracies.

  Calculates the accuracy for each class, then takes the mean of that.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates the accuracy of each class and returns
  them.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since two variables with shape =
      [num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_per_class_accuracy` should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_accuracy: A `Tensor` representing the mean per class accuracy.
    update_op: An operation that updates the accuracy tensor.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
                       'when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean_accuracy',
                                     (predictions, labels, weights)):
    labels = math_ops.cast(labels, dtypes.int64)

    # Flatten the input if its rank > 1.
    if labels.get_shape().ndims > 1:
      labels = array_ops.reshape(labels, [-1])

    if predictions.get_shape().ndims > 1:
      predictions = array_ops.reshape(predictions, [-1])

    # Check if shape is compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    total = metric_variable([num_classes], dtypes.float32, name='total')
    count = metric_variable([num_classes], dtypes.float32, name='count')

    ones = array_ops.ones([array_ops.size(labels)], dtypes.float32)

    if labels.dtype != predictions.dtype:
      predictions = math_ops.cast(predictions, labels.dtype)
    is_correct = math_ops.cast(
        math_ops.equal(predictions, labels), dtypes.float32)

    if weights is not None:
      if weights.get_shape().ndims > 1:
        weights = array_ops.reshape(weights, [-1])
      weights = math_ops.cast(weights, dtypes.float32)

      is_correct *= weights
      ones *= weights

    update_total_op = state_ops.scatter_add(total, labels, ones)
    update_count_op = state_ops.scatter_add(count, labels, is_correct)

    def compute_mean_accuracy(_, count, total):
      per_class_accuracy = math_ops.div_no_nan(
          count, math_ops.maximum(total, 0), name=None)
      mean_accuracy_v = math_ops.reduce_mean(
          per_class_accuracy, name='mean_accuracy')
      return mean_accuracy_v

    mean_accuracy_v = _aggregate_across_replicas(
        metrics_collections, compute_mean_accuracy, count, total)

    update_op = math_ops.div_no_nan(
        update_count_op, math_ops.maximum(update_total_op, 0), name='update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_accuracy_v, update_op
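

# Example usage (illustrative, comment-only sketch; assumes `tf` is
# `tensorflow.compat.v1` and `sess` is a `tf.Session`):
#
#   labels = tf.constant([0, 0, 1])
#   predictions = tf.constant([0, 1, 1])
#   mpca, update_op = tf.metrics.mean_per_class_accuracy(
#       labels, predictions, num_classes=2)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   sess.run(mpca)  # ==> 0.75 (class 0: 1/2 correct; class 1: 1/1 correct)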


@tf_export(v1=['metrics.mean_iou'])
def mean_iou(labels,
             predictions,
             num_classes,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculate per-step mean Intersection-Over-Union (mIOU).

  Mean Intersection-Over-Union is a common evaluation metric for
  semantic image segmentation, which first computes the IOU for each
  semantic class and then computes the average over classes.
  IOU is defined as follows:
    IOU = true_positive / (true_positive + false_positive + false_negative).
  The predictions are accumulated in a confusion matrix, weighted by `weights`,
  and mIOU is then calculated from it.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean_iou`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `mean_iou`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_iou: A `Tensor` representing the mean intersection-over-union.
    update_op: An operation that increments the confusion matrix.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_iou is not supported when '
                       'eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean_iou',
                                     (predictions, labels, weights)):
    # Check if shape is compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
                                                      num_classes, weights)

    def compute_mean_iou(_, total_cm):
      """Compute the mean intersection-over-union via the confusion matrix."""
      sum_over_row = math_ops.cast(
          math_ops.reduce_sum(total_cm, 0), dtypes.float32)
      sum_over_col = math_ops.cast(
          math_ops.reduce_sum(total_cm, 1), dtypes.float32)
      cm_diag = math_ops.cast(array_ops.diag_part(total_cm), dtypes.float32)
      denominator = sum_over_row + sum_over_col - cm_diag

      # The mean is only computed over classes that appear in the
      # label or prediction tensor. If the denominator is 0, we need to
      # ignore the class.
      num_valid_entries = math_ops.reduce_sum(
          math_ops.cast(
              math_ops.not_equal(denominator, 0), dtype=dtypes.float32))

      # If the value of the denominator is 0, set it to 1 to avoid
      # zero division.
      denominator = array_ops.where(
          math_ops.greater(denominator, 0), denominator,
          array_ops.ones_like(denominator))
      iou = math_ops.div(cm_diag, denominator)

      # If the number of valid entries is 0 (no classes) we return 0.
      result = array_ops.where(
          math_ops.greater(num_valid_entries, 0),
          math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0)
      return result

    # TODO(priyag): Use outside_compilation if in TPU context.
    mean_iou_v = _aggregate_across_replicas(
        metrics_collections, compute_mean_iou, total_cm)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_iou_v, update_op
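

# Example usage (illustrative, comment-only sketch; assumes `tf` is
# `tensorflow.compat.v1` and `sess` is a `tf.Session`):
#
#   labels = tf.constant([0, 0, 1, 1])
#   predictions = tf.constant([0, 1, 1, 1])
#   miou, update_op = tf.metrics.mean_iou(labels, predictions, num_classes=2)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   sess.run(miou)  # ==> approximately 0.583 (IOUs 1/2 and 2/3, averaged)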


@tf_export(v1=['metrics.mean_relative_error'])
def mean_relative_error(labels,
                        predictions,
                        normalizer,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the mean relative error by normalizing with the given values.

  The `mean_relative_error` function creates two local variables,
  `total` and `count` that are used to compute the mean relative absolute error.
  This average is weighted by `weights`, and it is ultimately returned as
  `mean_relative_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_relative_error`. Internally, a `relative_errors` operation divides the
  absolute value of the differences between `predictions` and `labels` by the
  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `relative_errors`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    normalizer: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_relative_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_relative_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_relative_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_relative_error is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)

  predictions, normalizer = confusion_matrix.remove_squeezable_dimensions(
      predictions, normalizer)
  predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
  relative_errors = array_ops.where(
      math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels),
      math_ops.div(math_ops.abs(labels - predictions), normalizer))
  return mean(relative_errors, weights, metrics_collections,
              updates_collections, name or 'mean_relative_error')
1244
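# Illustrative usage sketch (editor's addition): mean relative error using the
# labels themselves as the normalizer. Graph mode is assumed; values are
# hypothetical.
#
#   labels = tf.constant([1., 2., 4.])
#   predictions = tf.constant([1.5, 2., 3.])
#   mre, update_op = tf.compat.v1.metrics.mean_relative_error(
#       labels, predictions, normalizer=labels)
#   # After sess.run(update_op): mre evaluates to
#   # (0.5/1 + 0/2 + 1/4) / 3 = 0.25.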
1245
1246@tf_export(v1=['metrics.mean_squared_error'])
1247def mean_squared_error(labels,
1248                       predictions,
1249                       weights=None,
1250                       metrics_collections=None,
1251                       updates_collections=None,
1252                       name=None):
1253  """Computes the mean squared error between the labels and predictions.
1254
1255  The `mean_squared_error` function creates two local variables,
1256  `total` and `count`, that are used to compute the mean squared error.
1257  This average is weighted by `weights`, and it is ultimately returned as
1258  `mean_squared_error`: an idempotent operation that simply divides `total` by
1259  `count`.
1260
1261  For estimation of the metric over a stream of data, the function creates an
1262  `update_op` operation that updates these variables and returns the
1263  `mean_squared_error`. Internally, a `squared_error` operation computes the
1264  element-wise square of the difference between `predictions` and `labels`. Then
1265  `update_op` increments `total` with the reduced sum of the product of
1266  `weights` and `squared_error`, and it increments `count` with the reduced sum
1267  of `weights`.
1268
1269  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1270
1271  Args:
1272    labels: A `Tensor` of the same shape as `predictions`.
1273    predictions: A `Tensor` of arbitrary shape.
1274    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1275      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1276      be either `1`, or the same as the corresponding `labels` dimension).
1277    metrics_collections: An optional list of collections that
1278      `mean_squared_error` should be added to.
1279    updates_collections: An optional list of collections that `update_op` should
1280      be added to.
1281    name: An optional variable_scope name.
1282
1283  Returns:
1284    mean_squared_error: A `Tensor` representing the current mean, the value of
1285      `total` divided by `count`.
1286    update_op: An operation that increments the `total` and `count` variables
1287      appropriately and whose value matches `mean_squared_error`.
1288
1289  Raises:
1290    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1291      `weights` is not `None` and its shape doesn't match `predictions`, or if
1292      either `metrics_collections` or `updates_collections` are not a list or
1293      tuple.
1294    RuntimeError: If eager execution is enabled.
1295  """
1296  if context.executing_eagerly():
1297    raise RuntimeError('tf.metrics.mean_squared_error is not supported when '
1298                       'eager execution is enabled.')
1299
1300  predictions, labels, weights = _remove_squeezable_dimensions(
1301      predictions=predictions, labels=labels, weights=weights)
1302  squared_error = math_ops.squared_difference(labels, predictions)
1303  return mean(squared_error, weights, metrics_collections, updates_collections,
1304              name or 'mean_squared_error')
1305
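# Illustrative usage sketch (editor's addition): streaming MSE accumulated
# over two batches fed through placeholders. Graph mode is assumed; names are
# hypothetical.
#
#   y = tf.compat.v1.placeholder(tf.float32, shape=[3])
#   y_hat = tf.compat.v1.placeholder(tf.float32, shape=[3])
#   mse, update_op = tf.compat.v1.metrics.mean_squared_error(y, y_hat)
#   with tf.compat.v1.Session() as sess:
#     sess.run(tf.compat.v1.local_variables_initializer())
#     sess.run(update_op, {y: [0., 1., 2.], y_hat: [1., 1., 1.]})
#     sess.run(update_op, {y: [0., 0., 0.], y_hat: [0., 0., 0.]})
#     print(sess.run(mse))  # 2/6 ~ 0.33: total squared error 2 over 6 values.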
1306
1307@tf_export(v1=['metrics.mean_tensor'])
1308def mean_tensor(values,
1309                weights=None,
1310                metrics_collections=None,
1311                updates_collections=None,
1312                name=None):
1313  """Computes the element-wise (weighted) mean of the given tensors.
1314
1315  In contrast to the `mean` function, which returns a scalar with the
1316  mean, this function returns an average tensor with the same shape as the
1317  input tensors.
1318
1319  The `mean_tensor` function creates two local variables,
1320  `total_tensor` and `count_tensor`, that are used to compute the average of
1321  `values`. This average is ultimately returned as `mean` which is an idempotent
1322  operation that simply divides `total` by `count`.
1323
1324  For estimation of the metric over a stream of data, the function creates an
1325  `update_op` operation that updates these variables and returns the `mean`.
1326  `update_op` increments `total` with the reduced sum of the product of `values`
1327  and `weights`, and it increments `count` with the reduced sum of `weights`.
1328
1329  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1330
1331  Args:
1332    values: A `Tensor` of arbitrary dimensions.
1333    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1334      `values`, and must be broadcastable to `values` (i.e., all dimensions must
1335      be either `1`, or the same as the corresponding `values` dimension).
1336    metrics_collections: An optional list of collections that `mean`
1337      should be added to.
1338    updates_collections: An optional list of collections that `update_op`
1339      should be added to.
1340    name: An optional variable_scope name.
1341
1342  Returns:
1343    mean: A float `Tensor` representing the current mean, the value of `total`
1344      divided by `count`.
1345    update_op: An operation that increments the `total` and `count` variables
1346      appropriately and whose value matches `mean`.
1347
1348  Raises:
1349    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
1350      or if either `metrics_collections` or `updates_collections` are not a list
1351      or tuple.
1352    RuntimeError: If eager execution is enabled.
1353  """
1354  if context.executing_eagerly():
1355    raise RuntimeError('tf.metrics.mean_tensor is not supported when '
1356                       'eager execution is enabled.')
1357
1358  with variable_scope.variable_scope(name, 'mean', (values, weights)):
1359    values = math_ops.cast(values, dtypes.float32)
1360    total = metric_variable(
1361        values.get_shape(), dtypes.float32, name='total_tensor')
1362    count = metric_variable(
1363        values.get_shape(), dtypes.float32, name='count_tensor')
1364
1365    num_values = array_ops.ones_like(values)
1366    if weights is not None:
1367      values, _, weights = _remove_squeezable_dimensions(
1368          predictions=values, labels=None, weights=weights)
1369      weights = weights_broadcast_ops.broadcast_weights(
1370          math_ops.cast(weights, dtypes.float32), values)
1371      values = math_ops.multiply(values, weights)
1372      num_values = math_ops.multiply(num_values, weights)
1373
1374    update_total_op = state_ops.assign_add(total, values)
1375    with ops.control_dependencies([values]):
1376      update_count_op = state_ops.assign_add(count, num_values)
1377
1378    compute_mean = lambda _, t, c: math_ops.div_no_nan(  # pylint: disable=g-long-lambda
1379        t, math_ops.maximum(c, 0), name='value')
1380
1381    mean_t = _aggregate_across_replicas(
1382        metrics_collections, compute_mean, total, count)
1383
1384    update_op = math_ops.div_no_nan(
1385        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')
1386    if updates_collections:
1387      ops.add_to_collections(updates_collections, update_op)
1388
1389    return mean_t, update_op
1390
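# Illustrative usage sketch (editor's addition): unlike `mean`, `mean_tensor`
# preserves the input shape, averaging element-wise across update calls.
# Graph mode is assumed; names are hypothetical.
#
#   v = tf.compat.v1.placeholder(tf.float32, shape=[2])
#   mean_t, update_op = tf.compat.v1.metrics.mean_tensor(v)
#   with tf.compat.v1.Session() as sess:
#     sess.run(tf.compat.v1.local_variables_initializer())
#     sess.run(update_op, {v: [0., 2.]})
#     sess.run(update_op, {v: [4., 6.]})
#     print(sess.run(mean_t))  # [2., 4.]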
1391
1392@tf_export(v1=['metrics.percentage_below'])
1393def percentage_below(values,
1394                     threshold,
1395                     weights=None,
1396                     metrics_collections=None,
1397                     updates_collections=None,
1398                     name=None):
1399  """Computes the percentage of values less than the given threshold.
1400
1401  The `percentage_below` function creates two local variables,
1402  `total` and `count`, that are used to compute the percentage of `values` that
1403  fall below `threshold`. This rate is weighted by `weights`, and it is
1404  ultimately returned as `percentage` which is an idempotent operation that
1405  simply divides `total` by `count`.
1406
1407  For estimation of the metric over a stream of data, the function creates an
1408  `update_op` operation that updates these variables and returns the
1409  `percentage`.
1410
1411  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1412
1413  Args:
1414    values: A numeric `Tensor` of arbitrary size.
1415    threshold: A scalar threshold.
1416    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1417      `values`, and must be broadcastable to `values` (i.e., all dimensions must
1418      be either `1`, or the same as the corresponding `values` dimension).
1419    metrics_collections: An optional list of collections that the metric
1420      value variable should be added to.
1421    updates_collections: An optional list of collections that the metric update
1422      ops should be added to.
1423    name: An optional variable_scope name.
1424
1425  Returns:
1426    percentage: A `Tensor` representing the current mean, the value of `total`
1427      divided by `count`.
1428    update_op: An operation that increments the `total` and `count` variables
1429      appropriately.
1430
1431  Raises:
1432    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
1433      or if either `metrics_collections` or `updates_collections` are not a list
1434      or tuple.
1435    RuntimeError: If eager execution is enabled.
1436  """
1437  if context.executing_eagerly():
1438    raise RuntimeError('tf.metrics.percentage_below is not supported when '
1439                       'eager execution is enabled.')
1440
1441  is_below_threshold = math_ops.cast(
1442      math_ops.less(values, threshold), dtypes.float32)
1443  return mean(is_below_threshold, weights, metrics_collections,
1444              updates_collections, name or 'percentage_below_threshold')
1445
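# Illustrative usage sketch (editor's addition): despite the name, the result
# is a fraction in [0, 1] (the weighted mean of the `values < threshold`
# indicator), not a number scaled to 100. Graph mode is assumed.
#
#   values = tf.constant([1., 2., 3., 4.])
#   frac, update_op = tf.compat.v1.metrics.percentage_below(values, 3.0)
#   # After sess.run(update_op): frac evaluates to 0.5, since two of the
#   # four values are strictly less than 3.0.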
1446
1447def _count_condition(values,
1448                     weights=None,
1449                     metrics_collections=None,
1450                     updates_collections=None):
1451  """Sums the weights of cases where the given values are True.
1452
1453  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1454
1455  Args:
1456    values: A `bool` `Tensor` of arbitrary size.
1457    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1458      `values`, and must be broadcastable to `values` (i.e., all dimensions must
1459      be either `1`, or the same as the corresponding `values` dimension).
1460    metrics_collections: An optional list of collections that the metric
1461      value variable should be added to.
1462    updates_collections: An optional list of collections that the metric update
1463      ops should be added to.
1464
1465  Returns:
1466    value_tensor: A `Tensor` representing the current value of the metric.
1467    update_op: An operation that accumulates the error from a batch of data.
1468
1469  Raises:
1470    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
1471      or if either `metrics_collections` or `updates_collections` are not a list
1472      or tuple.
1473  """
1474  check_ops.assert_type(values, dtypes.bool)
1475  count = metric_variable([], dtypes.float32, name='count')
1476
1477  values = math_ops.cast(values, dtypes.float32)
1478  if weights is not None:
1479    with ops.control_dependencies((check_ops.assert_rank_in(
1480        weights, (0, array_ops.rank(values))),)):
1481      weights = math_ops.cast(weights, dtypes.float32)
1482      values = math_ops.multiply(values, weights)
1483
1484  value_tensor = _aggregate_variable(count, metrics_collections)
1485
1486  update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
1487  if updates_collections:
1488    ops.add_to_collections(updates_collections, update_op)
1489
1490  return value_tensor, update_op
1491
1492
1493@tf_export(v1=['metrics.false_negatives'])
1494def false_negatives(labels,
1495                    predictions,
1496                    weights=None,
1497                    metrics_collections=None,
1498                    updates_collections=None,
1499                    name=None):
1500  """Computes the total number of false negatives.
1501
1502  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1503
1504  Args:
1505    labels: The ground truth values, a `Tensor` whose dimensions must match
1506      `predictions`. Will be cast to `bool`.
1507    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1508      be cast to `bool`.
1509    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1510      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1511      be either `1`, or the same as the corresponding `labels` dimension).
1512    metrics_collections: An optional list of collections that the metric
1513      value variable should be added to.
1514    updates_collections: An optional list of collections that the metric update
1515      ops should be added to.
1516    name: An optional variable_scope name.
1517
1518  Returns:
1519    value_tensor: A `Tensor` representing the current value of the metric.
1520    update_op: An operation that accumulates the error from a batch of data.
1521
1522  Raises:
1523    ValueError: If `weights` is not `None` and its shape doesn't match `predictions`,
1524      or if either `metrics_collections` or `updates_collections` are not a list
1525      or tuple.
1526    RuntimeError: If eager execution is enabled.
1527  """
1528  if context.executing_eagerly():
1529    raise RuntimeError('tf.metrics.false_negatives is not supported when '
1530                       'eager execution is enabled.')
1531
1532  with variable_scope.variable_scope(name, 'false_negatives',
1533                                     (predictions, labels, weights)):
1534
1535    predictions, labels, weights = _remove_squeezable_dimensions(
1536        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1537        labels=math_ops.cast(labels, dtype=dtypes.bool),
1538        weights=weights)
1539    is_false_negative = math_ops.logical_and(
1540        math_ops.equal(labels, True), math_ops.equal(predictions, False))
1541    return _count_condition(is_false_negative, weights, metrics_collections,
1542                            updates_collections)
1543
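# Illustrative usage sketch (editor's addition): a false negative is counted
# where the label is True but the prediction is False. `false_positives`,
# `true_negatives` and `true_positives` below follow the same pattern with
# the other label/prediction combinations. Graph mode is assumed.
#
#   labels = tf.constant([True, True, False, False])
#   predictions = tf.constant([False, True, False, True])
#   fn, update_op = tf.compat.v1.metrics.false_negatives(labels, predictions)
#   # After sess.run(update_op): fn evaluates to 1.0 (only the first element
#   # has label True and prediction False).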
1544
1545@tf_export(v1=['metrics.false_negatives_at_thresholds'])
1546def false_negatives_at_thresholds(labels,
1547                                  predictions,
1548                                  thresholds,
1549                                  weights=None,
1550                                  metrics_collections=None,
1551                                  updates_collections=None,
1552                                  name=None):
1553  """Computes false negatives at provided threshold values.
1554
1555  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1556
1557  Args:
1558    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
1559      `bool`.
1560    predictions: A floating point `Tensor` of arbitrary shape and whose values
1561      are in the range `[0, 1]`.
1562    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1563    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1564      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1565      be either `1`, or the same as the corresponding `labels` dimension).
1566    metrics_collections: An optional list of collections that `false_negatives`
1567      should be added to.
1568    updates_collections: An optional list of collections that `update_op` should
1569      be added to.
1570    name: An optional variable_scope name.
1571
1572  Returns:
1573    false_negatives: A float `Tensor` of shape `[len(thresholds)]`.
1574    update_op: An operation that updates the `false_negatives` variable and
1575      returns its current value.
1576
1577  Raises:
1578    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1579      `weights` is not `None` and its shape doesn't match `predictions`, or if
1580      either `metrics_collections` or `updates_collections` are not a list or
1581      tuple.
1582    RuntimeError: If eager execution is enabled.
1583  """
1584  if context.executing_eagerly():
1585    raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not '
1586                       'supported when eager execution is enabled.')
1587
1588  with variable_scope.variable_scope(name, 'false_negatives',
1589                                     (predictions, labels, weights)):
1590    values, update_ops = _confusion_matrix_at_thresholds(
1591        labels, predictions, thresholds, weights=weights, includes=('fn',))
1592
1593    fn_value = _aggregate_variable(values['fn'], metrics_collections)
1594
1595    if updates_collections:
1596      ops.add_to_collections(updates_collections, update_ops['fn'])
1597
1598    return fn_value, update_ops['fn']
1599
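# Illustrative usage sketch (editor's addition): predictions are thresholded
# with `predictions > thresholds[i]`, giving one count per threshold. The
# other `*_at_thresholds` variants below work the same way. Graph mode is
# assumed; values are hypothetical.
#
#   labels = tf.constant([True, True, False])
#   predictions = tf.constant([0.9, 0.4, 0.2])
#   fn, update_op = tf.compat.v1.metrics.false_negatives_at_thresholds(
#       labels, predictions, thresholds=[0.3, 0.6])
#   # After sess.run(update_op): fn evaluates to [0., 1.]; at threshold 0.6
#   # the second example (label True, score 0.4) becomes a false negative.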
1600
1601@tf_export(v1=['metrics.false_positives'])
1602def false_positives(labels,
1603                    predictions,
1604                    weights=None,
1605                    metrics_collections=None,
1606                    updates_collections=None,
1607                    name=None):
1608  """Sum the weights of false positives.
1609
1610  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1611
1612  Args:
1613    labels: The ground truth values, a `Tensor` whose dimensions must match
1614      `predictions`. Will be cast to `bool`.
1615    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1616      be cast to `bool`.
1617    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1618      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1619      be either `1`, or the same as the corresponding `labels` dimension).
1620    metrics_collections: An optional list of collections that the metric
1621      value variable should be added to.
1622    updates_collections: An optional list of collections that the metric update
1623      ops should be added to.
1624    name: An optional variable_scope name.
1625
1626  Returns:
1627    value_tensor: A `Tensor` representing the current value of the metric.
1628    update_op: An operation that accumulates the error from a batch of data.
1629
1630  Raises:
1631    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1632      `weights` is not `None` and its shape doesn't match `predictions`, or if
1633      either `metrics_collections` or `updates_collections` are not a list or
1634      tuple.
1635    RuntimeError: If eager execution is enabled.
1636  """
1637  if context.executing_eagerly():
1638    raise RuntimeError('tf.metrics.false_positives is not supported when '
1639                       'eager execution is enabled.')
1640
1641  with variable_scope.variable_scope(name, 'false_positives',
1642                                     (predictions, labels, weights)):
1643
1644    predictions, labels, weights = _remove_squeezable_dimensions(
1645        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1646        labels=math_ops.cast(labels, dtype=dtypes.bool),
1647        weights=weights)
1648    is_false_positive = math_ops.logical_and(
1649        math_ops.equal(labels, False), math_ops.equal(predictions, True))
1650    return _count_condition(is_false_positive, weights, metrics_collections,
1651                            updates_collections)
1652
1653
1654@tf_export(v1=['metrics.false_positives_at_thresholds'])
1655def false_positives_at_thresholds(labels,
1656                                  predictions,
1657                                  thresholds,
1658                                  weights=None,
1659                                  metrics_collections=None,
1660                                  updates_collections=None,
1661                                  name=None):
1662  """Computes false positives at provided threshold values.
1663
1664  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1665
1666  Args:
1667    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
1668      `bool`.
1669    predictions: A floating point `Tensor` of arbitrary shape and whose values
1670      are in the range `[0, 1]`.
1671    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1672    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1673      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1674      be either `1`, or the same as the corresponding `labels` dimension).
1675    metrics_collections: An optional list of collections that `false_positives`
1676      should be added to.
1677    updates_collections: An optional list of collections that `update_op` should
1678      be added to.
1679    name: An optional variable_scope name.
1680
1681  Returns:
1682    false_positives: A float `Tensor` of shape `[len(thresholds)]`.
1683    update_op: An operation that updates the `false_positives` variable and
1684      returns its current value.
1685
1686  Raises:
1687    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1688      `weights` is not `None` and its shape doesn't match `predictions`, or if
1689      either `metrics_collections` or `updates_collections` are not a list or
1690      tuple.
1691    RuntimeError: If eager execution is enabled.
1692  """
1693  if context.executing_eagerly():
1694    raise RuntimeError('tf.metrics.false_positives_at_thresholds is not '
1695                       'supported when eager execution is enabled.')
1696
1697  with variable_scope.variable_scope(name, 'false_positives',
1698                                     (predictions, labels, weights)):
1699    values, update_ops = _confusion_matrix_at_thresholds(
1700        labels, predictions, thresholds, weights=weights, includes=('fp',))
1701
1702    fp_value = _aggregate_variable(values['fp'], metrics_collections)
1703
1704    if updates_collections:
1705      ops.add_to_collections(updates_collections, update_ops['fp'])
1706
1707    return fp_value, update_ops['fp']
1708
1709
1710@tf_export(v1=['metrics.true_negatives'])
1711def true_negatives(labels,
1712                   predictions,
1713                   weights=None,
1714                   metrics_collections=None,
1715                   updates_collections=None,
1716                   name=None):
1717  """Sum the weights of true_negatives.
1718
1719  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1720
1721  Args:
1722    labels: The ground truth values, a `Tensor` whose dimensions must match
1723      `predictions`. Will be cast to `bool`.
1724    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1725      be cast to `bool`.
1726    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1727      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1728      be either `1`, or the same as the corresponding `labels` dimension).
1729    metrics_collections: An optional list of collections that the metric
1730      value variable should be added to.
1731    updates_collections: An optional list of collections that the metric update
1732      ops should be added to.
1733    name: An optional variable_scope name.
1734
1735  Returns:
1736    value_tensor: A `Tensor` representing the current value of the metric.
1737    update_op: An operation that accumulates the error from a batch of data.
1738
1739  Raises:
1740    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1741      `weights` is not `None` and its shape doesn't match `predictions`, or if
1742      either `metrics_collections` or `updates_collections` are not a list or
1743      tuple.
1744    RuntimeError: If eager execution is enabled.
1745  """
1746  if context.executing_eagerly():
1747    raise RuntimeError('tf.metrics.true_negatives is not '
1748                       'supported when eager execution is enabled.')
1749
1750  with variable_scope.variable_scope(name, 'true_negatives',
1751                                     (predictions, labels, weights)):
1752
1753    predictions, labels, weights = _remove_squeezable_dimensions(
1754        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1755        labels=math_ops.cast(labels, dtype=dtypes.bool),
1756        weights=weights)
1757    is_true_negative = math_ops.logical_and(
1758        math_ops.equal(labels, False), math_ops.equal(predictions, False))
1759    return _count_condition(is_true_negative, weights, metrics_collections,
1760                            updates_collections)
1761
1762
1763@tf_export(v1=['metrics.true_negatives_at_thresholds'])
1764def true_negatives_at_thresholds(labels,
1765                                 predictions,
1766                                 thresholds,
1767                                 weights=None,
1768                                 metrics_collections=None,
1769                                 updates_collections=None,
1770                                 name=None):
1771  """Computes true negatives at provided threshold values.
1772
1773  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1774
1775  Args:
1776    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
1777      `bool`.
1778    predictions: A floating point `Tensor` of arbitrary shape and whose values
1779      are in the range `[0, 1]`.
1780    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1781    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1782      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1783      be either `1`, or the same as the corresponding `labels` dimension).
1784    metrics_collections: An optional list of collections that `true_negatives`
1785      should be added to.
1786    updates_collections: An optional list of collections that `update_op` should
1787      be added to.
1788    name: An optional variable_scope name.
1789
1790  Returns:
1791    true_negatives: A float `Tensor` of shape `[len(thresholds)]`.
1792    update_op: An operation that updates the `true_negatives` variable and
1793      returns its current value.
1794
1795  Raises:
1796    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1797      `weights` is not `None` and its shape doesn't match `predictions`, or if
1798      either `metrics_collections` or `updates_collections` are not a list or
1799      tuple.
1800    RuntimeError: If eager execution is enabled.
1801  """
1802  if context.executing_eagerly():
1803    raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not '
1804                       'supported when eager execution is enabled.')
1805
1806  with variable_scope.variable_scope(name, 'true_negatives',
1807                                     (predictions, labels, weights)):
1808    values, update_ops = _confusion_matrix_at_thresholds(
1809        labels, predictions, thresholds, weights=weights, includes=('tn',))
1810
1811    tn_value = _aggregate_variable(values['tn'], metrics_collections)
1812
1813    if updates_collections:
1814      ops.add_to_collections(updates_collections, update_ops['tn'])
1815
1816    return tn_value, update_ops['tn']
1817
1818
1819@tf_export(v1=['metrics.true_positives'])
1820def true_positives(labels,
1821                   predictions,
1822                   weights=None,
1823                   metrics_collections=None,
1824                   updates_collections=None,
1825                   name=None):
1826  """Sum the weights of true_positives.
1827
1828  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1829
1830  Args:
1831    labels: The ground truth values, a `Tensor` whose dimensions must match
1832      `predictions`. Will be cast to `bool`.
1833    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1834      be cast to `bool`.
1835    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1836      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1837      be either `1`, or the same as the corresponding `labels` dimension).
1838    metrics_collections: An optional list of collections that the metric
1839      value variable should be added to.
1840    updates_collections: An optional list of collections that the metric update
1841      ops should be added to.
1842    name: An optional variable_scope name.
1843
1844  Returns:
1845    value_tensor: A `Tensor` representing the current value of the metric.
1846    update_op: An operation that accumulates the error from a batch of data.
1847
1848  Raises:
1849    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1850      `weights` is not `None` and its shape doesn't match `predictions`, or if
1851      either `metrics_collections` or `updates_collections` are not a list or
1852      tuple.
1853    RuntimeError: If eager execution is enabled.
1854  """
1855  if context.executing_eagerly():
1856    raise RuntimeError('tf.metrics.true_positives is not '
1857                       'supported when eager execution is enabled.')
1858
1859  with variable_scope.variable_scope(name, 'true_positives',
1860                                     (predictions, labels, weights)):
1861
1862    predictions, labels, weights = _remove_squeezable_dimensions(
1863        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1864        labels=math_ops.cast(labels, dtype=dtypes.bool),
1865        weights=weights)
1866    is_true_positive = math_ops.logical_and(
1867        math_ops.equal(labels, True), math_ops.equal(predictions, True))
1868    return _count_condition(is_true_positive, weights, metrics_collections,
1869                            updates_collections)
1870
1871
1872@tf_export(v1=['metrics.true_positives_at_thresholds'])
1873def true_positives_at_thresholds(labels,
1874                                 predictions,
1875                                 thresholds,
1876                                 weights=None,
1877                                 metrics_collections=None,
1878                                 updates_collections=None,
1879                                 name=None):
1880  """Computes true positives at provided threshold values.
1881
1882  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1883
1884  Args:
1885    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
1886      `bool`.
1887    predictions: A floating point `Tensor` of arbitrary shape and whose values
1888      are in the range `[0, 1]`.
1889    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1890    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1891      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1892      be either `1`, or the same as the corresponding `labels` dimension).
1893    metrics_collections: An optional list of collections that `true_positives`
1894      should be added to.
1895    updates_collections: An optional list of collections that `update_op` should
1896      be added to.
1897    name: An optional variable_scope name.
1898
1899  Returns:
1900    true_positives: A float `Tensor` of shape `[len(thresholds)]`.
1901    update_op: An operation that updates the `true_positives` variable and
1902      returns its current value.
1903
1904  Raises:
1905    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1906      `weights` is not `None` and its shape doesn't match `predictions`, or if
1907      either `metrics_collections` or `updates_collections` are not a list or
1908      tuple.
1909    RuntimeError: If eager execution is enabled.
1910  """
1911  if context.executing_eagerly():
1912    raise RuntimeError('tf.metrics.true_positives_at_thresholds is not '
1913                       'supported when eager execution is enabled.')
1914
1915  with variable_scope.variable_scope(name, 'true_positives',
1916                                     (predictions, labels, weights)):
1917    values, update_ops = _confusion_matrix_at_thresholds(
1918        labels, predictions, thresholds, weights=weights, includes=('tp',))
1919
1920    tp_value = _aggregate_variable(values['tp'], metrics_collections)
1921
1922    if updates_collections:
1923      ops.add_to_collections(updates_collections, update_ops['tp'])
1924
1925    return tp_value, update_ops['tp']
1926
1927
1928@tf_export(v1=['metrics.precision'])
1929def precision(labels,
1930              predictions,
1931              weights=None,
1932              metrics_collections=None,
1933              updates_collections=None,
1934              name=None):
1935  """Computes the precision of the predictions with respect to the labels.
1936
1937  The `precision` function creates two local variables,
1938  `true_positives` and `false_positives`, that are used to compute the
1939  precision. This value is ultimately returned as `precision`, an idempotent
1940  operation that simply divides `true_positives` by the sum of `true_positives`
1941  and `false_positives`.
1942
1943  For estimation of the metric over a stream of data, the function creates an
1944  `update_op` operation that updates these variables and returns the
1945  `precision`. `update_op` weights each prediction by the corresponding value in
1946  `weights`.
1947
1948  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1949
1950  Args:
1951    labels: The ground truth values, a `Tensor` whose dimensions must match
1952      `predictions`. Will be cast to `bool`.
1953    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1954      be cast to `bool`.
1955    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1956      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1957      be either `1`, or the same as the corresponding `labels` dimension).
1958    metrics_collections: An optional list of collections that `precision` should
1959      be added to.
1960    updates_collections: An optional list of collections that `update_op` should
1961      be added to.
1962    name: An optional variable_scope name.
1963
1964  Returns:
1965    precision: Scalar float `Tensor` with the value of `true_positives`
1966      divided by the sum of `true_positives` and `false_positives`.
1967    update_op: `Operation` that increments `true_positives` and
1968      `false_positives` variables appropriately and whose value matches
1969      `precision`.
1970
1971  Raises:
1972    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1973      `weights` is not `None` and its shape doesn't match `predictions`, or if
1974      either `metrics_collections` or `updates_collections` are not a list or
1975      tuple.
1976    RuntimeError: If eager execution is enabled.
1977  """
1978  if context.executing_eagerly():
1979    raise RuntimeError('tf.metrics.precision is not '
1980                       'supported when eager execution is enabled.')
1981
1982  with variable_scope.variable_scope(name, 'precision',
1983                                     (predictions, labels, weights)):
1984
1985    predictions, labels, weights = _remove_squeezable_dimensions(
1986        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1987        labels=math_ops.cast(labels, dtype=dtypes.bool),
1988        weights=weights)
1989
1990    true_p, true_positives_update_op = true_positives(
1991        labels,
1992        predictions,
1993        weights,
1994        metrics_collections=None,
1995        updates_collections=None,
1996        name=None)
1997    false_p, false_positives_update_op = false_positives(
1998        labels,
1999        predictions,
2000        weights,
2001        metrics_collections=None,
2002        updates_collections=None,
2003        name=None)
2004
2005    def compute_precision(tp, fp, name):
2006      return array_ops.where(
2007          math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
2008
2009    def once_across_replicas(_, true_p, false_p):
2010      return compute_precision(true_p, false_p, 'value')
2011
2012    p = _aggregate_across_replicas(metrics_collections, once_across_replicas,
2013                                   true_p, false_p)
2014
2015    update_op = compute_precision(true_positives_update_op,
2016                                  false_positives_update_op, 'update_op')
2017    if updates_collections:
2018      ops.add_to_collections(updates_collections, update_op)
2019
2020    return p, update_op
2021
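# Illustrative usage sketch (editor's addition): streaming precision,
# tp / (tp + fp). Graph mode is assumed; values are hypothetical.
#
#   labels = tf.constant([True, False, True, False])
#   predictions = tf.constant([True, True, True, False])
#   prec, update_op = tf.compat.v1.metrics.precision(labels, predictions)
#   # After sess.run(update_op): prec evaluates to 2/3 (two true positives,
#   # one false positive).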
2022
2023@tf_export(v1=['metrics.precision_at_thresholds'])
2024def precision_at_thresholds(labels,
2025                            predictions,
2026                            thresholds,
2027                            weights=None,
2028                            metrics_collections=None,
2029                            updates_collections=None,
2030                            name=None):
2031  """Computes precision values for different `thresholds` on `predictions`.
2032
2033  The `precision_at_thresholds` function creates two local variables,
2034  `true_positives` and `false_positives`, one value per threshold in
2035  `thresholds`. `precision[i]` is defined as the total
2036  weight of values in `predictions` above `thresholds[i]` whose corresponding
2037  entry in `labels` is `True`, divided by the total weight of values in
2038  `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
2039  false_positives[i])`).
2040
2041  For estimation of the metric over a stream of data, the function creates an
2042  `update_op` operation that updates these variables and returns the
2043  `precision`.
2044
2045  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2046
2047  Args:
2048    labels: The ground truth values, a `Tensor` whose dimensions must match
2049      `predictions`. Will be cast to `bool`.
2050    predictions: A floating point `Tensor` of arbitrary shape and whose values
2051      are in the range `[0, 1]`.
2052    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
2053    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2054      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2055      be either `1`, or the same as the corresponding `labels` dimension).
2056    metrics_collections: An optional list of collections that `precision`
2057      should be added to.
2058    updates_collections: An optional list of collections that `update_op` should
2059      be added to.
2060    name: An optional variable_scope name.
2061
2062  Returns:
2063    precision: A float `Tensor` of shape `[len(thresholds)]`.
2064    update_op: An operation that increments the `true_positives` and
2065      `false_positives` variables that are used in the computation of
2066      `precision`.
2067
2068  Raises:
2069    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2070      `weights` is not `None` and its shape doesn't match `predictions`, or if
2071      either `metrics_collections` or `updates_collections` are not a list or
2072      tuple.
2073    RuntimeError: If eager execution is enabled.
2074  """
2075  if context.executing_eagerly():
2076    raise RuntimeError('tf.metrics.precision_at_thresholds is not '
2077                       'supported when eager execution is enabled.')
2078
2079  with variable_scope.variable_scope(name, 'precision_at_thresholds',
2080                                     (predictions, labels, weights)):
2081    values, update_ops = _confusion_matrix_at_thresholds(
2082        labels, predictions, thresholds, weights, includes=('tp', 'fp'))
2083
2084    # Avoid division by zero.
2085    epsilon = 1e-7
2086
2087    def compute_precision(tp, fp, name):
2088      return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
2089
2090    def precision_across_replicas(_, values):
2091      return compute_precision(values['tp'], values['fp'], 'value')
2092
2093    prec = _aggregate_across_replicas(
2094        metrics_collections, precision_across_replicas, values)
2095
2096    update_op = compute_precision(update_ops['tp'], update_ops['fp'],
2097                                  'update_op')
2098    if updates_collections:
2099      ops.add_to_collections(updates_collections, update_op)
2100
2101    return prec, update_op
2102
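# Illustrative usage sketch (editor's addition): note the epsilon in the
# denominator above, so results are slightly below the exact ratio and a
# threshold with no positive predictions yields ~0 rather than a division
# error. Graph mode is assumed; values are hypothetical.
#
#   labels = tf.constant([True, False])
#   predictions = tf.constant([0.8, 0.6])
#   prec, update_op = tf.compat.v1.metrics.precision_at_thresholds(
#       labels, predictions, thresholds=[0.5, 0.7])
#   # After sess.run(update_op): prec evaluates to approximately [0.5, 1.0].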
2103
2104@tf_export(v1=['metrics.recall'])
2105def recall(labels,
2106           predictions,
2107           weights=None,
2108           metrics_collections=None,
2109           updates_collections=None,
2110           name=None):
2111  """Computes the recall of the predictions with respect to the labels.
2112
2113  The `recall` function creates two local variables, `true_positives`
2114  and `false_negatives`, that are used to compute the recall. This value is
2115  ultimately returned as `recall`, an idempotent operation that simply divides
2116  `true_positives` by the sum of `true_positives` and `false_negatives`.
2117
2118  For estimation of the metric over a stream of data, the function creates an
2119  `update_op` that updates these variables and returns the `recall`. `update_op`
2120  weights each prediction by the corresponding value in `weights`.
2121
2122  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2123
2124  Args:
2125    labels: The ground truth values, a `Tensor` whose dimensions must match
2126      `predictions`. Will be cast to `bool`.
2127    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
2128      be cast to `bool`.
2129    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2130      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2131      be either `1`, or the same as the corresponding `labels` dimension).
2132    metrics_collections: An optional list of collections that `recall` should
2133      be added to.
2134    updates_collections: An optional list of collections that `update_op` should
2135      be added to.
2136    name: An optional variable_scope name.
2137
2138  Returns:
2139    recall: Scalar float `Tensor` with the value of `true_positives` divided
2140      by the sum of `true_positives` and `false_negatives`.
2141    update_op: `Operation` that increments `true_positives` and
2142      `false_negatives` variables appropriately and whose value matches
2143      `recall`.
2144
2145  Raises:
2146    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2147      `weights` is not `None` and its shape doesn't match `predictions`, or if
2148      either `metrics_collections` or `updates_collections` are not a list or
2149      tuple.
2150    RuntimeError: If eager execution is enabled.
2151  """
2152  if context.executing_eagerly():
2153    raise RuntimeError('tf.metrics.recall is not supported when '
2154                       'eager execution is enabled.')
2155
2156  with variable_scope.variable_scope(name, 'recall',
2157                                     (predictions, labels, weights)):
2158    predictions, labels, weights = _remove_squeezable_dimensions(
2159        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
2160        labels=math_ops.cast(labels, dtype=dtypes.bool),
2161        weights=weights)
2162
2163    true_p, true_positives_update_op = true_positives(
2164        labels,
2165        predictions,
2166        weights,
2167        metrics_collections=None,
2168        updates_collections=None,
2169        name=None)
2170    false_n, false_negatives_update_op = false_negatives(
2171        labels,
2172        predictions,
2173        weights,
2174        metrics_collections=None,
2175        updates_collections=None,
2176        name=None)
2177
2178    def compute_recall(true_p, false_n, name):
2179      return array_ops.where(
2180          math_ops.greater(true_p + false_n, 0),
2181          math_ops.div(true_p, true_p + false_n), 0, name)
2182
2183    def once_across_replicas(_, true_p, false_n):
2184      return compute_recall(true_p, false_n, 'value')
2185
2186    rec = _aggregate_across_replicas(
2187        metrics_collections, once_across_replicas, true_p, false_n)
2188
2189    update_op = compute_recall(true_positives_update_op,
2190                               false_negatives_update_op, 'update_op')
2191    if updates_collections:
2192      ops.add_to_collections(updates_collections, update_op)
2193
2194    return rec, update_op
2195
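# Illustrative usage sketch (editor's addition): streaming recall,
# tp / (tp + fn). Graph mode is assumed; values are hypothetical.
#
#   labels = tf.constant([True, False, True, True])
#   predictions = tf.constant([True, False, False, True])
#   rec, update_op = tf.compat.v1.metrics.recall(labels, predictions)
#   # After sess.run(update_op): rec evaluates to 2/3 (two of the three
#   # positive labels are predicted True).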
2196
2197def _at_k_name(name, k=None, class_id=None):
2198  if k is not None:
2199    name = '%s_at_%d' % (name, k)
2200  else:
2201    name = '%s_at_k' % (name)
2202  if class_id is not None:
2203    name = '%s_class%d' % (name, class_id)
2204  return name
2205
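# For reference, the names produced by `_at_k_name` (pure Python, no ops):
#   _at_k_name('recall', k=5)               # -> 'recall_at_5'
#   _at_k_name('recall')                    # -> 'recall_at_k'
#   _at_k_name('recall', k=5, class_id=2)   # -> 'recall_at_5_class2'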
2206
2207def _select_class_id(ids, selected_id):
2208  """Filter all but `selected_id` out of `ids`.
2209
2210  Args:
2211    ids: `int64` `Tensor` or `SparseTensor` of IDs.
2212    selected_id: Int id to select.
2213
2214  Returns:
2215    `SparseTensor` of same dimensions as `ids`. This contains only the entries
2216    equal to `selected_id`.
2217  """
2218  ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids)
2219  if isinstance(ids, sparse_tensor.SparseTensor):
2220    return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values,
2221                                                        selected_id))
2222
2223  # TODO(ptucker): Make this more efficient, maybe add a sparse version of
2224  # tf.equal and tf.reduce_any?
2225
2226  # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
2227  ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
2228  ids_last_dim = array_ops.size(ids_shape) - 1
2229  filled_selected_id_shape = math_ops.reduced_shape(ids_shape,
2230                                                    array_ops.reshape(
2231                                                        ids_last_dim, [1]))
2232
2233  # Intersect `ids` with the selected ID.
2234  filled_selected_id = array_ops.fill(filled_selected_id_shape,
2235                                      math_ops.cast(selected_id, dtypes.int64))
2236  result = sets.set_intersection(filled_selected_id, ids)
2237  return sparse_tensor.SparseTensor(
2238      indices=result.indices, values=result.values, dense_shape=ids_shape)
2239
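# Sketch of `_select_class_id` behavior (internal helper; example values are
# the editor's):
#   ids = tf.constant([[1, 2], [2, 3]], dtype=tf.int64)
#   selected = _select_class_id(ids, selected_id=2)
#   # `selected` is a SparseTensor with the same dense shape as `ids`
#   # ([2, 2]) containing only the entries equal to 2, one per row here.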
2240
2241def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
2242  """If class ID is specified, filter all other classes.
2243
2244  Args:
2245    labels: `int64` `Tensor` or `SparseTensor` with shape
2246      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2247      target classes for the associated prediction. Commonly, N=1 and `labels`
2248      has shape [batch_size, num_labels]. [D1, ... DN] must match
2249      `predictions_idx`.
2250    predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
2251      where N >= 1. Commonly, N=1 and `predictions_idx` has shape
2252      [batch size, k].
2253    selected_id: Int id to select.
2254
2255  Returns:
2256    Tuple of `labels` and `predictions_idx`, possibly with classes removed.
2257  """
2258  if selected_id is None:
2259    return labels, predictions_idx
2260  return (_select_class_id(labels, selected_id),
2261          _select_class_id(predictions_idx, selected_id))
2262
2263
2264def _sparse_true_positive_at_k(labels,
2265                               predictions_idx,
2266                               class_id=None,
2267                               weights=None,
2268                               name=None):
2269  """Calculates true positives for recall@k and precision@k.
2270
2271  If `class_id` is specified, calculate binary true positives for `class_id`
2272      only.
2273  If `class_id` is not specified, calculate metrics for `k` predicted vs
2274      `n` label classes, where `n` is the 2nd dimension of `labels`.
2275
2276  Args:
2277    labels: `int64` `Tensor` or `SparseTensor` with shape
2278      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2279      target classes for the associated prediction. Commonly, N=1 and `labels`
2280      has shape [batch_size, num_labels]. [D1, ... DN] must match
2281      `predictions_idx`.
2282    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2283      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2284      match `labels`.
2285    class_id: Class for which we want binary metrics.
2286    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2287      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2288      dimensions must be either `1`, or the same as the corresponding `labels`
2289      dimension).
2290    name: Name of operation.
2291
2292  Returns:
2293    A [D1, ... DN] `Tensor` of true positive counts.
2294  """
2295  with ops.name_scope(name, 'true_positives',
2296                      (predictions_idx, labels, weights)):
2297    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
2298                                                     class_id)
2299    tp = sets.set_size(sets.set_intersection(predictions_idx, labels))
2300    tp = math_ops.cast(tp, dtypes.float64)
2301    if weights is not None:
2302      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
2303          weights, tp),)):
2304        weights = math_ops.cast(weights, dtypes.float64)
2305        tp = math_ops.multiply(tp, weights)
2306    return tp
2307
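# Sketch of the per-example counts from `_sparse_true_positive_at_k`
# (internal helper; example values are the editor's):
#   labels = tf.constant([[0, 1]], dtype=tf.int64)
#   predictions_idx = tf.constant([[1, 2]], dtype=tf.int64)  # top-2 classes
#   tp = _sparse_true_positive_at_k(labels, predictions_idx)
#   # tp evaluates to [1.]: of the top-2 predicted classes, exactly one
#   # (class 1) appears among the example's labels.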
2308
2309def _streaming_sparse_true_positive_at_k(labels,
2310                                         predictions_idx,
2311                                         k=None,
2312                                         class_id=None,
2313                                         weights=None,
2314                                         name=None):
2315  """Calculates weighted per step true positives for recall@k and precision@k.
2316
2317  If `class_id` is specified, calculate binary true positives for `class_id`
2318      only.
2319  If `class_id` is not specified, calculate metrics for `k` predicted vs
2320      `n` label classes, where `n` is the 2nd dimension of `labels`.
2321
2322  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2323
2324  Args:
2325    labels: `int64` `Tensor` or `SparseTensor` with shape
2326      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2327      target classes for the associated prediction. Commonly, N=1 and `labels`
2328      has shape [batch_size, num_labels]. [D1, ... DN] must match
2329      `predictions_idx`.
2330    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2331      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2332      match `labels`.
2333    k: Integer, k for @k metric. This is only used for default op name.
2334    class_id: Class for which we want binary metrics.
2335    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2336      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2337      dimensions must be either `1`, or the same as the corresponding `labels`
2338      dimension).
2339    name: Name of new variable, and namespace for other dependent ops.
2340
2341  Returns:
2342    A tuple of `Variable` and update `Operation`.
2343
2344  Raises:
2345    ValueError: If `weights` is not `None` and has an incompatible shape.
2346  """
2347  with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id),
2348                      (predictions_idx, labels, weights)) as scope:
2349    tp = _sparse_true_positive_at_k(
2350        predictions_idx=predictions_idx,
2351        labels=labels,
2352        class_id=class_id,
2353        weights=weights)
2354    batch_total_tp = math_ops.cast(math_ops.reduce_sum(tp), dtypes.float64)
2355
2356    var = metric_variable([], dtypes.float64, name=scope)
2357    return var, state_ops.assign_add(var, batch_total_tp, name='update')
2358
2359
2360def _sparse_false_negative_at_k(labels,
2361                                predictions_idx,
2362                                class_id=None,
2363                                weights=None):
2364  """Calculates false negatives for recall@k.
2365
2366  If `class_id` is specified, calculate binary true positives for `class_id`
2367      only.
2368  If `class_id` is not specified, calculate metrics for `k` predicted vs
2369      `n` label classes, where `n` is the 2nd dimension of `labels`.
2370
2371  Args:
2372    labels: `int64` `Tensor` or `SparseTensor` with shape
2373      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2374      target classes for the associated prediction. Commonly, N=1 and `labels`
2375      has shape [batch_size, num_labels]. [D1, ... DN] must match
2376      `predictions_idx`.
2377    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2378      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2379      match `labels`.
2380    class_id: Class for which we want binary metrics.
2381    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2382      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2383      dimensions must be either `1`, or the same as the corresponding `labels`
2384      dimension).
2385
2386  Returns:
2387    A [D1, ... DN] `Tensor` of false negative counts.
2388  """
2389  with ops.name_scope(None, 'false_negatives',
2390                      (predictions_idx, labels, weights)):
2391    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
2392                                                     class_id)
2393    fn = sets.set_size(
2394        sets.set_difference(predictions_idx, labels, aminusb=False))
2395    fn = math_ops.cast(fn, dtypes.float64)
2396    if weights is not None:
2397      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
2398          weights, fn),)):
2399        weights = math_ops.cast(weights, dtypes.float64)
2400        fn = math_ops.multiply(fn, weights)
2401    return fn
2402
2403
def _streaming_sparse_false_negative_at_k(labels,
                                          predictions_idx,
                                          k,
                                          class_id=None,
                                          weights=None,
                                          name=None):
  """Calculates weighted per step false negatives for recall@k.

  If `class_id` is specified, calculate binary false negatives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
  """
  with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id),
                      (predictions_idx, labels, weights)) as scope:
    fn = _sparse_false_negative_at_k(
        predictions_idx=predictions_idx,
        labels=labels,
        class_id=class_id,
        weights=weights)
    batch_total_fn = math_ops.cast(math_ops.reduce_sum(fn), dtypes.float64)

    var = metric_variable([], dtypes.float64, name=scope)
    return var, state_ops.assign_add(var, batch_total_fn, name='update')


@tf_export(v1=['metrics.recall_at_k'])
def recall_at_k(labels,
                predictions,
                k,
                class_id=None,
                weights=None,
                metrics_collections=None,
                updates_collections=None,
                name=None):
  """Computes recall@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate recall by considering only the
      entries in the batch for which `class_id` is in the label, and computing
      the fraction of them for which `class_id` is in the top-k `predictions`.
  If `class_id` is not specified, we'll calculate recall as how often on
      average a class among the labels of a batch entry is in the top-k
      `predictions`.

  `recall_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
  the recall_at_k frequency. This frequency is ultimately returned as
  `recall_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_negative_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false negatives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_negative_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
      should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range always count
      towards `false_negative_at_<k>`.
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If class_id is outside this range, the method returns NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.recall_at_k is not '
                       'supported when eager execution is enabled.')

  with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
                      (predictions, labels, weights)) as scope:
    _, top_k_idx = nn.top_k(predictions, k)
    return recall_at_top_k(
        labels=labels,
        predictions_idx=top_k_idx,
        k=k,
        class_id=class_id,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=scope)


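# Example (illustrative usage sketch, not part of the library): driving
# `recall_at_k` over a batch in graph mode. The tensors and session setup
# below are made up for demonstration; `tf` refers to the TF1-style API.
#
#   labels = tf.constant([[0], [2]], dtype=tf.int64)    # shape [2, 1]
#   logits = tf.constant([[0.1, 0.8, 0.1],
#                         [0.3, 0.6, 0.1]])             # shape [2, 3]
#   recall, update_op = tf.metrics.recall_at_k(labels, logits, k=2)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)        # accumulates tp and fn for this batch
#     print(sess.run(recall))    # 0.5: class 0 is in row 0's top-2,
#                                # class 2 is not in row 1's top-2

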
@tf_export(v1=['metrics.recall_at_top_k'])
def recall_at_top_k(labels,
                    predictions_idx,
                    k=None,
                    class_id=None,
                    weights=None,
                    metrics_collections=None,
                    updates_collections=None,
                    name=None):
  """Computes recall@k of top-k predictions with respect to sparse labels.

  Differs from `recall_at_k` in that predictions must be in the form of top `k`
  class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k`
  for more details.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
      Values should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range always count
      towards `false_negative_at_<k>`.
    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
      Commonly, N=1 and predictions has shape [batch size, k]. The final
      dimension contains the top `k` predicted class indices. [D1, ... DN] must
      match `labels`.
    k: Integer, k for @k metric. Only used for the default op name.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If class_id is outside this range, the method returns NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
                      (predictions_idx, labels, weights)) as scope:
    labels = _maybe_expand_labels(labels, predictions_idx)
    top_k_idx = math_ops.cast(predictions_idx, dtypes.int64)
    tp, tp_update = _streaming_sparse_true_positive_at_k(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)
    fn, fn_update = _streaming_sparse_false_negative_at_k(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)

    def compute_recall(_, tp, fn):
      return math_ops.div(tp, math_ops.add(tp, fn), name=scope)

    metric = _aggregate_across_replicas(
        metrics_collections, compute_recall, tp, fn)

    update = math_ops.div(
        tp_update, math_ops.add(tp_update, fn_update), name='update')
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update


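# Example (illustrative sketch): `recall_at_top_k` takes precomputed top-k
# class indices instead of logits. With the made-up `labels` from the sketch
# above, this is equivalent to the `recall_at_k` call with k=2, since these
# are the top-2 indices of those logits.
#
#   top_k_idx = tf.constant([[1, 0], [1, 0]], dtype=tf.int64)
#   recall, update_op = tf.metrics.recall_at_top_k(labels, top_k_idx, k=2)

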
@tf_export(v1=['metrics.recall_at_thresholds'])
def recall_at_thresholds(labels,
                         predictions,
                         thresholds,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes various recall values for different `thresholds` on `predictions`.

  The `recall_at_thresholds` function creates two local variables,
  `true_positives` and `false_negatives`, for various values of thresholds.
  `recall[i]` is defined as the total weight of values in `predictions` above
  `thresholds[i]` whose corresponding entry in `labels` is `True`, divided by
  the total weight of `True` values in `labels`
  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `recall`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `recall` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `true_positives` and
      `false_negatives` variables that are used in the computation of `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.recall_at_thresholds is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'recall_at_thresholds',
                                     (predictions, labels, weights)):
    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights, includes=('tp', 'fn'))

    # Avoid division by zero.
    epsilon = 1e-7

    def compute_recall(tp, fn, name):
      return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)

    def recall_across_replicas(_, values):
      return compute_recall(values['tp'], values['fn'], 'value')

    rec = _aggregate_across_replicas(
        metrics_collections, recall_across_replicas, values)

    update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return rec, update_op


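# Example (illustrative sketch, made-up tensors): `recall_at_thresholds`
# returns one recall value per threshold.
#
#   labels = tf.constant([True, False, True])
#   predictions = tf.constant([0.9, 0.6, 0.4])
#   recall, update_op = tf.metrics.recall_at_thresholds(
#       labels, predictions, thresholds=[0.3, 0.5, 0.7])
#   # After running update_op, recall evaluates to approximately
#   # [1.0, 0.5, 0.5]: both positives exceed 0.3, but only the 0.9
#   # prediction exceeds 0.5 and 0.7.

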
@tf_export(v1=['metrics.root_mean_squared_error'])
def root_mean_squared_error(labels,
                            predictions,
                            weights=None,
                            metrics_collections=None,
                            updates_collections=None,
                            name=None):
  """Computes the root mean squared error between the labels and predictions.

  The `root_mean_squared_error` function creates two local variables,
  `total` and `count`, that are used to compute the root mean squared error.
  This average is weighted by `weights`, and it is ultimately returned as
  `root_mean_squared_error`: an idempotent operation that takes the square root
  of the division of `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `root_mean_squared_error`. Internally, a `squared_error` operation computes
  the element-wise square of the difference between `predictions` and `labels`.
  Then `update_op` increments `total` with the reduced sum of the product of
  `weights` and `squared_error`, and it increments `count` with the reduced sum
  of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `root_mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    root_mean_squared_error: A `Tensor` representing the current root mean
      squared error: the square root of `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `root_mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.root_mean_squared_error is not '
                       'supported when eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
                                          None, name or
                                          'root_mean_squared_error')

  once_across_replicas = lambda _, mse: math_ops.sqrt(mse)
  rmse = _aggregate_across_replicas(
      metrics_collections, once_across_replicas, mse)

  update_rmse_op = math_ops.sqrt(update_mse_op)
  if updates_collections:
    ops.add_to_collections(updates_collections, update_rmse_op)

  return rmse, update_rmse_op


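# Example (illustrative sketch, made-up setup): streaming RMSE over several
# batches. Note the result is the square root of the aggregate MSE, not the
# mean of per-batch RMSE values.
#
#   rmse, update_op = tf.metrics.root_mean_squared_error(labels, predictions)
#   sess.run(tf.local_variables_initializer())
#   for labels_batch, predictions_batch in batches:
#     sess.run(update_op, feed_dict={labels: labels_batch,
#                                    predictions: predictions_batch})
#   print(sess.run(rmse))

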
@tf_export(v1=['metrics.sensitivity_at_specificity'])
def sensitivity_at_specificity(labels,
                               predictions,
                               specificity,
                               weights=None,
                               num_thresholds=200,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Computes the sensitivity at a given specificity.

  The `sensitivity_at_specificity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the sensitivity at the given
  specificity value. The threshold for the given specificity value is computed
  and used to evaluate the corresponding sensitivity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    specificity: A scalar value in range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      specificity.
    metrics_collections: An optional list of collections that `sensitivity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    sensitivity: A scalar `Tensor` representing the sensitivity at the given
      `specificity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `sensitivity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `specificity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.sensitivity_at_specificity is not '
                       'supported when eager execution is enabled.')

  if specificity < 0 or specificity > 1:
    raise ValueError('`specificity` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
                                     (predictions, labels, weights)):
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    def compute_sensitivity_at_specificity(tp, tn, fp, fn, name):
      specificities = math_ops.div(tn, tn + fp + kepsilon)
      tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the sensitivity:
      return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
                          name)

    def sensitivity_across_replicas(_, values):
      return compute_sensitivity_at_specificity(
          values['tp'], values['tn'], values['fp'], values['fn'], 'value')

    sensitivity = _aggregate_across_replicas(
        metrics_collections, sensitivity_across_replicas, values)

    update_op = compute_sensitivity_at_specificity(
        update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
        'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return sensitivity, update_op


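# Example (illustrative sketch, made-up tensors): accumulate confusion
# counts over the stream, then read off the sensitivity at the threshold
# whose specificity is closest to 0.8.
#
#   sens, update_op = tf.metrics.sensitivity_at_specificity(
#       labels, predictions, specificity=0.8, num_thresholds=200)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   print(sess.run(sens))

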
def _expand_and_tile(tensor, multiple, dim=0, name=None):
  """Slice `tensor` shape in 2, then tile along the sliced dimension.

  A new dimension is inserted in shape of `tensor` before `dim`, then values are
  tiled `multiple` times along the new dimension.

  Args:
    tensor: Input `Tensor` or `SparseTensor`.
    multiple: Integer, number of times to tile.
    dim: Integer, dimension along which to tile.
    name: Name of operation.

  Returns:
    `Tensor` result of expanding and tiling `tensor`.

  Raises:
    ValueError: if `multiple` is less than 1, or `dim` is not in
    `[-rank(tensor), rank(tensor)]`.
  """
  if multiple < 1:
    raise ValueError('Invalid multiple %s, must be > 0.' % multiple)
  with ops.name_scope(name, 'expand_and_tile',
                      (tensor, multiple, dim)) as scope:
    # Sparse.
    tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
    if isinstance(tensor, sparse_tensor.SparseTensor):
      if dim < 0:
        expand_dims = array_ops.reshape(
            array_ops.size(tensor.dense_shape) + dim, [1])
      else:
        expand_dims = [dim]
      expanded_shape = array_ops.concat(
          (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1],
           array_ops.slice(tensor.dense_shape, expand_dims, [-1])),
          0,
          name='expanded_shape')
      expanded = sparse_ops.sparse_reshape(
          tensor, shape=expanded_shape, name='expand')
      if multiple == 1:
        return expanded
      return sparse_ops.sparse_concat(
          dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope)

    # Dense.
    expanded = array_ops.expand_dims(
        tensor, dim if (dim >= 0) else (dim - 1), name='expand')
    if multiple == 1:
      return expanded
    ones = array_ops.ones_like(array_ops.shape(tensor))
    tile_multiples = array_ops.concat(
        (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples')
    return array_ops.tile(expanded, tile_multiples, name=scope)


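# Shape illustration (hypothetical values): for a dense `tensor` of shape
# [2, 3], `_expand_and_tile(tensor, multiple=4, dim=-1)` inserts a new axis
# before the last dimension and tiles along it, producing shape [2, 4, 3].
# This is how labels of shape [D1, ... DN, num_labels] are replicated to
# [D1, ... DN, k, num_labels] for the per-k computation below.

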
def _num_relevant(labels, k):
  """Computes number of relevant values for each row in labels.

  For labels with shape [D1, ... DN, num_labels], this is the minimum of
  `num_labels` and `k`.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels].
    k: Integer, k for @k metric.

  Returns:
    Integer `Tensor` of shape [D1, ... DN], where each value is the number of
    relevant values for that row.

  Raises:
    ValueError: if inputs have invalid dtypes or values.
  """
  if k < 1:
    raise ValueError('Invalid k=%s.' % k)
  with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
    # For SparseTensor, calculate separate count for each row.
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
    if isinstance(labels, sparse_tensor.SparseTensor):
      return math_ops.minimum(sets.set_size(labels), k, name=scope)

    # For dense Tensor, calculate scalar count based on last dimension, and
    # tile across labels shape.
    labels_shape = array_ops.shape(labels)
    labels_size = labels_shape[-1]
    num_relevant_scalar = math_ops.minimum(labels_size, k)
    return array_ops.fill(labels_shape[0:-1], num_relevant_scalar, name=scope)


def _sparse_average_precision_at_top_k(labels, predictions_idx):
  """Computes average precision@k of predictions with respect to sparse labels.

  From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula
  for each row is:

    AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items

  A "row" is the elements in dimension [D1, ... DN] of `predictions_idx`,
  `labels`, and the result `Tensors`. In the common case, this is [batch_size].
  Each row of the results contains the average precision for that row.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
      Values should be in range [0, num_classes).
    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
      Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
      dimension must be set and contains the top `k` predicted class indices.
      [D1, ... DN] must match `labels`. Values should be in range
      [0, num_classes).

  Returns:
    `float64` `Tensor` of shape [D1, ... DN], where each value is the average
    precision for that row.

  Raises:
    ValueError: if the last dimension of predictions_idx is not set.
  """
  with ops.name_scope(None, 'average_precision',
                      (predictions_idx, labels)) as scope:
    predictions_idx = math_ops.cast(
        predictions_idx, dtypes.int64, name='predictions_idx')
    if predictions_idx.get_shape().ndims == 0:
      raise ValueError('The rank of predictions_idx must be at least 1.')
    k = predictions_idx.get_shape().as_list()[-1]
    if k is None:
      raise ValueError('The last dimension of predictions_idx must be set.')
    labels = _maybe_expand_labels(labels, predictions_idx)

    # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate
    # prediction for each k, so we can calculate separate true positive values
    # for each k.
    predictions_idx_per_k = array_ops.expand_dims(
        predictions_idx, -1, name='predictions_idx_per_k')

    # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor.
    labels_per_k = _expand_and_tile(
        labels, multiple=k, dim=-1, name='labels_per_k')

    # The following tensors are all of shape [D1, ... DN, k], containing values
    # per row, per k value.
    # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at
    #     that k value is correct, 0 otherwise. This is the "rel_{i}" term from
    #     the formula above.
    # `tp_per_k` (int32) - True positive counts.
    # `retrieved_per_k` (int32) - Number of predicted values at each k. This is
    #     the precision denominator.
    # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}"
    #     term from the formula above.
    # `relevant_precision_per_k` (float64) - Relevant precisions; i.e.,
    #     precisions at all k for which relevance indicator is true.
    relevant_per_k = _sparse_true_positive_at_k(
        labels_per_k, predictions_idx_per_k, name='relevant_per_k')
    tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k')
    retrieved_per_k = math_ops.cumsum(
        array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
    precision_per_k = math_ops.div(
        math_ops.cast(tp_per_k, dtypes.float64),
        math_ops.cast(retrieved_per_k, dtypes.float64),
        name='precision_per_k')
    relevant_precision_per_k = math_ops.multiply(
        precision_per_k,
        math_ops.cast(relevant_per_k, dtypes.float64),
        name='relevant_precision_per_k')

    # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
    precision_sum = math_ops.reduce_sum(
        relevant_precision_per_k, axis=(-1,), name='precision_sum')

    # Divide by number of relevant items to get average precision. These are
    # the "num_relevant_items" and "AveP" terms from the formula above.
    num_relevant_items = math_ops.cast(_num_relevant(labels, k), dtypes.float64)
    return math_ops.div(precision_sum, num_relevant_items, name=scope)


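# Worked example of the AveP formula above (hypothetical values): for one
# row with label set {1, 3} and predictions_idx = [3, 0, 1] (k=3):
#   rel_{i}      = [1, 0, 1]          (predictions 3 and 1 are relevant)
#   P_{i}        = [1/1, 1/2, 2/3]    (cumulative precision at each i)
#   num_relevant = min(2 labels, k=3) = 2
#   AveP = (1 * 1 + (1/2) * 0 + (2/3) * 1) / 2 = 5/6

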
def _streaming_sparse_average_precision_at_top_k(labels,
                                                 predictions_idx,
                                                 weights=None,
                                                 metrics_collections=None,
                                                 updates_collections=None,
                                                 name=None):
  """Computes average precision@k of predictions with respect to sparse labels.

  `_streaming_sparse_average_precision_at_top_k` creates two local variables,
  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
  are used to compute the metric. The metric is ultimately returned as
  `average_precision_at_<k>`: an idempotent operation that simply divides
  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `average_precision_at_<k>`. Internally, per-example average precision is
  computed from `predictions_idx` and `labels`, weighted by `weights`. Then
  `update_op` increments `average_precision_at_<k>/total` with the batch total
  of weighted average precisions, and `average_precision_at_<k>/max` with the
  batch total of weights.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
      Values should be in range [0, num_classes).
    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
      Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
      dimension contains the top `k` predicted class indices. [D1, ... DN] must
      match `labels`. Values should be in range [0, num_classes).
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    mean_average_precision: Scalar `float64` `Tensor` with the mean average
      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `metric`.
  """
  with ops.name_scope(name, 'average_precision_at_top_k',
                      (predictions_idx, labels, weights)) as scope:
    # Calculate per-example average precision, and apply weights.
    average_precision = _sparse_average_precision_at_top_k(
        predictions_idx=predictions_idx, labels=labels)
    if weights is not None:
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.cast(weights, dtypes.float64), average_precision)
      average_precision = math_ops.multiply(average_precision, weights)

    # Create accumulation variables and update ops for max average precision and
    # total average precision.
    with ops.name_scope(None, 'max', (average_precision,)) as max_scope:
      # `max` is the max possible precision. Since max for any row is 1.0:
      # - For the unweighted case, this is just the number of rows.
      # - For the weighted case, it's the sum of the weights broadcast across
      #   `average_precision` rows.
      max_var = metric_variable([], dtypes.float64, name=max_scope)
      if weights is None:
        batch_max = math_ops.cast(
            array_ops.size(average_precision, name='batch_max'), dtypes.float64)
      else:
        batch_max = math_ops.reduce_sum(weights, name='batch_max')
      max_update = state_ops.assign_add(max_var, batch_max, name='update')
    with ops.name_scope(None, 'total', (average_precision,)) as total_scope:
      total_var = metric_variable([], dtypes.float64, name=total_scope)
      batch_total = math_ops.reduce_sum(average_precision, name='batch_total')
      total_update = state_ops.assign_add(total_var, batch_total, name='update')

    # Divide total by max to get mean, for both vars and the update ops.
    def precision_across_replicas(_, total_var, max_var):
      return _safe_scalar_div(total_var, max_var, name='mean')

    mean_average_precision = _aggregate_across_replicas(
        metrics_collections, precision_across_replicas, total_var, max_var)

    update = _safe_scalar_div(total_update, max_update, name=scope)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)

    return mean_average_precision, update


@tf_export(v1=['metrics.sparse_average_precision_at_k'])
@deprecated(None, 'Use average_precision_at_k instead')
def sparse_average_precision_at_k(labels,
                                  predictions,
                                  k,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Renamed to `average_precision_at_k`, please use that method instead."""
  return average_precision_at_k(
      labels=labels,
      predictions=predictions,
      k=k,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@tf_export(v1=['metrics.average_precision_at_k'])
def average_precision_at_k(labels,
                           predictions,
                           k,
                           weights=None,
                           metrics_collections=None,
                           updates_collections=None,
                           name=None):
  """Computes average precision@k of predictions with respect to sparse labels.

  `average_precision_at_k` creates two local variables,
  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
  are used to compute the metric. The metric is ultimately returned as
  `average_precision_at_<k>`: an idempotent operation that simply divides
  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `average_precision_at_<k>`. Internally, a `top_k` operation computes a
  `Tensor` indicating the top `k` `predictions`, and per-example average
  precision is computed from those indices and `labels`, weighted by `weights`.
  Then `update_op` increments `average_precision_at_<k>/total` and
  `average_precision_at_<k>/max` with the batch totals.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
      should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range are ignored.
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and `predictions` has shape
      [batch size, num_classes]. The final dimension contains the logit values
      for each class. [D1, ... DN] must match `labels`.
    k: Integer, k for @k metric. This will calculate an average precision for
      range `[1,k]`, as documented above.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    mean_average_precision: Scalar `float64` `Tensor` with the mean average
      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `metric`.

  Raises:
    ValueError: if k is invalid.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.average_precision_at_k is not '
                       'supported when eager execution is enabled.')

  if k < 1:
    raise ValueError('Invalid k=%s.' % k)
  with ops.name_scope(name, _at_k_name('average_precision', k),
                      (predictions, labels, weights)) as scope:
    # Calculate top k indices to produce [D1, ... DN, k] tensor.
    _, predictions_idx = nn.top_k(predictions, k)
    return _streaming_sparse_average_precision_at_top_k(
        labels=labels,
        predictions_idx=predictions_idx,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=scope)


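# Example (illustrative sketch, made-up tensors): mean average precision
# over a batch. Each row's AveP is computed as in the worked example above,
# then the weighted mean across rows is returned.
#
#   labels = tf.constant([[1, 3], [0, 2]], dtype=tf.int64)
#   logits = tf.random_uniform([2, 4])
#   mean_ap, update_op = tf.metrics.average_precision_at_k(labels, logits, k=3)

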
def _sparse_false_positive_at_k(labels,
                                predictions_idx,
                                class_id=None,
                                weights=None):
  """Calculates false positives for precision@k.

  If `class_id` is specified, calculate binary false positives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).

  Returns:
    A [D1, ... DN] `Tensor` of false positive counts.
  """
  with ops.name_scope(None, 'false_positives',
                      (predictions_idx, labels, weights)):
    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
                                                     class_id)
    fp = sets.set_size(
        sets.set_difference(predictions_idx, labels, aminusb=True))
    fp = math_ops.cast(fp, dtypes.float64)
    if weights is not None:
      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
          weights, fp),)):
        weights = math_ops.cast(weights, dtypes.float64)
        fp = math_ops.multiply(fp, weights)
    return fp


def _streaming_sparse_false_positive_at_k(labels,
                                          predictions_idx,
                                          k=None,
                                          class_id=None,
                                          weights=None,
                                          name=None):
  """Calculates weighted per step false positives for precision@k.

  If `class_id` is specified, calculate binary false positives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
  """
  with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id),
                      (predictions_idx, labels, weights)) as scope:
    fp = _sparse_false_positive_at_k(
        predictions_idx=predictions_idx,
        labels=labels,
        class_id=class_id,
        weights=weights)
    batch_total_fp = math_ops.cast(math_ops.reduce_sum(fp), dtypes.float64)

    var = metric_variable([], dtypes.float64, name=scope)
    return var, state_ops.assign_add(var, batch_total_fp, name='update')


@tf_export(v1=['metrics.precision_at_top_k'])
def precision_at_top_k(labels,
                       predictions_idx,
                       k=None,
                       class_id=None,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  Differs from `precision_at_k` in that predictions must be in the form of top
  `k` class indices, whereas `precision_at_k` expects logits. Refer to
  `precision_at_k` for more details.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
      Values should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range are ignored.
    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, k].
      The final dimension contains the top `k` predicted class indices.
      [D1, ... DN] must match `labels`.
    k: Integer, k for @k metric. Only used for the default op name.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.precision_at_top_k is not '
                       'supported when eager execution is enabled.')

  with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
                      (predictions_idx, labels, weights)) as scope:
    labels = _maybe_expand_labels(labels, predictions_idx)
    top_k_idx = math_ops.cast(predictions_idx, dtypes.int64)
    tp, tp_update = _streaming_sparse_true_positive_at_k(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)
    fp, fp_update = _streaming_sparse_false_positive_at_k(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)

    def precision_across_replicas(_, tp, fp):
      return math_ops.div(tp, math_ops.add(tp, fp), name=scope)

    metric = _aggregate_across_replicas(
        metrics_collections, precision_across_replicas, tp, fp)

    update = math_ops.div(
        tp_update, math_ops.add(tp_update, fp_update), name='update')
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update


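# Example (illustrative sketch, made-up tensors): `precision_at_top_k` with
# precomputed indices. With labels {0} and {2} per row and the top-2 indices
# below, 2 of the 4 predicted slots are correct, so the metric evaluates to
# 0.5 after the update op runs.
#
#   labels = tf.constant([[0], [2]], dtype=tf.int64)
#   top_k_idx = tf.constant([[1, 0], [2, 1]], dtype=tf.int64)
#   precision, update_op = tf.metrics.precision_at_top_k(labels, top_k_idx, k=2)

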
@tf_export(v1=['metrics.sparse_precision_at_k'])
@deprecated(None, 'Use precision_at_k instead')
def sparse_precision_at_k(labels,
                          predictions,
                          k,
                          class_id=None,
                          weights=None,
                          metrics_collections=None,
                          updates_collections=None,
                          name=None):
  """Renamed to `precision_at_k`, please use that method instead."""
  return precision_at_k(
      labels=labels,
      predictions=predictions,
      k=k,
      class_id=class_id,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@tf_export(v1=['metrics.precision_at_k'])
def precision_at_k(labels,
                   predictions,
                   k,
                   class_id=None,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate precision by considering only the
      entries in the batch for which `class_id` is in the top-k highest
      `predictions`, and computing the fraction of them for which `class_id` is
      indeed a correct label.
  If `class_id` is not specified, we'll calculate precision as how often on
      average a class among the top-k classes with the highest predicted values
      of a batch entry is correct and can be found in the label for that entry.

  `precision_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_positive_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
      should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range are ignored.
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.precision_at_k is not '
                       'supported when eager execution is enabled.')

  with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
                      (predictions, labels, weights)) as scope:
    _, top_k_idx = nn.top_k(predictions, k)
    return precision_at_top_k(
        labels=labels,
        predictions_idx=top_k_idx,
        k=k,
        class_id=class_id,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=scope)


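# Example (illustrative sketch, made-up tensors): `precision_at_k` computes
# the top-k indices internally from logits. With `class_id` set, only the
# slots where `class_id` appears in the top-k count towards the metric.
#
#   precision, update_op = tf.metrics.precision_at_k(
#       labels, logits, k=5, class_id=2)

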
3554@tf_export(v1=['metrics.specificity_at_sensitivity'])
3555def specificity_at_sensitivity(labels,
3556                               predictions,
3557                               sensitivity,
3558                               weights=None,
3559                               num_thresholds=200,
3560                               metrics_collections=None,
3561                               updates_collections=None,
3562                               name=None):
  """Computes the specificity at a given sensitivity.

  The `specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is computed
  and used to evaluate the corresponding specificity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar `Tensor` representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
                       'supported when eager execution is enabled.')

  if sensitivity < 0 or sensitivity > 1:
    raise ValueError('`sensitivity` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
                                     (predictions, labels, weights)):
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
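    # For illustration (comment added editorially): with num_thresholds=5 the
    # list comprehension yields [0.25, 0.5, 0.75], and after adding the
    # endpoints the thresholds are [-1e-07, 0.25, 0.5, 0.75, 1.0000001]; the
    # lowest threshold marks every prediction positive and the highest marks
    # none, so the achievable sensitivities span the full [0, 1] range.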

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    def compute_specificity_at_sensitivity(tp, tn, fp, fn, name):
      """Computes the specificity at the given sensitivity.

      Args:
        tp: True positives.
        tn: True negatives.
        fp: False positives.
        fn: False negatives.
        name: The name of the operation.

      Returns:
        The specificity using the aggregated values.
      """
      sensitivities = math_ops.div(tp, tp + fn + kepsilon)
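      # (Editorial note) `sensitivities[i]` is the recall (true positive
      # rate) obtained when thresholding `predictions` at `thresholds[i]`;
      # `kepsilon` guards against division by zero when there are no
      # positive labels.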

      # We'll need to use this trick until tf.argmax allows us to specify
      # whether we should use the first or last index in case of ties.
      min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
      indices_at_minval = math_ops.equal(
          math_ops.abs(sensitivities - sensitivity), min_val)
      indices_at_minval = math_ops.cast(indices_at_minval, dtypes.int64)
      indices_at_minval = math_ops.cumsum(indices_at_minval)
      tf_index = math_ops.argmax(indices_at_minval, 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)
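      # Worked example (added editorially): if |sensitivities - sensitivity|
      # is [0.3, 0.1, 0.1, 0.2], the indicator of the minimum is
      # [0, 1, 1, 0], its cumsum is [0, 1, 2, 2], and argmax returns 2,
      # i.e. the last threshold attaining the minimum distance.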

      # Now, we have the implicit threshold, so compute the specificity:
      return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
                          name)

    def specificity_across_replicas(_, values):
      return compute_specificity_at_sensitivity(
          values['tp'], values['tn'], values['fp'], values['fn'], 'value')

    specificity = _aggregate_across_replicas(
        metrics_collections, specificity_across_replicas, values)

    update_op = compute_specificity_at_sensitivity(
        update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
        'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return specificity, update_op
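
# A minimal usage sketch for the metric above (not part of the original
# module), assuming TF 1.x graph mode; `sess`, `labels` and `preds` are
# hypothetical names:
#
#   labels = tf.constant([0, 0, 1, 1])
#   preds = tf.constant([0.1, 0.6, 0.4, 0.9])
#   spec, update_op = tf.metrics.specificity_at_sensitivity(
#       labels, preds, sensitivity=0.5)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)  # accumulates tp/tn/fp/fn at each threshold
#   sess.run(spec)       # specificity at the threshold whose sensitivity is
#                        # closest to 0.5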