• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2# Licensed under the Apache License, Version 2.0 (the "License");
3# you may not use this file except in compliance with the License.
4# You may obtain a copy of the License at
5#
6#     http://www.apache.org/licenses/LICENSE-2.0
7#
8# Unless required by applicable law or agreed to in writing, software
9# distributed under the License is distributed on an "AS IS" BASIS,
10# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11# See the License for the specific language governing permissions and
12# limitations under the License.
13# ==============================================================================
14"""Implementation of tf.metrics module."""
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.distribute import distribution_strategy_context
21from tensorflow.python.eager import context
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import sparse_tensor
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import check_ops
27from tensorflow.python.ops import confusion_matrix
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import nn
31from tensorflow.python.ops import sets
32from tensorflow.python.ops import sparse_ops
33from tensorflow.python.ops import state_ops
34from tensorflow.python.ops import variable_scope
35from tensorflow.python.ops import weights_broadcast_ops
36from tensorflow.python.platform import tf_logging as logging
37from tensorflow.python.util.deprecation import deprecated
38from tensorflow.python.util.tf_export import tf_export
39
40
41def metric_variable(shape, dtype, validate_shape=True, name=None):
42  """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections.
43
44  If running in a `DistributionStrategy` context, the variable will be
45  "sync on read". This means:
46
47  *   The returned object will be a container with separate variables
48      per replica of the model.
49
50  *   When writing to the variable, e.g. using `assign_add` in a metric
51      update, the update will be applied to the variable local to the
52      replica.
53
54  *   To get a metric's result value, we need to sum the variable values
55      across the replicas before computing the final answer. Furthermore,
56      the final answer should be computed once instead of in every
57      replica. Both of these are accomplished by running the computation
58      of the final result value inside
59      `distribution_strategy_context.get_replica_context().merge_call(fn)`.
60      Inside the `merge_call()`, ops are only added to the graph once
61      and access to a sync on read variable in a computation returns
62      the sum across all replicas.
63
64  Args:
65    shape: Shape of the created variable.
66    dtype: Type of the created variable.
67    validate_shape: (Optional) Whether shape validation is enabled for
68      the created variable.
69    name: (Optional) String name of the created variable.
70
71  Returns:
72    A (non-trainable) variable initialized to zero, or if inside a
73    `DistributionStrategy` scope a sync on read variable container.
74  """
75  # Note that synchronization "ON_READ" implies trainable=False.
76  return variable_scope.variable(
77      lambda: array_ops.zeros(shape, dtype),
78      trainable=False,
79      collections=[
80          ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
81      ],
82      validate_shape=validate_shape,
83      synchronization=variable_scope.VariableSynchronization.ON_READ,
84      aggregation=variable_scope.VariableAggregation.SUM,
85      name=name)
86
87
88def _remove_squeezable_dimensions(predictions, labels, weights):
89  """Squeeze or expand last dim if needed.
90
91  Squeezes last dim of `predictions` or `labels` if their rank differs by 1
92  (using confusion_matrix.remove_squeezable_dimensions).
93  Squeezes or expands last dim of `weights` if its rank differs by 1 from the
94  new rank of `predictions`.
95
96  If `weights` is scalar, it is kept scalar.
97
98  This will use static shape if available. Otherwise, it will add graph
99  operations, which could result in a performance hit.
100
101  Args:
102    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
103    labels: Optional label `Tensor` whose dimensions match `predictions`.
104    weights: Optional weight scalar or `Tensor` whose dimensions match
105      `predictions`.
106
107  Returns:
108    Tuple of `predictions`, `labels` and `weights`. Each of them possibly has
109    the last dimension squeezed, `weights` could be extended by one dimension.
110  """
111  predictions = ops.convert_to_tensor(predictions)
112  if labels is not None:
113    labels, predictions = confusion_matrix.remove_squeezable_dimensions(
114        labels, predictions)
115    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
116
117  if weights is None:
118    return predictions, labels, None
119
120  weights = ops.convert_to_tensor(weights)
121  weights_shape = weights.get_shape()
122  weights_rank = weights_shape.ndims
123  if weights_rank == 0:
124    return predictions, labels, weights
125
126  predictions_shape = predictions.get_shape()
127  predictions_rank = predictions_shape.ndims
128  if (predictions_rank is not None) and (weights_rank is not None):
129    # Use static rank.
130    if weights_rank - predictions_rank == 1:
131      weights = array_ops.squeeze(weights, [-1])
132    elif predictions_rank - weights_rank == 1:
133      weights = array_ops.expand_dims(weights, [-1])
134  else:
135    # Use dynamic rank.
136    weights_rank_tensor = array_ops.rank(weights)
137    rank_diff = weights_rank_tensor - array_ops.rank(predictions)
138
139    def _maybe_expand_weights():
140      return control_flow_ops.cond(
141          math_ops.equal(rank_diff, -1),
142          lambda: array_ops.expand_dims(weights, [-1]), lambda: weights)
143
144    # Don't attempt squeeze if it will fail based on static check.
145    if ((weights_rank is not None) and
146        (not weights_shape.dims[-1].is_compatible_with(1))):
147      maybe_squeeze_weights = lambda: weights
148    else:
149      maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1])
150
151    def _maybe_adjust_weights():
152      return control_flow_ops.cond(
153          math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
154          _maybe_expand_weights)
155
156    # If weights are scalar, do nothing. Otherwise, try to add or remove a
157    # dimension to match predictions.
158    weights = control_flow_ops.cond(
159        math_ops.equal(weights_rank_tensor, 0), lambda: weights,
160        _maybe_adjust_weights)
161  return predictions, labels, weights
162
163
164def _maybe_expand_labels(labels, predictions):
165  """If necessary, expand `labels` along last dimension to match `predictions`.
166
167  Args:
168    labels: `Tensor` or `SparseTensor` with shape
169      [D1, ... DN, num_labels] or [D1, ... DN]. The latter implies
170      num_labels=1, in which case the result is an expanded `labels` with shape
171      [D1, ... DN, 1].
172    predictions: `Tensor` with shape [D1, ... DN, num_classes].
173
174  Returns:
175    `labels` with the same rank as `predictions`.
176
177  Raises:
178    ValueError: if `labels` has invalid shape.
179  """
180  with ops.name_scope(None, 'expand_labels', (labels, predictions)) as scope:
181    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
182
183    # If sparse, expand sparse shape.
184    if isinstance(labels, sparse_tensor.SparseTensor):
185      return control_flow_ops.cond(
186          math_ops.equal(
187              array_ops.rank(predictions),
188              array_ops.size(labels.dense_shape) + 1),
189          lambda: sparse_ops.sparse_reshape(  # pylint: disable=g-long-lambda
190              labels,
191              shape=array_ops.concat((labels.dense_shape, (1,)), 0),
192              name=scope),
193          lambda: labels)
194
195    # Otherwise, try to use static shape.
196    labels_rank = labels.get_shape().ndims
197    if labels_rank is not None:
198      predictions_rank = predictions.get_shape().ndims
199      if predictions_rank is not None:
200        if predictions_rank == labels_rank:
201          return labels
202        if predictions_rank == labels_rank + 1:
203          return array_ops.expand_dims(labels, -1, name=scope)
204        raise ValueError(
205            f'Unexpected labels shape {labels.get_shape()} for predictions '
206            f'shape {predictions.get_shape()}. Predictions rank should be the '
207            'same rank as labels rank or labels rank plus one .')
208
209    # Otherwise, use dynamic shape.
210    return control_flow_ops.cond(
211        math_ops.equal(array_ops.rank(predictions),
212                       array_ops.rank(labels) + 1),
213        lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)
214
215
216def _safe_scalar_div(numerator, denominator, name):
217  """Divides two values, returning 0 if the denominator is 0.
218
219  Args:
220    numerator: A scalar `float64` `Tensor`.
221    denominator: A scalar `float64` `Tensor`.
222    name: Name for the returned op.
223
224  Returns:
225    0 if `denominator` == 0, else `numerator` / `denominator`
226  """
227  numerator.get_shape().with_rank_at_most(1)
228  denominator.get_shape().with_rank_at_most(1)
229  return math_ops.div_no_nan(numerator, denominator, name=name)
230
231
232def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
233  """Calculate a streaming confusion matrix.
234
235  Calculates a confusion matrix. For estimation over a stream of data,
236  the function creates an  `update_op` operation.
237
238  Args:
239    labels: A `Tensor` of ground truth labels with shape [batch size] and of
240      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
241    predictions: A `Tensor` of prediction results for semantic labels, whose
242      shape is [batch size] and type `int32` or `int64`. The tensor will be
243      flattened if its rank > 1.
244    num_classes: The possible number of labels the prediction task can
245      have. This value must be provided, since a confusion matrix of
246      dimension = [num_classes, num_classes] will be allocated.
247    weights: Optional `Tensor` whose rank is either 0, or the same rank as
248      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
249      be either `1`, or the same as the corresponding `labels` dimension).
250
251  Returns:
252    total_cm: A `Tensor` representing the confusion matrix.
253    update_op: An operation that increments the confusion matrix.
254  """
255  # Local variable to accumulate the predictions in the confusion matrix.
256  total_cm = metric_variable(
257      [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix')
258
259  # Cast the type to int64 required by confusion_matrix_ops.
260  predictions = math_ops.cast(predictions, dtypes.int64)
261  labels = math_ops.cast(labels, dtypes.int64)
262  num_classes = math_ops.cast(num_classes, dtypes.int64)
263
264  # Flatten the input if its rank > 1.
265  if predictions.get_shape().ndims > 1:
266    predictions = array_ops.reshape(predictions, [-1])
267
268  if labels.get_shape().ndims > 1:
269    labels = array_ops.reshape(labels, [-1])
270
271  if (weights is not None) and (weights.get_shape().ndims > 1):
272    weights = array_ops.reshape(weights, [-1])
273
274  # Accumulate the prediction to current confusion matrix.
275  current_cm = confusion_matrix.confusion_matrix(
276      labels, predictions, num_classes, weights=weights, dtype=dtypes.float64)
277  update_op = state_ops.assign_add(total_cm, current_cm)
278  return total_cm, update_op
279
280
281def _aggregate_across_replicas(metrics_collections, metric_value_fn, *args):
282  """Aggregate metric value across replicas."""
283  def fn(distribution, *a):
284    """Call `metric_value_fn` in the correct control flow context."""
285    if hasattr(distribution.extended, '_outer_control_flow_context'):
286      # If there was an outer context captured before this method was called,
287      # then we enter that context to create the metric value op. If the
288      # captured context is `None`, ops.control_dependencies(None) gives the
289      # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
290      # captured context.
291      # This special handling is needed because sometimes the metric is created
292      # inside a while_loop (and perhaps a TPU rewrite context). But we don't
293      # want the value op to be evaluated every step or on the TPU. So we
294      # create it outside so that it can be evaluated at the end on the host,
295      # once the update ops have been evaluated.
296
297      # pylint: disable=protected-access
298      if distribution.extended._outer_control_flow_context is None:
299        with ops.control_dependencies(None):
300          metric_value = metric_value_fn(distribution, *a)
301      else:
302        distribution.extended._outer_control_flow_context.Enter()
303        metric_value = metric_value_fn(distribution, *a)
304        distribution.extended._outer_control_flow_context.Exit()
305        # pylint: enable=protected-access
306    else:
307      metric_value = metric_value_fn(distribution, *a)
308    if metrics_collections:
309      ops.add_to_collections(metrics_collections, metric_value)
310    return metric_value
311
312  return distribution_strategy_context.get_replica_context().merge_call(
313      fn, args=args)
314
315
316@tf_export(v1=['metrics.mean'])
317def mean(values,
318         weights=None,
319         metrics_collections=None,
320         updates_collections=None,
321         name=None):
322  """Computes the (weighted) mean of the given values.
323
324  The `mean` function creates two local variables, `total` and `count`
325  that are used to compute the average of `values`. This average is ultimately
326  returned as `mean` which is an idempotent operation that simply divides
327  `total` by `count`.
328
329  For estimation of the metric over a stream of data, the function creates an
330  `update_op` operation that updates these variables and returns the `mean`.
331  `update_op` increments `total` with the reduced sum of the product of `values`
332  and `weights`, and it increments `count` with the reduced sum of `weights`.
333
334  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
335
336  Args:
337    values: A `Tensor` of arbitrary dimensions.
338    weights: Optional `Tensor` whose rank is either 0, or the same rank as
339      `values`, and must be broadcastable to `values` (i.e., all dimensions must
340      be either `1`, or the same as the corresponding `values` dimension).
341    metrics_collections: An optional list of collections that `mean`
342      should be added to.
343    updates_collections: An optional list of collections that `update_op`
344      should be added to.
345    name: An optional variable_scope name.
346
347  Returns:
348    mean: A `Tensor` representing the current mean, the value of `total` divided
349      by `count`.
350    update_op: An operation that increments the `total` and `count` variables
351      appropriately and whose value matches `mean_value`.
352
353  Raises:
354    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
355      or if either `metrics_collections` or `updates_collections` are not a list
356      or tuple.
357    RuntimeError: If eager execution is enabled.
358  """
359  if context.executing_eagerly():
360    raise RuntimeError('tf.metrics.mean is not supported when eager execution '
361                       'is enabled.')
362
363  with variable_scope.variable_scope(name, 'mean', (values, weights)):
364    values = math_ops.cast(values, dtypes.float32)
365
366    total = metric_variable([], dtypes.float32, name='total')
367    count = metric_variable([], dtypes.float32, name='count')
368
369    if weights is None:
370      num_values = math_ops.cast(array_ops.size(values), dtypes.float32)
371    else:
372      values, _, weights = _remove_squeezable_dimensions(
373          predictions=values, labels=None, weights=weights)
374      weights = weights_broadcast_ops.broadcast_weights(
375          math_ops.cast(weights, dtypes.float32), values)
376      values = math_ops.multiply(values, weights)
377      num_values = math_ops.reduce_sum(weights)
378
379    update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
380    with ops.control_dependencies([values]):
381      update_count_op = state_ops.assign_add(count, num_values)
382
383    def compute_mean(_, t, c):
384      return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value')
385
386    mean_t = _aggregate_across_replicas(
387        metrics_collections, compute_mean, total, count)
388    update_op = math_ops.div_no_nan(
389        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')
390
391    if updates_collections:
392      ops.add_to_collections(updates_collections, update_op)
393
394    return mean_t, update_op
395
396
397@tf_export(v1=['metrics.accuracy'])
398def accuracy(labels,
399             predictions,
400             weights=None,
401             metrics_collections=None,
402             updates_collections=None,
403             name=None):
404  """Calculates how often `predictions` matches `labels`.
405
406  The `accuracy` function creates two local variables, `total` and
407  `count` that are used to compute the frequency with which `predictions`
408  matches `labels`. This frequency is ultimately returned as `accuracy`: an
409  idempotent operation that simply divides `total` by `count`.
410
411  For estimation of the metric over a stream of data, the function creates an
412  `update_op` operation that updates these variables and returns the `accuracy`.
413  Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
414  where the corresponding elements of `predictions` and `labels` match and 0.0
415  otherwise. Then `update_op` increments `total` with the reduced sum of the
416  product of `weights` and `is_correct`, and it increments `count` with the
417  reduced sum of `weights`.
418
419  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
420
421  Args:
422    labels: The ground truth values, a `Tensor` whose shape matches
423      `predictions`.
424    predictions: The predicted values, a `Tensor` of any shape.
425    weights: Optional `Tensor` whose rank is either 0, or the same rank as
426      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
427      be either `1`, or the same as the corresponding `labels` dimension).
428    metrics_collections: An optional list of collections that `accuracy` should
429      be added to.
430    updates_collections: An optional list of collections that `update_op` should
431      be added to.
432    name: An optional variable_scope name.
433
434  Returns:
435    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
436      by `count`.
437    update_op: An operation that increments the `total` and `count` variables
438      appropriately and whose value matches `accuracy`.
439
440  Raises:
441    ValueError: If `predictions` and `labels` have mismatched shapes, or if
442      `weights` is not `None` and its shape doesn't match `predictions`, or if
443      either `metrics_collections` or `updates_collections` are not a list or
444      tuple.
445    RuntimeError: If eager execution is enabled.
446
447  @compatibility(TF2)
448  `tf.compat.v1.metrics.accuracy` is not compatible with eager
449  execution or `tf.function`.
450  Please use `tf.keras.metrics.Accuracy` instead for TF2 migration. After
451  instantiating a `tf.keras.metrics.Accuracy` object, you can first call the
452  `update_state()` method to record the prediction/labels, and then call the
453  `result()` method to get the accuracy eagerly. You can also attach it to a
454  Keras model when calling the `compile` method. Please refer to [this
455  guide](https://www.tensorflow.org/guide/migrate#new-style_metrics_and_losses)
456  for more details.
457
458  #### Structural Mapping to Native TF2
459
460  Before:
461
462  ```python
463  accuracy, update_op = tf.compat.v1.metrics.accuracy(
464    labels=labels,
465    predictions=predictions,
466    weights=weights,
467    metrics_collections=metrics_collections,
468    update_collections=update_collections,
469    name=name)
470  ```
471
472  After:
473
474  ```python
475   m = tf.keras.metrics.Accuracy(
476     name=name,
477     dtype=None)
478
479   m.update_state(
480   y_true=labels,
481   y_pred=predictions,
482   sample_weight=weights)
483
484   accuracy = m.result()
485  ```
486
487  #### How to Map Arguments
488
489  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
490  | :-------------------- | :-------------- | :------------------------- |
491  | `label`               | `y_true`        | In `update_state()` method |
492  | `predictions`         | `y_true`        | In `update_state()` method |
493  | `weights`             | `sample_weight` | In `update_state()` method |
494  | `metrics_collections` | Not supported   | Metrics should be tracked  |
495  :                       :                 : explicitly or with Keras   :
496  :                       :                 : APIs, for example,         :
497  :                       :                 : [add_metric][add_metric],  :
498  :                       :                 : instead of via collections :
499  | `updates_collections` | Not supported   | -                          |
500  | `name`                | `name`          | In constructor             |
501
502  [add_metric]:https//www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_metric
503
504
505  #### Before & After Usage Example
506
507  Before:
508
509  >>> g = tf.Graph()
510  >>> with g.as_default():
511  ...   logits = [1, 2, 3]
512  ...   labels = [0, 2, 3]
513  ...   acc, acc_op = tf.compat.v1.metrics.accuracy(logits, labels)
514  ...   global_init = tf.compat.v1.global_variables_initializer()
515  ...   local_init = tf.compat.v1.local_variables_initializer()
516  >>> sess = tf.compat.v1.Session(graph=g)
517  >>> sess.run([global_init, local_init])
518  >>> print(sess.run([acc, acc_op]))
519  [0.0, 0.66667]
520
521
522  After:
523
524  >>> m = tf.keras.metrics.Accuracy()
525  >>> m.update_state([1, 2, 3], [0, 2, 3])
526  >>> m.result().numpy()
527  0.66667
528
529  ```python
530  # Used within Keras model
531  model.compile(optimizer='sgd',
532                loss='mse',
533                metrics=[tf.keras.metrics.Accuracy()])
534  ```
535
536  @end_compatibility
537  """
538  if context.executing_eagerly():
539    raise RuntimeError('tf.metrics.accuracy is not supported when eager '
540                       'execution is enabled.')
541
542  predictions, labels, weights = _remove_squeezable_dimensions(
543      predictions=predictions, labels=labels, weights=weights)
544  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
545  if labels.dtype != predictions.dtype:
546    predictions = math_ops.cast(predictions, labels.dtype)
547  is_correct = math_ops.cast(
548      math_ops.equal(predictions, labels), dtypes.float32)
549  return mean(is_correct, weights, metrics_collections, updates_collections,
550              name or 'accuracy')
551
552
553def _confusion_matrix_at_thresholds(labels,
554                                    predictions,
555                                    thresholds,
556                                    weights=None,
557                                    includes=None):
558  """Computes true_positives, false_negatives, true_negatives, false_positives.
559
560  This function creates up to four local variables, `true_positives`,
561  `true_negatives`, `false_positives` and `false_negatives`.
562  `true_positive[i]` is defined as the total weight of values in `predictions`
563  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
564  `false_negatives[i]` is defined as the total weight of values in `predictions`
565  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
566  `true_negatives[i]` is defined as the total weight of values in `predictions`
567  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
568  `false_positives[i]` is defined as the total weight of values in `predictions`
569  above `thresholds[i]` whose corresponding entry in `labels` is `False`.
570
571  For estimation of these metrics over a stream of data, for each metric the
572  function respectively creates an `update_op` operation that updates the
573  variable and returns its value.
574
575  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
576
577  Args:
578    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
579      `bool`.
580    predictions: A floating point `Tensor` of arbitrary shape and whose values
581      are in the range `[0, 1]`.
582    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
583    weights: Optional `Tensor` whose rank is either 0, or the same rank as
584      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
585      be either `1`, or the same as the corresponding `labels` dimension).
586    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', fp'. If `None`,
587        default to all four.
588
589  Returns:
590    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
591        `includes`.
592    update_ops: Dict of operations that increments the `values`. Keys are from
593        `includes`.
594
595  Raises:
596    ValueError: If `predictions` and `labels` have mismatched shapes, or if
597      `weights` is not `None` and its shape doesn't match `predictions`, or if
598      `includes` contains invalid keys.
599  """
600  all_includes = ('tp', 'fn', 'tn', 'fp')
601  if includes is None:
602    includes = all_includes
603  else:
604    for include in includes:
605      if include not in all_includes:
606        raise ValueError(f'Invalid key: {include}')
607
608  with ops.control_dependencies([
609      check_ops.assert_greater_equal(
610          predictions,
611          math_ops.cast(0.0, dtype=predictions.dtype),
612          message='predictions must be in [0, 1]'),
613      check_ops.assert_less_equal(
614          predictions,
615          math_ops.cast(1.0, dtype=predictions.dtype),
616          message='predictions must be in [0, 1]')
617  ]):
618    predictions, labels, weights = _remove_squeezable_dimensions(
619        predictions=math_ops.cast(predictions, dtypes.float32),
620        labels=math_ops.cast(labels, dtype=dtypes.bool),
621        weights=weights)
622
623  num_thresholds = len(thresholds)
624
625  # Reshape predictions and labels.
626  predictions_2d = array_ops.reshape(predictions, [-1, 1])
627  labels_2d = array_ops.reshape(
628      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])
629
630  # Use static shape if known.
631  num_predictions = predictions_2d.get_shape().as_list()[0]
632
633  # Otherwise use dynamic shape.
634  if num_predictions is None:
635    num_predictions = array_ops.shape(predictions_2d)[0]
636  thresh_tiled = array_ops.tile(
637      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
638      array_ops.stack([1, num_predictions]))
639
640  # Tile the predictions after thresholding them across different thresholds.
641  pred_is_pos = math_ops.greater(
642      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
643      thresh_tiled)
644  if ('fn' in includes) or ('tn' in includes):
645    pred_is_neg = math_ops.logical_not(pred_is_pos)
646
647  # Tile labels by number of thresholds
648  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
649  if ('fp' in includes) or ('tn' in includes):
650    label_is_neg = math_ops.logical_not(label_is_pos)
651
652  if weights is not None:
653    weights = weights_broadcast_ops.broadcast_weights(
654        math_ops.cast(weights, dtypes.float32), predictions)
655    weights_tiled = array_ops.tile(
656        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
657    thresh_tiled.get_shape().assert_is_compatible_with(
658        weights_tiled.get_shape())
659  else:
660    weights_tiled = None
661
662  values = {}
663  update_ops = {}
664
665  if 'tp' in includes:
666    true_p = metric_variable(
667        [num_thresholds], dtypes.float32, name='true_positives')
668    is_true_positive = math_ops.cast(
669        math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32)
670    if weights_tiled is not None:
671      is_true_positive *= weights_tiled
672    update_ops['tp'] = state_ops.assign_add(true_p,
673                                            math_ops.reduce_sum(
674                                                is_true_positive, 1))
675    values['tp'] = true_p
676
677  if 'fn' in includes:
678    false_n = metric_variable(
679        [num_thresholds], dtypes.float32, name='false_negatives')
680    is_false_negative = math_ops.cast(
681        math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32)
682    if weights_tiled is not None:
683      is_false_negative *= weights_tiled
684    update_ops['fn'] = state_ops.assign_add(false_n,
685                                            math_ops.reduce_sum(
686                                                is_false_negative, 1))
687    values['fn'] = false_n
688
689  if 'tn' in includes:
690    true_n = metric_variable(
691        [num_thresholds], dtypes.float32, name='true_negatives')
692    is_true_negative = math_ops.cast(
693        math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32)
694    if weights_tiled is not None:
695      is_true_negative *= weights_tiled
696    update_ops['tn'] = state_ops.assign_add(true_n,
697                                            math_ops.reduce_sum(
698                                                is_true_negative, 1))
699    values['tn'] = true_n
700
701  if 'fp' in includes:
702    false_p = metric_variable(
703        [num_thresholds], dtypes.float32, name='false_positives')
704    is_false_positive = math_ops.cast(
705        math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32)
706    if weights_tiled is not None:
707      is_false_positive *= weights_tiled
708    update_ops['fp'] = state_ops.assign_add(false_p,
709                                            math_ops.reduce_sum(
710                                                is_false_positive, 1))
711    values['fp'] = false_p
712
713  return values, update_ops
714
715
716def _aggregate_variable(v, collections):
717  f = lambda distribution, value: distribution.extended.read_var(value)
718  return _aggregate_across_replicas(collections, f, v)
719
720
721@tf_export(v1=['metrics.auc'])
722@deprecated(None,
723            'The value of AUC returned by this may race with the update so '
724            'this is deprecated. Please use tf.keras.metrics.AUC instead.')
725def auc(labels,
726        predictions,
727        weights=None,
728        num_thresholds=200,
729        metrics_collections=None,
730        updates_collections=None,
731        curve='ROC',
732        name=None,
733        summation_method='trapezoidal',
734        thresholds=None):
735  """Computes the approximate AUC via a Riemann sum.
736
737  The `auc` function creates four local variables, `true_positives`,
738  `true_negatives`, `false_positives` and `false_negatives` that are used to
739  compute the AUC. To discretize the AUC curve, a linearly spaced set of
740  thresholds is used to compute pairs of recall and precision values. The area
741  under the ROC-curve is therefore computed using the height of the recall
742  values by the false positive rate, while the area under the PR-curve is the
743  computed using the height of the precision values by the recall.
744
745  This value is ultimately returned as `auc`, an idempotent operation that
746  computes the area under a discretized curve of precision versus recall values
747  (computed using the aforementioned variables). The `num_thresholds` variable
748  controls the degree of discretization with larger numbers of thresholds more
749  closely approximating the true AUC. The quality of the approximation may vary
750  dramatically depending on `num_thresholds`.
751
752  For best results, `predictions` should be distributed approximately uniformly
753  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
754  approximation may be poor if this is not the case. Setting `summation_method`
755  to 'minoring' or 'majoring' can help quantify the error in the approximation
756  by providing lower or upper bound estimate of the AUC. The `thresholds`
757  parameter can be used to manually specify thresholds which split the
758  predictions more evenly.
759
760  For estimation of the metric over a stream of data, the function creates an
761  `update_op` operation that updates these variables and returns the `auc`.
762
763  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
764
765  Args:
766    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
767      `bool`.
768    predictions: A floating point `Tensor` of arbitrary shape and whose values
769      are in the range `[0, 1]`.
770    weights: Optional `Tensor` whose rank is either 0, or the same rank as
771      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
772      be either `1`, or the same as the corresponding `labels` dimension).
773    num_thresholds: The number of thresholds to use when discretizing the roc
774      curve.
775    metrics_collections: An optional list of collections that `auc` should be
776      added to.
777    updates_collections: An optional list of collections that `update_op` should
778      be added to.
779    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
780      'PR' for the Precision-Recall-curve.
781    name: An optional variable_scope name.
782    summation_method: Specifies the Riemann summation method used
783      (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that
784      applies the trapezoidal rule; 'careful_interpolation', a variant of it
785      differing only by a more correct interpolation scheme for PR-AUC -
786      interpolating (true/false) positives but not the ratio that is precision;
787      'minoring' that applies left summation for increasing intervals and right
788      summation for decreasing intervals; 'majoring' that does the opposite.
789      Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
790      (to be deprecated soon) as it applies the same method for ROC, and a
791      better one (see Davis & Goadrich 2006 for details) for the PR curve.
792    thresholds: An optional list of floating point values to use as the
793      thresholds for discretizing the curve. If set, the `num_thresholds`
794      parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
795      equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be
796      automatically included with these to correctly handle predictions equal to
797       exactly 0 or 1.
798
799  Returns:
800    auc: A scalar `Tensor` representing the current area-under-curve.
801    update_op: An operation that increments the `true_positives`,
802      `true_negatives`, `false_positives` and `false_negatives` variables
803      appropriately and whose value matches `auc`.
804
805  Raises:
806    ValueError: If `predictions` and `labels` have mismatched shapes, or if
807      `weights` is not `None` and its shape doesn't match `predictions`, or if
808      either `metrics_collections` or `updates_collections` are not a list or
809      tuple.
810    RuntimeError: If eager execution is enabled.
811  """
812  if context.executing_eagerly():
813    raise RuntimeError('tf.metrics.auc is not supported when eager execution '
814                       'is enabled.')
815
816  with variable_scope.variable_scope(name, 'auc',
817                                     (labels, predictions, weights)):
818    if curve != 'ROC' and curve != 'PR':
819      raise ValueError(f'Curve must be either ROC or PR. Curve {curve} is '
820                       'unknown.')
821
822    kepsilon = 1e-7  # To account for floating point imprecisions.
823    if thresholds is not None:
824      # If specified, use the supplied thresholds.
825      thresholds = sorted(thresholds)
826      num_thresholds = len(thresholds) + 2
827    else:
828      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
829      # (0, 1).
830      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
831                    for i in range(num_thresholds - 2)]
832
833    # Add an endpoint "threshold" below zero and above one for either threshold
834    # method.
835    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
836
837    values, update_ops = _confusion_matrix_at_thresholds(
838        labels, predictions, thresholds, weights)
839
840    # Add epsilons to avoid dividing by 0.
841    epsilon = 1.0e-6
842
843    def interpolate_pr_auc(tp, fp, fn):
844      """Interpolation formula inspired by section 4 of (Davis et al., 2006).
845
846      Note here we derive & use a closed formula not present in the paper
847      - as follows:
848      Modeling all of TP (true positive weight),
849      FP (false positive weight) and their sum P = TP + FP (positive weight)
850      as varying linearly within each interval [A, B] between successive
851      thresholds, we get
852        Precision = (TP_A + slope * (P - P_A)) / P
853      with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A).
854      The area within the interval is thus (slope / total_pos_weight) times
855        int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
856        int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
857      where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
858        int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
859      Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
860         slope * [dTP + intercept *  log(P_B / P_A)] / total_pos_weight
861      where dTP == TP_B - TP_A.
862      Note that when P_A == 0 the above calculation simplifies into
863        int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
864      which is really equivalent to imputing constant precision throughout the
865      first bucket having >0 true positives.
866
867      Args:
868        tp: true positive counts
869        fp: false positive counts
870        fn: false negative counts
871
872      Returns:
873        pr_auc: an approximation of the area under the P-R curve.
874
875      References:
876        The Relationship Between Precision-Recall and ROC Curves:
877          [Davis et al., 2006](https://dl.acm.org/citation.cfm?id=1143874)
878          ([pdf](https://www.biostat.wisc.edu/~page/rocpr.pdf))
879      """
880      dtp = tp[:num_thresholds - 1] - tp[1:]
881      p = tp + fp
882      prec_slope = math_ops.div_no_nan(
883          dtp,
884          math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0),
885          name='prec_slope')
886      intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
887      safe_p_ratio = array_ops.where(
888          math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
889          math_ops.div_no_nan(
890              p[:num_thresholds - 1],
891              math_ops.maximum(p[1:], 0),
892              name='recall_relative_ratio'), array_ops.ones_like(p[1:]))
893      return math_ops.reduce_sum(
894          math_ops.div_no_nan(
895              prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
896              math_ops.maximum(tp[1:] + fn[1:], 0),
897              name='pr_auc_increment'),
898          name='interpolate_pr_auc')
899
900    def compute_auc(tp, fn, tn, fp, name):
901      """Computes the roc-auc or pr-auc based on confusion counts."""
902      if curve == 'PR':
903        if summation_method == 'trapezoidal':
904          logging.warning(
905              'Trapezoidal rule is known to produce incorrect PR-AUCs; '
906              'please switch to "careful_interpolation" instead.')
907        elif summation_method == 'careful_interpolation':
908          # This one is a bit tricky and is handled separately.
909          return interpolate_pr_auc(tp, fp, fn)
910      rec = math_ops.divide(tp + epsilon, tp + fn + epsilon)
911      if curve == 'ROC':
912        fp_rate = math_ops.divide(fp, fp + tn + epsilon)
913        x = fp_rate
914        y = rec
915      else:  # curve == 'PR'.
916        prec = math_ops.divide(tp + epsilon, tp + fp + epsilon)
917        x = rec
918        y = prec
919      if summation_method in ('trapezoidal', 'careful_interpolation'):
920        # Note that the case ('PR', 'careful_interpolation') has been handled
921        # above.
922        return math_ops.reduce_sum(
923            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
924                              (y[:num_thresholds - 1] + y[1:]) / 2.),
925            name=name)
926      elif summation_method == 'minoring':
927        return math_ops.reduce_sum(
928            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
929                              math_ops.minimum(y[:num_thresholds - 1], y[1:])),
930            name=name)
931      elif summation_method == 'majoring':
932        return math_ops.reduce_sum(
933            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
934                              math_ops.maximum(y[:num_thresholds - 1], y[1:])),
935            name=name)
936      else:
937        raise ValueError(f'Invalid summation_method: {summation_method} '
938                         'summation_method should be \'trapezoidal\', '
939                         '\'careful_interpolation\', \'minoring\', or '
940                         '\'majoring\'.')
941
942    # sum up the areas of all the trapeziums
943    def compute_auc_value(_, values):
944      return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'],
945                         'value')
946
947    auc_value = _aggregate_across_replicas(
948        metrics_collections, compute_auc_value, values)
949    update_op = compute_auc(update_ops['tp'], update_ops['fn'],
950                            update_ops['tn'], update_ops['fp'], 'update_op')
951
952    if updates_collections:
953      ops.add_to_collections(updates_collections, update_op)
954
955    return auc_value, update_op
956
957
958@tf_export(v1=['metrics.mean_absolute_error'])
959def mean_absolute_error(labels,
960                        predictions,
961                        weights=None,
962                        metrics_collections=None,
963                        updates_collections=None,
964                        name=None):
965  """Computes the mean absolute error between the labels and predictions.
966
967  The `mean_absolute_error` function creates two local variables,
968  `total` and `count` that are used to compute the mean absolute error. This
969  average is weighted by `weights`, and it is ultimately returned as
970  `mean_absolute_error`: an idempotent operation that simply divides `total` by
971  `count`.
972
973  For estimation of the metric over a stream of data, the function creates an
974  `update_op` operation that updates these variables and returns the
975  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
976  absolute value of the differences between `predictions` and `labels`. Then
977  `update_op` increments `total` with the reduced sum of the product of
978  `weights` and `absolute_errors`, and it increments `count` with the reduced
979  sum of `weights`
980
981  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
982
983  Args:
984    labels: A `Tensor` of the same shape as `predictions`.
985    predictions: A `Tensor` of arbitrary shape.
986    weights: Optional `Tensor` whose rank is either 0, or the same rank as
987      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
988      be either `1`, or the same as the corresponding `labels` dimension).
989    metrics_collections: An optional list of collections that
990      `mean_absolute_error` should be added to.
991    updates_collections: An optional list of collections that `update_op` should
992      be added to.
993    name: An optional variable_scope name.
994
995  Returns:
996    mean_absolute_error: A `Tensor` representing the current mean, the value of
997      `total` divided by `count`.
998    update_op: An operation that increments the `total` and `count` variables
999      appropriately and whose value matches `mean_absolute_error`.
1000
1001  Raises:
1002    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1003      `weights` is not `None` and its shape doesn't match `predictions`, or if
1004      either `metrics_collections` or `updates_collections` are not a list or
1005      tuple.
1006    RuntimeError: If eager execution is enabled.
1007  """
1008  if context.executing_eagerly():
1009    raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
1010                       'when eager execution is enabled.')
1011
1012  predictions, labels, weights = _remove_squeezable_dimensions(
1013      predictions=predictions, labels=labels, weights=weights)
1014  absolute_errors = math_ops.abs(predictions - labels)
1015  return mean(absolute_errors, weights, metrics_collections,
1016              updates_collections, name or 'mean_absolute_error')
1017
1018
1019@tf_export(v1=['metrics.mean_cosine_distance'])
1020def mean_cosine_distance(labels,
1021                         predictions,
1022                         dim,
1023                         weights=None,
1024                         metrics_collections=None,
1025                         updates_collections=None,
1026                         name=None):
1027  """Computes the cosine distance between the labels and predictions.
1028
1029  The `mean_cosine_distance` function creates two local variables,
1030  `total` and `count` that are used to compute the average cosine distance
1031  between `predictions` and `labels`. This average is weighted by `weights`,
1032  and it is ultimately returned as `mean_distance`, which is an idempotent
1033  operation that simply divides `total` by `count`.
1034
1035  For estimation of the metric over a stream of data, the function creates an
1036  `update_op` operation that updates these variables and returns the
1037  `mean_distance`.
1038
1039  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1040
1041  Args:
1042    labels: A `Tensor` of arbitrary shape.
1043    predictions: A `Tensor` of the same shape as `labels`.
1044    dim: The dimension along which the cosine distance is computed.
1045    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1046      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1047      be either `1`, or the same as the corresponding `labels` dimension). Also,
1048      dimension `dim` must be `1`.
1049    metrics_collections: An optional list of collections that the metric
1050      value variable should be added to.
1051    updates_collections: An optional list of collections that the metric update
1052      ops should be added to.
1053    name: An optional variable_scope name.
1054
1055  Returns:
1056    mean_distance: A `Tensor` representing the current mean, the value of
1057      `total` divided by `count`.
1058    update_op: An operation that increments the `total` and `count` variables
1059      appropriately.
1060
1061  Raises:
1062    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1063      `weights` is not `None` and its shape doesn't match `predictions`, or if
1064      either `metrics_collections` or `updates_collections` are not a list or
1065      tuple.
1066    RuntimeError: If eager execution is enabled.
1067  """
1068  if context.executing_eagerly():
1069    raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
1070                       'eager execution is enabled.')
1071
1072  predictions, labels, weights = _remove_squeezable_dimensions(
1073      predictions=predictions, labels=labels, weights=weights)
1074  radial_diffs = math_ops.multiply(predictions, labels)
1075  radial_diffs = math_ops.reduce_sum(
1076      radial_diffs, axis=[
1077          dim,
1078      ], keepdims=True)
1079  mean_distance, update_op = mean(radial_diffs, weights, None, None, name or
1080                                  'mean_cosine_distance')
1081  mean_distance = math_ops.subtract(1.0, mean_distance)
1082  update_op = math_ops.subtract(1.0, update_op)
1083
1084  if metrics_collections:
1085    ops.add_to_collections(metrics_collections, mean_distance)
1086
1087  if updates_collections:
1088    ops.add_to_collections(updates_collections, update_op)
1089
1090  return mean_distance, update_op
1091
1092
1093@tf_export(v1=['metrics.mean_per_class_accuracy'])
1094def mean_per_class_accuracy(labels,
1095                            predictions,
1096                            num_classes,
1097                            weights=None,
1098                            metrics_collections=None,
1099                            updates_collections=None,
1100                            name=None):
1101  """Calculates the mean of the per-class accuracies.
1102
1103  Calculates the accuracy for each class, then takes the mean of that.
1104
1105  For estimation of the metric over a stream of data, the function creates an
1106  `update_op` operation that updates the accuracy of each class and returns
1107  them.
1108
1109  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1110
1111  Args:
1112    labels: A `Tensor` of ground truth labels with shape [batch size] and of
1113      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
1114    predictions: A `Tensor` of prediction results for semantic labels, whose
1115      shape is [batch size] and type `int32` or `int64`. The tensor will be
1116      flattened if its rank > 1.
1117    num_classes: The possible number of labels the prediction task can
1118      have. This value must be provided, since two variables with shape =
1119      [num_classes] will be allocated.
1120    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1121      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1122      be either `1`, or the same as the corresponding `labels` dimension).
1123    metrics_collections: An optional list of collections that
1124      `mean_per_class_accuracy'
1125      should be added to.
1126    updates_collections: An optional list of collections `update_op` should be
1127      added to.
1128    name: An optional variable_scope name.
1129
1130  Returns:
1131    mean_accuracy: A `Tensor` representing the mean per class accuracy.
1132    update_op: An operation that updates the accuracy tensor.
1133
1134  Raises:
1135    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1136      `weights` is not `None` and its shape doesn't match `predictions`, or if
1137      either `metrics_collections` or `updates_collections` are not a list or
1138      tuple.
1139    RuntimeError: If eager execution is enabled.
1140  """
1141  if context.executing_eagerly():
1142    raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
1143                       'when eager execution is enabled.')
1144
1145  with variable_scope.variable_scope(name, 'mean_accuracy',
1146                                     (predictions, labels, weights)):
1147    labels = math_ops.cast(labels, dtypes.int64)
1148
1149    # Flatten the input if its rank > 1.
1150    if labels.get_shape().ndims > 1:
1151      labels = array_ops.reshape(labels, [-1])
1152
1153    if predictions.get_shape().ndims > 1:
1154      predictions = array_ops.reshape(predictions, [-1])
1155
1156    # Check if shape is compatible.
1157    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
1158
1159    total = metric_variable([num_classes], dtypes.float32, name='total')
1160    count = metric_variable([num_classes], dtypes.float32, name='count')
1161
1162    ones = array_ops.ones([array_ops.size(labels)], dtypes.float32)
1163
1164    if labels.dtype != predictions.dtype:
1165      predictions = math_ops.cast(predictions, labels.dtype)
1166    is_correct = math_ops.cast(
1167        math_ops.equal(predictions, labels), dtypes.float32)
1168
1169    if weights is not None:
1170      if weights.get_shape().ndims > 1:
1171        weights = array_ops.reshape(weights, [-1])
1172      weights = math_ops.cast(weights, dtypes.float32)
1173
1174      is_correct *= weights
1175      ones *= weights
1176
1177    update_total_op = state_ops.scatter_add(total, labels, ones)
1178    update_count_op = state_ops.scatter_add(count, labels, is_correct)
1179
1180    def compute_mean_accuracy(_, count, total):
1181      per_class_accuracy = math_ops.div_no_nan(
1182          count, math_ops.maximum(total, 0), name=None)
1183      mean_accuracy_v = math_ops.reduce_mean(
1184          per_class_accuracy, name='mean_accuracy')
1185      return mean_accuracy_v
1186
1187    mean_accuracy_v = _aggregate_across_replicas(
1188        metrics_collections, compute_mean_accuracy, count, total)
1189
1190    update_op = math_ops.div_no_nan(
1191        update_count_op, math_ops.maximum(update_total_op, 0), name='update_op')
1192    if updates_collections:
1193      ops.add_to_collections(updates_collections, update_op)
1194
1195    return mean_accuracy_v, update_op
1196
1197
1198@tf_export(v1=['metrics.mean_iou'])
1199def mean_iou(labels,
1200             predictions,
1201             num_classes,
1202             weights=None,
1203             metrics_collections=None,
1204             updates_collections=None,
1205             name=None):
1206  """Calculate per-step mean Intersection-Over-Union (mIOU).
1207
1208  Mean Intersection-Over-Union is a common evaluation metric for
1209  semantic image segmentation, which first computes the IOU for each
1210  semantic class and then computes the average over classes.
1211  IOU is defined as follows:
1212    IOU = true_positive / (true_positive + false_positive + false_negative).
1213  The predictions are accumulated in a confusion matrix, weighted by `weights`,
1214  and mIOU is then calculated from it.
1215
1216  For estimation of the metric over a stream of data, the function creates an
1217  `update_op` operation that updates these variables and returns the `mean_iou`.
1218
1219  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1220
1221  Args:
1222    labels: A `Tensor` of ground truth labels with shape [batch size] and of
1223      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
1224    predictions: A `Tensor` of prediction results for semantic labels, whose
1225      shape is [batch size] and type `int32` or `int64`. The tensor will be
1226      flattened if its rank > 1.
1227    num_classes: The possible number of labels the prediction task can
1228      have. This value must be provided, since a confusion matrix of
1229      dimension = [num_classes, num_classes] will be allocated.
1230    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1231      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1232      be either `1`, or the same as the corresponding `labels` dimension).
1233    metrics_collections: An optional list of collections that `mean_iou`
1234      should be added to.
1235    updates_collections: An optional list of collections `update_op` should be
1236      added to.
1237    name: An optional variable_scope name.
1238
1239  Returns:
1240    mean_iou: A `Tensor` representing the mean intersection-over-union.
1241    update_op: An operation that increments the confusion matrix.
1242
1243  Raises:
1244    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1245      `weights` is not `None` and its shape doesn't match `predictions`, or if
1246      either `metrics_collections` or `updates_collections` are not a list or
1247      tuple.
1248    RuntimeError: If eager execution is enabled.
1249  """
1250  if context.executing_eagerly():
1251    raise RuntimeError('tf.metrics.mean_iou is not supported when '
1252                       'eager execution is enabled.')
1253
1254  with variable_scope.variable_scope(name, 'mean_iou',
1255                                     (predictions, labels, weights)):
1256    # Check if shape is compatible.
1257    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
1258
1259    total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
1260                                                      num_classes, weights)
1261
1262    def compute_mean_iou(_, total_cm):
1263      """Compute the mean intersection-over-union via the confusion matrix."""
1264      sum_over_row = math_ops.cast(
1265          math_ops.reduce_sum(total_cm, 0), dtypes.float32)
1266      sum_over_col = math_ops.cast(
1267          math_ops.reduce_sum(total_cm, 1), dtypes.float32)
1268      cm_diag = math_ops.cast(array_ops.diag_part(total_cm), dtypes.float32)
1269      denominator = sum_over_row + sum_over_col - cm_diag
1270
1271      # The mean is only computed over classes that appear in the
1272      # label or prediction tensor. If the denominator is 0, we need to
1273      # ignore the class.
1274      num_valid_entries = math_ops.reduce_sum(
1275          math_ops.cast(
1276              math_ops.not_equal(denominator, 0), dtype=dtypes.float32))
1277
1278      # If the value of the denominator is 0, set it to 1 to avoid
1279      # zero division.
1280      denominator = array_ops.where(
1281          math_ops.greater(denominator, 0), denominator,
1282          array_ops.ones_like(denominator))
1283      iou = math_ops.divide(cm_diag, denominator)
1284
1285      # If the number of valid entries is 0 (no classes) we return 0.
1286      result = array_ops.where(
1287          math_ops.greater(num_valid_entries, 0),
1288          math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0)
1289      return result
1290
1291    # TODO(priyag): Use outside_compilation if in TPU context.
1292    mean_iou_v = _aggregate_across_replicas(
1293        metrics_collections, compute_mean_iou, total_cm)
1294
1295    if updates_collections:
1296      ops.add_to_collections(updates_collections, update_op)
1297
1298    return mean_iou_v, update_op
1299
1300
1301@tf_export(v1=['metrics.mean_relative_error'])
1302def mean_relative_error(labels,
1303                        predictions,
1304                        normalizer,
1305                        weights=None,
1306                        metrics_collections=None,
1307                        updates_collections=None,
1308                        name=None):
1309  """Computes the mean relative error by normalizing with the given values.
1310
1311  The `mean_relative_error` function creates two local variables,
1312  `total` and `count` that are used to compute the mean relative absolute error.
1313  This average is weighted by `weights`, and it is ultimately returned as
1314  `mean_relative_error`: an idempotent operation that simply divides `total` by
1315  `count`.
1316
1317  For estimation of the metric over a stream of data, the function creates an
1318  `update_op` operation that updates these variables and returns the
1319  `mean_reative_error`. Internally, a `relative_errors` operation divides the
1320  absolute value of the differences between `predictions` and `labels` by the
1321  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
1322  product of `weights` and `relative_errors`, and it increments `count` with the
1323  reduced sum of `weights`.
1324
1325  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1326
1327  Args:
1328    labels: A `Tensor` of the same shape as `predictions`.
1329    predictions: A `Tensor` of arbitrary shape.
1330    normalizer: A `Tensor` of the same shape as `predictions`.
1331    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1332      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1333      be either `1`, or the same as the corresponding `labels` dimension).
1334    metrics_collections: An optional list of collections that
1335      `mean_relative_error` should be added to.
1336    updates_collections: An optional list of collections that `update_op` should
1337      be added to.
1338    name: An optional variable_scope name.
1339
1340  Returns:
1341    mean_relative_error: A `Tensor` representing the current mean, the value of
1342      `total` divided by `count`.
1343    update_op: An operation that increments the `total` and `count` variables
1344      appropriately and whose value matches `mean_relative_error`.
1345
1346  Raises:
1347    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1348      `weights` is not `None` and its shape doesn't match `predictions`, or if
1349      either `metrics_collections` or `updates_collections` are not a list or
1350      tuple.
1351    RuntimeError: If eager execution is enabled.
1352  """
1353  if context.executing_eagerly():
1354    raise RuntimeError('tf.metrics.mean_relative_error is not supported when '
1355                       'eager execution is enabled.')
1356
1357  predictions, labels, weights = _remove_squeezable_dimensions(
1358      predictions=predictions, labels=labels, weights=weights)
1359
1360  predictions, normalizer = confusion_matrix.remove_squeezable_dimensions(
1361      predictions, normalizer)
1362  predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
1363  relative_errors = array_ops.where(
1364      math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels),
1365      math_ops.divide(math_ops.abs(labels - predictions), normalizer))
1366  return mean(relative_errors, weights, metrics_collections,
1367              updates_collections, name or 'mean_relative_error')
1368
1369
1370@tf_export(v1=['metrics.mean_squared_error'])
1371def mean_squared_error(labels,
1372                       predictions,
1373                       weights=None,
1374                       metrics_collections=None,
1375                       updates_collections=None,
1376                       name=None):
1377  """Computes the mean squared error between the labels and predictions.
1378
1379  The `mean_squared_error` function creates two local variables,
1380  `total` and `count` that are used to compute the mean squared error.
1381  This average is weighted by `weights`, and it is ultimately returned as
1382  `mean_squared_error`: an idempotent operation that simply divides `total` by
1383  `count`.
1384
1385  For estimation of the metric over a stream of data, the function creates an
1386  `update_op` operation that updates these variables and returns the
1387  `mean_squared_error`. Internally, a `squared_error` operation computes the
1388  element-wise square of the difference between `predictions` and `labels`. Then
1389  `update_op` increments `total` with the reduced sum of the product of
1390  `weights` and `squared_error`, and it increments `count` with the reduced sum
1391  of `weights`.
1392
1393  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1394
1395  Args:
1396    labels: A `Tensor` of the same shape as `predictions`.
1397    predictions: A `Tensor` of arbitrary shape.
1398    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1399      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1400      be either `1`, or the same as the corresponding `labels` dimension).
1401    metrics_collections: An optional list of collections that
1402      `mean_squared_error` should be added to.
1403    updates_collections: An optional list of collections that `update_op` should
1404      be added to.
1405    name: An optional variable_scope name.
1406
1407  Returns:
1408    mean_squared_error: A `Tensor` representing the current mean, the value of
1409      `total` divided by `count`.
1410    update_op: An operation that increments the `total` and `count` variables
1411      appropriately and whose value matches `mean_squared_error`.
1412
1413  Raises:
1414    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1415      `weights` is not `None` and its shape doesn't match `predictions`, or if
1416      either `metrics_collections` or `updates_collections` are not a list or
1417      tuple.
1418    RuntimeError: If eager execution is enabled.
1419  """
1420  if context.executing_eagerly():
1421    raise RuntimeError('tf.metrics.mean_squared_error is not supported when '
1422                       'eager execution is enabled.')
1423
1424  predictions, labels, weights = _remove_squeezable_dimensions(
1425      predictions=predictions, labels=labels, weights=weights)
1426  squared_error = math_ops.squared_difference(labels, predictions)
1427  return mean(squared_error, weights, metrics_collections, updates_collections,
1428              name or 'mean_squared_error')
1429
1430
1431@tf_export(v1=['metrics.mean_tensor'])
1432def mean_tensor(values,
1433                weights=None,
1434                metrics_collections=None,
1435                updates_collections=None,
1436                name=None):
1437  """Computes the element-wise (weighted) mean of the given tensors.
1438
1439  In contrast to the `mean` function which returns a scalar with the
1440  mean,  this function returns an average tensor with the same shape as the
1441  input tensors.
1442
1443  The `mean_tensor` function creates two local variables,
1444  `total_tensor` and `count_tensor` that are used to compute the average of
1445  `values`. This average is ultimately returned as `mean` which is an idempotent
1446  operation that simply divides `total` by `count`.
1447
1448  For estimation of the metric over a stream of data, the function creates an
1449  `update_op` operation that updates these variables and returns the `mean`.
1450  `update_op` increments `total` with the reduced sum of the product of `values`
1451  and `weights`, and it increments `count` with the reduced sum of `weights`.
1452
1453  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1454
1455  Args:
1456    values: A `Tensor` of arbitrary dimensions.
1457    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1458      `values`, and must be broadcastable to `values` (i.e., all dimensions must
1459      be either `1`, or the same as the corresponding `values` dimension).
1460    metrics_collections: An optional list of collections that `mean`
1461      should be added to.
1462    updates_collections: An optional list of collections that `update_op`
1463      should be added to.
1464    name: An optional variable_scope name.
1465
1466  Returns:
1467    mean: A float `Tensor` representing the current mean, the value of `total`
1468      divided by `count`.
1469    update_op: An operation that increments the `total` and `count` variables
1470      appropriately and whose value matches `mean_value`.
1471
1472  Raises:
1473    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
1474      or if either `metrics_collections` or `updates_collections` are not a list
1475      or tuple.
1476    RuntimeError: If eager execution is enabled.
1477  """
1478  if context.executing_eagerly():
1479    raise RuntimeError('tf.metrics.mean_tensor is not supported when '
1480                       'eager execution is enabled.')
1481
1482  with variable_scope.variable_scope(name, 'mean', (values, weights)):
1483    values = math_ops.cast(values, dtypes.float32)
1484    total = metric_variable(
1485        values.get_shape(), dtypes.float32, name='total_tensor')
1486    count = metric_variable(
1487        values.get_shape(), dtypes.float32, name='count_tensor')
1488
1489    num_values = array_ops.ones_like(values)
1490    if weights is not None:
1491      values, _, weights = _remove_squeezable_dimensions(
1492          predictions=values, labels=None, weights=weights)
1493      weights = weights_broadcast_ops.broadcast_weights(
1494          math_ops.cast(weights, dtypes.float32), values)
1495      values = math_ops.multiply(values, weights)
1496      num_values = math_ops.multiply(num_values, weights)
1497
1498    update_total_op = state_ops.assign_add(total, values)
1499    with ops.control_dependencies([values]):
1500      update_count_op = state_ops.assign_add(count, num_values)
1501
1502    compute_mean = lambda _, t, c: math_ops.div_no_nan(  # pylint: disable=g-long-lambda
1503        t, math_ops.maximum(c, 0), name='value')
1504
1505    mean_t = _aggregate_across_replicas(
1506        metrics_collections, compute_mean, total, count)
1507
1508    update_op = math_ops.div_no_nan(
1509        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')
1510    if updates_collections:
1511      ops.add_to_collections(updates_collections, update_op)
1512
1513    return mean_t, update_op
1514
1515
1516@tf_export(v1=['metrics.percentage_below'])
1517def percentage_below(values,
1518                     threshold,
1519                     weights=None,
1520                     metrics_collections=None,
1521                     updates_collections=None,
1522                     name=None):
1523  """Computes the percentage of values less than the given threshold.
1524
1525  The `percentage_below` function creates two local variables,
1526  `total` and `count` that are used to compute the percentage of `values` that
1527  fall below `threshold`. This rate is weighted by `weights`, and it is
1528  ultimately returned as `percentage` which is an idempotent operation that
1529  simply divides `total` by `count`.
1530
1531  For estimation of the metric over a stream of data, the function creates an
1532  `update_op` operation that updates these variables and returns the
1533  `percentage`.
1534
1535  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1536
1537  Args:
1538    values: A numeric `Tensor` of arbitrary size.
1539    threshold: A scalar threshold.
1540    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1541      `values`, and must be broadcastable to `values` (i.e., all dimensions must
1542      be either `1`, or the same as the corresponding `values` dimension).
1543    metrics_collections: An optional list of collections that the metric
1544      value variable should be added to.
1545    updates_collections: An optional list of collections that the metric update
1546      ops should be added to.
1547    name: An optional variable_scope name.
1548
1549  Returns:
1550    percentage: A `Tensor` representing the current mean, the value of `total`
1551      divided by `count`.
1552    update_op: An operation that increments the `total` and `count` variables
1553      appropriately.
1554
1555  Raises:
1556    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
1557      or if either `metrics_collections` or `updates_collections` are not a list
1558      or tuple.
1559    RuntimeError: If eager execution is enabled.
1560  """
1561  if context.executing_eagerly():
1562    raise RuntimeError('tf.metrics.percentage_below is not supported when '
1563                       'eager execution is enabled.')
1564
1565  is_below_threshold = math_ops.cast(
1566      math_ops.less(values, threshold), dtypes.float32)
1567  return mean(is_below_threshold, weights, metrics_collections,
1568              updates_collections, name or 'percentage_below_threshold')
1569
1570
1571def _count_condition(values,
1572                     weights=None,
1573                     metrics_collections=None,
1574                     updates_collections=None):
1575  """Sums the weights of cases where the given values are True.
1576
1577  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1578
1579  Args:
1580    values: A `bool` `Tensor` of arbitrary size.
1581    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1582      `values`, and must be broadcastable to `values` (i.e., all dimensions must
1583      be either `1`, or the same as the corresponding `values` dimension).
1584    metrics_collections: An optional list of collections that the metric
1585      value variable should be added to.
1586    updates_collections: An optional list of collections that the metric update
1587      ops should be added to.
1588
1589  Returns:
1590    value_tensor: A `Tensor` representing the current value of the metric.
1591    update_op: An operation that accumulates the error from a batch of data.
1592
1593  Raises:
1594    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
1595      or if either `metrics_collections` or `updates_collections` are not a list
1596      or tuple.
1597  """
1598  check_ops.assert_type(values, dtypes.bool)
1599  count = metric_variable([], dtypes.float32, name='count')
1600
1601  values = math_ops.cast(values, dtypes.float32)
1602  if weights is not None:
1603    with ops.control_dependencies((check_ops.assert_rank_in(
1604        weights, (0, array_ops.rank(values))),)):
1605      weights = math_ops.cast(weights, dtypes.float32)
1606      values = math_ops.multiply(values, weights)
1607
1608  value_tensor = _aggregate_variable(count, metrics_collections)
1609
1610  update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
1611  if updates_collections:
1612    ops.add_to_collections(updates_collections, update_op)
1613
1614  return value_tensor, update_op
1615
1616
1617@tf_export(v1=['metrics.false_negatives'])
1618def false_negatives(labels,
1619                    predictions,
1620                    weights=None,
1621                    metrics_collections=None,
1622                    updates_collections=None,
1623                    name=None):
1624  """Computes the total number of false negatives.
1625
1626  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1627
1628  Args:
1629    labels: The ground truth values, a `Tensor` whose dimensions must match
1630      `predictions`. Will be cast to `bool`.
1631    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1632      be cast to `bool`.
1633    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1634      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1635      be either `1`, or the same as the corresponding `labels` dimension).
1636    metrics_collections: An optional list of collections that the metric
1637      value variable should be added to.
1638    updates_collections: An optional list of collections that the metric update
1639      ops should be added to.
1640    name: An optional variable_scope name.
1641
1642  Returns:
1643    value_tensor: A `Tensor` representing the current value of the metric.
1644    update_op: An operation that accumulates the error from a batch of data.
1645
1646  Raises:
1647    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
1648      or if either `metrics_collections` or `updates_collections` are not a list
1649      or tuple.
1650    RuntimeError: If eager execution is enabled.
1651  """
1652  if context.executing_eagerly():
1653    raise RuntimeError('tf.metrics.false_negatives is not supported when '
1654                       'eager execution is enabled.')
1655
1656  with variable_scope.variable_scope(name, 'false_negatives',
1657                                     (predictions, labels, weights)):
1658
1659    predictions, labels, weights = _remove_squeezable_dimensions(
1660        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1661        labels=math_ops.cast(labels, dtype=dtypes.bool),
1662        weights=weights)
1663    is_false_negative = math_ops.logical_and(
1664        math_ops.equal(labels, True), math_ops.equal(predictions, False))
1665    return _count_condition(is_false_negative, weights, metrics_collections,
1666                            updates_collections)
1667
1668
1669@tf_export(v1=['metrics.false_negatives_at_thresholds'])
1670def false_negatives_at_thresholds(labels,
1671                                  predictions,
1672                                  thresholds,
1673                                  weights=None,
1674                                  metrics_collections=None,
1675                                  updates_collections=None,
1676                                  name=None):
1677  """Computes false negatives at provided threshold values.
1678
1679  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1680
1681  Args:
1682    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
1683      `bool`.
1684    predictions: A floating point `Tensor` of arbitrary shape and whose values
1685      are in the range `[0, 1]`.
1686    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1687    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1688      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1689      be either `1`, or the same as the corresponding `labels` dimension).
1690    metrics_collections: An optional list of collections that `false_negatives`
1691      should be added to.
1692    updates_collections: An optional list of collections that `update_op` should
1693      be added to.
1694    name: An optional variable_scope name.
1695
1696  Returns:
1697    false_negatives:  A float `Tensor` of shape `[len(thresholds)]`.
1698    update_op: An operation that updates the `false_negatives` variable and
1699      returns its current value.
1700
1701  Raises:
1702    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1703      `weights` is not `None` and its shape doesn't match `predictions`, or if
1704      either `metrics_collections` or `updates_collections` are not a list or
1705      tuple.
1706    RuntimeError: If eager execution is enabled.
1707  """
1708  if context.executing_eagerly():
1709    raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not '
1710                       'supported when eager execution is enabled.')
1711
1712  with variable_scope.variable_scope(name, 'false_negatives',
1713                                     (predictions, labels, weights)):
1714    values, update_ops = _confusion_matrix_at_thresholds(
1715        labels, predictions, thresholds, weights=weights, includes=('fn',))
1716
1717    fn_value = _aggregate_variable(values['fn'], metrics_collections)
1718
1719    if updates_collections:
1720      ops.add_to_collections(updates_collections, update_ops['fn'])
1721
1722    return fn_value, update_ops['fn']
1723
1724
1725@tf_export(v1=['metrics.false_positives'])
1726def false_positives(labels,
1727                    predictions,
1728                    weights=None,
1729                    metrics_collections=None,
1730                    updates_collections=None,
1731                    name=None):
1732  """Sum the weights of false positives.
1733
1734  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1735
1736  Args:
1737    labels: The ground truth values, a `Tensor` whose dimensions must match
1738      `predictions`. Will be cast to `bool`.
1739    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1740      be cast to `bool`.
1741    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1742      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1743      be either `1`, or the same as the corresponding `labels` dimension).
1744    metrics_collections: An optional list of collections that the metric
1745      value variable should be added to.
1746    updates_collections: An optional list of collections that the metric update
1747      ops should be added to.
1748    name: An optional variable_scope name.
1749
1750  Returns:
1751    value_tensor: A `Tensor` representing the current value of the metric.
1752    update_op: An operation that accumulates the error from a batch of data.
1753
1754  Raises:
1755    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1756      `weights` is not `None` and its shape doesn't match `predictions`, or if
1757      either `metrics_collections` or `updates_collections` are not a list or
1758      tuple.
1759    RuntimeError: If eager execution is enabled.
1760  """
1761  if context.executing_eagerly():
1762    raise RuntimeError('tf.metrics.false_positives is not supported when '
1763                       'eager execution is enabled.')
1764
1765  with variable_scope.variable_scope(name, 'false_positives',
1766                                     (predictions, labels, weights)):
1767
1768    predictions, labels, weights = _remove_squeezable_dimensions(
1769        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1770        labels=math_ops.cast(labels, dtype=dtypes.bool),
1771        weights=weights)
1772    is_false_positive = math_ops.logical_and(
1773        math_ops.equal(labels, False), math_ops.equal(predictions, True))
1774    return _count_condition(is_false_positive, weights, metrics_collections,
1775                            updates_collections)
1776
1777
1778@tf_export(v1=['metrics.false_positives_at_thresholds'])
1779def false_positives_at_thresholds(labels,
1780                                  predictions,
1781                                  thresholds,
1782                                  weights=None,
1783                                  metrics_collections=None,
1784                                  updates_collections=None,
1785                                  name=None):
1786  """Computes false positives at provided threshold values.
1787
1788  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1789
1790  Args:
1791    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
1792      `bool`.
1793    predictions: A floating point `Tensor` of arbitrary shape and whose values
1794      are in the range `[0, 1]`.
1795    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1796    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1797      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1798      be either `1`, or the same as the corresponding `labels` dimension).
1799    metrics_collections: An optional list of collections that `false_positives`
1800      should be added to.
1801    updates_collections: An optional list of collections that `update_op` should
1802      be added to.
1803    name: An optional variable_scope name.
1804
1805  Returns:
1806    false_positives:  A float `Tensor` of shape `[len(thresholds)]`.
1807    update_op: An operation that updates the `false_positives` variable and
1808      returns its current value.
1809
1810  Raises:
1811    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1812      `weights` is not `None` and its shape doesn't match `predictions`, or if
1813      either `metrics_collections` or `updates_collections` are not a list or
1814      tuple.
1815    RuntimeError: If eager execution is enabled.
1816  """
1817  if context.executing_eagerly():
1818    raise RuntimeError('tf.metrics.false_positives_at_thresholds is not '
1819                       'supported when eager execution is enabled.')
1820
1821  with variable_scope.variable_scope(name, 'false_positives',
1822                                     (predictions, labels, weights)):
1823    values, update_ops = _confusion_matrix_at_thresholds(
1824        labels, predictions, thresholds, weights=weights, includes=('fp',))
1825
1826    fp_value = _aggregate_variable(values['fp'], metrics_collections)
1827
1828    if updates_collections:
1829      ops.add_to_collections(updates_collections, update_ops['fp'])
1830
1831    return fp_value, update_ops['fp']
1832
1833
1834@tf_export(v1=['metrics.true_negatives'])
1835def true_negatives(labels,
1836                   predictions,
1837                   weights=None,
1838                   metrics_collections=None,
1839                   updates_collections=None,
1840                   name=None):
1841  """Sum the weights of true_negatives.
1842
1843  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1844
1845  Args:
1846    labels: The ground truth values, a `Tensor` whose dimensions must match
1847      `predictions`. Will be cast to `bool`.
1848    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1849      be cast to `bool`.
1850    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1851      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1852      be either `1`, or the same as the corresponding `labels` dimension).
1853    metrics_collections: An optional list of collections that the metric
1854      value variable should be added to.
1855    updates_collections: An optional list of collections that the metric update
1856      ops should be added to.
1857    name: An optional variable_scope name.
1858
1859  Returns:
1860    value_tensor: A `Tensor` representing the current value of the metric.
1861    update_op: An operation that accumulates the error from a batch of data.
1862
1863  Raises:
1864    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1865      `weights` is not `None` and its shape doesn't match `predictions`, or if
1866      either `metrics_collections` or `updates_collections` are not a list or
1867      tuple.
1868    RuntimeError: If eager execution is enabled.
1869  """
1870  if context.executing_eagerly():
1871    raise RuntimeError('tf.metrics.true_negatives is not '
1872                       'supported when eager execution is enabled.')
1873
1874  with variable_scope.variable_scope(name, 'true_negatives',
1875                                     (predictions, labels, weights)):
1876
1877    predictions, labels, weights = _remove_squeezable_dimensions(
1878        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1879        labels=math_ops.cast(labels, dtype=dtypes.bool),
1880        weights=weights)
1881    is_true_negative = math_ops.logical_and(
1882        math_ops.equal(labels, False), math_ops.equal(predictions, False))
1883    return _count_condition(is_true_negative, weights, metrics_collections,
1884                            updates_collections)
1885
1886
1887@tf_export(v1=['metrics.true_negatives_at_thresholds'])
1888def true_negatives_at_thresholds(labels,
1889                                 predictions,
1890                                 thresholds,
1891                                 weights=None,
1892                                 metrics_collections=None,
1893                                 updates_collections=None,
1894                                 name=None):
1895  """Computes true negatives at provided threshold values.
1896
1897  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1898
1899  Args:
1900    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
1901      `bool`.
1902    predictions: A floating point `Tensor` of arbitrary shape and whose values
1903      are in the range `[0, 1]`.
1904    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1905    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1906      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1907      be either `1`, or the same as the corresponding `labels` dimension).
1908    metrics_collections: An optional list of collections that `true_negatives`
1909      should be added to.
1910    updates_collections: An optional list of collections that `update_op` should
1911      be added to.
1912    name: An optional variable_scope name.
1913
1914  Returns:
1915    true_negatives:  A float `Tensor` of shape `[len(thresholds)]`.
1916    update_op: An operation that updates the `true_negatives` variable and
1917      returns its current value.
1918
1919  Raises:
1920    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1921      `weights` is not `None` and its shape doesn't match `predictions`, or if
1922      either `metrics_collections` or `updates_collections` are not a list or
1923      tuple.
1924    RuntimeError: If eager execution is enabled.
1925  """
1926  if context.executing_eagerly():
1927    raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not '
1928                       'supported when eager execution is enabled.')
1929
1930  with variable_scope.variable_scope(name, 'true_negatives',
1931                                     (predictions, labels, weights)):
1932    values, update_ops = _confusion_matrix_at_thresholds(
1933        labels, predictions, thresholds, weights=weights, includes=('tn',))
1934
1935    tn_value = _aggregate_variable(values['tn'], metrics_collections)
1936
1937    if updates_collections:
1938      ops.add_to_collections(updates_collections, update_ops['tn'])
1939
1940    return tn_value, update_ops['tn']
1941
1942
1943@tf_export(v1=['metrics.true_positives'])
1944def true_positives(labels,
1945                   predictions,
1946                   weights=None,
1947                   metrics_collections=None,
1948                   updates_collections=None,
1949                   name=None):
1950  """Sum the weights of true_positives.
1951
1952  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1953
1954  Args:
1955    labels: The ground truth values, a `Tensor` whose dimensions must match
1956      `predictions`. Will be cast to `bool`.
1957    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1958      be cast to `bool`.
1959    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1960      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1961      be either `1`, or the same as the corresponding `labels` dimension).
1962    metrics_collections: An optional list of collections that the metric
1963      value variable should be added to.
1964    updates_collections: An optional list of collections that the metric update
1965      ops should be added to.
1966    name: An optional variable_scope name.
1967
1968  Returns:
1969    value_tensor: A `Tensor` representing the current value of the metric.
1970    update_op: An operation that accumulates the error from a batch of data.
1971
1972  Raises:
1973    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1974      `weights` is not `None` and its shape doesn't match `predictions`, or if
1975      either `metrics_collections` or `updates_collections` are not a list or
1976      tuple.
1977    RuntimeError: If eager execution is enabled.
1978  """
1979  if context.executing_eagerly():
1980    raise RuntimeError('tf.metrics.true_positives is not '
1981                       'supported when eager execution is enabled.')
1982
1983  with variable_scope.variable_scope(name, 'true_positives',
1984                                     (predictions, labels, weights)):
1985
1986    predictions, labels, weights = _remove_squeezable_dimensions(
1987        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1988        labels=math_ops.cast(labels, dtype=dtypes.bool),
1989        weights=weights)
1990    is_true_positive = math_ops.logical_and(
1991        math_ops.equal(labels, True), math_ops.equal(predictions, True))
1992    return _count_condition(is_true_positive, weights, metrics_collections,
1993                            updates_collections)
1994
1995
1996@tf_export(v1=['metrics.true_positives_at_thresholds'])
1997def true_positives_at_thresholds(labels,
1998                                 predictions,
1999                                 thresholds,
2000                                 weights=None,
2001                                 metrics_collections=None,
2002                                 updates_collections=None,
2003                                 name=None):
2004  """Computes true positives at provided threshold values.
2005
2006  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2007
2008  Args:
2009    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
2010      `bool`.
2011    predictions: A floating point `Tensor` of arbitrary shape and whose values
2012      are in the range `[0, 1]`.
2013    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
2014    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2015      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2016      be either `1`, or the same as the corresponding `labels` dimension).
2017    metrics_collections: An optional list of collections that `true_positives`
2018      should be added to.
2019    updates_collections: An optional list of collections that `update_op` should
2020      be added to.
2021    name: An optional variable_scope name.
2022
2023  Returns:
2024    true_positives:  A float `Tensor` of shape `[len(thresholds)]`.
2025    update_op: An operation that updates the `true_positives` variable and
2026      returns its current value.
2027
2028  Raises:
2029    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2030      `weights` is not `None` and its shape doesn't match `predictions`, or if
2031      either `metrics_collections` or `updates_collections` are not a list or
2032      tuple.
2033    RuntimeError: If eager execution is enabled.
2034  """
2035  if context.executing_eagerly():
2036    raise RuntimeError('tf.metrics.true_positives_at_thresholds is not '
2037                       'supported when eager execution is enabled.')
2038
2039  with variable_scope.variable_scope(name, 'true_positives',
2040                                     (predictions, labels, weights)):
2041    values, update_ops = _confusion_matrix_at_thresholds(
2042        labels, predictions, thresholds, weights=weights, includes=('tp',))
2043
2044    tp_value = _aggregate_variable(values['tp'], metrics_collections)
2045
2046    if updates_collections:
2047      ops.add_to_collections(updates_collections, update_ops['tp'])
2048
2049    return tp_value, update_ops['tp']
2050
2051
2052@tf_export(v1=['metrics.precision'])
2053def precision(labels,
2054              predictions,
2055              weights=None,
2056              metrics_collections=None,
2057              updates_collections=None,
2058              name=None):
2059  """Computes the precision of the predictions with respect to the labels.
2060
2061  The `precision` function creates two local variables,
2062  `true_positives` and `false_positives`, that are used to compute the
2063  precision. This value is ultimately returned as `precision`, an idempotent
2064  operation that simply divides `true_positives` by the sum of `true_positives`
2065  and `false_positives`.
2066
2067  For estimation of the metric over a stream of data, the function creates an
2068  `update_op` operation that updates these variables and returns the
2069  `precision`. `update_op` weights each prediction by the corresponding value in
2070  `weights`.
2071
2072  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2073
2074  Args:
2075    labels: The ground truth values, a `Tensor` whose dimensions must match
2076      `predictions`. Will be cast to `bool`.
2077    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
2078      be cast to `bool`.
2079    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2080      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2081      be either `1`, or the same as the corresponding `labels` dimension).
2082    metrics_collections: An optional list of collections that `precision` should
2083      be added to.
2084    updates_collections: An optional list of collections that `update_op` should
2085      be added to.
2086    name: An optional variable_scope name.
2087
2088  Returns:
2089    precision: Scalar float `Tensor` with the value of `true_positives`
2090      divided by the sum of `true_positives` and `false_positives`.
2091    update_op: `Operation` that increments `true_positives` and
2092      `false_positives` variables appropriately and whose value matches
2093      `precision`.
2094
2095  Raises:
2096    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2097      `weights` is not `None` and its shape doesn't match `predictions`, or if
2098      either `metrics_collections` or `updates_collections` are not a list or
2099      tuple.
2100    RuntimeError: If eager execution is enabled.
2101  """
2102  if context.executing_eagerly():
2103    raise RuntimeError('tf.metrics.precision is not '
2104                       'supported when eager execution is enabled.')
2105
2106  with variable_scope.variable_scope(name, 'precision',
2107                                     (predictions, labels, weights)):
2108
2109    predictions, labels, weights = _remove_squeezable_dimensions(
2110        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
2111        labels=math_ops.cast(labels, dtype=dtypes.bool),
2112        weights=weights)
2113
2114    true_p, true_positives_update_op = true_positives(
2115        labels,
2116        predictions,
2117        weights,
2118        metrics_collections=None,
2119        updates_collections=None,
2120        name=None)
2121    false_p, false_positives_update_op = false_positives(
2122        labels,
2123        predictions,
2124        weights,
2125        metrics_collections=None,
2126        updates_collections=None,
2127        name=None)
2128
2129    def compute_precision(tp, fp, name):
2130      return array_ops.where(
2131          math_ops.greater(tp + fp, 0), math_ops.divide(tp, tp + fp), 0, name)
2132
2133    def once_across_replicas(_, true_p, false_p):
2134      return compute_precision(true_p, false_p, 'value')
2135
2136    p = _aggregate_across_replicas(metrics_collections, once_across_replicas,
2137                                   true_p, false_p)
2138
2139    update_op = compute_precision(true_positives_update_op,
2140                                  false_positives_update_op, 'update_op')
2141    if updates_collections:
2142      ops.add_to_collections(updates_collections, update_op)
2143
2144    return p, update_op
2145
2146
2147@tf_export(v1=['metrics.precision_at_thresholds'])
2148def precision_at_thresholds(labels,
2149                            predictions,
2150                            thresholds,
2151                            weights=None,
2152                            metrics_collections=None,
2153                            updates_collections=None,
2154                            name=None):
2155  """Computes precision values for different `thresholds` on `predictions`.
2156
2157  The `precision_at_thresholds` function creates four local variables,
2158  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
2159  for various values of thresholds. `precision[i]` is defined as the total
2160  weight of values in `predictions` above `thresholds[i]` whose corresponding
2161  entry in `labels` is `True`, divided by the total weight of values in
2162  `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
2163  false_positives[i])`).
2164
2165  For estimation of the metric over a stream of data, the function creates an
2166  `update_op` operation that updates these variables and returns the
2167  `precision`.
2168
2169  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2170
2171  Args:
2172    labels: The ground truth values, a `Tensor` whose dimensions must match
2173      `predictions`. Will be cast to `bool`.
2174    predictions: A floating point `Tensor` of arbitrary shape and whose values
2175      are in the range `[0, 1]`.
2176    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
2177    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2178      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2179      be either `1`, or the same as the corresponding `labels` dimension).
2180    metrics_collections: An optional list of collections that `auc` should be
2181      added to.
2182    updates_collections: An optional list of collections that `update_op` should
2183      be added to.
2184    name: An optional variable_scope name.
2185
2186  Returns:
2187    precision: A float `Tensor` of shape `[len(thresholds)]`.
2188    update_op: An operation that increments the `true_positives`,
2189      `true_negatives`, `false_positives` and `false_negatives` variables that
2190      are used in the computation of `precision`.
2191
2192  Raises:
2193    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2194      `weights` is not `None` and its shape doesn't match `predictions`, or if
2195      either `metrics_collections` or `updates_collections` are not a list or
2196      tuple.
2197    RuntimeError: If eager execution is enabled.
2198  """
2199  if context.executing_eagerly():
2200    raise RuntimeError('tf.metrics.precision_at_thresholds is not '
2201                       'supported when eager execution is enabled.')
2202
2203  with variable_scope.variable_scope(name, 'precision_at_thresholds',
2204                                     (predictions, labels, weights)):
2205    values, update_ops = _confusion_matrix_at_thresholds(
2206        labels, predictions, thresholds, weights, includes=('tp', 'fp'))
2207
2208    # Avoid division by zero.
2209    epsilon = 1e-7
2210
2211    def compute_precision(tp, fp, name):
2212      return math_ops.divide(tp, epsilon + tp + fp, name='precision_' + name)
2213
2214    def precision_across_replicas(_, values):
2215      return compute_precision(values['tp'], values['fp'], 'value')
2216
2217    prec = _aggregate_across_replicas(
2218        metrics_collections, precision_across_replicas, values)
2219
2220    update_op = compute_precision(update_ops['tp'], update_ops['fp'],
2221                                  'update_op')
2222    if updates_collections:
2223      ops.add_to_collections(updates_collections, update_op)
2224
2225    return prec, update_op
2226
2227
2228@tf_export(v1=['metrics.recall'])
2229def recall(labels,
2230           predictions,
2231           weights=None,
2232           metrics_collections=None,
2233           updates_collections=None,
2234           name=None):
2235  """Computes the recall of the predictions with respect to the labels.
2236
2237  The `recall` function creates two local variables, `true_positives`
2238  and `false_negatives`, that are used to compute the recall. This value is
2239  ultimately returned as `recall`, an idempotent operation that simply divides
2240  `true_positives` by the sum of `true_positives` and `false_negatives`.
2241
2242  For estimation of the metric over a stream of data, the function creates an
2243  `update_op` that updates these variables and returns the `recall`. `update_op`
2244  weights each prediction by the corresponding value in `weights`.
2245
2246  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2247
2248  Args:
2249    labels: The ground truth values, a `Tensor` whose dimensions must match
2250      `predictions`. Will be cast to `bool`.
2251    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
2252      be cast to `bool`.
2253    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2254      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2255      be either `1`, or the same as the corresponding `labels` dimension).
2256    metrics_collections: An optional list of collections that `recall` should
2257      be added to.
2258    updates_collections: An optional list of collections that `update_op` should
2259      be added to.
2260    name: An optional variable_scope name.
2261
2262  Returns:
2263    recall: Scalar float `Tensor` with the value of `true_positives` divided
2264      by the sum of `true_positives` and `false_negatives`.
2265    update_op: `Operation` that increments `true_positives` and
2266      `false_negatives` variables appropriately and whose value matches
2267      `recall`.
2268
2269  Raises:
2270    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2271      `weights` is not `None` and its shape doesn't match `predictions`, or if
2272      either `metrics_collections` or `updates_collections` are not a list or
2273      tuple.
2274    RuntimeError: If eager execution is enabled.
2275  """
2276  if context.executing_eagerly():
2277    raise RuntimeError('tf.metrics.recall is not supported is not '
2278                       'supported when eager execution is enabled.')
2279
2280  with variable_scope.variable_scope(name, 'recall',
2281                                     (predictions, labels, weights)):
2282    predictions, labels, weights = _remove_squeezable_dimensions(
2283        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
2284        labels=math_ops.cast(labels, dtype=dtypes.bool),
2285        weights=weights)
2286
2287    true_p, true_positives_update_op = true_positives(
2288        labels,
2289        predictions,
2290        weights,
2291        metrics_collections=None,
2292        updates_collections=None,
2293        name=None)
2294    false_n, false_negatives_update_op = false_negatives(
2295        labels,
2296        predictions,
2297        weights,
2298        metrics_collections=None,
2299        updates_collections=None,
2300        name=None)
2301
2302    def compute_recall(true_p, false_n, name):
2303      return array_ops.where(
2304          math_ops.greater(true_p + false_n, 0),
2305          math_ops.divide(true_p, true_p + false_n), 0, name)
2306
2307    def once_across_replicas(_, true_p, false_n):
2308      return compute_recall(true_p, false_n, 'value')
2309
2310    rec = _aggregate_across_replicas(
2311        metrics_collections, once_across_replicas, true_p, false_n)
2312
2313    update_op = compute_recall(true_positives_update_op,
2314                               false_negatives_update_op, 'update_op')
2315    if updates_collections:
2316      ops.add_to_collections(updates_collections, update_op)
2317
2318    return rec, update_op
2319
2320
2321def _at_k_name(name, k=None, class_id=None):
2322  if k is not None:
2323    name = '%s_at_%d' % (name, k)
2324  else:
2325    name = '%s_at_k' % (name)
2326  if class_id is not None:
2327    name = '%s_class%d' % (name, class_id)
2328  return name
2329
2330
2331def _select_class_id(ids, selected_id):
2332  """Filter all but `selected_id` out of `ids`.
2333
2334  Args:
2335    ids: `int64` `Tensor` or `SparseTensor` of IDs.
2336    selected_id: Int id to select.
2337
2338  Returns:
2339    `SparseTensor` of same dimensions as `ids`. This contains only the entries
2340    equal to `selected_id`.
2341  """
2342  ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids)
2343  if isinstance(ids, sparse_tensor.SparseTensor):
2344    return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values,
2345                                                        selected_id))
2346
2347  # TODO(ptucker): Make this more efficient, maybe add a sparse version of
2348  # tf.equal and tf.reduce_any?
2349
2350  # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
2351  ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
2352  ids_last_dim = array_ops.size(ids_shape) - 1
2353  filled_selected_id_shape = math_ops.reduced_shape(ids_shape,
2354                                                    array_ops.reshape(
2355                                                        ids_last_dim, [1]))
2356
2357  # Intersect `ids` with the selected ID.
2358  filled_selected_id = array_ops.fill(filled_selected_id_shape,
2359                                      math_ops.cast(selected_id, dtypes.int64))
2360  result = sets.set_intersection(filled_selected_id, ids)
2361  return sparse_tensor.SparseTensor(
2362      indices=result.indices, values=result.values, dense_shape=ids_shape)
2363
2364
2365def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
2366  """If class ID is specified, filter all other classes.
2367
2368  Args:
2369    labels: `int64` `Tensor` or `SparseTensor` with shape
2370      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2371      target classes for the associated prediction. Commonly, N=1 and `labels`
2372      has shape [batch_size, num_labels]. [D1, ... DN] must match
2373      `predictions_idx`.
2374    predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
2375      where N >= 1. Commonly, N=1 and `predictions_idx` has shape
2376      [batch size, k].
2377    selected_id: Int id to select.
2378
2379  Returns:
2380    Tuple of `labels` and `predictions_idx`, possibly with classes removed.
2381  """
2382  if selected_id is None:
2383    return labels, predictions_idx
2384  return (_select_class_id(labels, selected_id),
2385          _select_class_id(predictions_idx, selected_id))
2386
2387
2388def _sparse_true_positive_at_k(labels,
2389                               predictions_idx,
2390                               class_id=None,
2391                               weights=None,
2392                               name=None):
2393  """Calculates true positives for recall@k and precision@k.
2394
2395  If `class_id` is specified, calculate binary true positives for `class_id`
2396      only.
2397  If `class_id` is not specified, calculate metrics for `k` predicted vs
2398      `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.
2399
2400  Args:
2401    labels: `int64` `Tensor` or `SparseTensor` with shape
2402      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2403      target classes for the associated prediction. Commonly, N=1 and `labels`
2404      has shape [batch_size, num_labels]. [D1, ... DN] must match
2405      `predictions_idx`.
2406    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2407      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2408      match `labels`.
2409    class_id: Class for which we want binary metrics.
2410    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2411      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2412      dimensions must be either `1`, or the same as the corresponding `labels`
2413      dimension).
2414    name: Name of operation.
2415
2416  Returns:
2417    A [D1, ... DN] `Tensor` of true positive counts.
2418  """
2419  with ops.name_scope(name, 'true_positives',
2420                      (predictions_idx, labels, weights)):
2421    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
2422                                                     class_id)
2423    tp = sets.set_size(sets.set_intersection(predictions_idx, labels))
2424    tp = math_ops.cast(tp, dtypes.float64)
2425    if weights is not None:
2426      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
2427          weights, tp),)):
2428        weights = math_ops.cast(weights, dtypes.float64)
2429        tp = math_ops.multiply(tp, weights)
2430    return tp
2431
2432
2433def _streaming_sparse_true_positive_at_k(labels,
2434                                         predictions_idx,
2435                                         k=None,
2436                                         class_id=None,
2437                                         weights=None,
2438                                         name=None):
2439  """Calculates weighted per step true positives for recall@k and precision@k.
2440
2441  If `class_id` is specified, calculate binary true positives for `class_id`
2442      only.
2443  If `class_id` is not specified, calculate metrics for `k` predicted vs
2444      `n` label classes, where `n` is the 2nd dimension of `labels`.
2445
2446  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2447
2448  Args:
2449    labels: `int64` `Tensor` or `SparseTensor` with shape
2450      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2451      target classes for the associated prediction. Commonly, N=1 and `labels`
2452      has shape [batch_size, num_labels]. [D1, ... DN] must match
2453      `predictions_idx`.
2454    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2455      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2456      match `labels`.
2457    k: Integer, k for @k metric. This is only used for default op name.
2458    class_id: Class for which we want binary metrics.
2459    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2460      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2461      dimensions must be either `1`, or the same as the corresponding `labels`
2462      dimension).
2463    name: Name of new variable, and namespace for other dependent ops.
2464
2465  Returns:
2466    A tuple of `Variable` and update `Operation`.
2467
2468  Raises:
2469    ValueError: If `weights` is not `None` and has an incompatible shape.
2470  """
2471  with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id),
2472                      (predictions_idx, labels, weights)) as scope:
2473    tp = _sparse_true_positive_at_k(
2474        predictions_idx=predictions_idx,
2475        labels=labels,
2476        class_id=class_id,
2477        weights=weights)
2478    batch_total_tp = math_ops.cast(math_ops.reduce_sum(tp), dtypes.float64)
2479
2480    var = metric_variable([], dtypes.float64, name=scope)
2481    return var, state_ops.assign_add(var, batch_total_tp, name='update')
2482
2483
2484def _sparse_false_negative_at_k(labels,
2485                                predictions_idx,
2486                                class_id=None,
2487                                weights=None):
2488  """Calculates false negatives for recall@k.
2489
2490  If `class_id` is specified, calculate binary true positives for `class_id`
2491      only.
2492  If `class_id` is not specified, calculate metrics for `k` predicted vs
2493      `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.
2494
2495  Args:
2496    labels: `int64` `Tensor` or `SparseTensor` with shape
2497      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2498      target classes for the associated prediction. Commonly, N=1 and `labels`
2499      has shape [batch_size, num_labels]. [D1, ... DN] must match
2500      `predictions_idx`.
2501    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2502      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2503      match `labels`.
2504    class_id: Class for which we want binary metrics.
2505    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2506      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2507      dimensions must be either `1`, or the same as the corresponding `labels`
2508      dimension).
2509
2510  Returns:
2511    A [D1, ... DN] `Tensor` of false negative counts.
2512  """
2513  with ops.name_scope(None, 'false_negatives',
2514                      (predictions_idx, labels, weights)):
2515    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
2516                                                     class_id)
2517    fn = sets.set_size(
2518        sets.set_difference(predictions_idx, labels, aminusb=False))
2519    fn = math_ops.cast(fn, dtypes.float64)
2520    if weights is not None:
2521      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
2522          weights, fn),)):
2523        weights = math_ops.cast(weights, dtypes.float64)
2524        fn = math_ops.multiply(fn, weights)
2525    return fn
2526
2527
2528def _streaming_sparse_false_negative_at_k(labels,
2529                                          predictions_idx,
2530                                          k,
2531                                          class_id=None,
2532                                          weights=None,
2533                                          name=None):
2534  """Calculates weighted per step false negatives for recall@k.
2535
2536  If `class_id` is specified, calculate binary true positives for `class_id`
2537      only.
2538  If `class_id` is not specified, calculate metrics for `k` predicted vs
2539      `n` label classes, where `n` is the 2nd dimension of `labels`.
2540
2541  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2542
2543  Args:
2544    labels: `int64` `Tensor` or `SparseTensor` with shape
2545      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2546      target classes for the associated prediction. Commonly, N=1 and `labels`
2547      has shape [batch_size, num_labels]. [D1, ... DN] must match
2548      `predictions_idx`.
2549    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2550      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2551      match `labels`.
2552    k: Integer, k for @k metric. This is only used for default op name.
2553    class_id: Class for which we want binary metrics.
2554    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2555      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2556      dimensions must be either `1`, or the same as the corresponding `labels`
2557      dimension).
2558    name: Name of new variable, and namespace for other dependent ops.
2559
2560  Returns:
2561    A tuple of `Variable` and update `Operation`.
2562
2563  Raises:
2564    ValueError: If `weights` is not `None` and has an incompatible shape.
2565  """
2566  with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id),
2567                      (predictions_idx, labels, weights)) as scope:
2568    fn = _sparse_false_negative_at_k(
2569        predictions_idx=predictions_idx,
2570        labels=labels,
2571        class_id=class_id,
2572        weights=weights)
2573    batch_total_fn = math_ops.cast(math_ops.reduce_sum(fn), dtypes.float64)
2574
2575    var = metric_variable([], dtypes.float64, name=scope)
2576    return var, state_ops.assign_add(var, batch_total_fn, name='update')
2577
2578
2579@tf_export(v1=['metrics.recall_at_k'])
2580def recall_at_k(labels,
2581                predictions,
2582                k,
2583                class_id=None,
2584                weights=None,
2585                metrics_collections=None,
2586                updates_collections=None,
2587                name=None):
2588  """Computes recall@k of the predictions with respect to sparse labels.
2589
2590  If `class_id` is specified, we calculate recall by considering only the
2591      entries in the batch for which `class_id` is in the label, and computing
2592      the fraction of them for which `class_id` is in the top-k `predictions`.
2593  If `class_id` is not specified, we'll calculate recall as how often on
2594      average a class among the labels of a batch entry is in the top-k
2595      `predictions`.
2596
2597  `sparse_recall_at_k` creates two local variables,
2598  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
2599  the recall_at_k frequency. This frequency is ultimately returned as
2600  `recall_at_<k>`: an idempotent operation that simply divides
2601  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
2602  `false_negative_at_<k>`).
2603
2604  For estimation of the metric over a stream of data, the function creates an
2605  `update_op` operation that updates these variables and returns the
2606  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
2607  indicating the top `k` `predictions`. Set operations applied to `top_k` and
2608  `labels` calculate the true positives and false negatives weighted by
2609  `weights`. Then `update_op` increments `true_positive_at_<k>` and
2610  `false_negative_at_<k>` using these values.
2611
2612  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2613
2614  Args:
2615    labels: `int64` `Tensor` or `SparseTensor` with shape
2616      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
2617      num_labels=1. N >= 1 and num_labels is the number of target classes for
2618      the associated prediction. Commonly, N=1 and `labels` has shape
2619      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
2620      should be in range [0, num_classes), where num_classes is the last
2621      dimension of `predictions`. Values outside this range always count
2622      towards `false_negative_at_<k>`.
2623    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
2624      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
2625      The final dimension contains the logit values for each class. [D1, ... DN]
2626      must match `labels`.
2627    k: Integer, k for @k metric.
2628    class_id: Integer class ID for which we want binary metrics. This should be
2629      in range [0, num_classes), where num_classes is the last dimension of
2630      `predictions`. If class_id is outside this range, the method returns NAN.
2631    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2632      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2633      dimensions must be either `1`, or the same as the corresponding `labels`
2634      dimension).
2635    metrics_collections: An optional list of collections that values should
2636      be added to.
2637    updates_collections: An optional list of collections that updates should
2638      be added to.
2639    name: Name of new update operation, and namespace for other dependent ops.
2640
2641  Returns:
2642    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
2643      by the sum of `true_positives` and `false_negatives`.
2644    update_op: `Operation` that increments `true_positives` and
2645      `false_negatives` variables appropriately, and whose value matches
2646      `recall`.
2647
2648  Raises:
2649    ValueError: If `weights` is not `None` and its shape doesn't match
2650    `predictions`, or if either `metrics_collections` or `updates_collections`
2651    are not a list or tuple.
2652    RuntimeError: If eager execution is enabled.
2653  """
2654  if context.executing_eagerly():
2655    raise RuntimeError('tf.metrics.recall_at_k is not '
2656                       'supported when eager execution is enabled.')
2657
2658  with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
2659                      (predictions, labels, weights)) as scope:
2660    _, top_k_idx = nn.top_k(predictions, k)
2661    return recall_at_top_k(
2662        labels=labels,
2663        predictions_idx=top_k_idx,
2664        k=k,
2665        class_id=class_id,
2666        weights=weights,
2667        metrics_collections=metrics_collections,
2668        updates_collections=updates_collections,
2669        name=scope)
2670
2671
2672@tf_export(v1=['metrics.recall_at_top_k'])
2673def recall_at_top_k(labels,
2674                    predictions_idx,
2675                    k=None,
2676                    class_id=None,
2677                    weights=None,
2678                    metrics_collections=None,
2679                    updates_collections=None,
2680                    name=None):
2681  """Computes recall@k of top-k predictions with respect to sparse labels.
2682
2683  Differs from `recall_at_k` in that predictions must be in the form of top `k`
2684  class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k`
2685  for more details.
2686
2687  Args:
2688    labels: `int64` `Tensor` or `SparseTensor` with shape
2689      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
2690      num_labels=1. N >= 1 and num_labels is the number of target classes for
2691      the associated prediction. Commonly, N=1 and `labels` has shape
2692      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
2693      should be in range [0, num_classes), where num_classes is the last
2694      dimension of `predictions`. Values outside this range always count
2695      towards `false_negative_at_<k>`.
2696    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
2697      Commonly, N=1 and predictions has shape [batch size, k]. The final
2698      dimension contains the top `k` predicted class indices. [D1, ... DN] must
2699      match `labels`.
2700    k: Integer, k for @k metric. Only used for the default op name.
2701    class_id: Integer class ID for which we want binary metrics. This should be
2702      in range [0, num_classes), where num_classes is the last dimension of
2703      `predictions`. If class_id is outside this range, the method returns NAN.
2704    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2705      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2706      dimensions must be either `1`, or the same as the corresponding `labels`
2707      dimension).
2708    metrics_collections: An optional list of collections that values should
2709      be added to.
2710    updates_collections: An optional list of collections that updates should
2711      be added to.
2712    name: Name of new update operation, and namespace for other dependent ops.
2713
2714  Returns:
2715    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
2716      by the sum of `true_positives` and `false_negatives`.
2717    update_op: `Operation` that increments `true_positives` and
2718      `false_negatives` variables appropriately, and whose value matches
2719      `recall`.
2720
2721  Raises:
2722    ValueError: If `weights` is not `None` and its shape doesn't match
2723    `predictions`, or if either `metrics_collections` or `updates_collections`
2724    are not a list or tuple.
2725  """
2726  with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
2727                      (predictions_idx, labels, weights)) as scope:
2728    labels = _maybe_expand_labels(labels, predictions_idx)
2729    top_k_idx = math_ops.cast(predictions_idx, dtypes.int64)
2730    tp, tp_update = _streaming_sparse_true_positive_at_k(
2731        predictions_idx=top_k_idx,
2732        labels=labels,
2733        k=k,
2734        class_id=class_id,
2735        weights=weights)
2736    fn, fn_update = _streaming_sparse_false_negative_at_k(
2737        predictions_idx=top_k_idx,
2738        labels=labels,
2739        k=k,
2740        class_id=class_id,
2741        weights=weights)
2742
2743    def compute_recall(_, tp, fn):
2744      return math_ops.divide(tp, math_ops.add(tp, fn), name=scope)
2745
2746    metric = _aggregate_across_replicas(
2747        metrics_collections, compute_recall, tp, fn)
2748
2749    update = math_ops.divide(
2750        tp_update, math_ops.add(tp_update, fn_update), name='update')
2751    if updates_collections:
2752      ops.add_to_collections(updates_collections, update)
2753    return metric, update
2754
2755
2756@tf_export(v1=['metrics.recall_at_thresholds'])
2757def recall_at_thresholds(labels,
2758                         predictions,
2759                         thresholds,
2760                         weights=None,
2761                         metrics_collections=None,
2762                         updates_collections=None,
2763                         name=None):
2764  """Computes various recall values for different `thresholds` on `predictions`.
2765
2766  The `recall_at_thresholds` function creates four local variables,
2767  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
2768  for various values of thresholds. `recall[i]` is defined as the total weight
2769  of values in `predictions` above `thresholds[i]` whose corresponding entry in
2770  `labels` is `True`, divided by the total weight of `True` values in `labels`
2771  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).
2772
2773  For estimation of the metric over a stream of data, the function creates an
2774  `update_op` operation that updates these variables and returns the `recall`.
2775
2776  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2777
2778  Args:
2779    labels: The ground truth values, a `Tensor` whose dimensions must match
2780      `predictions`. Will be cast to `bool`.
2781    predictions: A floating point `Tensor` of arbitrary shape and whose values
2782      are in the range `[0, 1]`.
2783    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
2784    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2785      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2786      be either `1`, or the same as the corresponding `labels` dimension).
2787    metrics_collections: An optional list of collections that `recall` should be
2788      added to.
2789    updates_collections: An optional list of collections that `update_op` should
2790      be added to.
2791    name: An optional variable_scope name.
2792
2793  Returns:
2794    recall: A float `Tensor` of shape `[len(thresholds)]`.
2795    update_op: An operation that increments the `true_positives`,
2796      `true_negatives`, `false_positives` and `false_negatives` variables that
2797      are used in the computation of `recall`.
2798
2799  Raises:
2800    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2801      `weights` is not `None` and its shape doesn't match `predictions`, or if
2802      either `metrics_collections` or `updates_collections` are not a list or
2803      tuple.
2804    RuntimeError: If eager execution is enabled.
2805  """
2806  if context.executing_eagerly():
2807    raise RuntimeError('tf.metrics.recall_at_thresholds is not '
2808                       'supported when eager execution is enabled.')
2809
2810  with variable_scope.variable_scope(name, 'recall_at_thresholds',
2811                                     (predictions, labels, weights)):
2812    values, update_ops = _confusion_matrix_at_thresholds(
2813        labels, predictions, thresholds, weights, includes=('tp', 'fn'))
2814
2815    # Avoid division by zero.
2816    epsilon = 1e-7
2817
2818    def compute_recall(tp, fn, name):
2819      return math_ops.divide(tp, epsilon + tp + fn, name='recall_' + name)
2820
2821    def recall_across_replicas(_, values):
2822      return compute_recall(values['tp'], values['fn'], 'value')
2823
2824    rec = _aggregate_across_replicas(
2825        metrics_collections, recall_across_replicas, values)
2826
2827    update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
2828    if updates_collections:
2829      ops.add_to_collections(updates_collections, update_op)
2830
2831    return rec, update_op
2832
2833
2834@tf_export(v1=['metrics.root_mean_squared_error'])
2835def root_mean_squared_error(labels,
2836                            predictions,
2837                            weights=None,
2838                            metrics_collections=None,
2839                            updates_collections=None,
2840                            name=None):
2841  """Computes the root mean squared error between the labels and predictions.
2842
2843  The `root_mean_squared_error` function creates two local variables,
2844  `total` and `count` that are used to compute the root mean squared error.
2845  This average is weighted by `weights`, and it is ultimately returned as
2846  `root_mean_squared_error`: an idempotent operation that takes the square root
2847  of the division of `total` by `count`.
2848
2849  For estimation of the metric over a stream of data, the function creates an
2850  `update_op` operation that updates these variables and returns the
2851  `root_mean_squared_error`. Internally, a `squared_error` operation computes
2852  the element-wise square of the difference between `predictions` and `labels`.
2853  Then `update_op` increments `total` with the reduced sum of the product of
2854  `weights` and `squared_error`, and it increments `count` with the reduced sum
2855  of `weights`.
2856
2857  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2858
2859  Args:
2860    labels: A `Tensor` of the same shape as `predictions`.
2861    predictions: A `Tensor` of arbitrary shape.
2862    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2863      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2864      be either `1`, or the same as the corresponding `labels` dimension).
2865    metrics_collections: An optional list of collections that
2866      `root_mean_squared_error` should be added to.
2867    updates_collections: An optional list of collections that `update_op` should
2868      be added to.
2869    name: An optional variable_scope name.
2870
2871  Returns:
2872    root_mean_squared_error: A `Tensor` representing the current mean, the value
2873      of `total` divided by `count`.
2874    update_op: An operation that increments the `total` and `count` variables
2875      appropriately and whose value matches `root_mean_squared_error`.
2876
2877  Raises:
2878    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2879      `weights` is not `None` and its shape doesn't match `predictions`, or if
2880      either `metrics_collections` or `updates_collections` are not a list or
2881      tuple.
2882    RuntimeError: If eager execution is enabled.
2883  """
2884  if context.executing_eagerly():
2885    raise RuntimeError('tf.metrics.root_mean_squared_error is not '
2886                       'supported when eager execution is enabled.')
2887
2888  predictions, labels, weights = _remove_squeezable_dimensions(
2889      predictions=predictions, labels=labels, weights=weights)
2890  mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
2891                                          None, name or
2892                                          'root_mean_squared_error')
2893
2894  once_across_replicas = lambda _, mse: math_ops.sqrt(mse)
2895  rmse = _aggregate_across_replicas(
2896      metrics_collections, once_across_replicas, mse)
2897
2898  update_rmse_op = math_ops.sqrt(update_mse_op)
2899  if updates_collections:
2900    ops.add_to_collections(updates_collections, update_rmse_op)
2901
2902  return rmse, update_rmse_op
2903
2904
2905@tf_export(v1=['metrics.sensitivity_at_specificity'])
2906def sensitivity_at_specificity(labels,
2907                               predictions,
2908                               specificity,
2909                               weights=None,
2910                               num_thresholds=200,
2911                               metrics_collections=None,
2912                               updates_collections=None,
2913                               name=None):
2914  """Computes the specificity at a given sensitivity.
2915
2916  The `sensitivity_at_specificity` function creates four local
2917  variables, `true_positives`, `true_negatives`, `false_positives` and
2918  `false_negatives` that are used to compute the sensitivity at the given
2919  specificity value. The threshold for the given specificity value is computed
2920  and used to evaluate the corresponding sensitivity.
2921
2922  For estimation of the metric over a stream of data, the function creates an
2923  `update_op` operation that updates these variables and returns the
2924  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
2925  `false_positives` and `false_negatives` counts with the weight of each case
2926  found in the `predictions` and `labels`.
2927
2928  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2929
2930  For additional information about specificity and sensitivity, see the
2931  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
2932
2933  Args:
2934    labels: The ground truth values, a `Tensor` whose dimensions must match
2935      `predictions`. Will be cast to `bool`.
2936    predictions: A floating point `Tensor` of arbitrary shape and whose values
2937      are in the range `[0, 1]`.
2938    specificity: A scalar value in range `[0, 1]`.
2939    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2940      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2941      be either `1`, or the same as the corresponding `labels` dimension).
2942    num_thresholds: The number of thresholds to use for matching the given
2943      specificity.
2944    metrics_collections: An optional list of collections that `sensitivity`
2945      should be added to.
2946    updates_collections: An optional list of collections that `update_op` should
2947      be added to.
2948    name: An optional variable_scope name.
2949
2950  Returns:
2951    sensitivity: A scalar `Tensor` representing the sensitivity at the given
2952      `specificity` value.
2953    update_op: An operation that increments the `true_positives`,
2954      `true_negatives`, `false_positives` and `false_negatives` variables
2955      appropriately and whose value matches `sensitivity`.
2956
2957  Raises:
2958    ValueError: If `predictions` and `labels` have mismatched shapes, if
2959      `weights` is not `None` and its shape doesn't match `predictions`, or if
2960      `specificity` is not between 0 and 1, or if either `metrics_collections`
2961      or `updates_collections` are not a list or tuple.
2962    RuntimeError: If eager execution is enabled.
2963  """
2964  if context.executing_eagerly():
2965    raise RuntimeError('tf.metrics.sensitivity_at_specificity is not '
2966                       'supported when eager execution is enabled.')
2967
2968  if specificity < 0 or specificity > 1:
2969    raise ValueError('`specificity` must be in the range [0, 1]. Currently, '
2970                     f'`specificity` got {specificity}.')
2971
2972  with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
2973                                     (predictions, labels, weights)):
2974    kepsilon = 1e-7  # to account for floating point imprecisions
2975    thresholds = [
2976        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
2977    ]
2978    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
2979
2980    values, update_ops = _confusion_matrix_at_thresholds(
2981        labels, predictions, thresholds, weights)
2982
2983    def compute_sensitivity_at_specificity(tp, tn, fp, fn, name):
2984      specificities = math_ops.divide(tn, tn + fp + kepsilon)
2985      tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
2986      tf_index = math_ops.cast(tf_index, dtypes.int32)
2987
2988      # Now, we have the implicit threshold, so compute the sensitivity:
2989      return math_ops.divide(tp[tf_index],
2990                             tp[tf_index] + fn[tf_index] + kepsilon, name)
2991
2992    def sensitivity_across_replicas(_, values):
2993      return compute_sensitivity_at_specificity(
2994          values['tp'], values['tn'], values['fp'], values['fn'], 'value')
2995
2996    sensitivity = _aggregate_across_replicas(
2997        metrics_collections, sensitivity_across_replicas, values)
2998
2999    update_op = compute_sensitivity_at_specificity(
3000        update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
3001        'update_op')
3002    if updates_collections:
3003      ops.add_to_collections(updates_collections, update_op)
3004
3005    return sensitivity, update_op
3006
3007
3008def _expand_and_tile(tensor, multiple, dim=0, name=None):
3009  """Slice `tensor` shape in 2, then tile along the sliced dimension.
3010
3011  A new dimension is inserted in shape of `tensor` before `dim`, then values are
3012  tiled `multiple` times along the new dimension.
3013
3014  Args:
3015    tensor: Input `Tensor` or `SparseTensor`.
3016    multiple: Integer, number of times to tile.
3017    dim: Integer, dimension along which to tile.
3018    name: Name of operation.
3019
3020  Returns:
3021    `Tensor` result of expanding and tiling `tensor`.
3022
3023  Raises:
3024    ValueError: if `multiple` is less than 1, or `dim` is not in
3025    `[-rank(tensor), rank(tensor)]`.
3026  """
3027  if multiple < 1:
3028    raise ValueError(f'Invalid argument multiple={multiple} for '
3029                     'expand_and_tile  call. `multiple` must be an integer > 0')
3030  with ops.name_scope(name, 'expand_and_tile',
3031                      (tensor, multiple, dim)) as scope:
3032    # Sparse.
3033    tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
3034    if isinstance(tensor, sparse_tensor.SparseTensor):
3035      if dim < 0:
3036        expand_dims = array_ops.reshape(
3037            array_ops.size(tensor.dense_shape) + dim, [1])
3038      else:
3039        expand_dims = [dim]
3040      expanded_shape = array_ops.concat(
3041          (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1],
3042           array_ops.slice(tensor.dense_shape, expand_dims, [-1])),
3043          0,
3044          name='expanded_shape')
3045      expanded = sparse_ops.sparse_reshape(
3046          tensor, shape=expanded_shape, name='expand')
3047      if multiple == 1:
3048        return expanded
3049      return sparse_ops.sparse_concat(
3050          dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope)
3051
3052    # Dense.
3053    expanded = array_ops.expand_dims(
3054        tensor, dim if (dim >= 0) else (dim - 1), name='expand')
3055    if multiple == 1:
3056      return expanded
3057    ones = array_ops.ones_like(array_ops.shape(tensor))
3058    tile_multiples = array_ops.concat(
3059        (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples')
3060    return array_ops.tile(expanded, tile_multiples, name=scope)
3061
3062
3063def _num_relevant(labels, k):
3064  """Computes number of relevant values for each row in labels.
3065
3066  For labels with shape [D1, ... DN, num_labels], this is the minimum of
3067  `num_labels` and `k`.
3068
3069  Args:
3070    labels: `int64` `Tensor` or `SparseTensor` with shape
3071      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
3072      target classes for the associated prediction. Commonly, N=1 and `labels`
3073      has shape [batch_size, num_labels].
3074    k: Integer, k for @k metric.
3075
3076  Returns:
3077    Integer `Tensor` of shape [D1, ... DN], where each value is the number of
3078    relevant values for that row.
3079
3080  Raises:
3081    ValueError: if inputs have invalid dtypes or values.
3082  """
3083  if k < 1:
3084    raise ValueError(f'Invalid k={k}')
3085  with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
3086    # For SparseTensor, calculate separate count for each row.
3087    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
3088    if isinstance(labels, sparse_tensor.SparseTensor):
3089      return math_ops.minimum(sets.set_size(labels), k, name=scope)
3090
3091    # The relevant values for each (d1, ... dN) is the minimum of k and the
3092    # number of labels along the last dimension that are non-negative.
3093    num_labels = math_ops.reduce_sum(
3094        array_ops.where_v2(math_ops.greater_equal(labels, 0),
3095                           array_ops.ones_like(labels),
3096                           array_ops.zeros_like(labels)),
3097        axis=-1)
3098    return math_ops.minimum(num_labels, k, name=scope)
3099
3100
3101def _sparse_average_precision_at_top_k(labels, predictions_idx):
3102  """Computes average precision@k of predictions with respect to sparse labels.
3103
3104  From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula
3105  for each row is:
3106
3107    AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items
3108
3109  A "row" is the elements in dimension [D1, ... DN] of `predictions_idx`,
3110  `labels`, and the result `Tensors`. In the common case, this is [batch_size].
3111  Each row of the results contains the average precision for that row.
3112
3113  Args:
3114    labels: `int64` `Tensor` or `SparseTensor` with shape
3115      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3116      num_labels=1. N >= 1 and num_labels is the number of target classes for
3117      the associated prediction. Commonly, N=1 and `labels` has shape
3118      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
3119      Values should be non-negative. Negative values are ignored.
3120    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
3121      Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
3122      dimension must be set and contains the top `k` predicted class indices.
3123      [D1, ... DN] must match `labels`. Values should be in range
3124      [0, num_classes).
3125
3126  Returns:
3127    `float64` `Tensor` of shape [D1, ... DN], where each value is the average
3128    precision for that row.
3129
3130  Raises:
3131    ValueError: if the last dimension of predictions_idx is not set.
3132  """
3133  with ops.name_scope(None, 'average_precision',
3134                      (predictions_idx, labels)) as scope:
3135    predictions_idx = math_ops.cast(
3136        predictions_idx, dtypes.int64, name='predictions_idx')
3137    if predictions_idx.get_shape().ndims == 0:
3138      raise ValueError('The rank of `predictions_idx` must be at least 1.')
3139    k = predictions_idx.get_shape().as_list()[-1]
3140    if k is None:
3141      raise ValueError('The last dimension of predictions_idx must be set. '
3142                       'Currently, it is None.')
3143    labels = _maybe_expand_labels(labels, predictions_idx)
3144
3145    # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate
3146    # prediction for each k, so we can calculate separate true positive values
3147    # for each k.
3148    predictions_idx_per_k = array_ops.expand_dims(
3149        predictions_idx, -1, name='predictions_idx_per_k')
3150
3151    # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor.
3152    labels_per_k = _expand_and_tile(
3153        labels, multiple=k, dim=-1, name='labels_per_k')
3154
3155    # The following tensors are all of shape [D1, ... DN, k], containing values
3156    # per row, per k value.
3157    # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at
3158    #     that k value is correct, 0 otherwise. This is the "rel_{i}" term from
3159    #     the formula above.
3160    # `tp_per_k` (int32) - True positive counts.
3161    # `retrieved_per_k` (int32) - Number of predicted values at each k. This is
3162    #     the precision denominator.
3163    # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}"
3164    #     term from the formula above.
3165    # `relevant_precision_per_k` (float64) - Relevant precisions; i.e.,
3166    #     precisions at all k for which relevance indicator is true.
3167    relevant_per_k = _sparse_true_positive_at_k(
3168        labels_per_k, predictions_idx_per_k, name='relevant_per_k')
3169    tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k')
3170    retrieved_per_k = math_ops.cumsum(
3171        array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
3172    precision_per_k = math_ops.divide(
3173        math_ops.cast(tp_per_k, dtypes.float64),
3174        math_ops.cast(retrieved_per_k, dtypes.float64),
3175        name='precision_per_k')
3176    relevant_precision_per_k = math_ops.multiply(
3177        precision_per_k,
3178        math_ops.cast(relevant_per_k, dtypes.float64),
3179        name='relevant_precision_per_k')
3180
3181    # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
3182    precision_sum = math_ops.reduce_sum(
3183        relevant_precision_per_k, axis=(-1,), name='precision_sum')
3184
3185    # Divide by number of relevant items to get average precision. These are
3186    # the "num_relevant_items" and "AveP" terms from the formula above.
3187    num_relevant_items = math_ops.cast(_num_relevant(labels, k), dtypes.float64)
3188    return math_ops.divide(precision_sum, num_relevant_items, name=scope)
3189
3190
3191def _streaming_sparse_average_precision_at_top_k(labels,
3192                                                 predictions_idx,
3193                                                 weights=None,
3194                                                 metrics_collections=None,
3195                                                 updates_collections=None,
3196                                                 name=None):
3197  """Computes average precision@k of predictions with respect to sparse labels.
3198
3199  `sparse_average_precision_at_top_k` creates two local variables,
3200  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
3201  are used to compute the frequency. This frequency is ultimately returned as
3202  `average_precision_at_<k>`: an idempotent operation that simply divides
3203  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
3204
3205  For estimation of the metric over a stream of data, the function creates an
3206  `update_op` operation that updates these variables and returns the
3207  `precision_at_<k>`. Set operations applied to `top_k` and `labels` calculate
3208  the true positives and false positives weighted by `weights`. Then `update_op`
3209  increments `true_positive_at_<k>` and `false_positive_at_<k>` using these
3210  values.
3211
3212  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3213
3214  Args:
3215    labels: `int64` `Tensor` or `SparseTensor` with shape
3216      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3217      num_labels=1. N >= 1 and num_labels is the number of target classes for
3218      the associated prediction. Commonly, N=1 and `labels` has shape
3219      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
3220      Values should be non-negative. Negative values are ignored.
3221    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
3222      Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
3223      dimension contains the top `k` predicted class indices. [D1, ... DN] must
3224      match `labels`. Values should be in range [0, num_classes).
3225    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3226      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3227      dimensions must be either `1`, or the same as the corresponding `labels`
3228      dimension).
3229    metrics_collections: An optional list of collections that values should
3230      be added to.
3231    updates_collections: An optional list of collections that updates should
3232      be added to.
3233    name: Name of new update operation, and namespace for other dependent ops.
3234
3235  Returns:
3236    mean_average_precision: Scalar `float64` `Tensor` with the mean average
3237      precision values.
3238    update: `Operation` that increments variables appropriately, and whose
3239      value matches `metric`.
3240  """
3241  with ops.name_scope(name, 'average_precision_at_top_k',
3242                      (predictions_idx, labels, weights)) as scope:
3243    # Calculate per-example average precision, and apply weights.
3244    average_precision = _sparse_average_precision_at_top_k(
3245        predictions_idx=predictions_idx, labels=labels)
3246    if weights is not None:
3247      weights = weights_broadcast_ops.broadcast_weights(
3248          math_ops.cast(weights, dtypes.float64), average_precision)
3249      average_precision = math_ops.multiply(average_precision, weights)
3250
3251    # Create accumulation variables and update ops for max average precision and
3252    # total average precision.
3253    with ops.name_scope(None, 'max', (average_precision,)) as max_scope:
3254      # `max` is the max possible precision. Since max for any row is 1.0:
3255      # - For the unweighted case, this is just the number of rows.
3256      # - For the weighted case, it's the sum of the weights broadcast across
3257      #   `average_precision` rows.
3258      max_var = metric_variable([], dtypes.float64, name=max_scope)
3259      if weights is None:
3260        batch_max = math_ops.cast(
3261            array_ops.size(average_precision, name='batch_max'), dtypes.float64)
3262      else:
3263        batch_max = math_ops.reduce_sum(weights, name='batch_max')
3264      max_update = state_ops.assign_add(max_var, batch_max, name='update')
3265    with ops.name_scope(None, 'total', (average_precision,)) as total_scope:
3266      total_var = metric_variable([], dtypes.float64, name=total_scope)
3267      batch_total = math_ops.reduce_sum(average_precision, name='batch_total')
3268      total_update = state_ops.assign_add(total_var, batch_total, name='update')
3269
3270    # Divide total by max to get mean, for both vars and the update ops.
3271    def precision_across_replicas(_, total_var, max_var):
3272      return _safe_scalar_div(total_var, max_var, name='mean')
3273
3274    mean_average_precision = _aggregate_across_replicas(
3275        metrics_collections, precision_across_replicas, total_var, max_var)
3276
3277    update = _safe_scalar_div(total_update, max_update, name=scope)
3278    if updates_collections:
3279      ops.add_to_collections(updates_collections, update)
3280
3281    return mean_average_precision, update
3282
3283
3284def _clean_out_of_range_indices(labels, num_classes):
3285  """Replaces large out-of-range labels by small out-of-range labels.
3286
3287  Replaces any value in `labels` that is greater or equal to `num_classes` by
3288  -1. Do this conditionally for efficiency in case there are no such values.
3289
3290  Args:
3291    labels: `int64` `Tensor` or `SparseTensor`.
3292    num_classes: `int64` scalar `Tensor`.
3293  Returns:
3294    An `int64` `Tensor` or `SparseTensor` as `labels` with indices greater
3295    or equal to num_classes replaced by -1.
3296  """
3297
3298  def _labels_is_sparse():
3299    """Returns true is `labels` is a sparse tensor."""
3300    return isinstance(labels, (sparse_tensor.SparseTensor,
3301                               sparse_tensor.SparseTensorValue))
3302
3303  def _clean_out_of_range(values):
3304    """Replaces by -1 any large out-of-range `values`."""
3305    return array_ops.where_v2(math_ops.greater_equal(values, num_classes),
3306                              -1 * array_ops.ones_like(values), values)
3307
3308  def _clean_labels_out_of_range():
3309    """Replaces by -1 ane large out-of-range values in `labels`."""
3310    if _labels_is_sparse():
3311      return type(labels)(indices=labels.indices,
3312                          values=_clean_out_of_range(labels.values),
3313                          dense_shape=labels.dense_shape)
3314    else:
3315      return _clean_out_of_range(labels)
3316
3317  max_labels = math_ops.reduce_max(
3318      labels.values if _labels_is_sparse() else labels)
3319  return control_flow_ops.cond(
3320      math_ops.greater_equal(max_labels, num_classes),
3321      _clean_labels_out_of_range,
3322      lambda: labels)
3323
3324
3325@tf_export(v1=['metrics.sparse_average_precision_at_k'])
3326@deprecated(None, 'Use average_precision_at_k instead')
3327def sparse_average_precision_at_k(labels,
3328                                  predictions,
3329                                  k,
3330                                  weights=None,
3331                                  metrics_collections=None,
3332                                  updates_collections=None,
3333                                  name=None):
3334  """Renamed to `average_precision_at_k`, please use that method instead."""
3335  return average_precision_at_k(
3336      labels=labels,
3337      predictions=predictions,
3338      k=k,
3339      weights=weights,
3340      metrics_collections=metrics_collections,
3341      updates_collections=updates_collections,
3342      name=name)
3343
3344
3345@tf_export(v1=['metrics.average_precision_at_k'])
3346def average_precision_at_k(labels,
3347                           predictions,
3348                           k,
3349                           weights=None,
3350                           metrics_collections=None,
3351                           updates_collections=None,
3352                           name=None):
3353  """Computes average precision@k of predictions with respect to sparse labels.
3354
3355  `average_precision_at_k` creates two local variables,
3356  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
3357  are used to compute the frequency. This frequency is ultimately returned as
3358  `average_precision_at_<k>`: an idempotent operation that simply divides
3359  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
3360
3361  For estimation of the metric over a stream of data, the function creates an
3362  `update_op` operation that updates these variables and returns the
3363  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
3364  indicating the top `k` `predictions`. Set operations applied to `top_k` and
3365  `labels` calculate the true positives and false positives weighted by
3366  `weights`. Then `update_op` increments `true_positive_at_<k>` and
3367  `false_positive_at_<k>` using these values.
3368
3369  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3370
3371  Args:
3372    labels: `int64` `Tensor` or `SparseTensor` with shape
3373      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3374      num_labels=1. N >= 1 and num_labels is the number of target classes for
3375      the associated prediction. Commonly, N=1 and `labels` has shape
3376      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
3377      should be in range [0, num_classes), where num_classes is the last
3378      dimension of `predictions`. Values outside this range are ignored.
3379    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
3380      N >= 1. Commonly, N=1 and `predictions` has shape
3381      [batch size, num_classes]. The final dimension contains the logit values
3382      for each class. [D1, ... DN] must match `labels`.
3383    k: Integer, k for @k metric. This will calculate an average precision for
3384      range `[1,k]`, as documented above.
3385    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3386      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3387      dimensions must be either `1`, or the same as the corresponding `labels`
3388      dimension).
3389    metrics_collections: An optional list of collections that values should
3390      be added to.
3391    updates_collections: An optional list of collections that updates should
3392      be added to.
3393    name: Name of new update operation, and namespace for other dependent ops.
3394
3395  Returns:
3396    mean_average_precision: Scalar `float64` `Tensor` with the mean average
3397      precision values.
3398    update: `Operation` that increments variables appropriately, and whose
3399      value matches `metric`.
3400
3401  Raises:
3402    ValueError: if k is invalid.
3403    RuntimeError: If eager execution is enabled.
3404  """
3405  if context.executing_eagerly():
3406    raise RuntimeError('tf.metrics.sparse_average_precision_at_k is not '
3407                       'supported when eager execution is enabled.')
3408
3409  if k < 1:
3410    raise ValueError(f'Invalid k={k}. `k` should be >= 1.')
3411  with ops.name_scope(name, _at_k_name('average_precision', k),
3412                      (predictions, labels, weights)) as scope:
3413    # Calculate top k indices to produce [D1, ... DN, k] tensor.
3414    _, predictions_idx = nn.top_k(predictions, k)
3415    # The documentation states that labels should be in [0, ..., num_classes),
3416    # but num_classes is lost when predictions_idx replaces predictions.
3417    # For conformity with the documentation, any label >= num_classes, which is
3418    # ignored, is replaced by -1.
3419    labels = _clean_out_of_range_indices(
3420        labels, math_ops.cast(array_ops.shape(predictions)[-1], dtypes.int64))
3421    return _streaming_sparse_average_precision_at_top_k(
3422        labels=labels,
3423        predictions_idx=predictions_idx,
3424        weights=weights,
3425        metrics_collections=metrics_collections,
3426        updates_collections=updates_collections,
3427        name=scope)
3428
3429
3430def _sparse_false_positive_at_k(labels,
3431                                predictions_idx,
3432                                class_id=None,
3433                                weights=None):
3434  """Calculates false positives for precision@k.
3435
3436  If `class_id` is specified, calculate binary true positives for `class_id`
3437      only.
3438  If `class_id` is not specified, calculate metrics for `k` predicted vs
3439      `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.
3440
3441  Args:
3442    labels: `int64` `Tensor` or `SparseTensor` with shape
3443      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
3444      target classes for the associated prediction. Commonly, N=1 and `labels`
3445      has shape [batch_size, num_labels]. [D1, ... DN] must match
3446      `predictions_idx`.
3447    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
3448      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
3449      match `labels`.
3450    class_id: Class for which we want binary metrics.
3451    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3452      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3453      dimensions must be either `1`, or the same as the corresponding `labels`
3454      dimension).
3455
3456  Returns:
3457    A [D1, ... DN] `Tensor` of false positive counts.
3458  """
3459  with ops.name_scope(None, 'false_positives',
3460                      (predictions_idx, labels, weights)):
3461    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
3462                                                     class_id)
3463    fp = sets.set_size(
3464        sets.set_difference(predictions_idx, labels, aminusb=True))
3465    fp = math_ops.cast(fp, dtypes.float64)
3466    if weights is not None:
3467      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
3468          weights, fp),)):
3469        weights = math_ops.cast(weights, dtypes.float64)
3470        fp = math_ops.multiply(fp, weights)
3471    return fp
3472
3473
3474def _streaming_sparse_false_positive_at_k(labels,
3475                                          predictions_idx,
3476                                          k=None,
3477                                          class_id=None,
3478                                          weights=None,
3479                                          name=None):
3480  """Calculates weighted per step false positives for precision@k.
3481
3482  If `class_id` is specified, calculate binary true positives for `class_id`
3483      only.
3484  If `class_id` is not specified, calculate metrics for `k` predicted vs
3485      `n` label classes, where `n` is the 2nd dimension of `labels`.
3486
3487  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3488
3489  Args:
3490    labels: `int64` `Tensor` or `SparseTensor` with shape
3491      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
3492      target classes for the associated prediction. Commonly, N=1 and `labels`
3493      has shape [batch_size, num_labels]. [D1, ... DN] must match
3494      `predictions_idx`.
3495    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
3496      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
3497      match `labels`.
3498    k: Integer, k for @k metric. This is only used for default op name.
3499    class_id: Class for which we want binary metrics.
3500    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3501      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3502      dimensions must be either `1`, or the same as the corresponding `labels`
3503      dimension).
3504    name: Name of new variable, and namespace for other dependent ops.
3505
3506  Returns:
3507    A tuple of `Variable` and update `Operation`.
3508
3509  Raises:
3510    ValueError: If `weights` is not `None` and has an incompatible shape.
3511  """
3512  with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id),
3513                      (predictions_idx, labels, weights)) as scope:
3514    fp = _sparse_false_positive_at_k(
3515        predictions_idx=predictions_idx,
3516        labels=labels,
3517        class_id=class_id,
3518        weights=weights)
3519    batch_total_fp = math_ops.cast(math_ops.reduce_sum(fp), dtypes.float64)
3520
3521    var = metric_variable([], dtypes.float64, name=scope)
3522    return var, state_ops.assign_add(var, batch_total_fp, name='update')
3523
3524
3525@tf_export(v1=['metrics.precision_at_top_k'])
3526def precision_at_top_k(labels,
3527                       predictions_idx,
3528                       k=None,
3529                       class_id=None,
3530                       weights=None,
3531                       metrics_collections=None,
3532                       updates_collections=None,
3533                       name=None):
3534  """Computes precision@k of the predictions with respect to sparse labels.
3535
3536  Differs from `sparse_precision_at_k` in that predictions must be in the form
3537  of top `k` class indices, whereas `sparse_precision_at_k` expects logits.
3538  Refer to `sparse_precision_at_k` for more details.
3539
3540  Args:
3541    labels: `int64` `Tensor` or `SparseTensor` with shape
3542      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3543      num_labels=1. N >= 1 and num_labels is the number of target classes for
3544      the associated prediction. Commonly, N=1 and `labels` has shape
3545      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
3546      should be in range [0, num_classes), where num_classes is the last
3547      dimension of `predictions`. Values outside this range are ignored.
3548    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where
3549      N >= 1. Commonly, N=1 and predictions has shape [batch size, k].
3550      The final dimension contains the top `k` predicted class indices.
3551      [D1, ... DN] must match `labels`.
3552    k: Integer, k for @k metric. Only used for the default op name.
3553    class_id: Integer class ID for which we want binary metrics. This should be
3554      in range [0, num_classes], where num_classes is the last dimension of
3555      `predictions`. If `class_id` is outside this range, the method returns
3556      NAN.
3557    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3558      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3559      dimensions must be either `1`, or the same as the corresponding `labels`
3560      dimension).
3561    metrics_collections: An optional list of collections that values should
3562      be added to.
3563    updates_collections: An optional list of collections that updates should
3564      be added to.
3565    name: Name of new update operation, and namespace for other dependent ops.
3566
3567  Returns:
3568    precision: Scalar `float64` `Tensor` with the value of `true_positives`
3569      divided by the sum of `true_positives` and `false_positives`.
3570    update_op: `Operation` that increments `true_positives` and
3571      `false_positives` variables appropriately, and whose value matches
3572      `precision`.
3573
3574  Raises:
3575    ValueError: If `weights` is not `None` and its shape doesn't match
3576      `predictions`, or if either `metrics_collections` or `updates_collections`
3577      are not a list or tuple.
3578    RuntimeError: If eager execution is enabled.
3579  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.precision_at_top_k is not '
                       'supported when eager execution is enabled.')

  with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
                      (predictions_idx, labels, weights)) as scope:
    labels = _maybe_expand_labels(labels, predictions_idx)
    top_k_idx = math_ops.cast(predictions_idx, dtypes.int64)
    tp, tp_update = _streaming_sparse_true_positive_at_k(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)
    fp, fp_update = _streaming_sparse_false_positive_at_k(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)

    def precision_across_replicas(_, tp, fp):
      return math_ops.divide(tp, math_ops.add(tp, fp), name=scope)

    metric = _aggregate_across_replicas(
        metrics_collections, precision_across_replicas, tp, fp)

    update = math_ops.divide(
        tp_update, math_ops.add(tp_update, fp_update), name='update')
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update


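# Illustrative usage sketch (not part of the library): streaming precision@k
# from precomputed top-k indices in a TF1-style graph. The tensors, values,
# and the `_example_*` name are hypothetical.
def _example_precision_at_top_k():
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops import variables

  with ops.Graph().as_default():
    # Top-2 predicted class indices for two examples.
    predictions_idx = ops.convert_to_tensor([[0, 2], [1, 3]], dtypes.int64)
    # Ground truth: example 0 has label 2, example 1 has label 0.
    labels = ops.convert_to_tensor([[2], [0]], dtypes.int64)
    precision, update_op = precision_at_top_k(
        labels=labels, predictions_idx=predictions_idx, k=2)
    with session_lib.Session() as sess:
      # Metric variables live in the LOCAL_VARIABLES collection.
      sess.run(variables.local_variables_initializer())
      sess.run(update_op)  # Accumulates tp/fp for this batch.
      # 1 of the 4 top-k predictions is a true positive, so this prints 0.25.
      print(sess.run(precision))

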
@tf_export(v1=['metrics.sparse_precision_at_k'])
@deprecated(None, 'Use precision_at_k instead')
def sparse_precision_at_k(labels,
                          predictions,
                          k,
                          class_id=None,
                          weights=None,
                          metrics_collections=None,
                          updates_collections=None,
                          name=None):
  """Renamed to `precision_at_k`, please use that method instead."""
  return precision_at_k(
      labels=labels,
      predictions=predictions,
      k=k,
      class_id=class_id,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@tf_export(v1=['metrics.precision_at_k'])
def precision_at_k(labels,
                   predictions,
                   k,
                   class_id=None,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate precision by considering only the
      entries in the batch for which `class_id` is in the top-k highest
      `predictions`, and computing the fraction of them for which `class_id` is
      indeed a correct label.
  If `class_id` is not specified, we calculate precision as how often, on
      average, a class among the top-k classes with the highest predicted
      values of a batch entry is actually one of the labels for that entry.

  `precision_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_positive_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
      should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range are ignored.
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and `predictions` has shape
      [batch_size, num_classes]. The final dimension contains the logit values
      for each class. [D1, ... DN] must match `labels`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.precision_at_k is not '
                       'supported when eager execution is enabled.')

  with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
                      (predictions, labels, weights)) as scope:
    _, top_k_idx = nn.top_k(predictions, k)
    return precision_at_top_k(
        labels=labels,
        predictions_idx=top_k_idx,
        k=k,
        class_id=class_id,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=scope)


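# Illustrative usage sketch (not part of the library): streaming precision@k
# from raw logits. `top_k` is applied internally, so only logits and sparse
# labels are supplied. The values and the `_example_*` name are hypothetical.
def _example_precision_at_k():
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops import variables

  with ops.Graph().as_default():
    # Logits over 4 classes for two examples.
    predictions = ops.convert_to_tensor(
        [[0.1, 0.3, 0.2, 0.4],
         [0.5, 0.1, 0.3, 0.1]], dtypes.float32)
    labels = ops.convert_to_tensor([[3], [2]], dtypes.int64)
    precision, update_op = precision_at_k(
        labels=labels, predictions=predictions, k=2)
    with session_lib.Session() as sess:
      sess.run(variables.local_variables_initializer())
      sess.run(update_op)  # Accumulates tp/fp for this batch.
      # Top-2 classes are {3, 1} and {0, 2}; labels 3 and 2 each appear once,
      # so tp=2, fp=2 and this prints 0.5.
      print(sess.run(precision))

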
@tf_export(v1=['metrics.specificity_at_sensitivity'])
def specificity_at_sensitivity(labels,
                               predictions,
                               sensitivity,
                               weights=None,
                               num_thresholds=200,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Computes the specificity at a given sensitivity.

  The `specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is computed
  and used to evaluate the corresponding specificity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see
  https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar `Tensor` representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
                       'supported when eager execution is enabled.')

  if sensitivity < 0 or sensitivity > 1:
    raise ValueError('`sensitivity` must be in the range [0, 1]. Currently, '
                     f'`sensitivity` is {sensitivity}.')

  with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
                                     (predictions, labels, weights)):
    kepsilon = 1e-7  # to account for floating point imprecisions
    # num_thresholds - 2 evenly spaced interior thresholds in (0, 1), plus
    # endpoint thresholds offset by kepsilon to guard against floating point
    # imprecision at the boundaries.
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    def compute_specificity_at_sensitivity(tp, tn, fp, fn, name):
      """Computes the specificity at the given sensitivity.

      Args:
        tp: True positives.
        tn: True negatives.
        fp: False positives.
        fn: False negatives.
        name: The name of the operation.

      Returns:
        The specificity using the aggregated values.
      """
      sensitivities = math_ops.divide(tp, tp + fn + kepsilon)

      # We'll need to use this trick until tf.argmax allows us to specify
      # whether we should use the first or last index in case of ties.
      min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
      indices_at_minval = math_ops.equal(
          math_ops.abs(sensitivities - sensitivity), min_val)
      indices_at_minval = math_ops.cast(indices_at_minval, dtypes.int64)
      indices_at_minval = math_ops.cumsum(indices_at_minval)
      tf_index = math_ops.argmax(indices_at_minval, 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)
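      # Illustrative walkthrough (hypothetical values): if `sensitivities` is
      # [0.9, 0.5, 0.5, 0.1] and `sensitivity` is 0.5, the distances are
      # [0.4, 0.0, 0.0, 0.4], so indices 1 and 2 tie at the minimum. Their
      # cumulative sum is [0, 1, 2, 2], and argmax returns the first
      # occurrence of the maximum, index 2, i.e. the last tied threshold.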

      # Now, we have the implicit threshold, so compute the specificity:
      return math_ops.divide(tn[tf_index],
                             tn[tf_index] + fp[tf_index] + kepsilon, name)

    def specificity_across_replicas(_, values):
      return compute_specificity_at_sensitivity(
          values['tp'], values['tn'], values['fp'], values['fn'], 'value')

    specificity = _aggregate_across_replicas(
        metrics_collections, specificity_across_replicas, values)

    update_op = compute_specificity_at_sensitivity(
        update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
        'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return specificity, update_op

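
# Illustrative usage sketch (not part of the library): specificity at a
# target sensitivity of 0.8 over one batch. All values and the `_example_*`
# name are hypothetical; the printed value depends on the internal threshold
# grid.
def _example_specificity_at_sensitivity():
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops import variables

  with ops.Graph().as_default():
    labels = ops.convert_to_tensor([0, 0, 1, 1, 1, 0, 1, 1], dtypes.int64)
    predictions = ops.convert_to_tensor(
        [0.1, 0.4, 0.35, 0.8, 0.7, 0.2, 0.6, 0.9], dtypes.float32)
    specificity, update_op = specificity_at_sensitivity(
        labels, predictions, sensitivity=0.8)
    with session_lib.Session() as sess:
      sess.run(variables.local_variables_initializer())
      sess.run(update_op)  # Accumulates tp/tn/fp/fn at each threshold.
      print(sess.run(specificity))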