# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains metric-computing operations on streamed tensors.

Module documentation, including "@@" callouts, should be put in
third_party/tensorflow/contrib/metrics/__init__.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as collections_lib

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import metrics_impl
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.distributions.normal import Normal
from tensorflow.python.util.deprecation import deprecated

# Epsilon constant used to represent an extremely small quantity.
_EPSILON = 1e-7


@deprecated(None, 'Please switch to tf.metrics.true_positives. Note that the '
            'order of the labels and predictions arguments has been switched.')
def streaming_true_positives(predictions,
                             labels,
                             weights=None,
                             metrics_collections=None,
                             updates_collections=None,
                             name=None):
  """Sum the weights of true_positives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.true_positives(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@deprecated(None, 'Please switch to tf.metrics.true_negatives. Note that the '
            'order of the labels and predictions arguments has been switched.')
def streaming_true_negatives(predictions,
                             labels,
                             weights=None,
                             metrics_collections=None,
                             updates_collections=None,
                             name=None):
  """Sum the weights of true_negatives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.true_negatives(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@deprecated(None, 'Please switch to tf.metrics.false_positives. Note that the '
            'order of the labels and predictions arguments has been switched.')
def streaming_false_positives(predictions,
                              labels,
                              weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Sum the weights of false positives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.false_positives(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@deprecated(None, 'Please switch to tf.metrics.false_negatives. Note that the '
            'order of the labels and predictions arguments has been switched.')
def streaming_false_negatives(predictions,
                              labels,
                              weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Computes the total number of false negatives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.false_negatives(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


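# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original API surface. Every
# streaming metric above returns a (value, update_op) pair: run `update_op`
# once per batch to accumulate, then read the value tensor. This assumes a
# TF1-style graph and `Session`; the helper name and inputs are hypothetical.
def _example_true_positives_usage():
  """Sketch: accumulate weighted true positives over two batches."""
  import tensorflow as tf  # Assumed available; used only in this sketch.

  predictions = tf.placeholder(dtypes.bool, shape=[None])
  labels = tf.placeholder(dtypes.bool, shape=[None])
  weights = tf.placeholder(dtypes.float32, shape=[None])
  tp, update_op = streaming_true_positives(predictions, labels, weights)
  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # Metric variables are local.
    # The second entry of the first batch is masked out by its zero weight.
    sess.run(update_op, {predictions: [True, True], labels: [True, True],
                         weights: [1.0, 0.0]})
    sess.run(update_op, {predictions: [True, False], labels: [True, True],
                         weights: [1.0, 1.0]})
    return sess.run(tp)  # -> 2.0: one unmasked true positive per batch.

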
@deprecated(None, 'Please switch to tf.metrics.mean')
def streaming_mean(values,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Computes the (weighted) mean of the given values.

  The `streaming_mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean`, which is an idempotent operation that simply divides
  `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: `Tensor` whose rank is either 0, or the same rank as `values`, and
      must be broadcastable to `values` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  return metrics.mean(
      values=values,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


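# Numeric illustration, not from the original file, of the accumulator logic
# described in the docstring above: `total` accumulates sum(weights * values)
# and `count` accumulates sum(weights), so the reported mean is total / count.
def _example_weighted_mean_arithmetic():
  """Sketch: a pure-Python equivalent of streaming_mean's update rule."""
  total, count = 0.0, 0.0
  for values, weights in [([1.0, 2.0], [1.0, 1.0]), ([10.0], [0.5])]:
    total += sum(w * v for v, w in zip(values, weights))
    count += sum(weights)
  return total / count  # -> 8.0 / 2.5 = 3.2

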
@deprecated(None, 'Please switch to tf.metrics.mean_tensor')
def streaming_mean_tensor(values,
                          weights=None,
                          metrics_collections=None,
                          updates_collections=None,
                          name=None):
  """Computes the element-wise (weighted) mean of the given tensors.

  In contrast to the `streaming_mean` function, which returns a scalar with the
  mean, this function returns an average tensor with the same shape as the
  input tensors.

  The `streaming_mean_tensor` function creates two local variables,
  `total_tensor` and `count_tensor` that are used to compute the average of
  `values`. This average is ultimately returned as `mean`, which is an
  idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: `Tensor` whose rank is either 0, or the same rank as `values`, and
      must be broadcastable to `values` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A float `Tensor` representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  return metrics.mean_tensor(
      values=values,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@deprecated(None, 'Please switch to tf.metrics.accuracy. Note that the order '
            'of the labels and predictions arguments has been switched.')
def streaming_accuracy(predictions,
                       labels,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Calculates how often `predictions` matches `labels`.

  The `streaming_accuracy` function creates two local variables, `total` and
  `count` that are used to compute the frequency with which `predictions`
  matches `labels`. This frequency is ultimately returned as `accuracy`: an
  idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `accuracy`.
  Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
  where the corresponding elements of `predictions` and `labels` match and 0.0
  otherwise. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `is_correct`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of any shape.
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.accuracy(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@deprecated(None, 'Please switch to tf.metrics.precision. Note that the order '
            'of the labels and predictions arguments has been switched.')
def streaming_precision(predictions,
                        labels,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the precision of the predictions with respect to the labels.

  The `streaming_precision` function creates two local variables,
  `true_positives` and `false_positives`, that are used to compute the
  precision. This value is ultimately returned as `precision`, an idempotent
  operation that simply divides `true_positives` by the sum of `true_positives`
  and `false_positives`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision`. `update_op` weights each prediction by the corresponding value in
  `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `precision` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: Scalar float `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately and whose value matches
      `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.precision(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@deprecated(None, 'Please switch to tf.metrics.recall. Note that the order '
            'of the labels and predictions arguments has been switched.')
def streaming_recall(predictions,
                     labels,
                     weights=None,
                     metrics_collections=None,
                     updates_collections=None,
                     name=None):
  """Computes the recall of the predictions with respect to the labels.

  The `streaming_recall` function creates two local variables, `true_positives`
  and `false_negatives`, that are used to compute the recall. This value is
  ultimately returned as `recall`, an idempotent operation that simply divides
  `true_positives` by the sum of `true_positives` and `false_negatives`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` that updates these variables and returns the `recall`. `update_op`
  weights each prediction by the corresponding value in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `recall` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: Scalar float `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately and whose value matches
      `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.recall(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


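# Worked example, not from the original file: with boolean predictions
# [T, T, F, F] and labels [T, F, T, F] there is one TP, one FP and one FN,
# so precision = recall = 1/2. This sketch mirrors the definitions in the two
# docstrings above using plain Python; the helper name is hypothetical.
def _example_precision_recall_arithmetic():
  """Sketch: precision = TP / (TP + FP), recall = TP / (TP + FN)."""
  preds = [True, True, False, False]
  labels = [True, False, True, False]
  tp = sum(1 for p, l in zip(preds, labels) if p and l)      # 1
  fp = sum(1 for p, l in zip(preds, labels) if p and not l)  # 1
  fn = sum(1 for p, l in zip(preds, labels) if not p and l)  # 1
  # True division courtesy of the __future__ import at the top of the module.
  return tp / (tp + fp), tp / (tp + fn)  # -> (0.5, 0.5)

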
def streaming_false_positive_rate(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the false positive rate of predictions with respect to labels.

  The `false_positive_rate` function creates two local variables,
  `false_positives` and `true_negatives`, that are used to compute the
  false positive rate. This value is ultimately returned as
  `false_positive_rate`, an idempotent operation that simply divides
  `false_positives` by the sum of `false_positives` and `true_negatives`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `false_positive_rate`. `update_op` weights each prediction by the
  corresponding value in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `false_positive_rate` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    false_positive_rate: Scalar float `Tensor` with the value of
      `false_positives` divided by the sum of `false_positives` and
      `true_negatives`.
    update_op: `Operation` that increments `false_positives` and
      `true_negatives` variables appropriately and whose value matches
      `false_positive_rate`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'false_positive_rate',
                                     (predictions, labels, weights)):
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

    false_p, false_positives_update_op = metrics.false_positives(
        labels=labels,
        predictions=predictions,
        weights=weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)
    true_n, true_negatives_update_op = metrics.true_negatives(
        labels=labels,
        predictions=predictions,
        weights=weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)

    def compute_fpr(fp, tn, name):
      return array_ops.where(
          math_ops.greater(fp + tn, 0), math_ops.div(fp, fp + tn), 0, name)

    fpr = compute_fpr(false_p, true_n, 'value')
    update_op = compute_fpr(false_positives_update_op, true_negatives_update_op,
                            'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, fpr)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return fpr, update_op


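# The `compute_fpr` helper above guards against division by zero: when the
# stream contains no negatives at all, the rate is reported as 0 rather than
# NaN. A pure-Python rendering of that rule (illustrative only; the helper
# name is hypothetical):
def _example_fpr_guard(fp, tn):
  """Sketch: FPR = FP / (FP + TN), defined as 0 when FP + TN == 0."""
  return fp / (fp + tn) if (fp + tn) > 0 else 0.0
# e.g. _example_fpr_guard(1.0, 3.0) -> 0.25; _example_fpr_guard(0.0, 0.0) -> 0.0

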
def streaming_false_negative_rate(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the false negative rate of predictions with respect to labels.

  The `false_negative_rate` function creates two local variables,
  `false_negatives` and `true_positives`, that are used to compute the
  false negative rate. This value is ultimately returned as
  `false_negative_rate`, an idempotent operation that simply divides
  `false_negatives` by the sum of `false_negatives` and `true_positives`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `false_negative_rate`. `update_op` weights each prediction by the
  corresponding value in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `false_negative_rate` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    false_negative_rate: Scalar float `Tensor` with the value of
      `false_negatives` divided by the sum of `false_negatives` and
      `true_positives`.
    update_op: `Operation` that increments `false_negatives` and
      `true_positives` variables appropriately and whose value matches
      `false_negative_rate`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'false_negative_rate',
                                     (predictions, labels, weights)):
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

    false_n, false_negatives_update_op = metrics.false_negatives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)
    true_p, true_positives_update_op = metrics.true_positives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)

    def compute_fnr(fn, tp, name):
      return array_ops.where(
          math_ops.greater(fn + tp, 0), math_ops.div(fn, fn + tp), 0, name)

    fnr = compute_fnr(false_n, true_p, 'value')
    update_op = compute_fnr(false_negatives_update_op, true_positives_update_op,
                            'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, fnr)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return fnr, update_op


def _streaming_confusion_matrix_at_thresholds(predictions,
                                              labels,
                                              thresholds,
                                              weights=None,
                                              includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `Tensor` whose shape matches `predictions`. `labels` will be cast
      to `bool`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
      default to all four.

  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
        `includes`.
    update_ops: Dict of operations that increments the `values`. Keys are from
        `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  all_includes = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = all_includes
  else:
    for include in includes:
      if include not in all_includes:
        raise ValueError('Invalid key: %s.' % include)

  predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  num_thresholds = len(thresholds)

  # Reshape predictions and labels.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Use static shape if known.
  num_predictions = predictions_2d.get_shape().as_list()[0]

  # Otherwise use dynamic shape.
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
      array_ops.stack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
      thresh_tiled)
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds.
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  if weights is not None:
    broadcast_weights = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(weights, dtypes.float32), predictions)
    weights_tiled = array_ops.tile(
        array_ops.reshape(broadcast_weights, [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
  else:
    weights_tiled = None

  values = {}
  update_ops = {}

  if 'tp' in includes:
    true_positives = metrics_impl.metric_variable(
        [num_thresholds], dtypes.float32, name='true_positives')
    is_true_positive = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_true_positive *= weights_tiled
    update_ops['tp'] = state_ops.assign_add(true_positives,
                                            math_ops.reduce_sum(
                                                is_true_positive, 1))
    values['tp'] = true_positives

  if 'fn' in includes:
    false_negatives = metrics_impl.metric_variable(
        [num_thresholds], dtypes.float32, name='false_negatives')
    is_false_negative = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_false_negative *= weights_tiled
    update_ops['fn'] = state_ops.assign_add(false_negatives,
                                            math_ops.reduce_sum(
                                                is_false_negative, 1))
    values['fn'] = false_negatives

  if 'tn' in includes:
    true_negatives = metrics_impl.metric_variable(
        [num_thresholds], dtypes.float32, name='true_negatives')
    is_true_negative = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_true_negative *= weights_tiled
    update_ops['tn'] = state_ops.assign_add(true_negatives,
                                            math_ops.reduce_sum(
                                                is_true_negative, 1))
    values['tn'] = true_negatives

  if 'fp' in includes:
    false_positives = metrics_impl.metric_variable(
        [num_thresholds], dtypes.float32, name='false_positives')
    is_false_positive = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_false_positive *= weights_tiled
    update_ops['fp'] = state_ops.assign_add(false_positives,
                                            math_ops.reduce_sum(
                                                is_false_positive, 1))
    values['fp'] = false_positives

  return values, update_ops


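# Illustration, not from the original file, of the threshold semantics used
# above: a prediction counts as positive at threshold t only when it is
# strictly greater than t, so each confusion count is a vector with one entry
# per threshold. The helper name and toy data are hypothetical.
def _example_tp_at_thresholds():
  """Sketch: per-threshold true-positive counts for one toy batch."""
  predictions = [0.2, 0.6, 0.9]
  labels = [True, True, False]
  thresholds = [0.0, 0.5, 0.8]
  return [
      sum(1.0 for p, l in zip(predictions, labels) if p > t and l)
      for t in thresholds
  ]  # -> [2.0, 1.0, 0.0]

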
def streaming_true_positives_at_thresholds(predictions,
                                           labels,
                                           thresholds,
                                           weights=None):
  """Computes true positives at provided threshold values."""
  values, update_ops = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=('tp',))
  return values['tp'], update_ops['tp']


def streaming_false_negatives_at_thresholds(predictions,
                                            labels,
                                            thresholds,
                                            weights=None):
  """Computes false negatives at provided threshold values."""
  values, update_ops = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=('fn',))
  return values['fn'], update_ops['fn']


def streaming_false_positives_at_thresholds(predictions,
                                            labels,
                                            thresholds,
                                            weights=None):
  """Computes false positives at provided threshold values."""
  values, update_ops = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=('fp',))
  return values['fp'], update_ops['fp']


def streaming_true_negatives_at_thresholds(predictions,
                                           labels,
                                           thresholds,
                                           weights=None):
  """Computes true negatives at provided threshold values."""
  values, update_ops = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=('tn',))
  return values['tn'], update_ops['tn']


def streaming_curve_points(labels=None,
                           predictions=None,
                           weights=None,
                           num_thresholds=200,
                           metrics_collections=None,
                           updates_collections=None,
                           curve='ROC',
                           name=None):
  """Computes curve (ROC or PR) values for a prespecified number of points.

  The `streaming_curve_points` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  that are used to compute the curve values. To discretize the curve, a linearly
  spaced set of thresholds is used to compute pairs of recall and precision
  values.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    points: A `Tensor` with shape [num_thresholds, 2] that contains points of
      the curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.

  TODO(chizeng): Consider rewriting this method to make use of logic within the
  precision_recall_at_equal_thresholds method (to improve run time).
  """
  with variable_scope.variable_scope(name, 'curve_points',
                                     (labels, predictions, weights)):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
    kepsilon = _EPSILON  # to account for floating point imprecisions
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _streaming_confusion_matrix_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def compute_points(tp, fn, tn, fp):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        return fp_rate, rec
      else:  # curve == 'PR'.
        prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
        return rec, prec

    xs, ys = compute_points(values['tp'], values['fn'], values['tn'],
                            values['fp'])
    points = array_ops.stack([xs, ys], axis=1)
    update_op = control_flow_ops.group(*update_ops.values())

    if metrics_collections:
      ops.add_to_collections(metrics_collections, points)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return points, update_op


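# Illustration, not from the original file, of the threshold grid built
# above: num_thresholds - 2 evenly spaced interior points, bracketed by
# -epsilon and 1 + epsilon so that predictions of exactly 0.0 and 1.0 fall
# strictly inside the outermost buckets. The helper name is hypothetical.
def _example_curve_thresholds(num_thresholds=5, kepsilon=_EPSILON):
  """Sketch: the same threshold list streaming_curve_points constructs."""
  interior = [(i + 1) * 1.0 / (num_thresholds - 1)
              for i in range(num_thresholds - 2)]
  return [0.0 - kepsilon] + interior + [1.0 + kepsilon]
# e.g. _example_curve_thresholds(5) -> [-1e-07, 0.25, 0.5, 0.75, 1.0000001]

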
@deprecated(None, 'Please switch to tf.metrics.auc. Note that the order of '
            'the labels and predictions arguments has been switched.')
def streaming_auc(predictions,
                  labels,
                  weights=None,
                  num_thresholds=200,
                  metrics_collections=None,
                  updates_collections=None,
                  curve='ROC',
                  name=None):
  """Computes the approximate AUC via a Riemann sum.

  The `streaming_auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall values
  (computed using the aforementioned variables). The `num_thresholds` variable
  controls the degree of discretization with larger numbers of thresholds more
  closely approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.auc(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      num_thresholds=num_thresholds,
      curve=curve,
      updates_collections=updates_collections,
      name=name)


def _compute_dynamic_auc(labels, predictions, curve='ROC', weights=None):
  """Computes the approximate AUC by a Riemann sum with data-derived thresholds.

  Computes the area under the ROC or PR curve using each prediction as a
  threshold. This could be slow for large batches, but has the advantage of not
  having its results degrade depending on the distribution of predictions.

  Args:
    labels: A `Tensor` of ground truth labels with the same shape as
      `predictions` with values of 0 or 1 and type `int64`.
    predictions: A 1-D `Tensor` of predictions whose values are `float64`.
    curve: The name of the curve to be computed, 'ROC' for the Receiver
      Operating Characteristic or 'PR' for the Precision-Recall curve.
    weights: A 1-D `Tensor` of weights whose values are `float64`.

  Returns:
    A scalar `Tensor` containing the area-under-curve value for the input.
  """
  # Compute the total weight and the total positive weight.
  size = array_ops.size(predictions)
  if weights is None:
    weights = array_ops.ones_like(labels, dtype=dtypes.float64)
  labels, predictions, weights = metrics_impl._remove_squeezable_dimensions(
      labels, predictions, weights)
  total_weight = math_ops.reduce_sum(weights)
  total_positive = math_ops.reduce_sum(
      array_ops.where(
          math_ops.greater(labels, 0), weights,
          array_ops.zeros_like(labels, dtype=dtypes.float64)))

  def continue_computing_dynamic_auc():
    """Continues dynamic auc computation, entered if labels are not all equal.

    Returns:
      A scalar `Tensor` containing the area-under-curve value.
    """
    # Sort the predictions descending, keeping the same order for the
    # corresponding labels and weights.
    ordered_predictions, indices = nn.top_k(predictions, k=size)
    ordered_labels = array_ops.gather(labels, indices)
    ordered_weights = array_ops.gather(weights, indices)

    # Get the counts of the unique ordered predictions.
    _, _, counts = array_ops.unique_with_counts(ordered_predictions)

    # Compute the indices of the split points between different predictions.
    splits = math_ops.cast(
        array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32)

    # Count the positives to the left of the split indices.
    true_positives = array_ops.gather(
        array_ops.pad(
            math_ops.cumsum(
                array_ops.where(
                    math_ops.greater(ordered_labels, 0), ordered_weights,
                    array_ops.zeros_like(ordered_labels,
                                         dtype=dtypes.float64))),
            paddings=[[1, 0]]), splits)
    if curve == 'ROC':
      # Compute the weight of the negatives to the left of every split point,
      # and the total weight of the negatives, for computing the FPR.
      false_positives = array_ops.gather(
          array_ops.pad(
              math_ops.cumsum(
                  array_ops.where(
                      math_ops.less(ordered_labels, 1), ordered_weights,
                      array_ops.zeros_like(
                          ordered_labels, dtype=dtypes.float64))),
              paddings=[[1, 0]]), splits)
      total_negative = total_weight - total_positive
      x_axis_values = math_ops.truediv(false_positives, total_negative)
      y_axis_values = math_ops.truediv(true_positives, total_positive)
    elif curve == 'PR':
      x_axis_values = math_ops.truediv(true_positives, total_positive)
      # For conformance, set precision to 1 when the number of positive
      # classifications is 0.
      positives = array_ops.gather(
          array_ops.pad(math_ops.cumsum(ordered_weights), paddings=[[1, 0]]),
          splits)
      y_axis_values = array_ops.where(
          math_ops.greater(splits, 0),
          math_ops.truediv(true_positives, positives),
          array_ops.ones_like(true_positives, dtype=dtypes.float64))

    # Calculate trapezoid areas.
    heights = math_ops.add(y_axis_values[1:], y_axis_values[:-1]) / 2.0
    widths = math_ops.abs(
        math_ops.subtract(x_axis_values[1:], x_axis_values[:-1]))
    return math_ops.reduce_sum(math_ops.multiply(heights, widths))

  # If all the labels are the same, AUC isn't well-defined (but raising an
  # exception seems excessive), so we return 0; otherwise we finish computing.
  return control_flow_ops.cond(
      math_ops.logical_or(
          math_ops.equal(total_positive, 0), math_ops.equal(
              total_positive, total_weight)),
      true_fn=lambda: array_ops.constant(0, dtypes.float64),
      false_fn=continue_computing_dynamic_auc)


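# Numeric illustration, not from the original file, of the trapezoid-rule
# integration above: each segment of the curve contributes the mean of its
# two y values times the width between its two x values. The helper name is
# hypothetical.
def _example_trapezoid_auc(xs, ys):
  """Sketch: sum of trapezoid areas over consecutive curve points."""
  return sum((y0 + y1) / 2.0 * abs(x1 - x0)
             for x0, x1, y0, y1 in zip(xs, xs[1:], ys, ys[1:]))
# e.g. a perfect ROC curve: _example_trapezoid_auc([0., 0., 1.], [0., 1., 1.])
# -> 1.0

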
1150def streaming_dynamic_auc(labels,
1151                          predictions,
1152                          curve='ROC',
1153                          metrics_collections=(),
1154                          updates_collections=(),
1155                          name=None,
1156                          weights=None):
1157  """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds.
1158
1159  USAGE NOTE: this approach requires storing all of the predictions and labels
1160  for a single evaluation in memory, so it may not be usable when the evaluation
1161  batch size and/or the number of evaluation steps is very large.
1162
1163  Computes the area under the ROC or PR curve using each prediction as a
1164  threshold. This has the advantage of being resilient to the distribution of
1165  predictions by aggregating across batches, accumulating labels and predictions
1166  and performing the final calculation using all of the concatenated values.
1167
1168  Args:
1169    labels: A `Tensor` of ground truth labels with the same shape as `labels`
1170      and with values of 0 or 1 whose values are castable to `int64`.
1171    predictions: A `Tensor` of predictions whose values are castable to
1172      `float64`. Will be flattened into a 1-D `Tensor`.
1173    curve: The name of the curve for which to compute AUC, 'ROC' for the
1174      Receiving Operating Characteristic or 'PR' for the Precision-Recall curve.
1175    metrics_collections: An optional iterable of collections that `auc` should
1176      be added to.
1177    updates_collections: An optional iterable of collections that `update_op`
1178      should be added to.
1179    name: An optional name for the variable_scope that contains the metric
1180      variables.
1181    weights: A 'Tensor' of non-negative weights whose values are castable to
1182      `float64`. Will be flattened into a 1-D `Tensor`.
1183
1184  Returns:
1185    auc: A scalar `Tensor` containing the current area-under-curve value.
1186    update_op: An operation that concatenates the input labels and predictions
1187      to the accumulated values.
1188
1189  Raises:
1190    ValueError: If `labels` and `predictions` have mismatched shapes or if
1191      `curve` isn't a recognized curve type.
1192  """

  if curve not in ['PR', 'ROC']:
    raise ValueError('curve must be either ROC or PR, %s unknown' % curve)

  with variable_scope.variable_scope(name, default_name='dynamic_auc'):
    labels.get_shape().assert_is_compatible_with(predictions.get_shape())
    predictions = array_ops.reshape(
        math_ops.cast(predictions, dtypes.float64), [-1])
    labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1])
    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            labels,
            array_ops.zeros_like(labels, dtypes.int64),
            message='labels must be 0 or 1, at least one is <0'),
        check_ops.assert_less_equal(
            labels,
            array_ops.ones_like(labels, dtypes.int64),
            message='labels must be 0 or 1, at least one is >1'),
    ]):
      preds_accum, update_preds = streaming_concat(
          predictions, name='concat_preds')
      labels_accum, update_labels = streaming_concat(
          labels, name='concat_labels')
      if weights is not None:
        weights = array_ops.reshape(
            math_ops.cast(weights, dtypes.float64), [-1])
        weights_accum, update_weights = streaming_concat(
            weights, name='concat_weights')
        update_op = control_flow_ops.group(update_labels, update_preds,
                                           update_weights)
      else:
        weights_accum = None
        update_op = control_flow_ops.group(update_labels, update_preds)
      auc = _compute_dynamic_auc(
          labels_accum, preds_accum, curve=curve, weights=weights_accum)
      if updates_collections:
        ops.add_to_collections(updates_collections, update_op)
      if metrics_collections:
        ops.add_to_collections(metrics_collections, auc)
      return auc, update_op
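

# A minimal usage sketch (illustrative only; `labels_t`, `preds_t` and
# `num_eval_batches` are hypothetical names): evaluating
# `streaming_dynamic_auc` over several batches in graph mode. The accumulators
# are metric variables, so local variables must be initialized first.
#
#   auc, update_op = streaming_dynamic_auc(labels_t, preds_t, curve='ROC')
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     for _ in range(num_eval_batches):
#       sess.run(update_op)
#     final_auc = sess.run(auc)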


def _compute_placement_auc(labels, predictions, weights, alpha,
                           logit_transformation, is_valid):
  """Computes the AUC and asymptotic normally distributed confidence interval.

  The calculations are achieved using the fact that AUC = P(Y_1 > Y_0) and the
  concept of placement values for each labeled group, as presented by DeLong
  and DeLong (1988). The actual algorithm used is a more computationally
  efficient approach presented by Sun and Xu (2014). This could be slow for
  large batches, but has the advantage of not having its results degrade
  depending on the distribution of predictions.

  Args:
    labels: A `Tensor` of ground truth labels with the same shape as
      `predictions` with values of 0 or 1 and type `int64`.
    predictions: A 1-D `Tensor` of predictions whose values are `float64`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`.
    alpha: Confidence interval level desired.
    logit_transformation: A boolean value indicating whether the estimate should
      be logit transformed prior to calculating the confidence interval. Doing
      so enforces the restriction that the AUC should never be outside the
      interval [0,1].
    is_valid: A bool tensor describing whether the input is valid.

  Returns:
    A 1-D `Tensor` containing the area-under-curve, lower, and upper confidence
    interval values.
  """
  # Disable the invalid-name checker so that we can capitalize the name.
  # pylint: disable=invalid-name
  AucData = collections_lib.namedtuple('AucData', ['auc', 'lower', 'upper'])
  # pylint: enable=invalid-name

  # If all the labels are the same, or if the number of observations is too
  # small, the AUC isn't well-defined.
  size = array_ops.size(predictions, out_type=dtypes.int32)

  # Count the total number of positive and negative labels in the input.
  total_0 = math_ops.reduce_sum(
      math_ops.cast(1 - labels, weights.dtype) * weights)
  total_1 = math_ops.reduce_sum(
      math_ops.cast(labels, weights.dtype) * weights)

  # Sort the predictions in ascending order, along with
  # (i) the corresponding labels and
  # (ii) the corresponding weights.
  ordered_predictions, indices = nn.top_k(predictions, k=size, sorted=True)
  ordered_predictions = array_ops.reverse(
      ordered_predictions, axis=array_ops.zeros(1, dtypes.int32))
  indices = array_ops.reverse(indices, axis=array_ops.zeros(1, dtypes.int32))
  ordered_labels = array_ops.gather(labels, indices)
  ordered_weights = array_ops.gather(weights, indices)

  # We now compute values required for computing placement values.

  # We generate a list of indices (segmented_indices) of increasing order. An
  # index is assigned for each unique prediction float value. Prediction
  # values that are the same share the same index.
  _, segmented_indices = array_ops.unique(ordered_predictions)

  # We create 2 tensors of weights. weights_for_true is non-zero for true
  # labels. weights_for_false is non-zero for false labels.
  float_labels_for_true = math_ops.cast(ordered_labels, dtypes.float32)
  float_labels_for_false = 1.0 - float_labels_for_true
  weights_for_true = ordered_weights * float_labels_for_true
  weights_for_false = ordered_weights * float_labels_for_false

  # For each set of weights with the same segmented indices, we add up the
  # weight values. Note that for each label, we deliberately rely on weights
  # for the opposite label.
  weight_totals_for_true = math_ops.segment_sum(weights_for_false,
                                                segmented_indices)
  weight_totals_for_false = math_ops.segment_sum(weights_for_true,
                                                 segmented_indices)

  # These cumulative sums of weights importantly exclude the current weight
  # sums.
  cum_weight_totals_for_true = math_ops.cumsum(weight_totals_for_true,
                                               exclusive=True)
  cum_weight_totals_for_false = math_ops.cumsum(weight_totals_for_false,
                                                exclusive=True)

  # Compute placement values using the formula. Values with the same segmented
  # indices and labels share the same placement values.
  placements_for_true = (
      (cum_weight_totals_for_true + weight_totals_for_true / 2.0) /
      (math_ops.reduce_sum(weight_totals_for_true) + _EPSILON))
  placements_for_false = (
      (cum_weight_totals_for_false + weight_totals_for_false / 2.0) /
      (math_ops.reduce_sum(weight_totals_for_false) + _EPSILON))

  # We expand the tensors of placement values (for each label) so that their
  # shapes match that of predictions.
  placements_for_true = array_ops.gather(placements_for_true, segmented_indices)
  placements_for_false = array_ops.gather(placements_for_false,
                                          segmented_indices)

  # Select placement values based on the label for each index.
  placement_values = (
      placements_for_true * float_labels_for_true +
      placements_for_false * float_labels_for_false)

  # Split placement values by labeled groups.
  placement_values_0 = placement_values * math_ops.cast(
      1 - ordered_labels, weights.dtype)
  weights_0 = ordered_weights * math_ops.cast(
      1 - ordered_labels, weights.dtype)
  placement_values_1 = placement_values * math_ops.cast(
      ordered_labels, weights.dtype)
  weights_1 = ordered_weights * math_ops.cast(
      ordered_labels, weights.dtype)

  # Calculate AUC using placement values
  auc_0 = (math_ops.reduce_sum(weights_0 * (1. - placement_values_0)) /
           (total_0 + _EPSILON))
  auc_1 = (math_ops.reduce_sum(weights_1 * (placement_values_1)) /
           (total_1 + _EPSILON))
  auc = array_ops.where(math_ops.less(total_0, total_1), auc_1, auc_0)

  # Calculate variance and standard error using the placement values.
  var_0 = (
      math_ops.reduce_sum(
          weights_0 * math_ops.square(1. - placement_values_0 - auc_0)) /
      (total_0 - 1. + _EPSILON))
  var_1 = (
      math_ops.reduce_sum(weights_1 * math_ops.squared_difference(
          placement_values_1, auc_1)) / (total_1 - 1. + _EPSILON))
  auc_std_err = math_ops.sqrt(
      (var_0 / (total_0 + _EPSILON)) + (var_1 / (total_1 + _EPSILON)))

  # Calculate asymptotic normal confidence intervals
  std_norm_dist = Normal(loc=0., scale=1.)
  z_value = std_norm_dist.quantile((1.0 - alpha) / 2.0)
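  # Note: for alpha in (0, 1), `z_value` is negative (about -1.96 when
  # alpha=0.95), which is why the lower bound below is formed with `+` and the
  # upper bound with `-`.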
  if logit_transformation:
    estimate = math_ops.log(auc / (1. - auc + _EPSILON))
    std_err = auc_std_err / (auc * (1. - auc + _EPSILON))
    transformed_auc_lower = estimate + (z_value * std_err)
    transformed_auc_upper = estimate - (z_value * std_err)

    def inverse_logit_transformation(x):
      exp_negative = math_ops.exp(math_ops.negative(x))
      return 1. / (1. + exp_negative + _EPSILON)

    auc_lower = inverse_logit_transformation(transformed_auc_lower)
    auc_upper = inverse_logit_transformation(transformed_auc_upper)
  else:
    estimate = auc
    std_err = auc_std_err
    auc_lower = estimate + (z_value * std_err)
    auc_upper = estimate - (z_value * std_err)
  # If the estimate is 1 or 0, no variance is present, so the confidence
  # interval collapses to the point estimate. N.B. this can be misleading,
  # since the number of observations may simply be too low.
  lower = array_ops.where(
      math_ops.logical_or(
          math_ops.equal(auc, array_ops.ones_like(auc)),
          math_ops.equal(auc, array_ops.zeros_like(auc))),
      auc, auc_lower)
  upper = array_ops.where(
      math_ops.logical_or(
          math_ops.equal(auc, array_ops.ones_like(auc)),
          math_ops.equal(auc, array_ops.zeros_like(auc))),
      auc, auc_upper)

  # If all the labels are the same, AUC isn't well-defined (but raising an
  # exception seems excessive) so we return 0, otherwise we finish computing.
  trivial_value = array_ops.constant(0.0)

  return AucData(*control_flow_ops.cond(
      is_valid, lambda: [auc, lower, upper], lambda: [trivial_value]*3))
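

# For intuition, a tiny pure-Python sketch (hypothetical values, unweighted,
# no ties) of the placement-value identity used above: the placement value of
# a positive example is the fraction of negatives ranked below it, and the AUC
# is the mean placement value over the positives.
#
#   preds = [0.1, 0.4, 0.35, 0.8]
#   labels = [0, 0, 1, 1]
#   negatives = [p for p, l in zip(preds, labels) if l == 0]
#   placements = [sum(n < p for n in negatives) / float(len(negatives))
#                 for p, l in zip(preds, labels) if l == 1]
#   auc = sum(placements) / len(placements)  # == 0.75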


def auc_with_confidence_intervals(labels,
                                  predictions,
                                  weights=None,
                                  alpha=0.95,
                                  logit_transformation=True,
                                  metrics_collections=(),
                                  updates_collections=(),
                                  name=None):
  """Computes the AUC and asymptotic normally distributed confidence interval.

  USAGE NOTE: this approach requires storing all of the predictions and labels
  for a single evaluation in memory, so it may not be usable when the evaluation
  batch size and/or the number of evaluation steps is very large.

  Computes the area under the ROC curve and its confidence interval using
  placement values. This has the advantage of being resilient to the
  distribution of predictions by aggregating across batches, accumulating labels
  and predictions and performing the final calculation using all of the
  concatenated values.

  Args:
    labels: A `Tensor` of ground truth labels with the same shape as
      `predictions` and with values of 0 or 1 whose values are castable to
      `int64`.
    predictions: A `Tensor` of predictions whose values are castable to
      `float64`. Will be flattened into a 1-D `Tensor`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`.
    alpha: Confidence interval level desired.
    logit_transformation: A boolean value indicating whether the estimate should
      be logit transformed prior to calculating the confidence interval. Doing
      so enforces the restriction that the AUC should never be outside the
      interval [0,1].
    metrics_collections: An optional iterable of collections that `auc` should
      be added to.
    updates_collections: An optional iterable of collections that `update_op`
      should be added to.
    name: An optional name for the variable_scope that contains the metric
      variables.

  Returns:
    auc: A 1-D `Tensor` containing the current area-under-curve, lower, and
      upper confidence interval values.
    update_op: An operation that concatenates the input labels and predictions
      to the accumulated values.

  Raises:
    ValueError: If `labels`, `predictions`, and `weights` have mismatched
      shapes, or if `alpha` isn't in the range (0, 1).
  """
  if not (alpha > 0 and alpha < 1):
    raise ValueError('alpha must be between 0 and 1; currently %.02f' % alpha)

  if weights is None:
    weights = array_ops.ones_like(predictions)

  with variable_scope.variable_scope(
      name,
      default_name='auc_with_confidence_intervals',
      values=[labels, predictions, weights]):

    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions=predictions,
        labels=labels,
        weights=weights)

    total_weight = math_ops.reduce_sum(weights)

    weights = array_ops.reshape(weights, [-1])
    predictions = array_ops.reshape(
        math_ops.cast(predictions, dtypes.float64), [-1])
    labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1])

    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            labels,
            array_ops.zeros_like(labels, dtypes.int64),
            message='labels must be 0 or 1, at least one is <0'),
        check_ops.assert_less_equal(
            labels,
            array_ops.ones_like(labels, dtypes.int64),
            message='labels must be 0 or 1, at least one is >1'),
    ]):
      preds_accum, update_preds = streaming_concat(
          predictions, name='concat_preds')
      labels_accum, update_labels = streaming_concat(labels,
                                                     name='concat_labels')
      weights_accum, update_weights = streaming_concat(
          weights, name='concat_weights')
      update_op_for_valid_case = control_flow_ops.group(
          update_labels, update_preds, update_weights)

      # Only perform updates if this case is valid.
      all_labels_positive_or_0 = math_ops.logical_and(
          math_ops.equal(math_ops.reduce_min(labels), 0),
          math_ops.equal(math_ops.reduce_max(labels), 1))
      sums_of_weights_at_least_1 = math_ops.greater_equal(total_weight, 1.0)
      is_valid = math_ops.logical_and(all_labels_positive_or_0,
                                      sums_of_weights_at_least_1)

      update_op = control_flow_ops.cond(
          sums_of_weights_at_least_1,
          lambda: update_op_for_valid_case, control_flow_ops.no_op)

      auc = _compute_placement_auc(
          labels_accum,
          preds_accum,
          weights_accum,
          alpha=alpha,
          logit_transformation=logit_transformation,
          is_valid=is_valid)

      if updates_collections:
        ops.add_to_collections(updates_collections, update_op)
      if metrics_collections:
        ops.add_to_collections(metrics_collections, auc)
      return auc, update_op
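

# Usage sketch (illustrative; `labels_t` and `preds_t` are hypothetical
# tensors): the value op is the `AucData` namedtuple built by
# `_compute_placement_auc`, so the point estimate and both confidence bounds
# can be fetched together.
#
#   auc_data, update_op = auc_with_confidence_intervals(labels_t, preds_t)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     auc, lower, upper = sess.run(
#         [auc_data.auc, auc_data.lower, auc_data.upper])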


def precision_recall_at_equal_thresholds(labels,
                                         predictions,
                                         weights=None,
                                         num_thresholds=None,
                                         use_locking=None,
                                         name=None):
  """A helper method for creating metrics related to precision-recall curves.

  These values are true positives, false negatives, true negatives, false
  positives, precision, and recall. This function returns a data structure that
  contains ops within it.

  Unlike _streaming_confusion_matrix_at_thresholds (which exhibits O(T * N)
  space and run time), this op exhibits O(T + N) space and run time, where T is
  the number of thresholds and N is the size of the predictions tensor. Hence,
  it may be advantageous to use this function when `predictions` is big.

  For instance, prefer this method for per-pixel classification tasks, for which
  the predictions tensor may be very large.

  Each number in `predictions`, a float in `[0, 1]`, is compared with its
  corresponding label in `labels`, and counts as a single tp/fp/tn/fn value at
  each threshold. This is then multiplied with `weights` which can be used to
  reweight certain values, or more commonly used for masking values.

  Args:
    labels: A bool `Tensor` whose shape matches `predictions`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional; If provided, a `Tensor` that has the same dtype as,
      and broadcastable to, `predictions`. This tensor is multiplied by counts.
    num_thresholds: Optional; Number of thresholds, evenly distributed in
      `[0, 1]`. Should be `>= 2`. Defaults to 201. Note that the number of bins
      is 1 less than `num_thresholds`. Using an even `num_thresholds` value
      instead of an odd one yields bin edges that do not fall on round
      fractions of 1.
    use_locking: Optional; If True, the op will be protected by a lock.
      Otherwise, the behavior is undefined, but may exhibit less contention.
      Defaults to True.
    name: Optional; variable_scope name. If not provided, the string
      'precision_recall_at_equal_thresholds' is used.

  Returns:
    result: A named tuple (see PrecisionRecallData within the implementation of
      this function) with properties that are variables of shape
      `[num_thresholds]`. The names of the properties are tp, fp, tn, fn,
      precision, recall, thresholds. Types are the same as that of predictions.
    update_op: An op that accumulates values.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`.
  """
  # Disable the invalid-name checker so that we can capitalize the name.
  # pylint: disable=invalid-name
  PrecisionRecallData = collections_lib.namedtuple(
      'PrecisionRecallData',
      ['tp', 'fp', 'tn', 'fn', 'precision', 'recall', 'thresholds'])
  # pylint: enable=invalid-name

  if num_thresholds is None:
    num_thresholds = 201

  if weights is None:
    weights = 1.0

  if use_locking is None:
    use_locking = True

  check_ops.assert_type(labels, dtypes.bool)

  with variable_scope.variable_scope(name,
                                     'precision_recall_at_equal_thresholds',
                                     (labels, predictions, weights)):
    # Make sure that predictions are within [0.0, 1.0].
    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            predictions,
            math_ops.cast(0.0, dtype=predictions.dtype),
            message='predictions must be in [0, 1]'),
        check_ops.assert_less_equal(
            predictions,
            math_ops.cast(1.0, dtype=predictions.dtype),
            message='predictions must be in [0, 1]')
    ]):
      predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
          predictions=predictions,
          labels=labels,
          weights=weights)

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    # It's important we aggregate using float64 since we're accumulating a lot
    # of 1.0's for the true/false labels, and accumulating to float32 will
    # be quite inaccurate even with just a modest number of values (~20M).
    # We use float64 instead of integer primarily since the GPU scatter kernel
    # only supports floats.
    agg_dtype = dtypes.float64

    f_labels = math_ops.cast(labels, agg_dtype)
    weights = math_ops.cast(weights, agg_dtype)
    true_labels = f_labels * weights
    false_labels = (1.0 - f_labels) * weights

    # Flatten predictions and labels.
    predictions = array_ops.reshape(predictions, [-1])
    true_labels = array_ops.reshape(true_labels, [-1])
    false_labels = array_ops.reshape(false_labels, [-1])

    # To compute TP/FP/TN/FN, we are measuring a binary classifier
    #   C(t) = (predictions >= t)
    # at each threshold 't'. So we have
    #   TP(t) = sum( C(t) * true_labels )
    #   FP(t) = sum( C(t) * false_labels )
    #
    # But, computing C(t) requires computation for each t. To make it fast,
    # observe that C(t) is a cumulative integral, and so if we have
    #   thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
    # where n = num_thresholds, and if we can compute the bucket function
    #   B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
    # then we get
    #   C(t_i) = sum( B(j), j >= i )
    # which is the reversed cumulative sum in tf.cumsum().
    #
    # We can compute B(i) efficiently by taking advantage of the fact that
    # our thresholds are evenly distributed, in that
    #   width = 1.0 / (num_thresholds - 1)
    #   thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
    # Given a prediction value p, we can map it to its bucket by
    #   bucket_index(p) = floor( p * (num_thresholds - 1) )
    # so we can use tf.scatter_add() to update the buckets in one pass.
    #
    # This implementation exhibits a run time and space complexity of O(T + N),
    # where T is the number of thresholds and N is the size of predictions.
    # Metrics that rely on _streaming_confusion_matrix_at_thresholds instead
    # exhibit a complexity of O(T * N).

    # Compute the bucket indices for each prediction value.
    bucket_indices = math_ops.cast(
        math_ops.floor(predictions * (num_thresholds - 1)), dtypes.int32)
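    # For example, with num_thresholds=201 the bucket width is 0.005, and a
    # prediction p=0.3749 maps to bucket floor(0.3749 * 200) = 74, i.e. the
    # bin [0.370, 0.375).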

    with ops.name_scope('variables'):
      tp_buckets_v = metrics_impl.metric_variable(
          [num_thresholds], agg_dtype, name='tp_buckets')
      fp_buckets_v = metrics_impl.metric_variable(
          [num_thresholds], agg_dtype, name='fp_buckets')

    with ops.name_scope('update_op'):
      update_tp = state_ops.scatter_add(
          tp_buckets_v, bucket_indices, true_labels, use_locking=use_locking)
      update_fp = state_ops.scatter_add(
          fp_buckets_v, bucket_indices, false_labels, use_locking=use_locking)

    # Set up the cumulative sums to compute the actual metrics.
    tp = math_ops.cumsum(tp_buckets_v, reverse=True, name='tp')
    fp = math_ops.cumsum(fp_buckets_v, reverse=True, name='fp')
    # fn = sum(true_labels) - tp
    #    = sum(tp_buckets) - tp
    #    = tp[0] - tp
    # Similarly,
    # tn = fp[0] - fp
    tn = fp[0] - fp
    fn = tp[0] - tp

    # We use a minimum to prevent division by 0.
    epsilon = ops.convert_to_tensor(1e-7, dtype=agg_dtype)
    precision = tp / math_ops.maximum(epsilon, tp + fp)
    recall = tp / math_ops.maximum(epsilon, tp + fn)

    # Convert all tensors back to predictions' dtype (as per function contract).
    out_dtype = predictions.dtype
    _convert = lambda tensor: math_ops.cast(tensor, out_dtype)
    result = PrecisionRecallData(
        tp=_convert(tp),
        fp=_convert(fp),
        tn=_convert(tn),
        fn=_convert(fn),
        precision=_convert(precision),
        recall=_convert(recall),
        thresholds=_convert(math_ops.lin_space(0.0, 1.0, num_thresholds)))
    update_op = control_flow_ops.group(update_tp, update_fp)
    return result, update_op
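

# Usage sketch (illustrative; `labels_t` is a hypothetical bool tensor and
# `preds_t` a hypothetical float tensor with values in [0, 1]):
#
#   data, update_op = precision_recall_at_equal_thresholds(
#       labels=labels_t, predictions=preds_t, num_thresholds=11)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     precision, recall = sess.run([data.precision, data.recall])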


def streaming_specificity_at_sensitivity(predictions,
                                         labels,
                                         sensitivity,
                                         weights=None,
                                         num_thresholds=200,
                                         metrics_collections=None,
                                         updates_collections=None,
                                         name=None):
  """Computes the specificity at a given sensitivity.

  The `streaming_specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is computed
  and used to evaluate the corresponding specificity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar `Tensor` representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  return metrics.specificity_at_sensitivity(
      sensitivity=sensitivity,
      num_thresholds=num_thresholds,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
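

# Usage sketch (illustrative; `preds_t` and `labels_t` are hypothetical
# tensors): the specificity achieved at the threshold where sensitivity
# reaches 0.9.
#
#   specificity, update_op = streaming_specificity_at_sensitivity(
#       preds_t, labels_t, sensitivity=0.9)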


def streaming_sensitivity_at_specificity(predictions,
                                         labels,
                                         specificity,
                                         weights=None,
                                         num_thresholds=200,
                                         metrics_collections=None,
                                         updates_collections=None,
                                         name=None):
  """Computes the sensitivity at a given specificity.

  The `streaming_sensitivity_at_specificity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the sensitivity at the given
  specificity value. The threshold for the given specificity value is computed
  and used to evaluate the corresponding sensitivity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    specificity: A scalar value in range `[0, 1]`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      specificity.
    metrics_collections: An optional list of collections that `sensitivity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    sensitivity: A scalar `Tensor` representing the sensitivity at the given
      `specificity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `sensitivity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `specificity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  return metrics.sensitivity_at_specificity(
      specificity=specificity,
      num_thresholds=num_thresholds,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@deprecated(None,
            'Please switch to tf.metrics.precision_at_thresholds. Note that '
            'the order of the labels and predictions arguments has been '
            'switched.')
def streaming_precision_at_thresholds(predictions,
                                      labels,
                                      thresholds,
                                      weights=None,
                                      metrics_collections=None,
                                      updates_collections=None,
                                      name=None):
  """Computes precision values for different `thresholds` on `predictions`.

  The `streaming_precision_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds. `precision[i]` is defined as the total
  weight of values in `predictions` above `thresholds[i]` whose corresponding
  entry in `labels` is `True`, divided by the total weight of values in
  `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
  false_positives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `precision` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.precision_at_thresholds(
      thresholds=thresholds,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
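

# Worked example (hypothetical values): with predictions [0.2, 0.6, 0.9],
# labels [False, True, True] and thresholds=[0.5], the predictions above 0.5
# are 0.6 and 0.9, both labeled True, so tp=2, fp=0 and precision == [1.0].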


@deprecated(None,
            'Please switch to tf.metrics.recall_at_thresholds. Note that the '
            'order of the labels and predictions arguments has been switched.')
def streaming_recall_at_thresholds(predictions,
                                   labels,
                                   thresholds,
                                   weights=None,
                                   metrics_collections=None,
                                   updates_collections=None,
                                   name=None):
  """Computes various recall values for different `thresholds` on `predictions`.

  The `streaming_recall_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds. `recall[i]` is defined as the total weight
  of values in `predictions` above `thresholds[i]` whose corresponding entry in
  `labels` is `True`, divided by the total weight of `True` values in `labels`
  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `recall`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `recall` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.recall_at_thresholds(
      thresholds=thresholds,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
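

# Usage sketch (illustrative; `preds_t` and `labels_t` are hypothetical
# tensors): one recall value is produced per threshold.
#
#   recall, update_op = streaming_recall_at_thresholds(
#       preds_t, labels_t, thresholds=[0.25, 0.5, 0.75])
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(recall))  # shape [3], one value per threshold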


def streaming_false_positive_rate_at_thresholds(predictions,
                                                labels,
                                                thresholds,
                                                weights=None,
                                                metrics_collections=None,
                                                updates_collections=None,
                                                name=None):
  """Computes various fpr values for different `thresholds` on `predictions`.

  The `streaming_false_positive_rate_at_thresholds` function creates two local
  variables, `false_positives` and `true_negatives`, for various threshold
  values. `false_positive_rate[i]` is defined as the total weight of values in
  `predictions` above `thresholds[i]` whose corresponding entry in `labels` is
  `False`, divided by the total weight of `False` values in `labels`
  (`false_positives[i] / (false_positives[i] + true_negatives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `false_positive_rate`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `false_positive_rate` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    false_positive_rate: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `false_positives` and
      `true_negatives` variables that are used in the computation of
      `false_positive_rate`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'false_positive_rate_at_thresholds',
                                     (predictions, labels, weights)):
    values, update_ops = _streaming_confusion_matrix_at_thresholds(
        predictions, labels, thresholds, weights, includes=('fp', 'tn'))

    # Avoid division by zero.
    epsilon = _EPSILON

    def compute_fpr(fp, tn, name):
      return math_ops.div(fp, epsilon + fp + tn, name='fpr_' + name)

    fpr = compute_fpr(values['fp'], values['tn'], 'value')
    update_op = compute_fpr(update_ops['fp'], update_ops['tn'], 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, fpr)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return fpr, update_op
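

# Worked example (hypothetical values): with predictions [0.2, 0.6, 0.9],
# labels [False, False, True] and thresholds=[0.5], the False-labeled value
# 0.6 is above the threshold (fp=1) and 0.2 is below it (tn=1), so
# false_positive_rate == [0.5] (up to epsilon).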


def streaming_false_negative_rate_at_thresholds(predictions,
                                                labels,
                                                thresholds,
                                                weights=None,
                                                metrics_collections=None,
                                                updates_collections=None,
                                                name=None):
  """Computes various fnr values for different `thresholds` on `predictions`.

  The `streaming_false_negative_rate_at_thresholds` function creates two local
  variables, `false_negatives` and `true_positives`, for various threshold
  values. `false_negative_rate[i]` is defined as the total weight of values in
  `predictions` at or below `thresholds[i]` whose corresponding entry in
  `labels` is `True`, divided by the total weight of `True` values in `labels`
  (`false_negatives[i] / (false_negatives[i] + true_positives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `false_negative_rate`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `false_negative_rate` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    false_negative_rate: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `false_negatives` and
      `true_positives` variables that are used in the computation of
      `false_negative_rate`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'false_negative_rate_at_thresholds',
                                     (predictions, labels, weights)):
    values, update_ops = _streaming_confusion_matrix_at_thresholds(
        predictions, labels, thresholds, weights, includes=('fn', 'tp'))

    # Avoid division by zero.
    epsilon = _EPSILON

    def compute_fnr(fn, tp, name):
      return math_ops.div(fn, epsilon + fn + tp, name='fnr_' + name)

    fnr = compute_fnr(values['fn'], values['tp'], 'value')
    update_op = compute_fnr(update_ops['fn'], update_ops['tp'], 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, fnr)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return fnr, update_op
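

# Worked example (hypothetical values): with predictions [0.2, 0.6, 0.9],
# labels [False, False, True] and thresholds=[0.5], the only True-labeled
# value, 0.9, is above the threshold (tp=1, fn=0), so
# false_negative_rate == [0.0].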


def _at_k_name(name, k=None, class_id=None):
  if k is not None:
    name = '%s_at_%d' % (name, k)
  else:
    name = '%s_at_k' % (name)
  if class_id is not None:
    name = '%s_class%d' % (name, class_id)
  return name
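

# For example: _at_k_name('recall', 5) returns 'recall_at_5', and
# _at_k_name('recall', 5, class_id=2) returns 'recall_at_5_class2'.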


@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, '
            'and reshape labels from [batch_size] to [batch_size, 1].')
def streaming_recall_at_k(predictions,
                          labels,
                          k,
                          weights=None,
                          metrics_collections=None,
                          updates_collections=None,
                          name=None):
  """Computes the recall@k of the predictions with respect to dense labels.

  The `streaming_recall_at_k` function creates two local variables, `total` and
  `count`, that are used to compute the recall@k frequency. This frequency is
  ultimately returned as `recall_at_<k>`: an idempotent operation that simply
  divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, an `in_top_k` operation computes a `Tensor` with
  shape [batch_size] whose elements indicate whether or not the corresponding
  label is in the top `k` `predictions`. Then `update_op` increments `total`
  with the reduced sum of `weights` where `in_top_k` is `True`, and it
  increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A float `Tensor` of dimension [batch_size, num_classes].
    labels: A `Tensor` of dimension [batch_size] whose type is `int32` or
      `int64`.
    k: The number of top elements to look at for computing recall.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `recall_at_k`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    recall_at_k: A `Tensor` representing the recall@k, the fraction of labels
      which fall into the top `k` predictions.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `recall_at_k`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  in_top_k = math_ops.cast(nn.in_top_k(predictions, labels, k), dtypes.float32)
  return streaming_mean(in_top_k, weights, metrics_collections,
                        updates_collections, name or _at_k_name('recall', k))
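

# For intuition (hypothetical values): with predictions [[0.1, 0.7, 0.2],
# [0.3, 0.4, 0.3]], labels [1, 0] and k=1, `in_top_k` is [True, False], so
# recall@1 after a single update is 0.5.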


# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_recall_at_k(predictions,
                                 labels,
                                 k,
                                 class_id=None,
                                 weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes recall@k of the predictions with respect to sparse labels.

  If `class_id` is not specified, we'll calculate recall as the ratio of true
      positives (i.e., correct predictions, items in the top `k` highest
      `predictions` that are found in the corresponding row in `labels`) to
      actual positives (the full `labels` row).
  If `class_id` is specified, we calculate recall by considering only the rows
      in the batch for which `class_id` is in `labels`, and computing the
      fraction of them for which `class_id` is in the top `k` highest
      `predictions`.

  `streaming_sparse_recall_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
  the recall_at_k frequency. This frequency is ultimately returned as
  `recall_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_negative_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false negatives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_negative_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`.
      Values should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range always count
      towards `false_negative_at_<k>`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If class_id is outside this range, the method returns NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  return metrics.recall_at_k(
      k=k,
      class_id=class_id,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
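

# Worked example (hypothetical values): if a row of `labels` is {0, 2} and the
# top-2 prediction indices are {2, 4}, then tp=1 and fn=1, so recall@2 is 0.5.
# With class_id=2 the row contributes recall 1.0, since class 2 appears both
# in the labels and in the top-2 predictions.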


# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_precision_at_k(predictions,
                                    labels,
                                    k,
                                    class_id=None,
                                    weights=None,
                                    metrics_collections=None,
                                    updates_collections=None,
                                    name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  If `class_id` is not specified, we calculate precision as the ratio of true
      positives (i.e., correct predictions, items in the top `k` highest
      `predictions` that are found in the corresponding row in `labels`) to
      positives (all top `k` `predictions`).
  If `class_id` is specified, we calculate precision by considering only the
      rows in the batch for which `class_id` is in the top `k` highest
      `predictions`, and computing the fraction of them for which `class_id` is
      in the corresponding row in `labels`.

  We expect precision to decrease as `k` increases.

  `streaming_sparse_precision_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_positive_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  return metrics.precision_at_k(
      k=k,
      class_id=class_id,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
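

# Worked example (hypothetical values): for a row with labels {0, 2} and top-2
# prediction indices {2, 4}, tp=1 and fp=1, so precision@2 is 0.5. With
# class_id=4 the row contributes precision 0.0, since class 4 is in the top-2
# predictions but not in the labels.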
2345
2346
2347# TODO(ptucker): Validate range of values in labels?
2348def streaming_sparse_precision_at_top_k(top_k_predictions,
2349                                        labels,
2350                                        class_id=None,
2351                                        weights=None,
2352                                        metrics_collections=None,
2353                                        updates_collections=None,
2354                                        name=None):
2355  """Computes precision@k of top-k predictions with respect to sparse labels.
2356
2357  If `class_id` is not specified, we calculate precision as the ratio of
2358      true positives (i.e., correct predictions, items in `top_k_predictions`
2359      that are found in the corresponding row in `labels`) to positives (all
2360      `top_k_predictions`).
2361  If `class_id` is specified, we calculate precision by considering only the
2362      rows in the batch for which `class_id` is in the top `k` highest
2363      `predictions`, and computing the fraction of them for which `class_id` is
2364      in the corresponding row in `labels`.
2365
2366  We expect precision to decrease as `k` increases.
2367
2368  `streaming_sparse_precision_at_top_k` creates two local variables,
2369  `true_positive_at_k` and `false_positive_at_k`, that are used to compute
2370  the precision@k frequency. This frequency is ultimately returned as
2371  `precision_at_k`: an idempotent operation that simply divides
2372  `true_positive_at_k` by total (`true_positive_at_k` + `false_positive_at_k`).
2373
2374  For estimation of the metric over a stream of data, the function creates an
2375  `update_op` operation that updates these variables and returns the
2376  `precision_at_k`. Internally, set operations applied to `top_k_predictions`
2377  and `labels` calculate the true positives and false positives weighted by
2378  `weights`. Then `update_op` increments `true_positive_at_k` and
2379  `false_positive_at_k` using these values.
2380
2381  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2382
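  For example, a minimal sketch with pre-computed top-k indices (illustrative
  constants and session setup only, assuming the usual `tf.contrib.metrics`
  export):

  ```python
  import tensorflow as tf

  # Top-1 predicted classes are 2 and 1; true labels are 2 and 0.
  top_k_predictions = tf.constant([[2], [1]], dtype=tf.int64)
  labels = tf.constant([[2], [0]], dtype=tf.int64)

  precision, update_op = (
      tf.contrib.metrics.streaming_sparse_precision_at_top_k(
          top_k_predictions=top_k_predictions, labels=labels))

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(precision))  # 0.5: one of the two predictions is correct
  ```
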
  Args:
    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
      The final dimension contains the indices of top-k labels. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `top_k_predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
    ValueError: If `top_k_predictions` has rank < 2.
  """
  default_name = _at_k_name('precision', class_id=class_id)
  with ops.name_scope(name, default_name,
                      (top_k_predictions, labels, weights)) as name_scope:
    return metrics_impl.precision_at_top_k(
        labels=labels,
        predictions_idx=top_k_predictions,
        class_id=class_id,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=name_scope)


def sparse_recall_at_top_k(labels,
                           top_k_predictions,
                           class_id=None,
                           weights=None,
                           metrics_collections=None,
                           updates_collections=None,
                           name=None):
  """Computes recall@k of top-k predictions with respect to sparse labels.

  If `class_id` is specified, we calculate recall by considering only the
      entries in the batch for which `class_id` is in the label, and computing
      the fraction of them for which `class_id` is in the top-k `predictions`.
  If `class_id` is not specified, we'll calculate recall as how often on
      average a class among the labels of a batch entry is in the top-k
      `predictions`.

  `sparse_recall_at_top_k` creates two local variables, `true_positive_at_<k>`
  and `false_negative_at_<k>`, that are used to compute the recall_at_k
  frequency. This frequency is ultimately returned as `recall_at_<k>`: an
  idempotent operation that simply divides `true_positive_at_<k>` by total
  (`true_positive_at_<k>` + `false_negative_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Set operations applied to `top_k` and `labels` calculate the
  true positives and false negatives weighted by `weights`. Then `update_op`
  increments `true_positive_at_<k>` and `false_negative_at_<k>` using these
  values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

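  For example, a minimal sketch (illustrative constants and session setup
  only, assuming the usual `tf.contrib.metrics` export):

  ```python
  import tensorflow as tf

  # Each row has one label; only the first row's label appears in the top-k.
  top_k_predictions = tf.constant([[2], [1]], dtype=tf.int64)
  labels = tf.constant([[2], [0]], dtype=tf.int64)

  recall, update_op = tf.contrib.metrics.sparse_recall_at_top_k(
      labels=labels, top_k_predictions=top_k_predictions)

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(recall))  # 0.5: one of the two labels is retrieved
  ```
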
  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `top_k_predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range always count towards `false_negative_at_<k>`.
    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
      The final dimension contains the indices of top-k labels. [D1, ... DN]
      must match `labels`.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  default_name = _at_k_name('recall', class_id=class_id)
  with ops.name_scope(name, default_name,
                      (top_k_predictions, labels, weights)) as name_scope:
    return metrics_impl.recall_at_top_k(
        labels=labels,
        predictions_idx=top_k_predictions,
        class_id=class_id,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=name_scope)


def _compute_recall_at_precision(tp, fp, fn, precision, name,
                                 strict_mode=False):
  """Helper function to compute recall at a given `precision`.

  Args:
    tp: The number of true positives.
    fp: The number of false positives.
    fn: The number of false negatives.
    precision: The precision for which the recall will be calculated.
    name: An optional variable_scope name.
    strict_mode: If true and there exists a threshold where the precision is
      no smaller than the target precision, return the corresponding recall at
      the threshold. Otherwise, return 0. If false, find the threshold where the
      precision is closest to the target precision and return the recall at the
      threshold.

  Returns:
    The recall at a given `precision`.
  """
  precisions = math_ops.div(tp, tp + fp + _EPSILON)
  if not strict_mode:
    tf_index = math_ops.argmin(
        math_ops.abs(precisions - precision), 0, output_type=dtypes.int32)
    # Now, we have the implicit threshold, so compute the recall:
    return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
                        name)
  else:
    # We aim to find the threshold where the precision is minimum but no smaller
    # than the target precision.
    # The rationale:
    # 1. Compute the difference between precisions (by different thresholds) and
    #   the target precision.
    # 2. Take the reciprocal of the values by the above step. The intention is
    #   to make the positive values rank before negative values and also the
    #   smaller positives rank before larger positives.
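    # For example, with target precision 0.7 and per-threshold precisions
    # [0.5, 0.65, 0.75, 0.9], the differences are [-0.2, -0.05, 0.05, 0.2],
    # whose reciprocals are roughly [-5, -20, 20, 5]; argmax then selects the
    # threshold with precision 0.75, the smallest one meeting the target.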
    tf_index = math_ops.argmax(
        math_ops.div(1.0, precisions - precision + _EPSILON),
        0,
        output_type=dtypes.int32)

    def _return_good_recall():
      return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
                          name)

    return control_flow_ops.cond(precisions[tf_index] >= precision,
                                 _return_good_recall, lambda: 0.0)


def recall_at_precision(labels,
                        predictions,
                        precision,
                        weights=None,
                        num_thresholds=200,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None,
                        strict_mode=False):
  """Computes `recall` at `precision`.

  The `recall_at_precision` function creates four local variables,
  `tp` (true positives), `fp` (false positives), `fn` (false negatives) and
  `tn` (true negatives) that are used to compute the `recall` at the given
  `precision` value. The threshold for the given `precision` value is computed
  and used to evaluate the corresponding `recall`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall`. `update_op` increments the `tp`, `fp` and `fn` counts with the
  weight of each case found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

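  For example, a minimal sketch (illustrative constants and session setup
  only, assuming the usual `tf.contrib.metrics` export):

  ```python
  import tensorflow as tf

  labels = tf.constant([False, False, True, True])
  predictions = tf.constant([0.1, 0.4, 0.35, 0.8])

  recall, update_op = tf.contrib.metrics.recall_at_precision(
      labels, predictions, precision=1.0)

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    # Around threshold 0.5, precision reaches 1.0 with recall 0.5.
    print(sess.run(recall))  # ~0.5
  ```
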
  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    precision: A scalar value in range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      `precision`.
    metrics_collections: An optional list of collections that `recall`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.
    strict_mode: If true and there exists a threshold where the precision is
      no smaller than the target precision, return the corresponding recall at
      the threshold. Otherwise, return 0. If false, find the threshold where
      the precision is closest to the target precision and return the recall
      at the threshold.

  Returns:
    recall: A scalar `Tensor` representing the recall at the given
      `precision` value.
    update_op: An operation that increments the `tp`, `fp` and `fn`
      variables appropriately and whose value matches `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `precision` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  if not 0 <= precision <= 1:
    raise ValueError('`precision` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'recall_at_precision',
                                     (predictions, labels, weights)):
    thresholds = [
        i * 1.0 / (num_thresholds - 1) for i in range(1, num_thresholds - 1)
    ]
    thresholds = [0.0 - _EPSILON] + thresholds + [1.0 + _EPSILON]

    values, update_ops = _streaming_confusion_matrix_at_thresholds(
        predictions, labels, thresholds, weights)

    recall = _compute_recall_at_precision(values['tp'], values['fp'],
                                          values['fn'], precision, 'value',
                                          strict_mode)
    update_op = _compute_recall_at_precision(update_ops['tp'], update_ops['fp'],
                                             update_ops['fn'], precision,
                                             'update_op', strict_mode)

    if metrics_collections:
      ops.add_to_collections(metrics_collections, recall)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return recall, update_op


def precision_at_recall(labels,
                        predictions,
                        target_recall,
                        weights=None,
                        num_thresholds=200,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the precision at a given recall.

  This function creates variables to track the true positives, false positives,
  true negatives, and false negatives at a set of thresholds. Among those
  thresholds where recall is at least `target_recall`, precision is computed
  at the threshold where recall is closest to `target_recall`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  precision at `target_recall`. `update_op` increments the counts of true
  positives, false positives, true negatives, and false negatives with the
  weight of each case found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about precision and recall, see
  http://en.wikipedia.org/wiki/Precision_and_recall

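  For example, a minimal sketch (illustrative constants and session setup
  only, assuming the usual `tf.contrib.metrics` export):

  ```python
  import tensorflow as tf

  labels = tf.constant([False, False, True, True])
  predictions = tf.constant([0.1, 0.4, 0.35, 0.8])

  precision, update_op = tf.contrib.metrics.precision_at_recall(
      labels, predictions, target_recall=1.0)

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    # Recall 1.0 requires a threshold below 0.35, which also admits the
    # false positive 0.4, so precision is 2 / 3.
    print(sess.run(precision))  # ~0.67
  ```
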
  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    target_recall: A scalar value in range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      recall.
    metrics_collections: An optional list of collections to which `precision`
      should be added.
    updates_collections: An optional list of collections to which `update_op`
      should be added.
    name: An optional variable_scope name.

  Returns:
    precision: A scalar `Tensor` representing the precision at the given
      `target_recall` value.
    update_op: An operation that increments the variables for tracking the
      true positives, false positives, true negatives, and false negatives and
      whose value matches `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `target_recall` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.precision_at_recall is not '
                       'supported when eager execution is enabled.')

  if target_recall < 0 or target_recall > 1:
    raise ValueError('`target_recall` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'precision_at_recall',
                                     (predictions, labels, weights)):
    kepsilon = 1e-7  # Used to avoid division by zero.
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _streaming_confusion_matrix_at_thresholds(
        predictions, labels, thresholds, weights)

    def compute_precision_at_recall(tp, fp, fn, name):
      """Computes the precision at a given recall.

      Args:
        tp: True positives.
        fp: False positives.
        fn: False negatives.
        name: A name for the operation.

      Returns:
        The precision at the desired recall.
      """
      recalls = math_ops.div(tp, tp + fn + kepsilon)

      # Because recall is monotone decreasing as a function of the threshold,
      # the smallest recall exceeding target_recall occurs at the largest
      # threshold where recall >= target_recall.
      admissible_recalls = math_ops.cast(
          math_ops.greater_equal(recalls, target_recall), dtypes.int64)
      tf_index = math_ops.reduce_sum(admissible_recalls) - 1

      # Now we have the threshold at which to compute precision:
      return math_ops.div(tp[tf_index] + kepsilon,
                          tp[tf_index] + fp[tf_index] + kepsilon,
                          name)

    precision_value = compute_precision_at_recall(
        values['tp'], values['fp'], values['fn'], 'value')
    update_op = compute_precision_at_recall(
        update_ops['tp'], update_ops['fp'], update_ops['fn'], 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, precision_value)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return precision_value, update_op


def streaming_sparse_average_precision_at_k(predictions,
                                            labels,
                                            k,
                                            weights=None,
                                            metrics_collections=None,
                                            updates_collections=None,
                                            name=None):
  """Computes average precision@k of predictions with respect to sparse labels.

  See `sparse_average_precision_at_k` for details on formula. `weights` are
  applied to the result of `sparse_average_precision_at_k`.

  `streaming_sparse_average_precision_at_k` creates two local variables,
  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
  are used to compute the frequency. This frequency is ultimately returned as
  `average_precision_at_<k>`: an idempotent operation that simply divides
  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

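  For example, a minimal sketch (illustrative constants and session setup
  only, assuming the usual `tf.contrib.metrics` export):

  ```python
  import tensorflow as tf

  # Row 1: the label (2) ranks second; row 2: the label (0) ranks first.
  predictions = tf.constant([[0.1, 0.6, 0.3], [0.8, 0.1, 0.1]])
  labels = tf.constant([[2], [0]], dtype=tf.int64)

  ap, update_op = tf.contrib.metrics.streaming_sparse_average_precision_at_k(
      predictions, labels, k=2)

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(ap))  # 0.75: mean of per-row precisions 1/2 and 1
  ```
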
  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and `predictions` has shape
      [batch size, num_classes]. The final dimension contains the logit values
      for each class. [D1, ... DN] must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric. This will calculate an average precision for
      range `[1,k]`, as documented above.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    mean_average_precision: Scalar `float64` `Tensor` with the mean average
      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `metric`.
  """
  return metrics.average_precision_at_k(
      k=k,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_sparse_average_precision_at_top_k(top_k_predictions,
                                                labels,
                                                weights=None,
                                                metrics_collections=None,
                                                updates_collections=None,
                                                name=None):
  """Computes average precision@k of predictions with respect to sparse labels.

  `streaming_sparse_average_precision_at_top_k` creates two local variables,
  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
  are used to compute the frequency. This frequency is ultimately returned as
  `average_precision_at_<k>`: an idempotent operation that simply divides
  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Set operations applied to `top_k` and `labels` calculate
  the true positives and false positives weighted by `weights`. Then `update_op`
  increments `true_positive_at_<k>` and `false_positive_at_<k>` using these
  values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

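  For example, a minimal sketch with pre-computed top-k indices (illustrative
  constants and session setup only, assuming the usual `tf.contrib.metrics`
  export):

  ```python
  import tensorflow as tf

  # Row 1: the label (2) appears at rank 2; row 2: the label (0) at rank 1.
  top_k_predictions = tf.constant([[1, 2], [0, 2]], dtype=tf.int64)
  labels = tf.constant([[2], [0]], dtype=tf.int64)

  ap, update_op = (
      tf.contrib.metrics.streaming_sparse_average_precision_at_top_k(
          top_k_predictions, labels))

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(ap))  # 0.75: mean of per-row precisions 1/2 and 1
  ```
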
  Args:
    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
      Commonly, N=1 and `top_k_predictions` has shape [batch size, k]. The final
      dimension must be set and contains the top `k` predicted class indices.
      [D1, ... DN] must match `labels`. Values should be in range
      [0, num_classes).
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `top_k_predictions`.
      Values should be in range [0, num_classes).
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    mean_average_precision: Scalar `float64` `Tensor` with the mean average
      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `metric`.

  Raises:
    ValueError: If the last dimension of `top_k_predictions` is not set.
  """
  return metrics_impl._streaming_sparse_average_precision_at_top_k(  # pylint: disable=protected-access
      predictions_idx=top_k_predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@deprecated(None,
            'Please switch to tf.metrics.mean_absolute_error. Note that the '
            'order of the labels and predictions arguments has been switched.')
def streaming_mean_absolute_error(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the mean absolute error between the labels and predictions.

  The `streaming_mean_absolute_error` function creates two local variables,
  `total` and `count` that are used to compute the mean absolute error. This
  average is weighted by `weights`, and it is ultimately returned as
  `mean_absolute_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
  absolute value of the differences between `predictions` and `labels`. Then
  `update_op` increments `total` with the reduced sum of the product of
  `weights` and `absolute_errors`, and it increments `count` with the reduced
  sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

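  For example, a minimal sketch (illustrative constants and session setup
  only, assuming the usual `tf.contrib.metrics` export):

  ```python
  import tensorflow as tf

  predictions = tf.constant([1.0, 2.0, 4.0])
  labels = tf.constant([1.0, 3.0, 2.0])

  mae, update_op = tf.contrib.metrics.streaming_mean_absolute_error(
      predictions, labels)

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(mae))  # 1.0 = (0 + 1 + 2) / 3
  ```
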
  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_absolute_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_absolute_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.mean_absolute_error(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_mean_relative_error(predictions,
                                  labels,
                                  normalizer,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the mean relative error by normalizing with the given values.

  The `streaming_mean_relative_error` function creates two local variables,
  `total` and `count` that are used to compute the mean relative absolute error.
  This average is weighted by `weights`, and it is ultimately returned as
  `mean_relative_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_relative_error`. Internally, a `relative_errors` operation divides the
  absolute value of the differences between `predictions` and `labels` by the
  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `relative_errors`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

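  For example, a minimal sketch (illustrative constants and session setup
  only, assuming the usual `tf.contrib.metrics` export):

  ```python
  import tensorflow as tf

  predictions = tf.constant([2.0, 4.0])
  labels = tf.constant([1.0, 2.0])
  normalizer = tf.constant([1.0, 2.0])

  error, update_op = tf.contrib.metrics.streaming_mean_relative_error(
      predictions, labels, normalizer)

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(error))  # 1.0: |2-1|/1 and |4-2|/2 both equal 1
  ```
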
  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    normalizer: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_relative_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_relative_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_relative_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.mean_relative_error(
      normalizer=normalizer,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@deprecated(None,
            'Please switch to tf.metrics.mean_squared_error. Note that the '
            'order of the labels and predictions arguments has been switched.')
def streaming_mean_squared_error(predictions,
                                 labels,
                                 weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes the mean squared error between the labels and predictions.

  The `streaming_mean_squared_error` function creates two local variables,
  `total` and `count` that are used to compute the mean squared error.
  This average is weighted by `weights`, and it is ultimately returned as
  `mean_squared_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_squared_error`. Internally, a `squared_error` operation computes the
  element-wise square of the difference between `predictions` and `labels`. Then
  `update_op` increments `total` with the reduced sum of the product of
  `weights` and `squared_error`, and it increments `count` with the reduced sum
  of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

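  For example, a minimal sketch (illustrative constants and session setup
  only, assuming the usual `tf.contrib.metrics` export):

  ```python
  import tensorflow as tf

  predictions = tf.constant([1.0, 2.0, 4.0])
  labels = tf.constant([1.0, 3.0, 2.0])

  mse, update_op = tf.contrib.metrics.streaming_mean_squared_error(
      predictions, labels)

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(mse))  # ~1.67 = (0 + 1 + 4) / 3
  ```
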
  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_squared_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.mean_squared_error(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@deprecated(
    None,
    'Please switch to tf.metrics.root_mean_squared_error. Note that the '
    'order of the labels and predictions arguments has been switched.')
def streaming_root_mean_squared_error(predictions,
                                      labels,
                                      weights=None,
                                      metrics_collections=None,
                                      updates_collections=None,
                                      name=None):
  """Computes the root mean squared error between the labels and predictions.

  The `streaming_root_mean_squared_error` function creates two local variables,
  `total` and `count` that are used to compute the root mean squared error.
  This average is weighted by `weights`, and it is ultimately returned as
  `root_mean_squared_error`: an idempotent operation that takes the square root
  of the division of `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `root_mean_squared_error`. Internally, a `squared_error` operation computes
  the element-wise square of the difference between `predictions` and `labels`.
  Then `update_op` increments `total` with the reduced sum of the product of
  `weights` and `squared_error`, and it increments `count` with the reduced sum
  of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

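  For example, a minimal sketch (illustrative constants and session setup
  only, assuming the usual `tf.contrib.metrics` export):

  ```python
  import tensorflow as tf

  predictions = tf.constant([1.0, 2.0, 4.0])
  labels = tf.constant([1.0, 3.0, 2.0])

  rmse, update_op = tf.contrib.metrics.streaming_root_mean_squared_error(
      predictions, labels)

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(rmse))  # ~1.29 = sqrt((0 + 1 + 4) / 3)
  ```
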
  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `root_mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    root_mean_squared_error: A `Tensor` representing the current mean, the value
      of `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `root_mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.root_mean_squared_error(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_covariance(predictions,
                         labels,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes the unbiased sample covariance between `predictions` and `labels`.

  The `streaming_covariance` function creates four local variables,
  `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to
  compute the sample covariance between predictions and labels across multiple
  batches of data. The covariance is ultimately returned as an idempotent
  operation that simply divides `comoment` by `count` - 1. We use `count` - 1
  in order to get an unbiased estimate.

  The algorithm used for this online computation is described in
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
  Specifically, the formula used to combine two sample comoments is
  `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB`
  The comoment for a single batch of data is simply
  `sum((x - E[x]) * (y - E[y]))`, optionally weighted.

  If `weights` is not None, then it is used to compute weighted comoments,
  means, and count. NOTE: these weights are treated as "frequency weights", as
  opposed to "reliability weights". See discussion of the difference on
  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance

  To facilitate the computation of covariance across multiple batches of data,
  the function creates an `update_op` operation, which updates underlying
  variables and returns the updated covariance.

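  For example, a minimal sketch (illustrative constants and session setup
  only, assuming the usual `tf.contrib.metrics` export):

  ```python
  import tensorflow as tf

  predictions = tf.constant([2.0, 4.0, 6.0])
  labels = tf.constant([1.0, 2.0, 3.0])

  cov, update_op = tf.contrib.metrics.streaming_covariance(
      predictions, labels)

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    # Deviations from the means (4, 2) are (-2, 0, 2) and (-1, 0, 1), so the
    # comoment is 4 and the unbiased covariance is 4 / (3 - 1) = 2.
    print(sess.run(cov))  # 2.0
  ```
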
  Args:
    predictions: A `Tensor` of arbitrary size.
    labels: A `Tensor` of the same size as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    covariance: A `Tensor` representing the current unbiased sample covariance,
      `comoment` / (`count` - 1).
    update_op: An operation that updates the local variables appropriately.

  Raises:
    ValueError: If labels and predictions are of different sizes or if either
      `metrics_collections` or `updates_collections` are not a list or tuple.
  """
  with variable_scope.variable_scope(name, 'covariance',
                                     (predictions, labels, weights)):
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')
    mean_prediction = metrics_impl.metric_variable(
        [], dtypes.float32, name='mean_prediction')
    mean_label = metrics_impl.metric_variable(
        [], dtypes.float32, name='mean_label')
    comoment = metrics_impl.metric_variable(  # C_A in update equation
        [], dtypes.float32, name='comoment')

    if weights is None:
      batch_count = math_ops.cast(
          array_ops.size(labels), dtypes.float32)  # n_B in eqn
      weighted_predictions = predictions
      weighted_labels = labels
    else:
      weights = weights_broadcast_ops.broadcast_weights(weights, labels)
      batch_count = math_ops.reduce_sum(weights)  # n_B in eqn
      weighted_predictions = math_ops.multiply(predictions, weights)
      weighted_labels = math_ops.multiply(labels, weights)

    update_count = state_ops.assign_add(count_, batch_count)  # n_AB in eqn
    prev_count = update_count - batch_count  # n_A in update equation

    # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
    # batch_mean_prediction is E[x_B] in the update equation
    batch_mean_prediction = math_ops.div_no_nan(
        math_ops.reduce_sum(weighted_predictions), batch_count)
    delta_mean_prediction = math_ops.div_no_nan(
        (batch_mean_prediction - mean_prediction) * batch_count, update_count)
    update_mean_prediction = state_ops.assign_add(mean_prediction,
                                                  delta_mean_prediction)
    # prev_mean_prediction is E[x_A] in the update equation
    prev_mean_prediction = update_mean_prediction - delta_mean_prediction

    # batch_mean_label is E[y_B] in the update equation
    batch_mean_label = math_ops.div_no_nan(
        math_ops.reduce_sum(weighted_labels), batch_count)
    delta_mean_label = math_ops.div_no_nan(
        (batch_mean_label - mean_label) * batch_count, update_count)
    update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
    # prev_mean_label is E[y_A] in the update equation
    prev_mean_label = update_mean_label - delta_mean_label

    unweighted_batch_coresiduals = ((predictions - batch_mean_prediction) *
                                    (labels - batch_mean_label))
    # batch_comoment is C_B in the update equation
    if weights is None:
      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals)
    else:
      batch_comoment = math_ops.reduce_sum(
          unweighted_batch_coresiduals * weights)

    # View delta_comoment as = C_AB - C_A in the update equation above.
    # Since C_A is stored in a var, by how much do we need to increment that var
    # to make the var = C_AB?
    delta_comoment = (
        batch_comoment + (prev_mean_prediction - batch_mean_prediction) *
        (prev_mean_label - batch_mean_label) *
        (prev_count * batch_count / update_count))
    update_comoment = state_ops.assign_add(comoment, delta_comoment)

    covariance = array_ops.where(
        math_ops.less_equal(count_, 1.),
        float('nan'),
        math_ops.truediv(comoment, count_ - 1),
        name='covariance')
    with ops.control_dependencies([update_comoment]):
      update_op = array_ops.where(
          math_ops.less_equal(count_, 1.),
          float('nan'),
          math_ops.truediv(comoment, count_ - 1),
          name='update_op')

  if metrics_collections:
    ops.add_to_collections(metrics_collections, covariance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return covariance, update_op


def streaming_pearson_correlation(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes Pearson correlation coefficient between `predictions`, `labels`.

  The `streaming_pearson_correlation` function delegates to
  `streaming_covariance` the tracking of three [co]variances:

  - `streaming_covariance(predictions, labels)`, i.e. covariance
  - `streaming_covariance(predictions, predictions)`, i.e. variance
  - `streaming_covariance(labels, labels)`, i.e. variance

  The product-moment correlation ultimately returned is an idempotent operation
  `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. To
  facilitate correlation computation across multiple batches, the function
  groups the `update_op`s of the underlying streaming_covariance and returns an
  `update_op`.

  If `weights` is not None, then it is used to compute a weighted correlation.
  NOTE: these weights are treated as "frequency weights", as opposed to
  "reliability weights". See discussion of the difference on
  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance

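  For example, a minimal sketch (illustrative constants and session setup
  only, assuming the usual `tf.contrib.metrics` export):

  ```python
  import tensorflow as tf

  # Perfectly linearly related inputs have correlation 1.
  predictions = tf.constant([1.0, 2.0, 3.0, 4.0])
  labels = tf.constant([2.0, 4.0, 6.0, 8.0])

  pearson_r, update_op = tf.contrib.metrics.streaming_pearson_correlation(
      predictions, labels)

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(pearson_r))  # ~1.0
  ```
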
  Args:
    predictions: A `Tensor` of arbitrary size.
    labels: A `Tensor` of the same size as predictions.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    pearson_r: A `Tensor` representing the current Pearson product-moment
      correlation coefficient, the value of
      `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`.
    update_op: An operation that updates the underlying variables appropriately.

  Raises:
    ValueError: If `labels` and `predictions` are of different sizes, or if
      `weights` is the wrong size, or if either `metrics_collections` or
      `updates_collections` are not a `list` or `tuple`.
  """
  with variable_scope.variable_scope(name, 'pearson_r',
                                     (predictions, labels, weights)):
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    # Broadcast weights here to avoid duplicate broadcasting in each call to
    # `streaming_covariance`.
    if weights is not None:
      weights = weights_broadcast_ops.broadcast_weights(weights, labels)
    cov, update_cov = streaming_covariance(
        predictions, labels, weights=weights, name='covariance')
    var_predictions, update_var_predictions = streaming_covariance(
        predictions, predictions, weights=weights, name='variance_predictions')
    var_labels, update_var_labels = streaming_covariance(
        labels, labels, weights=weights, name='variance_labels')

    pearson_r = math_ops.truediv(
        cov,
        math_ops.multiply(
            math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)),
        name='pearson_r')
    update_op = math_ops.truediv(
        update_cov,
        math_ops.multiply(
            math_ops.sqrt(update_var_predictions),
            math_ops.sqrt(update_var_labels)),
        name='update_op')

  if metrics_collections:
    ops.add_to_collections(metrics_collections, pearson_r)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return pearson_r, update_op


# TODO(nsilberman): add a 'normalized' flag so that the user can request
# normalization if the inputs are not normalized.
def streaming_mean_cosine_distance(predictions,
                                   labels,
                                   dim,
                                   weights=None,
                                   metrics_collections=None,
                                   updates_collections=None,
                                   name=None):
  """Computes the cosine distance between the labels and predictions.

  The `streaming_mean_cosine_distance` function creates two local variables,
  `total` and `count` that are used to compute the average cosine distance
  between `predictions` and `labels`. This average is weighted by `weights`,
  and it is ultimately returned as `mean_distance`, which is an idempotent
  operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_distance`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

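  For example, a minimal sketch with unit-normalized inputs (illustrative
  constants and session setup only, assuming the usual `tf.contrib.metrics`
  export; note that inputs are assumed to be already unit-normalized):

  ```python
  import tensorflow as tf

  # Row 1 vectors are identical (distance 0); row 2 are orthogonal (distance 1).
  predictions = tf.constant([[1.0, 0.0], [0.0, 1.0]])
  labels = tf.constant([[1.0, 0.0], [1.0, 0.0]])

  distance, update_op = tf.contrib.metrics.streaming_mean_cosine_distance(
      predictions, labels, dim=1)

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(distance))  # 0.5
  ```
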
  Args:
    predictions: A `Tensor` of the same shape as `labels`.
    labels: A `Tensor` of arbitrary shape.
    dim: The dimension along which the cosine distance is computed.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`,
      and whose dimension `dim` is 1.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  radial_diffs = math_ops.multiply(predictions, labels)
  radial_diffs = math_ops.reduce_sum(
      radial_diffs, axis=[dim], keepdims=True)
  mean_distance, update_op = streaming_mean(radial_diffs, weights, None, None,
                                            name or 'mean_cosine_distance')
  mean_distance = math_ops.subtract(1.0, mean_distance)
  update_op = math_ops.subtract(1.0, update_op)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op


def streaming_percentage_less(values,
                              threshold,
                              weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Computes the percentage of values less than the given threshold.

  The `streaming_percentage_less` function creates two local variables,
  `total` and `count` that are used to compute the percentage of `values` that
  fall below `threshold`. This rate is weighted by `weights`, and it is
  ultimately returned as `percentage` which is an idempotent operation that
  simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `percentage`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

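  For example, a minimal sketch (illustrative constants and session setup
  only, assuming the usual `tf.contrib.metrics` export):

  ```python
  import tensorflow as tf

  values = tf.constant([1.0, 2.0, 3.0, 4.0])

  percentage, update_op = tf.contrib.metrics.streaming_percentage_less(
      values, threshold=3.0)

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(percentage))  # 0.5: two of the four values are below 3.0
  ```
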
3456  Args:
3457    values: A numeric `Tensor` of arbitrary size.
3458    threshold: A scalar threshold.
3459    weights: An optional `Tensor` whose shape is broadcastable to `values`.
3460    metrics_collections: An optional list of collections that the metric
3461      value variable should be added to.
3462    updates_collections: An optional list of collections that the metric update
3463      ops should be added to.
3464    name: An optional variable_scope name.
3465
3466  Returns:
3467    percentage: A `Tensor` representing the current mean, the value of `total`
3468      divided by `count`.
3469    update_op: An operation that increments the `total` and `count` variables
3470      appropriately.
3471
3472  Raises:
3473    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
3474      or if either `metrics_collections` or `updates_collections` are not a list
3475      or tuple.
3476  """
3477  return metrics.percentage_below(
3478      values=values,
3479      threshold=threshold,
3480      weights=weights,
3481      metrics_collections=metrics_collections,
3482      updates_collections=updates_collections,
3483      name=name)
3484
3485
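# Illustrative usage sketch for `streaming_percentage_less`; not part of the
# original module. The constant values below are hypothetical.
def _streaming_percentage_less_sketch():
  """Shows the running fraction of values that fall below the threshold."""
  import tensorflow as tf  # Local import keeps the sketch self-contained.

  values = tf.constant([1.0, 2.0, 3.0, 4.0])
  percentage, update_op = streaming_percentage_less(values, threshold=3.0)
  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(percentage))  # 2 of 4 values are below 3.0, so 0.5.
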
def streaming_mean_iou(predictions,
                       labels,
                       num_classes,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Calculate per-step mean Intersection-Over-Union (mIOU).

  Mean Intersection-Over-Union is a common evaluation metric for
  semantic image segmentation, which first computes the IOU for each
  semantic class and then computes the average over classes.
  IOU is defined as follows:
    IOU = true_positive / (true_positive + false_positive + false_negative).
  The predictions are accumulated in a confusion matrix, weighted by `weights`,
  and mIOU is then calculated from it.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean_iou`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened, if its rank > 1.
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened, if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `mean_iou`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_iou: A `Tensor` representing the mean intersection-over-union.
    update_op: An operation that increments the confusion matrix.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.mean_iou(
      num_classes=num_classes,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)

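# Illustrative usage sketch for `streaming_mean_iou`; not part of the original
# module. The labels/predictions below are hypothetical class indices.
def _streaming_mean_iou_sketch():
  """Accumulates one batch into the confusion matrix and reads off mIOU."""
  import tensorflow as tf  # Local import keeps the sketch self-contained.

  labels = tf.constant([0, 0, 1, 1])
  predictions = tf.constant([0, 1, 1, 1])
  miou, update_op = streaming_mean_iou(predictions, labels, num_classes=2)
  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    # IOU(class 0) = 1/2 and IOU(class 1) = 2/3, so mIOU = 7/12 ~= 0.583.
    print(sess.run(miou))
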
def _next_array_size(required_size, growth_factor=1.5):
  """Calculate the next size for reallocating a dynamic array.

  Args:
    required_size: number or tf.Tensor specifying required array capacity.
    growth_factor: optional number or tf.Tensor specifying the growth factor
      between subsequent allocations.

  Returns:
    tf.Tensor with dtype=int32 giving the next array size.
  """
  exponent = math_ops.ceil(
      math_ops.log(math_ops.cast(required_size, dtypes.float32)) / math_ops.log(
          math_ops.cast(growth_factor, dtypes.float32)))
  return math_ops.cast(math_ops.ceil(growth_factor**exponent), dtypes.int32)

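# Worked example for `_next_array_size` (illustrative, not part of the
# original module): with required_size=5 and the default growth_factor=1.5,
# exponent = ceil(log(5) / log(1.5)) = ceil(3.97...) = 4, so the next size is
# ceil(1.5**4) = ceil(5.0625) = 6. Capacity grows geometrically (~1.5x per
# reallocation), which is what keeps appends amortized constant time.
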
def streaming_concat(values,
                     axis=0,
                     max_size=None,
                     metrics_collections=None,
                     updates_collections=None,
                     name=None):
  """Concatenate values along an axis across batches.

  The function `streaming_concat` creates two local variables, `array` and
  `size`, that are used to store concatenated values. Internally, `array` is
  used as storage for a dynamic array (if `max_size` is `None`), which ensures
  that updates can be run in amortized constant time.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that appends the values of a tensor and returns the
  length of the concatenated axis.

  This op allows for evaluating metrics that cannot be updated incrementally
  using the same framework as other streaming metrics.

  Args:
    values: `Tensor` to concatenate. Rank and the shape along all axes other
      than the axis to concatenate along must be statically known.
    axis: optional integer axis to concatenate along.
    max_size: optional integer maximum size of `values` along the given axis.
      Once the maximum size is reached, further updates are no-ops. By default,
      there is no maximum size: the array is resized as necessary.
    metrics_collections: An optional list of collections that `value`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    value: A `Tensor` representing the concatenated values.
    update_op: An operation that concatenates the next values.

  Raises:
    ValueError: if `values` does not have a statically known rank, `axis` is
      not in the valid range or the size of `values` is not statically known
      along any axis other than `axis`.
  """
  with variable_scope.variable_scope(name, 'streaming_concat', (values,)):
    # pylint: disable=invalid-slice-index
    values_shape = values.get_shape()
    if values_shape.dims is None:
      raise ValueError('`values` must have statically known rank')

    ndim = len(values_shape)
    if axis < 0:
      axis += ndim
    if not 0 <= axis < ndim:
      raise ValueError('axis = %r not in [0, %r)' % (axis, ndim))

    fixed_shape = [dim.value for n, dim in enumerate(values_shape) if n != axis]
    if any(value is None for value in fixed_shape):
      raise ValueError('all dimensions of `values` other than the dimension to '
                       'concatenate along must have statically known size')

    # We move `axis` to the front of the internal array so assign ops can be
    # applied to contiguous slices
    init_size = 0 if max_size is None else max_size
    init_shape = [init_size] + fixed_shape
    array = metrics_impl.metric_variable(
        init_shape, values.dtype, validate_shape=False, name='array')
    size = metrics_impl.metric_variable([], dtypes.int32, name='size')

    perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)]
    valid_array = array[:size]
    valid_array.set_shape([None] + fixed_shape)
    value = array_ops.transpose(valid_array, perm, name='concat')

    values_size = array_ops.shape(values)[axis]
    if max_size is None:
      batch_size = values_size
    else:
      batch_size = math_ops.minimum(values_size, max_size - size)

    perm = [axis] + [n for n in range(ndim) if n != axis]
    batch_values = array_ops.transpose(values, perm)[:batch_size]

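    # `reallocate` swaps in a larger zero-initialized backing array and copies
    # the first `size` rows over. It closes over `new_size`, which is assigned
    # below: Python only resolves the name when `cond` calls this function,
    # after `new_size` exists.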
    def reallocate():
      next_size = _next_array_size(new_size)
      next_shape = array_ops.stack([next_size] + fixed_shape)
      new_value = array_ops.zeros(next_shape, dtype=values.dtype)
      old_value = array.value()
      assign_op = state_ops.assign(array, new_value, validate_shape=False)
      with ops.control_dependencies([assign_op]):
        copy_op = array[:size].assign(old_value[:size])
      # return value needs to be the same dtype as no_op() for cond
      with ops.control_dependencies([copy_op]):
        return control_flow_ops.no_op()

    new_size = size + batch_size
    array_size = array_ops.shape_internal(array, optimize=False)[0]
    maybe_reallocate_op = control_flow_ops.cond(
        new_size > array_size, reallocate, control_flow_ops.no_op)
    with ops.control_dependencies([maybe_reallocate_op]):
      append_values_op = array[size:new_size].assign(batch_values)
    with ops.control_dependencies([append_values_op]):
      update_op = size.assign(new_size)

    if metrics_collections:
      ops.add_to_collections(metrics_collections, value)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return value, update_op
    # pylint: enable=invalid-slice-index

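# Illustrative usage sketch for `streaming_concat`; not part of the original
# module. Each run of `update_op` appends one hypothetical batch.
def _streaming_concat_sketch():
  """Appends two batches and reads back the concatenated result."""
  import tensorflow as tf  # Local import keeps the sketch self-contained.

  batch = tf.placeholder(tf.float32, shape=[None])
  concatenated, update_op = streaming_concat(batch)
  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op, feed_dict={batch: [0.0, 1.0]})
    sess.run(update_op, feed_dict={batch: [2.0]})
    print(sess.run(concatenated))  # [0.0, 1.0, 2.0]
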
def aggregate_metrics(*value_update_tuples):
  """Aggregates the metric value tensors and update ops into two lists.

  Args:
    *value_update_tuples: a variable number of tuples, each of which contains
      the pair of (value_tensor, update_op) from a streaming metric.

  Returns:
    A list of value `Tensor` objects and a list of update ops.

  Raises:
    ValueError: if `value_update_tuples` is empty.
  """
  if not value_update_tuples:
    raise ValueError('Expected at least one value_tensor/update_op pair')
  value_ops, update_ops = zip(*value_update_tuples)
  return list(value_ops), list(update_ops)

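# Illustrative usage sketch for `aggregate_metrics`; not part of the original
# module. The constant tensors below are hypothetical.
def _aggregate_metrics_sketch():
  """Updates two metrics with a single sess.run over the update-op list."""
  import tensorflow as tf  # Local import keeps the sketch self-contained.

  predictions = tf.constant([0.0, 1.0])
  labels = tf.constant([0.0, 0.0])
  value_ops, update_ops = aggregate_metrics(
      streaming_mean_absolute_error(predictions, labels),
      streaming_mean_squared_error(predictions, labels))
  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_ops)  # One run drives every metric in the list.
    print(sess.run(value_ops))  # [0.5, 0.5] for the constants above.
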
def aggregate_metric_map(names_to_tuples):
  """Aggregates the metric names to tuple dictionary.

  This function is useful for pairing metric names with their associated value
  and update ops when the list of metrics is long. For example:

  ```python
    metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
        'Mean Absolute Error': slim.metrics.streaming_mean_absolute_error(
            predictions, labels, weights),
        'Mean Relative Error': slim.metrics.streaming_mean_relative_error(
            predictions, labels, labels, weights),
        'RMSE Linear': slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
        'RMSE Log': slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
    })
  ```

  Args:
    names_to_tuples: a map of metric names to tuples, each of which contains
      the pair of (value_tensor, update_op) from a streaming metric.

  Returns:
    A dictionary from metric names to value ops and a dictionary from metric
    names to update ops.
  """
  metric_names = names_to_tuples.keys()
  value_ops, update_ops = zip(*names_to_tuples.values())
  return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))

def count(values,
          weights=None,
          metrics_collections=None,
          updates_collections=None,
          name=None):
  """Computes the number of examples, or sum of `weights`.

  This metric keeps track of the denominator in `tf.metrics.mean`.
  When evaluating some metric (e.g. mean) on one or more subsets of the data,
  this auxiliary metric is useful for keeping track of how many examples there
  are in each subset.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions. Only its shape is used.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `values`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    count: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the metric from a batch of data.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.contrib.metrics.count is not supported when eager '
                       'execution is enabled.')

  with variable_scope.variable_scope(name, 'count', (values, weights)):

    count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')

    if weights is None:
      num_values = math_ops.cast(array_ops.size(values), dtypes.float32)
    else:
      values = math_ops.cast(values, dtypes.float32)
      values, _, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
          predictions=values,
          labels=None,
          weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.cast(weights, dtypes.float32), values)
      num_values = math_ops.reduce_sum(weights)

    with ops.control_dependencies([values]):
      update_count_op = state_ops.assign_add(count_, num_values)

    count_ = metrics_impl._aggregate_variable(count_, metrics_collections)  # pylint: disable=protected-access

    if updates_collections:
      ops.add_to_collections(updates_collections, update_count_op)

    return count_, update_count_op

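# Illustrative usage sketch for `count`; not part of the original module. The
# weights below mask the last example, so it does not contribute.
def _count_sketch():
  """Accumulates the effective number of examples, i.e. the weight sum."""
  import tensorflow as tf  # Local import keeps the sketch self-contained.

  values = tf.constant([1.0, 2.0, 3.0])
  weights = tf.constant([1.0, 1.0, 0.0])
  total, update_op = count(values, weights)
  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(total))  # 2.0: the sum of weights seen so far.
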
def cohen_kappa(labels,
                predictions_idx,
                num_classes,
                weights=None,
                metrics_collections=None,
                updates_collections=None,
                name=None):
  """Calculates Cohen's kappa.

  [Cohen's kappa](https://en.wikipedia.org/wiki/Cohen's_kappa) is a statistic
  that measures inter-annotator agreement.

  The `cohen_kappa` function calculates the confusion matrix, and creates three
  local variables to compute Cohen's kappa: `po`, `pe_row`, and `pe_col`,
  which refer to the diagonal part, row totals and column totals of the
  confusion matrix, respectively. This value is ultimately returned as `kappa`,
  an idempotent operation that is calculated by

      pe = (pe_row * pe_col) / N
      k = (sum(po) - sum(pe)) / (N - sum(pe))

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `kappa`. `update_op` weights each prediction by the corresponding value in
  `weights`.

  Class labels are expected to start at 0. E.g., if `num_classes`
  was three, then the possible labels would be [0, 1, 2].

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  NOTE: Equivalent to `sklearn.metrics.cohen_kappa_score`, but this method
  doesn't support a weighted confusion matrix yet.

  Args:
    labels: 1-D `Tensor` of real labels for the classification task. Must be
      one of the following types: int16, int32, int64.
    predictions_idx: 1-D `Tensor` of predicted class indices for a given
      classification. Must have the same type as `labels`.
    num_classes: The possible number of labels.
    weights: Optional `Tensor` whose shape matches `predictions_idx`.
    metrics_collections: An optional list of collections that `kappa` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    kappa: Scalar float `Tensor` representing the current Cohen's kappa.
    update_op: `Operation` that increments `po`, `pe_row` and `pe_col`
      variables appropriately and whose value matches `kappa`.

  Raises:
    ValueError: If `num_classes` is less than 2, or `predictions_idx` and
      `labels` have mismatched shapes, or if `weights` is not `None` and its
      shape doesn't match `predictions_idx`, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.contrib.metrics.cohen_kappa is not supported '
                       'when eager execution is enabled.')
  if num_classes < 2:
    raise ValueError('`num_classes` must be >= 2. '
                     'Found: {}'.format(num_classes))
  with variable_scope.variable_scope(name, 'cohen_kappa',
                                     (labels, predictions_idx, weights)):
    # Convert 2-dim (num, 1) to 1-dim (num,)
    labels.get_shape().with_rank_at_most(2)
    if labels.get_shape().ndims == 2:
      labels = array_ops.squeeze(labels, axis=[-1])
    predictions_idx, labels, weights = (
        metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
            predictions=predictions_idx,
            labels=labels,
            weights=weights))
    predictions_idx.get_shape().assert_is_compatible_with(labels.get_shape())

    stat_dtype = (
        dtypes.int64
        if weights is None or weights.dtype.is_integer else dtypes.float32)
    po = metrics_impl.metric_variable((num_classes,), stat_dtype, name='po')
    pe_row = metrics_impl.metric_variable(
        (num_classes,), stat_dtype, name='pe_row')
    pe_col = metrics_impl.metric_variable(
        (num_classes,), stat_dtype, name='pe_col')

    # Table of the counts of agreement:
    counts_in_table = confusion_matrix.confusion_matrix(
        labels,
        predictions_idx,
        num_classes=num_classes,
        weights=weights,
        dtype=stat_dtype,
        name='counts_in_table')

    po_t = array_ops.diag_part(counts_in_table)
    pe_row_t = math_ops.reduce_sum(counts_in_table, axis=0)
    pe_col_t = math_ops.reduce_sum(counts_in_table, axis=1)
    update_po = state_ops.assign_add(po, po_t)
    update_pe_row = state_ops.assign_add(pe_row, pe_row_t)
    update_pe_col = state_ops.assign_add(pe_col, pe_col_t)

    def _calculate_k(po, pe_row, pe_col, name):
      po_sum = math_ops.reduce_sum(po)
      total = math_ops.reduce_sum(pe_row)
      pe_sum = math_ops.reduce_sum(
          math_ops.div_no_nan(
              math_ops.cast(pe_row * pe_col, dtypes.float64),
              math_ops.cast(total, dtypes.float64)))
      po_sum, pe_sum, total = (math_ops.cast(po_sum, dtypes.float64),
                               math_ops.cast(pe_sum, dtypes.float64),
                               math_ops.cast(total, dtypes.float64))
      # kappa = (po - pe) / (N - pe)
      k = metrics_impl._safe_scalar_div(  # pylint: disable=protected-access
          po_sum - pe_sum,
          total - pe_sum,
          name=name)
      return k

    kappa = _calculate_k(po, pe_row, pe_col, name='value')
    update_op = _calculate_k(
        update_po, update_pe_row, update_pe_col, name='update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, kappa)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return kappa, update_op

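# Illustrative usage sketch for `cohen_kappa`; not part of the original
# module. The labels/predictions below are hypothetical class indices.
def _cohen_kappa_sketch():
  """Checks the kappa formula on a small hand-computable batch."""
  import tensorflow as tf  # Local import keeps the sketch self-contained.

  labels = tf.constant([0, 0, 1, 1], dtype=tf.int64)
  predictions = tf.constant([0, 1, 1, 1], dtype=tf.int64)
  kappa, update_op = cohen_kappa(labels, predictions, num_classes=2)
  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    # sum(po) = 3, pe = (1*2 + 3*2) / 4 = 2, so kappa = (3 - 2) / (4 - 2) = 0.5.
    print(sess.run(kappa))
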
__all__ = [
    'auc_with_confidence_intervals',
    'aggregate_metric_map',
    'aggregate_metrics',
    'cohen_kappa',
    'count',
    'precision_recall_at_equal_thresholds',
    'recall_at_precision',
    'sparse_recall_at_top_k',
    'streaming_accuracy',
    'streaming_auc',
    'streaming_curve_points',
    'streaming_dynamic_auc',
    'streaming_false_negative_rate',
    'streaming_false_negative_rate_at_thresholds',
    'streaming_false_negatives',
    'streaming_false_negatives_at_thresholds',
    'streaming_false_positive_rate',
    'streaming_false_positive_rate_at_thresholds',
    'streaming_false_positives',
    'streaming_false_positives_at_thresholds',
    'streaming_mean',
    'streaming_mean_absolute_error',
    'streaming_mean_cosine_distance',
    'streaming_mean_iou',
    'streaming_mean_relative_error',
    'streaming_mean_squared_error',
    'streaming_mean_tensor',
    'streaming_percentage_less',
    'streaming_precision',
    'streaming_precision_at_thresholds',
    'streaming_recall',
    'streaming_recall_at_k',
    'streaming_recall_at_thresholds',
    'streaming_root_mean_squared_error',
    'streaming_sensitivity_at_specificity',
    'streaming_sparse_average_precision_at_k',
    'streaming_sparse_average_precision_at_top_k',
    'streaming_sparse_precision_at_k',
    'streaming_sparse_precision_at_top_k',
    'streaming_sparse_recall_at_k',
    'streaming_specificity_at_sensitivity',
    'streaming_true_negatives',
    'streaming_true_negatives_at_thresholds',
    'streaming_true_positives',
    'streaming_true_positives_at_thresholds',
]