1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains metric-computing operations on streamed tensors. 16 17Module documentation, including "@@" callouts, should be put in 18third_party/tensorflow/contrib/metrics/__init__.py 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import collections as collections_lib 26 27from tensorflow.python.eager import context 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import check_ops 32from tensorflow.python.ops import confusion_matrix 33from tensorflow.python.ops import control_flow_ops 34from tensorflow.python.ops import math_ops 35from tensorflow.python.ops import metrics 36from tensorflow.python.ops import metrics_impl 37from tensorflow.python.ops import nn 38from tensorflow.python.ops import state_ops 39from tensorflow.python.ops import variable_scope 40from tensorflow.python.ops import weights_broadcast_ops 41from tensorflow.python.ops.distributions.normal import Normal 42from tensorflow.python.util.deprecation import deprecated 43 44# Epsilon constant used to represent extremely small quantity. 45_EPSILON = 1e-7 46 47 48@deprecated(None, 'Please switch to tf.metrics.true_positives. Note that the ' 49 'order of the labels and predictions arguments has been switched.') 50def streaming_true_positives(predictions, 51 labels, 52 weights=None, 53 metrics_collections=None, 54 updates_collections=None, 55 name=None): 56 """Sum the weights of true_positives. 57 58 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 59 60 Args: 61 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 62 be cast to `bool`. 63 labels: The ground truth values, a `Tensor` whose dimensions must match 64 `predictions`. Will be cast to `bool`. 65 weights: Optional `Tensor` whose rank is either 0, or the same rank as 66 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 67 must be either `1`, or the same as the corresponding `labels` 68 dimension). 69 metrics_collections: An optional list of collections that the metric 70 value variable should be added to. 71 updates_collections: An optional list of collections that the metric update 72 ops should be added to. 73 name: An optional variable_scope name. 74 75 Returns: 76 value_tensor: A `Tensor` representing the current value of the metric. 77 update_op: An operation that accumulates the error from a batch of data. 78 79 Raises: 80 ValueError: If `predictions` and `labels` have mismatched shapes, or if 81 `weights` is not `None` and its shape doesn't match `predictions`, or if 82 either `metrics_collections` or `updates_collections` are not a list or 83 tuple. 84 """ 85 return metrics.true_positives( 86 predictions=predictions, 87 labels=labels, 88 weights=weights, 89 metrics_collections=metrics_collections, 90 updates_collections=updates_collections, 91 name=name) 92 93 94@deprecated(None, 'Please switch to tf.metrics.true_negatives. Note that the ' 95 'order of the labels and predictions arguments has been switched.') 96def streaming_true_negatives(predictions, 97 labels, 98 weights=None, 99 metrics_collections=None, 100 updates_collections=None, 101 name=None): 102 """Sum the weights of true_negatives. 103 104 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 105 106 Args: 107 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 108 be cast to `bool`. 109 labels: The ground truth values, a `Tensor` whose dimensions must match 110 `predictions`. Will be cast to `bool`. 111 weights: Optional `Tensor` whose rank is either 0, or the same rank as 112 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 113 must be either `1`, or the same as the corresponding `labels` 114 dimension). 115 metrics_collections: An optional list of collections that the metric 116 value variable should be added to. 117 updates_collections: An optional list of collections that the metric update 118 ops should be added to. 119 name: An optional variable_scope name. 120 121 Returns: 122 value_tensor: A `Tensor` representing the current value of the metric. 123 update_op: An operation that accumulates the error from a batch of data. 124 125 Raises: 126 ValueError: If `predictions` and `labels` have mismatched shapes, or if 127 `weights` is not `None` and its shape doesn't match `predictions`, or if 128 either `metrics_collections` or `updates_collections` are not a list or 129 tuple. 130 """ 131 return metrics.true_negatives( 132 predictions=predictions, 133 labels=labels, 134 weights=weights, 135 metrics_collections=metrics_collections, 136 updates_collections=updates_collections, 137 name=name) 138 139 140@deprecated(None, 'Please switch to tf.metrics.false_positives. Note that the ' 141 'order of the labels and predictions arguments has been switched.') 142def streaming_false_positives(predictions, 143 labels, 144 weights=None, 145 metrics_collections=None, 146 updates_collections=None, 147 name=None): 148 """Sum the weights of false positives. 149 150 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 151 152 Args: 153 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 154 be cast to `bool`. 155 labels: The ground truth values, a `Tensor` whose dimensions must match 156 `predictions`. Will be cast to `bool`. 157 weights: Optional `Tensor` whose rank is either 0, or the same rank as 158 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 159 must be either `1`, or the same as the corresponding `labels` 160 dimension). 161 metrics_collections: An optional list of collections that the metric 162 value variable should be added to. 163 updates_collections: An optional list of collections that the metric update 164 ops should be added to. 165 name: An optional variable_scope name. 166 167 Returns: 168 value_tensor: A `Tensor` representing the current value of the metric. 169 update_op: An operation that accumulates the error from a batch of data. 170 171 Raises: 172 ValueError: If `predictions` and `labels` have mismatched shapes, or if 173 `weights` is not `None` and its shape doesn't match `predictions`, or if 174 either `metrics_collections` or `updates_collections` are not a list or 175 tuple. 176 """ 177 return metrics.false_positives( 178 predictions=predictions, 179 labels=labels, 180 weights=weights, 181 metrics_collections=metrics_collections, 182 updates_collections=updates_collections, 183 name=name) 184 185 186@deprecated(None, 'Please switch to tf.metrics.false_negatives. Note that the ' 187 'order of the labels and predictions arguments has been switched.') 188def streaming_false_negatives(predictions, 189 labels, 190 weights=None, 191 metrics_collections=None, 192 updates_collections=None, 193 name=None): 194 """Computes the total number of false negatives. 195 196 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 197 198 Args: 199 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 200 be cast to `bool`. 201 labels: The ground truth values, a `Tensor` whose dimensions must match 202 `predictions`. Will be cast to `bool`. 203 weights: Optional `Tensor` whose rank is either 0, or the same rank as 204 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 205 must be either `1`, or the same as the corresponding `labels` 206 dimension). 207 metrics_collections: An optional list of collections that the metric 208 value variable should be added to. 209 updates_collections: An optional list of collections that the metric update 210 ops should be added to. 211 name: An optional variable_scope name. 212 213 Returns: 214 value_tensor: A `Tensor` representing the current value of the metric. 215 update_op: An operation that accumulates the error from a batch of data. 216 217 Raises: 218 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 219 or if either `metrics_collections` or `updates_collections` are not a list 220 or tuple. 221 """ 222 return metrics.false_negatives( 223 predictions=predictions, 224 labels=labels, 225 weights=weights, 226 metrics_collections=metrics_collections, 227 updates_collections=updates_collections, 228 name=name) 229 230 231@deprecated(None, 'Please switch to tf.metrics.mean') 232def streaming_mean(values, 233 weights=None, 234 metrics_collections=None, 235 updates_collections=None, 236 name=None): 237 """Computes the (weighted) mean of the given values. 238 239 The `streaming_mean` function creates two local variables, `total` and `count` 240 that are used to compute the average of `values`. This average is ultimately 241 returned as `mean` which is an idempotent operation that simply divides 242 `total` by `count`. 243 244 For estimation of the metric over a stream of data, the function creates an 245 `update_op` operation that updates these variables and returns the `mean`. 246 `update_op` increments `total` with the reduced sum of the product of `values` 247 and `weights`, and it increments `count` with the reduced sum of `weights`. 248 249 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 250 251 Args: 252 values: A `Tensor` of arbitrary dimensions. 253 weights: `Tensor` whose rank is either 0, or the same rank as `values`, and 254 must be broadcastable to `values` (i.e., all dimensions must be either 255 `1`, or the same as the corresponding `values` dimension). 256 metrics_collections: An optional list of collections that `mean` 257 should be added to. 258 updates_collections: An optional list of collections that `update_op` 259 should be added to. 260 name: An optional variable_scope name. 261 262 Returns: 263 mean: A `Tensor` representing the current mean, the value of `total` divided 264 by `count`. 265 update_op: An operation that increments the `total` and `count` variables 266 appropriately and whose value matches `mean`. 267 268 Raises: 269 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 270 or if either `metrics_collections` or `updates_collections` are not a list 271 or tuple. 272 """ 273 return metrics.mean( 274 values=values, 275 weights=weights, 276 metrics_collections=metrics_collections, 277 updates_collections=updates_collections, 278 name=name) 279 280 281@deprecated(None, 'Please switch to tf.metrics.mean_tensor') 282def streaming_mean_tensor(values, 283 weights=None, 284 metrics_collections=None, 285 updates_collections=None, 286 name=None): 287 """Computes the element-wise (weighted) mean of the given tensors. 288 289 In contrast to the `streaming_mean` function which returns a scalar with the 290 mean, this function returns an average tensor with the same shape as the 291 input tensors. 292 293 The `streaming_mean_tensor` function creates two local variables, 294 `total_tensor` and `count_tensor` that are used to compute the average of 295 `values`. This average is ultimately returned as `mean` which is an idempotent 296 operation that simply divides `total` by `count`. 297 298 For estimation of the metric over a stream of data, the function creates an 299 `update_op` operation that updates these variables and returns the `mean`. 300 `update_op` increments `total` with the reduced sum of the product of `values` 301 and `weights`, and it increments `count` with the reduced sum of `weights`. 302 303 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 304 305 Args: 306 values: A `Tensor` of arbitrary dimensions. 307 weights: `Tensor` whose rank is either 0, or the same rank as `values`, and 308 must be broadcastable to `values` (i.e., all dimensions must be either 309 `1`, or the same as the corresponding `values` dimension). 310 metrics_collections: An optional list of collections that `mean` 311 should be added to. 312 updates_collections: An optional list of collections that `update_op` 313 should be added to. 314 name: An optional variable_scope name. 315 316 Returns: 317 mean: A float `Tensor` representing the current mean, the value of `total` 318 divided by `count`. 319 update_op: An operation that increments the `total` and `count` variables 320 appropriately and whose value matches `mean`. 321 322 Raises: 323 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 324 or if either `metrics_collections` or `updates_collections` are not a list 325 or tuple. 326 """ 327 return metrics.mean_tensor( 328 values=values, 329 weights=weights, 330 metrics_collections=metrics_collections, 331 updates_collections=updates_collections, 332 name=name) 333 334 335@deprecated(None, 'Please switch to tf.metrics.accuracy. Note that the order ' 336 'of the labels and predictions arguments has been switched.') 337def streaming_accuracy(predictions, 338 labels, 339 weights=None, 340 metrics_collections=None, 341 updates_collections=None, 342 name=None): 343 """Calculates how often `predictions` matches `labels`. 344 345 The `streaming_accuracy` function creates two local variables, `total` and 346 `count` that are used to compute the frequency with which `predictions` 347 matches `labels`. This frequency is ultimately returned as `accuracy`: an 348 idempotent operation that simply divides `total` by `count`. 349 350 For estimation of the metric over a stream of data, the function creates an 351 `update_op` operation that updates these variables and returns the `accuracy`. 352 Internally, an `is_correct` operation computes a `Tensor` with elements 1.0 353 where the corresponding elements of `predictions` and `labels` match and 0.0 354 otherwise. Then `update_op` increments `total` with the reduced sum of the 355 product of `weights` and `is_correct`, and it increments `count` with the 356 reduced sum of `weights`. 357 358 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 359 360 Args: 361 predictions: The predicted values, a `Tensor` of any shape. 362 labels: The ground truth values, a `Tensor` whose shape matches 363 `predictions`. 364 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 365 must be broadcastable to `labels` (i.e., all dimensions must be either 366 `1`, or the same as the corresponding `labels` dimension). 367 metrics_collections: An optional list of collections that `accuracy` should 368 be added to. 369 updates_collections: An optional list of collections that `update_op` should 370 be added to. 371 name: An optional variable_scope name. 372 373 Returns: 374 accuracy: A `Tensor` representing the accuracy, the value of `total` divided 375 by `count`. 376 update_op: An operation that increments the `total` and `count` variables 377 appropriately and whose value matches `accuracy`. 378 379 Raises: 380 ValueError: If `predictions` and `labels` have mismatched shapes, or if 381 `weights` is not `None` and its shape doesn't match `predictions`, or if 382 either `metrics_collections` or `updates_collections` are not a list or 383 tuple. 384 """ 385 return metrics.accuracy( 386 predictions=predictions, 387 labels=labels, 388 weights=weights, 389 metrics_collections=metrics_collections, 390 updates_collections=updates_collections, 391 name=name) 392 393 394@deprecated(None, 'Please switch to tf.metrics.precision. Note that the order ' 395 'of the labels and predictions arguments has been switched.') 396def streaming_precision(predictions, 397 labels, 398 weights=None, 399 metrics_collections=None, 400 updates_collections=None, 401 name=None): 402 """Computes the precision of the predictions with respect to the labels. 403 404 The `streaming_precision` function creates two local variables, 405 `true_positives` and `false_positives`, that are used to compute the 406 precision. This value is ultimately returned as `precision`, an idempotent 407 operation that simply divides `true_positives` by the sum of `true_positives` 408 and `false_positives`. 409 410 For estimation of the metric over a stream of data, the function creates an 411 `update_op` operation that updates these variables and returns the 412 `precision`. `update_op` weights each prediction by the corresponding value in 413 `weights`. 414 415 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 416 417 Args: 418 predictions: The predicted values, a `bool` `Tensor` of arbitrary shape. 419 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 420 match `predictions`. 421 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 422 must be broadcastable to `labels` (i.e., all dimensions must be either 423 `1`, or the same as the corresponding `labels` dimension). 424 metrics_collections: An optional list of collections that `precision` should 425 be added to. 426 updates_collections: An optional list of collections that `update_op` should 427 be added to. 428 name: An optional variable_scope name. 429 430 Returns: 431 precision: Scalar float `Tensor` with the value of `true_positives` 432 divided by the sum of `true_positives` and `false_positives`. 433 update_op: `Operation` that increments `true_positives` and 434 `false_positives` variables appropriately and whose value matches 435 `precision`. 436 437 Raises: 438 ValueError: If `predictions` and `labels` have mismatched shapes, or if 439 `weights` is not `None` and its shape doesn't match `predictions`, or if 440 either `metrics_collections` or `updates_collections` are not a list or 441 tuple. 442 """ 443 return metrics.precision( 444 predictions=predictions, 445 labels=labels, 446 weights=weights, 447 metrics_collections=metrics_collections, 448 updates_collections=updates_collections, 449 name=name) 450 451 452@deprecated(None, 'Please switch to tf.metrics.recall. Note that the order ' 453 'of the labels and predictions arguments has been switched.') 454def streaming_recall(predictions, 455 labels, 456 weights=None, 457 metrics_collections=None, 458 updates_collections=None, 459 name=None): 460 """Computes the recall of the predictions with respect to the labels. 461 462 The `streaming_recall` function creates two local variables, `true_positives` 463 and `false_negatives`, that are used to compute the recall. This value is 464 ultimately returned as `recall`, an idempotent operation that simply divides 465 `true_positives` by the sum of `true_positives` and `false_negatives`. 466 467 For estimation of the metric over a stream of data, the function creates an 468 `update_op` that updates these variables and returns the `recall`. `update_op` 469 weights each prediction by the corresponding value in `weights`. 470 471 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 472 473 Args: 474 predictions: The predicted values, a `bool` `Tensor` of arbitrary shape. 475 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 476 match `predictions`. 477 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 478 must be broadcastable to `labels` (i.e., all dimensions must be either 479 `1`, or the same as the corresponding `labels` dimension). 480 metrics_collections: An optional list of collections that `recall` should 481 be added to. 482 updates_collections: An optional list of collections that `update_op` should 483 be added to. 484 name: An optional variable_scope name. 485 486 Returns: 487 recall: Scalar float `Tensor` with the value of `true_positives` divided 488 by the sum of `true_positives` and `false_negatives`. 489 update_op: `Operation` that increments `true_positives` and 490 `false_negatives` variables appropriately and whose value matches 491 `recall`. 492 493 Raises: 494 ValueError: If `predictions` and `labels` have mismatched shapes, or if 495 `weights` is not `None` and its shape doesn't match `predictions`, or if 496 either `metrics_collections` or `updates_collections` are not a list or 497 tuple. 498 """ 499 return metrics.recall( 500 predictions=predictions, 501 labels=labels, 502 weights=weights, 503 metrics_collections=metrics_collections, 504 updates_collections=updates_collections, 505 name=name) 506 507 508def streaming_false_positive_rate(predictions, 509 labels, 510 weights=None, 511 metrics_collections=None, 512 updates_collections=None, 513 name=None): 514 """Computes the false positive rate of predictions with respect to labels. 515 516 The `false_positive_rate` function creates two local variables, 517 `false_positives` and `true_negatives`, that are used to compute the 518 false positive rate. This value is ultimately returned as 519 `false_positive_rate`, an idempotent operation that simply divides 520 `false_positives` by the sum of `false_positives` and `true_negatives`. 521 522 For estimation of the metric over a stream of data, the function creates an 523 `update_op` operation that updates these variables and returns the 524 `false_positive_rate`. `update_op` weights each prediction by the 525 corresponding value in `weights`. 526 527 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 528 529 Args: 530 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 531 be cast to `bool`. 532 labels: The ground truth values, a `Tensor` whose dimensions must match 533 `predictions`. Will be cast to `bool`. 534 weights: Optional `Tensor` whose rank is either 0, or the same rank as 535 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 536 be either `1`, or the same as the corresponding `labels` dimension). 537 metrics_collections: An optional list of collections that 538 `false_positive_rate` should be added to. 539 updates_collections: An optional list of collections that `update_op` should 540 be added to. 541 name: An optional variable_scope name. 542 543 Returns: 544 false_positive_rate: Scalar float `Tensor` with the value of 545 `false_positives` divided by the sum of `false_positives` and 546 `true_negatives`. 547 update_op: `Operation` that increments `false_positives` and 548 `true_negatives` variables appropriately and whose value matches 549 `false_positive_rate`. 550 551 Raises: 552 ValueError: If `predictions` and `labels` have mismatched shapes, or if 553 `weights` is not `None` and its shape doesn't match `predictions`, or if 554 either `metrics_collections` or `updates_collections` are not a list or 555 tuple. 556 """ 557 with variable_scope.variable_scope(name, 'false_positive_rate', 558 (predictions, labels, weights)): 559 predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access 560 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 561 labels=math_ops.cast(labels, dtype=dtypes.bool), 562 weights=weights) 563 564 false_p, false_positives_update_op = metrics.false_positives( 565 labels=labels, 566 predictions=predictions, 567 weights=weights, 568 metrics_collections=None, 569 updates_collections=None, 570 name=None) 571 true_n, true_negatives_update_op = metrics.true_negatives( 572 labels=labels, 573 predictions=predictions, 574 weights=weights, 575 metrics_collections=None, 576 updates_collections=None, 577 name=None) 578 579 def compute_fpr(fp, tn, name): 580 return array_ops.where( 581 math_ops.greater(fp + tn, 0), math_ops.div(fp, fp + tn), 0, name) 582 583 fpr = compute_fpr(false_p, true_n, 'value') 584 update_op = compute_fpr(false_positives_update_op, true_negatives_update_op, 585 'update_op') 586 587 if metrics_collections: 588 ops.add_to_collections(metrics_collections, fpr) 589 590 if updates_collections: 591 ops.add_to_collections(updates_collections, update_op) 592 593 return fpr, update_op 594 595 596def streaming_false_negative_rate(predictions, 597 labels, 598 weights=None, 599 metrics_collections=None, 600 updates_collections=None, 601 name=None): 602 """Computes the false negative rate of predictions with respect to labels. 603 604 The `false_negative_rate` function creates two local variables, 605 `false_negatives` and `true_positives`, that are used to compute the 606 false positive rate. This value is ultimately returned as 607 `false_negative_rate`, an idempotent operation that simply divides 608 `false_negatives` by the sum of `false_negatives` and `true_positives`. 609 610 For estimation of the metric over a stream of data, the function creates an 611 `update_op` operation that updates these variables and returns the 612 `false_negative_rate`. `update_op` weights each prediction by the 613 corresponding value in `weights`. 614 615 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 616 617 Args: 618 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 619 be cast to `bool`. 620 labels: The ground truth values, a `Tensor` whose dimensions must match 621 `predictions`. Will be cast to `bool`. 622 weights: Optional `Tensor` whose rank is either 0, or the same rank as 623 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 624 be either `1`, or the same as the corresponding `labels` dimension). 625 metrics_collections: An optional list of collections that 626 `false_negative_rate` should be added to. 627 updates_collections: An optional list of collections that `update_op` should 628 be added to. 629 name: An optional variable_scope name. 630 631 Returns: 632 false_negative_rate: Scalar float `Tensor` with the value of 633 `false_negatives` divided by the sum of `false_negatives` and 634 `true_positives`. 635 update_op: `Operation` that increments `false_negatives` and 636 `true_positives` variables appropriately and whose value matches 637 `false_negative_rate`. 638 639 Raises: 640 ValueError: If `predictions` and `labels` have mismatched shapes, or if 641 `weights` is not `None` and its shape doesn't match `predictions`, or if 642 either `metrics_collections` or `updates_collections` are not a list or 643 tuple. 644 """ 645 with variable_scope.variable_scope(name, 'false_negative_rate', 646 (predictions, labels, weights)): 647 predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access 648 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 649 labels=math_ops.cast(labels, dtype=dtypes.bool), 650 weights=weights) 651 652 false_n, false_negatives_update_op = metrics.false_negatives( 653 labels, 654 predictions, 655 weights, 656 metrics_collections=None, 657 updates_collections=None, 658 name=None) 659 true_p, true_positives_update_op = metrics.true_positives( 660 labels, 661 predictions, 662 weights, 663 metrics_collections=None, 664 updates_collections=None, 665 name=None) 666 667 def compute_fnr(fn, tp, name): 668 return array_ops.where( 669 math_ops.greater(fn + tp, 0), math_ops.div(fn, fn + tp), 0, name) 670 671 fnr = compute_fnr(false_n, true_p, 'value') 672 update_op = compute_fnr(false_negatives_update_op, true_positives_update_op, 673 'update_op') 674 675 if metrics_collections: 676 ops.add_to_collections(metrics_collections, fnr) 677 678 if updates_collections: 679 ops.add_to_collections(updates_collections, update_op) 680 681 return fnr, update_op 682 683 684def _streaming_confusion_matrix_at_thresholds(predictions, 685 labels, 686 thresholds, 687 weights=None, 688 includes=None): 689 """Computes true_positives, false_negatives, true_negatives, false_positives. 690 691 This function creates up to four local variables, `true_positives`, 692 `true_negatives`, `false_positives` and `false_negatives`. 693 `true_positive[i]` is defined as the total weight of values in `predictions` 694 above `thresholds[i]` whose corresponding entry in `labels` is `True`. 695 `false_negatives[i]` is defined as the total weight of values in `predictions` 696 at most `thresholds[i]` whose corresponding entry in `labels` is `True`. 697 `true_negatives[i]` is defined as the total weight of values in `predictions` 698 at most `thresholds[i]` whose corresponding entry in `labels` is `False`. 699 `false_positives[i]` is defined as the total weight of values in `predictions` 700 above `thresholds[i]` whose corresponding entry in `labels` is `False`. 701 702 For estimation of these metrics over a stream of data, for each metric the 703 function respectively creates an `update_op` operation that updates the 704 variable and returns its value. 705 706 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 707 708 Args: 709 predictions: A floating point `Tensor` of arbitrary shape and whose values 710 are in the range `[0, 1]`. 711 labels: A `Tensor` whose shape matches `predictions`. `labels` will be cast 712 to `bool`. 713 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 714 weights: Optional `Tensor` whose rank is either 0, or the same rank as 715 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 716 must be either `1`, or the same as the corresponding `labels` 717 dimension). 718 includes: Tuple of keys to return, from 'tp', 'fn', 'tn', fp'. If `None`, 719 default to all four. 720 721 Returns: 722 values: Dict of variables of shape `[len(thresholds)]`. Keys are from 723 `includes`. 724 update_ops: Dict of operations that increments the `values`. Keys are from 725 `includes`. 726 727 Raises: 728 ValueError: If `predictions` and `labels` have mismatched shapes, or if 729 `weights` is not `None` and its shape doesn't match `predictions`, or if 730 `includes` contains invalid keys. 731 """ 732 all_includes = ('tp', 'fn', 'tn', 'fp') 733 if includes is None: 734 includes = all_includes 735 else: 736 for include in includes: 737 if include not in all_includes: 738 raise ValueError('Invalid key: %s.' % include) 739 740 predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access 741 predictions, labels, weights) 742 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 743 744 num_thresholds = len(thresholds) 745 746 # Reshape predictions and labels. 747 predictions_2d = array_ops.reshape(predictions, [-1, 1]) 748 labels_2d = array_ops.reshape( 749 math_ops.cast(labels, dtype=dtypes.bool), [1, -1]) 750 751 # Use static shape if known. 752 num_predictions = predictions_2d.get_shape().as_list()[0] 753 754 # Otherwise use dynamic shape. 755 if num_predictions is None: 756 num_predictions = array_ops.shape(predictions_2d)[0] 757 thresh_tiled = array_ops.tile( 758 array_ops.expand_dims(array_ops.constant(thresholds), [1]), 759 array_ops.stack([1, num_predictions])) 760 761 # Tile the predictions after thresholding them across different thresholds. 762 pred_is_pos = math_ops.greater( 763 array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]), 764 thresh_tiled) 765 if ('fn' in includes) or ('tn' in includes): 766 pred_is_neg = math_ops.logical_not(pred_is_pos) 767 768 # Tile labels by number of thresholds 769 label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1]) 770 if ('fp' in includes) or ('tn' in includes): 771 label_is_neg = math_ops.logical_not(label_is_pos) 772 773 if weights is not None: 774 broadcast_weights = weights_broadcast_ops.broadcast_weights( 775 math_ops.cast(weights, dtypes.float32), predictions) 776 weights_tiled = array_ops.tile( 777 array_ops.reshape(broadcast_weights, [1, -1]), [num_thresholds, 1]) 778 thresh_tiled.get_shape().assert_is_compatible_with( 779 weights_tiled.get_shape()) 780 else: 781 weights_tiled = None 782 783 values = {} 784 update_ops = {} 785 786 if 'tp' in includes: 787 true_positives = metrics_impl.metric_variable( 788 [num_thresholds], dtypes.float32, name='true_positives') 789 is_true_positive = math_ops.cast( 790 math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32) 791 if weights_tiled is not None: 792 is_true_positive *= weights_tiled 793 update_ops['tp'] = state_ops.assign_add(true_positives, 794 math_ops.reduce_sum( 795 is_true_positive, 1)) 796 values['tp'] = true_positives 797 798 if 'fn' in includes: 799 false_negatives = metrics_impl.metric_variable( 800 [num_thresholds], dtypes.float32, name='false_negatives') 801 is_false_negative = math_ops.cast( 802 math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32) 803 if weights_tiled is not None: 804 is_false_negative *= weights_tiled 805 update_ops['fn'] = state_ops.assign_add(false_negatives, 806 math_ops.reduce_sum( 807 is_false_negative, 1)) 808 values['fn'] = false_negatives 809 810 if 'tn' in includes: 811 true_negatives = metrics_impl.metric_variable( 812 [num_thresholds], dtypes.float32, name='true_negatives') 813 is_true_negative = math_ops.cast( 814 math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32) 815 if weights_tiled is not None: 816 is_true_negative *= weights_tiled 817 update_ops['tn'] = state_ops.assign_add(true_negatives, 818 math_ops.reduce_sum( 819 is_true_negative, 1)) 820 values['tn'] = true_negatives 821 822 if 'fp' in includes: 823 false_positives = metrics_impl.metric_variable( 824 [num_thresholds], dtypes.float32, name='false_positives') 825 is_false_positive = math_ops.cast( 826 math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32) 827 if weights_tiled is not None: 828 is_false_positive *= weights_tiled 829 update_ops['fp'] = state_ops.assign_add(false_positives, 830 math_ops.reduce_sum( 831 is_false_positive, 1)) 832 values['fp'] = false_positives 833 834 return values, update_ops 835 836 837def streaming_true_positives_at_thresholds(predictions, 838 labels, 839 thresholds, 840 weights=None): 841 values, update_ops = _streaming_confusion_matrix_at_thresholds( 842 predictions, labels, thresholds, weights=weights, includes=('tp',)) 843 return values['tp'], update_ops['tp'] 844 845 846def streaming_false_negatives_at_thresholds(predictions, 847 labels, 848 thresholds, 849 weights=None): 850 values, update_ops = _streaming_confusion_matrix_at_thresholds( 851 predictions, labels, thresholds, weights=weights, includes=('fn',)) 852 return values['fn'], update_ops['fn'] 853 854 855def streaming_false_positives_at_thresholds(predictions, 856 labels, 857 thresholds, 858 weights=None): 859 values, update_ops = _streaming_confusion_matrix_at_thresholds( 860 predictions, labels, thresholds, weights=weights, includes=('fp',)) 861 return values['fp'], update_ops['fp'] 862 863 864def streaming_true_negatives_at_thresholds(predictions, 865 labels, 866 thresholds, 867 weights=None): 868 values, update_ops = _streaming_confusion_matrix_at_thresholds( 869 predictions, labels, thresholds, weights=weights, includes=('tn',)) 870 return values['tn'], update_ops['tn'] 871 872 873def streaming_curve_points(labels=None, 874 predictions=None, 875 weights=None, 876 num_thresholds=200, 877 metrics_collections=None, 878 updates_collections=None, 879 curve='ROC', 880 name=None): 881 """Computes curve (ROC or PR) values for a prespecified number of points. 882 883 The `streaming_curve_points` function creates four local variables, 884 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 885 that are used to compute the curve values. To discretize the curve, a linearly 886 spaced set of thresholds is used to compute pairs of recall and precision 887 values. 888 889 For best results, `predictions` should be distributed approximately uniformly 890 in the range [0, 1] and not peaked around 0 or 1. 891 892 For estimation of the metric over a stream of data, the function creates an 893 `update_op` operation that updates these variables. 894 895 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 896 897 Args: 898 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 899 `bool`. 900 predictions: A floating point `Tensor` of arbitrary shape and whose values 901 are in the range `[0, 1]`. 902 weights: Optional `Tensor` whose rank is either 0, or the same rank as 903 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 904 be either `1`, or the same as the corresponding `labels` dimension). 905 num_thresholds: The number of thresholds to use when discretizing the roc 906 curve. 907 metrics_collections: An optional list of collections that `auc` should be 908 added to. 909 updates_collections: An optional list of collections that `update_op` should 910 be added to. 911 curve: Specifies the name of the curve to be computed, 'ROC' [default] or 912 'PR' for the Precision-Recall-curve. 913 name: An optional variable_scope name. 914 915 Returns: 916 points: A `Tensor` with shape [num_thresholds, 2] that contains points of 917 the curve. 918 update_op: An operation that increments the `true_positives`, 919 `true_negatives`, `false_positives` and `false_negatives` variables. 920 921 Raises: 922 ValueError: If `predictions` and `labels` have mismatched shapes, or if 923 `weights` is not `None` and its shape doesn't match `predictions`, or if 924 either `metrics_collections` or `updates_collections` are not a list or 925 tuple. 926 927 TODO(chizeng): Consider rewriting this method to make use of logic within the 928 precision_recall_at_equal_thresholds method (to improve run time). 929 """ 930 with variable_scope.variable_scope(name, 'curve_points', 931 (labels, predictions, weights)): 932 if curve != 'ROC' and curve != 'PR': 933 raise ValueError('curve must be either ROC or PR, %s unknown' % (curve)) 934 kepsilon = _EPSILON # to account for floating point imprecisions 935 thresholds = [ 936 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) 937 ] 938 thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] 939 940 values, update_ops = _streaming_confusion_matrix_at_thresholds( 941 labels=labels, 942 predictions=predictions, 943 thresholds=thresholds, 944 weights=weights) 945 946 # Add epsilons to avoid dividing by 0. 947 epsilon = 1.0e-6 948 949 def compute_points(tp, fn, tn, fp): 950 """Computes the roc-auc or pr-auc based on confusion counts.""" 951 rec = math_ops.div(tp + epsilon, tp + fn + epsilon) 952 if curve == 'ROC': 953 fp_rate = math_ops.div(fp, fp + tn + epsilon) 954 return fp_rate, rec 955 else: # curve == 'PR'. 956 prec = math_ops.div(tp + epsilon, tp + fp + epsilon) 957 return rec, prec 958 959 xs, ys = compute_points(values['tp'], values['fn'], values['tn'], 960 values['fp']) 961 points = array_ops.stack([xs, ys], axis=1) 962 update_op = control_flow_ops.group(*update_ops.values()) 963 964 if metrics_collections: 965 ops.add_to_collections(metrics_collections, points) 966 967 if updates_collections: 968 ops.add_to_collections(updates_collections, update_op) 969 970 return points, update_op 971 972 973@deprecated(None, 'Please switch to tf.metrics.auc. Note that the order of ' 974 'the labels and predictions arguments has been switched.') 975def streaming_auc(predictions, 976 labels, 977 weights=None, 978 num_thresholds=200, 979 metrics_collections=None, 980 updates_collections=None, 981 curve='ROC', 982 name=None): 983 """Computes the approximate AUC via a Riemann sum. 984 985 The `streaming_auc` function creates four local variables, `true_positives`, 986 `true_negatives`, `false_positives` and `false_negatives` that are used to 987 compute the AUC. To discretize the AUC curve, a linearly spaced set of 988 thresholds is used to compute pairs of recall and precision values. The area 989 under the ROC-curve is therefore computed using the height of the recall 990 values by the false positive rate, while the area under the PR-curve is the 991 computed using the height of the precision values by the recall. 992 993 This value is ultimately returned as `auc`, an idempotent operation that 994 computes the area under a discretized curve of precision versus recall values 995 (computed using the aforementioned variables). The `num_thresholds` variable 996 controls the degree of discretization with larger numbers of thresholds more 997 closely approximating the true AUC. The quality of the approximation may vary 998 dramatically depending on `num_thresholds`. 999 1000 For best results, `predictions` should be distributed approximately uniformly 1001 in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC 1002 approximation may be poor if this is not the case. 1003 1004 For estimation of the metric over a stream of data, the function creates an 1005 `update_op` operation that updates these variables and returns the `auc`. 1006 1007 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1008 1009 Args: 1010 predictions: A floating point `Tensor` of arbitrary shape and whose values 1011 are in the range `[0, 1]`. 1012 labels: A `bool` `Tensor` whose shape matches `predictions`. 1013 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1014 must be broadcastable to `labels` (i.e., all dimensions must be either 1015 `1`, or the same as the corresponding `labels` dimension). 1016 num_thresholds: The number of thresholds to use when discretizing the roc 1017 curve. 1018 metrics_collections: An optional list of collections that `auc` should be 1019 added to. 1020 updates_collections: An optional list of collections that `update_op` should 1021 be added to. 1022 curve: Specifies the name of the curve to be computed, 'ROC' [default] or 1023 'PR' for the Precision-Recall-curve. 1024 name: An optional variable_scope name. 1025 1026 Returns: 1027 auc: A scalar `Tensor` representing the current area-under-curve. 1028 update_op: An operation that increments the `true_positives`, 1029 `true_negatives`, `false_positives` and `false_negatives` variables 1030 appropriately and whose value matches `auc`. 1031 1032 Raises: 1033 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1034 `weights` is not `None` and its shape doesn't match `predictions`, or if 1035 either `metrics_collections` or `updates_collections` are not a list or 1036 tuple. 1037 """ 1038 return metrics.auc( 1039 predictions=predictions, 1040 labels=labels, 1041 weights=weights, 1042 metrics_collections=metrics_collections, 1043 num_thresholds=num_thresholds, 1044 curve=curve, 1045 updates_collections=updates_collections, 1046 name=name) 1047 1048 1049def _compute_dynamic_auc(labels, predictions, curve='ROC', weights=None): 1050 """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds. 1051 1052 Computes the area under the ROC or PR curve using each prediction as a 1053 threshold. This could be slow for large batches, but has the advantage of not 1054 having its results degrade depending on the distribution of predictions. 1055 1056 Args: 1057 labels: A `Tensor` of ground truth labels with the same shape as 1058 `predictions` with values of 0 or 1 and type `int64`. 1059 predictions: A 1-D `Tensor` of predictions whose values are `float64`. 1060 curve: The name of the curve to be computed, 'ROC' for the Receiving 1061 Operating Characteristic or 'PR' for the Precision-Recall curve. 1062 weights: A 1-D `Tensor` of weights whose values are `float64`. 1063 1064 Returns: 1065 A scalar `Tensor` containing the area-under-curve value for the input. 1066 """ 1067 # Compute the total weight and the total positive weight. 1068 size = array_ops.size(predictions) 1069 if weights is None: 1070 weights = array_ops.ones_like(labels, dtype=dtypes.float64) 1071 labels, predictions, weights = metrics_impl._remove_squeezable_dimensions( 1072 labels, predictions, weights) 1073 total_weight = math_ops.reduce_sum(weights) 1074 total_positive = math_ops.reduce_sum( 1075 array_ops.where( 1076 math_ops.greater(labels, 0), weights, 1077 array_ops.zeros_like(labels, dtype=dtypes.float64))) 1078 1079 def continue_computing_dynamic_auc(): 1080 """Continues dynamic auc computation, entered if labels are not all equal. 1081 1082 Returns: 1083 A scalar `Tensor` containing the area-under-curve value. 1084 """ 1085 # Sort the predictions descending, keeping the same order for the 1086 # corresponding labels and weights. 1087 ordered_predictions, indices = nn.top_k(predictions, k=size) 1088 ordered_labels = array_ops.gather(labels, indices) 1089 ordered_weights = array_ops.gather(weights, indices) 1090 1091 # Get the counts of the unique ordered predictions. 1092 _, _, counts = array_ops.unique_with_counts(ordered_predictions) 1093 1094 # Compute the indices of the split points between different predictions. 1095 splits = math_ops.cast( 1096 array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32) 1097 1098 # Count the positives to the left of the split indices. 1099 true_positives = array_ops.gather( 1100 array_ops.pad( 1101 math_ops.cumsum( 1102 array_ops.where( 1103 math_ops.greater(ordered_labels, 0), ordered_weights, 1104 array_ops.zeros_like(ordered_labels, 1105 dtype=dtypes.float64))), 1106 paddings=[[1, 0]]), splits) 1107 if curve == 'ROC': 1108 # Compute the weight of the negatives to the left of every split point and 1109 # the total weight of the negatives number of negatives for computing the 1110 # FPR. 1111 false_positives = array_ops.gather( 1112 array_ops.pad( 1113 math_ops.cumsum( 1114 array_ops.where( 1115 math_ops.less(ordered_labels, 1), ordered_weights, 1116 array_ops.zeros_like( 1117 ordered_labels, dtype=dtypes.float64))), 1118 paddings=[[1, 0]]), splits) 1119 total_negative = total_weight - total_positive 1120 x_axis_values = math_ops.truediv(false_positives, total_negative) 1121 y_axis_values = math_ops.truediv(true_positives, total_positive) 1122 elif curve == 'PR': 1123 x_axis_values = math_ops.truediv(true_positives, total_positive) 1124 # For conformance, set precision to 1 when the number of positive 1125 # classifications is 0. 1126 positives = array_ops.gather( 1127 array_ops.pad(math_ops.cumsum(ordered_weights), paddings=[[1, 0]]), 1128 splits) 1129 y_axis_values = array_ops.where( 1130 math_ops.greater(splits, 0), 1131 math_ops.truediv(true_positives, positives), 1132 array_ops.ones_like(true_positives, dtype=dtypes.float64)) 1133 1134 # Calculate trapezoid areas. 1135 heights = math_ops.add(y_axis_values[1:], y_axis_values[:-1]) / 2.0 1136 widths = math_ops.abs( 1137 math_ops.subtract(x_axis_values[1:], x_axis_values[:-1])) 1138 return math_ops.reduce_sum(math_ops.multiply(heights, widths)) 1139 1140 # If all the labels are the same, AUC isn't well-defined (but raising an 1141 # exception seems excessive) so we return 0, otherwise we finish computing. 1142 return control_flow_ops.cond( 1143 math_ops.logical_or( 1144 math_ops.equal(total_positive, 0), math_ops.equal( 1145 total_positive, total_weight)), 1146 true_fn=lambda: array_ops.constant(0, dtypes.float64), 1147 false_fn=continue_computing_dynamic_auc) 1148 1149 1150def streaming_dynamic_auc(labels, 1151 predictions, 1152 curve='ROC', 1153 metrics_collections=(), 1154 updates_collections=(), 1155 name=None, 1156 weights=None): 1157 """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds. 1158 1159 USAGE NOTE: this approach requires storing all of the predictions and labels 1160 for a single evaluation in memory, so it may not be usable when the evaluation 1161 batch size and/or the number of evaluation steps is very large. 1162 1163 Computes the area under the ROC or PR curve using each prediction as a 1164 threshold. This has the advantage of being resilient to the distribution of 1165 predictions by aggregating across batches, accumulating labels and predictions 1166 and performing the final calculation using all of the concatenated values. 1167 1168 Args: 1169 labels: A `Tensor` of ground truth labels with the same shape as `labels` 1170 and with values of 0 or 1 whose values are castable to `int64`. 1171 predictions: A `Tensor` of predictions whose values are castable to 1172 `float64`. Will be flattened into a 1-D `Tensor`. 1173 curve: The name of the curve for which to compute AUC, 'ROC' for the 1174 Receiving Operating Characteristic or 'PR' for the Precision-Recall curve. 1175 metrics_collections: An optional iterable of collections that `auc` should 1176 be added to. 1177 updates_collections: An optional iterable of collections that `update_op` 1178 should be added to. 1179 name: An optional name for the variable_scope that contains the metric 1180 variables. 1181 weights: A 'Tensor' of non-negative weights whose values are castable to 1182 `float64`. Will be flattened into a 1-D `Tensor`. 1183 1184 Returns: 1185 auc: A scalar `Tensor` containing the current area-under-curve value. 1186 update_op: An operation that concatenates the input labels and predictions 1187 to the accumulated values. 1188 1189 Raises: 1190 ValueError: If `labels` and `predictions` have mismatched shapes or if 1191 `curve` isn't a recognized curve type. 1192 """ 1193 1194 if curve not in ['PR', 'ROC']: 1195 raise ValueError('curve must be either ROC or PR, %s unknown' % curve) 1196 1197 with variable_scope.variable_scope(name, default_name='dynamic_auc'): 1198 labels.get_shape().assert_is_compatible_with(predictions.get_shape()) 1199 predictions = array_ops.reshape( 1200 math_ops.cast(predictions, dtypes.float64), [-1]) 1201 labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1]) 1202 with ops.control_dependencies([ 1203 check_ops.assert_greater_equal( 1204 labels, 1205 array_ops.zeros_like(labels, dtypes.int64), 1206 message='labels must be 0 or 1, at least one is <0'), 1207 check_ops.assert_less_equal( 1208 labels, 1209 array_ops.ones_like(labels, dtypes.int64), 1210 message='labels must be 0 or 1, at least one is >1'), 1211 ]): 1212 preds_accum, update_preds = streaming_concat( 1213 predictions, name='concat_preds') 1214 labels_accum, update_labels = streaming_concat( 1215 labels, name='concat_labels') 1216 if weights is not None: 1217 weights = array_ops.reshape( 1218 math_ops.cast(weights, dtypes.float64), [-1]) 1219 weights_accum, update_weights = streaming_concat( 1220 weights, name='concat_weights') 1221 update_op = control_flow_ops.group(update_labels, update_preds, 1222 update_weights) 1223 else: 1224 weights_accum = None 1225 update_op = control_flow_ops.group(update_labels, update_preds) 1226 auc = _compute_dynamic_auc( 1227 labels_accum, preds_accum, curve=curve, weights=weights_accum) 1228 if updates_collections: 1229 ops.add_to_collections(updates_collections, update_op) 1230 if metrics_collections: 1231 ops.add_to_collections(metrics_collections, auc) 1232 return auc, update_op 1233 1234 1235def _compute_placement_auc(labels, predictions, weights, alpha, 1236 logit_transformation, is_valid): 1237 """Computes the AUC and asymptotic normally distributed confidence interval. 1238 1239 The calculations are achieved using the fact that AUC = P(Y_1>Y_0) and the 1240 concept of placement values for each labeled group, as presented by Delong and 1241 Delong (1988). The actual algorithm used is a more computationally efficient 1242 approach presented by Sun and Xu (2014). This could be slow for large batches, 1243 but has the advantage of not having its results degrade depending on the 1244 distribution of predictions. 1245 1246 Args: 1247 labels: A `Tensor` of ground truth labels with the same shape as 1248 `predictions` with values of 0 or 1 and type `int64`. 1249 predictions: A 1-D `Tensor` of predictions whose values are `float64`. 1250 weights: `Tensor` whose rank is either 0, or the same rank as `labels`. 1251 alpha: Confidence interval level desired. 1252 logit_transformation: A boolean value indicating whether the estimate should 1253 be logit transformed prior to calculating the confidence interval. Doing 1254 so enforces the restriction that the AUC should never be outside the 1255 interval [0,1]. 1256 is_valid: A bool tensor describing whether the input is valid. 1257 1258 Returns: 1259 A 1-D `Tensor` containing the area-under-curve, lower, and upper confidence 1260 interval values. 1261 """ 1262 # Disable the invalid-name checker so that we can capitalize the name. 1263 # pylint: disable=invalid-name 1264 AucData = collections_lib.namedtuple('AucData', ['auc', 'lower', 'upper']) 1265 # pylint: enable=invalid-name 1266 1267 # If all the labels are the same or if number of observations are too few, 1268 # AUC isn't well-defined 1269 size = array_ops.size(predictions, out_type=dtypes.int32) 1270 1271 # Count the total number of positive and negative labels in the input. 1272 total_0 = math_ops.reduce_sum( 1273 math_ops.cast(1 - labels, weights.dtype) * weights) 1274 total_1 = math_ops.reduce_sum( 1275 math_ops.cast(labels, weights.dtype) * weights) 1276 1277 # Sort the predictions ascending, as well as 1278 # (i) the corresponding labels and 1279 # (ii) the corresponding weights. 1280 ordered_predictions, indices = nn.top_k(predictions, k=size, sorted=True) 1281 ordered_predictions = array_ops.reverse( 1282 ordered_predictions, axis=array_ops.zeros(1, dtypes.int32)) 1283 indices = array_ops.reverse(indices, axis=array_ops.zeros(1, dtypes.int32)) 1284 ordered_labels = array_ops.gather(labels, indices) 1285 ordered_weights = array_ops.gather(weights, indices) 1286 1287 # We now compute values required for computing placement values. 1288 1289 # We generate a list of indices (segmented_indices) of increasing order. An 1290 # index is assigned for each unique prediction float value. Prediction 1291 # values that are the same share the same index. 1292 _, segmented_indices = array_ops.unique(ordered_predictions) 1293 1294 # We create 2 tensors of weights. weights_for_true is non-zero for true 1295 # labels. weights_for_false is non-zero for false labels. 1296 float_labels_for_true = math_ops.cast(ordered_labels, dtypes.float32) 1297 float_labels_for_false = 1.0 - float_labels_for_true 1298 weights_for_true = ordered_weights * float_labels_for_true 1299 weights_for_false = ordered_weights * float_labels_for_false 1300 1301 # For each set of weights with the same segmented indices, we add up the 1302 # weight values. Note that for each label, we deliberately rely on weights 1303 # for the opposite label. 1304 weight_totals_for_true = math_ops.segment_sum(weights_for_false, 1305 segmented_indices) 1306 weight_totals_for_false = math_ops.segment_sum(weights_for_true, 1307 segmented_indices) 1308 1309 # These cumulative sums of weights importantly exclude the current weight 1310 # sums. 1311 cum_weight_totals_for_true = math_ops.cumsum(weight_totals_for_true, 1312 exclusive=True) 1313 cum_weight_totals_for_false = math_ops.cumsum(weight_totals_for_false, 1314 exclusive=True) 1315 1316 # Compute placement values using the formula. Values with the same segmented 1317 # indices and labels share the same placement values. 1318 placements_for_true = ( 1319 (cum_weight_totals_for_true + weight_totals_for_true / 2.0) / 1320 (math_ops.reduce_sum(weight_totals_for_true) + _EPSILON)) 1321 placements_for_false = ( 1322 (cum_weight_totals_for_false + weight_totals_for_false / 2.0) / 1323 (math_ops.reduce_sum(weight_totals_for_false) + _EPSILON)) 1324 1325 # We expand the tensors of placement values (for each label) so that their 1326 # shapes match that of predictions. 1327 placements_for_true = array_ops.gather(placements_for_true, segmented_indices) 1328 placements_for_false = array_ops.gather(placements_for_false, 1329 segmented_indices) 1330 1331 # Select placement values based on the label for each index. 1332 placement_values = ( 1333 placements_for_true * float_labels_for_true + 1334 placements_for_false * float_labels_for_false) 1335 1336 # Split placement values by labeled groups. 1337 placement_values_0 = placement_values * math_ops.cast( 1338 1 - ordered_labels, weights.dtype) 1339 weights_0 = ordered_weights * math_ops.cast( 1340 1 - ordered_labels, weights.dtype) 1341 placement_values_1 = placement_values * math_ops.cast( 1342 ordered_labels, weights.dtype) 1343 weights_1 = ordered_weights * math_ops.cast( 1344 ordered_labels, weights.dtype) 1345 1346 # Calculate AUC using placement values 1347 auc_0 = (math_ops.reduce_sum(weights_0 * (1. - placement_values_0)) / 1348 (total_0 + _EPSILON)) 1349 auc_1 = (math_ops.reduce_sum(weights_1 * (placement_values_1)) / 1350 (total_1 + _EPSILON)) 1351 auc = array_ops.where(math_ops.less(total_0, total_1), auc_1, auc_0) 1352 1353 # Calculate variance and standard error using the placement values. 1354 var_0 = ( 1355 math_ops.reduce_sum( 1356 weights_0 * math_ops.square(1. - placement_values_0 - auc_0)) / 1357 (total_0 - 1. + _EPSILON)) 1358 var_1 = ( 1359 math_ops.reduce_sum(weights_1 * math_ops.squared_difference( 1360 placement_values_1, auc_1)) / (total_1 - 1. + _EPSILON)) 1361 auc_std_err = math_ops.sqrt( 1362 (var_0 / (total_0 + _EPSILON)) + (var_1 / (total_1 + _EPSILON))) 1363 1364 # Calculate asymptotic normal confidence intervals 1365 std_norm_dist = Normal(loc=0., scale=1.) 1366 z_value = std_norm_dist.quantile((1.0 - alpha) / 2.0) 1367 if logit_transformation: 1368 estimate = math_ops.log(auc / (1. - auc + _EPSILON)) 1369 std_err = auc_std_err / (auc * (1. - auc + _EPSILON)) 1370 transformed_auc_lower = estimate + (z_value * std_err) 1371 transformed_auc_upper = estimate - (z_value * std_err) 1372 def inverse_logit_transformation(x): 1373 exp_negative = math_ops.exp(math_ops.negative(x)) 1374 return 1. / (1. + exp_negative + _EPSILON) 1375 1376 auc_lower = inverse_logit_transformation(transformed_auc_lower) 1377 auc_upper = inverse_logit_transformation(transformed_auc_upper) 1378 else: 1379 estimate = auc 1380 std_err = auc_std_err 1381 auc_lower = estimate + (z_value * std_err) 1382 auc_upper = estimate - (z_value * std_err) 1383 1384 ## If estimate is 1 or 0, no variance is present so CI = 1 1385 ## n.b. This can be misleading, since number obs can just be too low. 1386 lower = array_ops.where( 1387 math_ops.logical_or( 1388 math_ops.equal(auc, array_ops.ones_like(auc)), 1389 math_ops.equal(auc, array_ops.zeros_like(auc))), 1390 auc, auc_lower) 1391 upper = array_ops.where( 1392 math_ops.logical_or( 1393 math_ops.equal(auc, array_ops.ones_like(auc)), 1394 math_ops.equal(auc, array_ops.zeros_like(auc))), 1395 auc, auc_upper) 1396 1397 # If all the labels are the same, AUC isn't well-defined (but raising an 1398 # exception seems excessive) so we return 0, otherwise we finish computing. 1399 trivial_value = array_ops.constant(0.0) 1400 1401 return AucData(*control_flow_ops.cond( 1402 is_valid, lambda: [auc, lower, upper], lambda: [trivial_value]*3)) 1403 1404 1405def auc_with_confidence_intervals(labels, 1406 predictions, 1407 weights=None, 1408 alpha=0.95, 1409 logit_transformation=True, 1410 metrics_collections=(), 1411 updates_collections=(), 1412 name=None): 1413 """Computes the AUC and asymptotic normally distributed confidence interval. 1414 1415 USAGE NOTE: this approach requires storing all of the predictions and labels 1416 for a single evaluation in memory, so it may not be usable when the evaluation 1417 batch size and/or the number of evaluation steps is very large. 1418 1419 Computes the area under the ROC curve and its confidence interval using 1420 placement values. This has the advantage of being resilient to the 1421 distribution of predictions by aggregating across batches, accumulating labels 1422 and predictions and performing the final calculation using all of the 1423 concatenated values. 1424 1425 Args: 1426 labels: A `Tensor` of ground truth labels with the same shape as `labels` 1427 and with values of 0 or 1 whose values are castable to `int64`. 1428 predictions: A `Tensor` of predictions whose values are castable to 1429 `float64`. Will be flattened into a 1-D `Tensor`. 1430 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1431 `labels`. 1432 alpha: Confidence interval level desired. 1433 logit_transformation: A boolean value indicating whether the estimate should 1434 be logit transformed prior to calculating the confidence interval. Doing 1435 so enforces the restriction that the AUC should never be outside the 1436 interval [0,1]. 1437 metrics_collections: An optional iterable of collections that `auc` should 1438 be added to. 1439 updates_collections: An optional iterable of collections that `update_op` 1440 should be added to. 1441 name: An optional name for the variable_scope that contains the metric 1442 variables. 1443 1444 Returns: 1445 auc: A 1-D `Tensor` containing the current area-under-curve, lower, and 1446 upper confidence interval values. 1447 update_op: An operation that concatenates the input labels and predictions 1448 to the accumulated values. 1449 1450 Raises: 1451 ValueError: If `labels`, `predictions`, and `weights` have mismatched shapes 1452 or if `alpha` isn't in the range (0,1). 1453 """ 1454 if not (alpha > 0 and alpha < 1): 1455 raise ValueError('alpha must be between 0 and 1; currently %.02f' % alpha) 1456 1457 if weights is None: 1458 weights = array_ops.ones_like(predictions) 1459 1460 with variable_scope.variable_scope( 1461 name, 1462 default_name='auc_with_confidence_intervals', 1463 values=[labels, predictions, weights]): 1464 1465 predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access 1466 predictions=predictions, 1467 labels=labels, 1468 weights=weights) 1469 1470 total_weight = math_ops.reduce_sum(weights) 1471 1472 weights = array_ops.reshape(weights, [-1]) 1473 predictions = array_ops.reshape( 1474 math_ops.cast(predictions, dtypes.float64), [-1]) 1475 labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1]) 1476 1477 with ops.control_dependencies([ 1478 check_ops.assert_greater_equal( 1479 labels, 1480 array_ops.zeros_like(labels, dtypes.int64), 1481 message='labels must be 0 or 1, at least one is <0'), 1482 check_ops.assert_less_equal( 1483 labels, 1484 array_ops.ones_like(labels, dtypes.int64), 1485 message='labels must be 0 or 1, at least one is >1'), 1486 ]): 1487 preds_accum, update_preds = streaming_concat( 1488 predictions, name='concat_preds') 1489 labels_accum, update_labels = streaming_concat(labels, 1490 name='concat_labels') 1491 weights_accum, update_weights = streaming_concat( 1492 weights, name='concat_weights') 1493 update_op_for_valid_case = control_flow_ops.group( 1494 update_labels, update_preds, update_weights) 1495 1496 # Only perform updates if this case is valid. 1497 all_labels_positive_or_0 = math_ops.logical_and( 1498 math_ops.equal(math_ops.reduce_min(labels), 0), 1499 math_ops.equal(math_ops.reduce_max(labels), 1)) 1500 sums_of_weights_at_least_1 = math_ops.greater_equal(total_weight, 1.0) 1501 is_valid = math_ops.logical_and(all_labels_positive_or_0, 1502 sums_of_weights_at_least_1) 1503 1504 update_op = control_flow_ops.cond( 1505 sums_of_weights_at_least_1, 1506 lambda: update_op_for_valid_case, control_flow_ops.no_op) 1507 1508 auc = _compute_placement_auc( 1509 labels_accum, 1510 preds_accum, 1511 weights_accum, 1512 alpha=alpha, 1513 logit_transformation=logit_transformation, 1514 is_valid=is_valid) 1515 1516 if updates_collections: 1517 ops.add_to_collections(updates_collections, update_op) 1518 if metrics_collections: 1519 ops.add_to_collections(metrics_collections, auc) 1520 return auc, update_op 1521 1522 1523def precision_recall_at_equal_thresholds(labels, 1524 predictions, 1525 weights=None, 1526 num_thresholds=None, 1527 use_locking=None, 1528 name=None): 1529 """A helper method for creating metrics related to precision-recall curves. 1530 1531 These values are true positives, false negatives, true negatives, false 1532 positives, precision, and recall. This function returns a data structure that 1533 contains ops within it. 1534 1535 Unlike _streaming_confusion_matrix_at_thresholds (which exhibits O(T * N) 1536 space and run time), this op exhibits O(T + N) space and run time, where T is 1537 the number of thresholds and N is the size of the predictions tensor. Hence, 1538 it may be advantageous to use this function when `predictions` is big. 1539 1540 For instance, prefer this method for per-pixel classification tasks, for which 1541 the predictions tensor may be very large. 1542 1543 Each number in `predictions`, a float in `[0, 1]`, is compared with its 1544 corresponding label in `labels`, and counts as a single tp/fp/tn/fn value at 1545 each threshold. This is then multiplied with `weights` which can be used to 1546 reweight certain values, or more commonly used for masking values. 1547 1548 Args: 1549 labels: A bool `Tensor` whose shape matches `predictions`. 1550 predictions: A floating point `Tensor` of arbitrary shape and whose values 1551 are in the range `[0, 1]`. 1552 weights: Optional; If provided, a `Tensor` that has the same dtype as, 1553 and broadcastable to, `predictions`. This tensor is multiplied by counts. 1554 num_thresholds: Optional; Number of thresholds, evenly distributed in 1555 `[0, 1]`. Should be `>= 2`. Defaults to 201. Note that the number of bins 1556 is 1 less than `num_thresholds`. Using an even `num_thresholds` value 1557 instead of an odd one may yield unfriendly edges for bins. 1558 use_locking: Optional; If True, the op will be protected by a lock. 1559 Otherwise, the behavior is undefined, but may exhibit less contention. 1560 Defaults to True. 1561 name: Optional; variable_scope name. If not provided, the string 1562 'precision_recall_at_equal_threshold' is used. 1563 1564 Returns: 1565 result: A named tuple (See PrecisionRecallData within the implementation of 1566 this function) with properties that are variables of shape 1567 `[num_thresholds]`. The names of the properties are tp, fp, tn, fn, 1568 precision, recall, thresholds. Types are same as that of predictions. 1569 update_op: An op that accumulates values. 1570 1571 Raises: 1572 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1573 `weights` is not `None` and its shape doesn't match `predictions`, or if 1574 `includes` contains invalid keys. 1575 """ 1576 # Disable the invalid-name checker so that we can capitalize the name. 1577 # pylint: disable=invalid-name 1578 PrecisionRecallData = collections_lib.namedtuple( 1579 'PrecisionRecallData', 1580 ['tp', 'fp', 'tn', 'fn', 'precision', 'recall', 'thresholds']) 1581 # pylint: enable=invalid-name 1582 1583 if num_thresholds is None: 1584 num_thresholds = 201 1585 1586 if weights is None: 1587 weights = 1.0 1588 1589 if use_locking is None: 1590 use_locking = True 1591 1592 check_ops.assert_type(labels, dtypes.bool) 1593 1594 with variable_scope.variable_scope(name, 1595 'precision_recall_at_equal_thresholds', 1596 (labels, predictions, weights)): 1597 # Make sure that predictions are within [0.0, 1.0]. 1598 with ops.control_dependencies([ 1599 check_ops.assert_greater_equal( 1600 predictions, 1601 math_ops.cast(0.0, dtype=predictions.dtype), 1602 message='predictions must be in [0, 1]'), 1603 check_ops.assert_less_equal( 1604 predictions, 1605 math_ops.cast(1.0, dtype=predictions.dtype), 1606 message='predictions must be in [0, 1]') 1607 ]): 1608 predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access 1609 predictions=predictions, 1610 labels=labels, 1611 weights=weights) 1612 1613 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1614 1615 # It's important we aggregate using float64 since we're accumulating a lot 1616 # of 1.0's for the true/false labels, and accumulating to float32 will 1617 # be quite inaccurate even with just a modest amount of values (~20M). 1618 # We use float64 instead of integer primarily since GPU scatter kernel 1619 # only support floats. 1620 agg_dtype = dtypes.float64 1621 1622 f_labels = math_ops.cast(labels, agg_dtype) 1623 weights = math_ops.cast(weights, agg_dtype) 1624 true_labels = f_labels * weights 1625 false_labels = (1.0 - f_labels) * weights 1626 1627 # Flatten predictions and labels. 1628 predictions = array_ops.reshape(predictions, [-1]) 1629 true_labels = array_ops.reshape(true_labels, [-1]) 1630 false_labels = array_ops.reshape(false_labels, [-1]) 1631 1632 # To compute TP/FP/TN/FN, we are measuring a binary classifier 1633 # C(t) = (predictions >= t) 1634 # at each threshold 't'. So we have 1635 # TP(t) = sum( C(t) * true_labels ) 1636 # FP(t) = sum( C(t) * false_labels ) 1637 # 1638 # But, computing C(t) requires computation for each t. To make it fast, 1639 # observe that C(t) is a cumulative integral, and so if we have 1640 # thresholds = [t_0, ..., t_{n-1}]; t_0 < ... < t_{n-1} 1641 # where n = num_thresholds, and if we can compute the bucket function 1642 # B(i) = Sum( (predictions == t), t_i <= t < t{i+1} ) 1643 # then we get 1644 # C(t_i) = sum( B(j), j >= i ) 1645 # which is the reversed cumulative sum in tf.cumsum(). 1646 # 1647 # We can compute B(i) efficiently by taking advantage of the fact that 1648 # our thresholds are evenly distributed, in that 1649 # width = 1.0 / (num_thresholds - 1) 1650 # thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0] 1651 # Given a prediction value p, we can map it to its bucket by 1652 # bucket_index(p) = floor( p * (num_thresholds - 1) ) 1653 # so we can use tf.scatter_add() to update the buckets in one pass. 1654 # 1655 # This implementation exhibits a run time and space complexity of O(T + N), 1656 # where T is the number of thresholds and N is the size of predictions. 1657 # Metrics that rely on _streaming_confusion_matrix_at_thresholds instead 1658 # exhibit a complexity of O(T * N). 1659 1660 # Compute the bucket indices for each prediction value. 1661 bucket_indices = math_ops.cast( 1662 math_ops.floor(predictions * (num_thresholds - 1)), dtypes.int32) 1663 1664 with ops.name_scope('variables'): 1665 tp_buckets_v = metrics_impl.metric_variable( 1666 [num_thresholds], agg_dtype, name='tp_buckets') 1667 fp_buckets_v = metrics_impl.metric_variable( 1668 [num_thresholds], agg_dtype, name='fp_buckets') 1669 1670 with ops.name_scope('update_op'): 1671 update_tp = state_ops.scatter_add( 1672 tp_buckets_v, bucket_indices, true_labels, use_locking=use_locking) 1673 update_fp = state_ops.scatter_add( 1674 fp_buckets_v, bucket_indices, false_labels, use_locking=use_locking) 1675 1676 # Set up the cumulative sums to compute the actual metrics. 1677 tp = math_ops.cumsum(tp_buckets_v, reverse=True, name='tp') 1678 fp = math_ops.cumsum(fp_buckets_v, reverse=True, name='fp') 1679 # fn = sum(true_labels) - tp 1680 # = sum(tp_buckets) - tp 1681 # = tp[0] - tp 1682 # Similarly, 1683 # tn = fp[0] - fp 1684 tn = fp[0] - fp 1685 fn = tp[0] - tp 1686 1687 # We use a minimum to prevent division by 0. 1688 epsilon = ops.convert_to_tensor(1e-7, dtype=agg_dtype) 1689 precision = tp / math_ops.maximum(epsilon, tp + fp) 1690 recall = tp / math_ops.maximum(epsilon, tp + fn) 1691 1692 # Convert all tensors back to predictions' dtype (as per function contract). 1693 out_dtype = predictions.dtype 1694 _convert = lambda tensor: math_ops.cast(tensor, out_dtype) 1695 result = PrecisionRecallData( 1696 tp=_convert(tp), 1697 fp=_convert(fp), 1698 tn=_convert(tn), 1699 fn=_convert(fn), 1700 precision=_convert(precision), 1701 recall=_convert(recall), 1702 thresholds=_convert(math_ops.lin_space(0.0, 1.0, num_thresholds))) 1703 update_op = control_flow_ops.group(update_tp, update_fp) 1704 return result, update_op 1705 1706 1707def streaming_specificity_at_sensitivity(predictions, 1708 labels, 1709 sensitivity, 1710 weights=None, 1711 num_thresholds=200, 1712 metrics_collections=None, 1713 updates_collections=None, 1714 name=None): 1715 """Computes the specificity at a given sensitivity. 1716 1717 The `streaming_specificity_at_sensitivity` function creates four local 1718 variables, `true_positives`, `true_negatives`, `false_positives` and 1719 `false_negatives` that are used to compute the specificity at the given 1720 sensitivity value. The threshold for the given sensitivity value is computed 1721 and used to evaluate the corresponding specificity. 1722 1723 For estimation of the metric over a stream of data, the function creates an 1724 `update_op` operation that updates these variables and returns the 1725 `specificity`. `update_op` increments the `true_positives`, `true_negatives`, 1726 `false_positives` and `false_negatives` counts with the weight of each case 1727 found in the `predictions` and `labels`. 1728 1729 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1730 1731 For additional information about specificity and sensitivity, see the 1732 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 1733 1734 Args: 1735 predictions: A floating point `Tensor` of arbitrary shape and whose values 1736 are in the range `[0, 1]`. 1737 labels: A `bool` `Tensor` whose shape matches `predictions`. 1738 sensitivity: A scalar value in range `[0, 1]`. 1739 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1740 must be broadcastable to `labels` (i.e., all dimensions must be either 1741 `1`, or the same as the corresponding `labels` dimension). 1742 num_thresholds: The number of thresholds to use for matching the given 1743 sensitivity. 1744 metrics_collections: An optional list of collections that `specificity` 1745 should be added to. 1746 updates_collections: An optional list of collections that `update_op` should 1747 be added to. 1748 name: An optional variable_scope name. 1749 1750 Returns: 1751 specificity: A scalar `Tensor` representing the specificity at the given 1752 `specificity` value. 1753 update_op: An operation that increments the `true_positives`, 1754 `true_negatives`, `false_positives` and `false_negatives` variables 1755 appropriately and whose value matches `specificity`. 1756 1757 Raises: 1758 ValueError: If `predictions` and `labels` have mismatched shapes, if 1759 `weights` is not `None` and its shape doesn't match `predictions`, or if 1760 `sensitivity` is not between 0 and 1, or if either `metrics_collections` 1761 or `updates_collections` are not a list or tuple. 1762 """ 1763 return metrics.specificity_at_sensitivity( 1764 sensitivity=sensitivity, 1765 num_thresholds=num_thresholds, 1766 predictions=predictions, 1767 labels=labels, 1768 weights=weights, 1769 metrics_collections=metrics_collections, 1770 updates_collections=updates_collections, 1771 name=name) 1772 1773 1774def streaming_sensitivity_at_specificity(predictions, 1775 labels, 1776 specificity, 1777 weights=None, 1778 num_thresholds=200, 1779 metrics_collections=None, 1780 updates_collections=None, 1781 name=None): 1782 """Computes the sensitivity at a given specificity. 1783 1784 The `streaming_sensitivity_at_specificity` function creates four local 1785 variables, `true_positives`, `true_negatives`, `false_positives` and 1786 `false_negatives` that are used to compute the sensitivity at the given 1787 specificity value. The threshold for the given specificity value is computed 1788 and used to evaluate the corresponding sensitivity. 1789 1790 For estimation of the metric over a stream of data, the function creates an 1791 `update_op` operation that updates these variables and returns the 1792 `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`, 1793 `false_positives` and `false_negatives` counts with the weight of each case 1794 found in the `predictions` and `labels`. 1795 1796 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1797 1798 For additional information about specificity and sensitivity, see the 1799 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 1800 1801 Args: 1802 predictions: A floating point `Tensor` of arbitrary shape and whose values 1803 are in the range `[0, 1]`. 1804 labels: A `bool` `Tensor` whose shape matches `predictions`. 1805 specificity: A scalar value in range `[0, 1]`. 1806 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1807 must be broadcastable to `labels` (i.e., all dimensions must be either 1808 `1`, or the same as the corresponding `labels` dimension). 1809 num_thresholds: The number of thresholds to use for matching the given 1810 specificity. 1811 metrics_collections: An optional list of collections that `sensitivity` 1812 should be added to. 1813 updates_collections: An optional list of collections that `update_op` should 1814 be added to. 1815 name: An optional variable_scope name. 1816 1817 Returns: 1818 sensitivity: A scalar `Tensor` representing the sensitivity at the given 1819 `specificity` value. 1820 update_op: An operation that increments the `true_positives`, 1821 `true_negatives`, `false_positives` and `false_negatives` variables 1822 appropriately and whose value matches `sensitivity`. 1823 1824 Raises: 1825 ValueError: If `predictions` and `labels` have mismatched shapes, if 1826 `weights` is not `None` and its shape doesn't match `predictions`, or if 1827 `specificity` is not between 0 and 1, or if either `metrics_collections` 1828 or `updates_collections` are not a list or tuple. 1829 """ 1830 return metrics.sensitivity_at_specificity( 1831 specificity=specificity, 1832 num_thresholds=num_thresholds, 1833 predictions=predictions, 1834 labels=labels, 1835 weights=weights, 1836 metrics_collections=metrics_collections, 1837 updates_collections=updates_collections, 1838 name=name) 1839 1840 1841@deprecated(None, 1842 'Please switch to tf.metrics.precision_at_thresholds. Note that ' 1843 'the order of the labels and predictions arguments are switched.') 1844def streaming_precision_at_thresholds(predictions, 1845 labels, 1846 thresholds, 1847 weights=None, 1848 metrics_collections=None, 1849 updates_collections=None, 1850 name=None): 1851 """Computes precision values for different `thresholds` on `predictions`. 1852 1853 The `streaming_precision_at_thresholds` function creates four local variables, 1854 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 1855 for various values of thresholds. `precision[i]` is defined as the total 1856 weight of values in `predictions` above `thresholds[i]` whose corresponding 1857 entry in `labels` is `True`, divided by the total weight of values in 1858 `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] + 1859 false_positives[i])`). 1860 1861 For estimation of the metric over a stream of data, the function creates an 1862 `update_op` operation that updates these variables and returns the 1863 `precision`. 1864 1865 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1866 1867 Args: 1868 predictions: A floating point `Tensor` of arbitrary shape and whose values 1869 are in the range `[0, 1]`. 1870 labels: A `bool` `Tensor` whose shape matches `predictions`. 1871 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1872 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1873 must be broadcastable to `labels` (i.e., all dimensions must be either 1874 `1`, or the same as the corresponding `labels` dimension). 1875 metrics_collections: An optional list of collections that `precision` should 1876 be added to. 1877 updates_collections: An optional list of collections that `update_op` should 1878 be added to. 1879 name: An optional variable_scope name. 1880 1881 Returns: 1882 precision: A float `Tensor` of shape `[len(thresholds)]`. 1883 update_op: An operation that increments the `true_positives`, 1884 `true_negatives`, `false_positives` and `false_negatives` variables that 1885 are used in the computation of `precision`. 1886 1887 Raises: 1888 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1889 `weights` is not `None` and its shape doesn't match `predictions`, or if 1890 either `metrics_collections` or `updates_collections` are not a list or 1891 tuple. 1892 """ 1893 return metrics.precision_at_thresholds( 1894 thresholds=thresholds, 1895 predictions=predictions, 1896 labels=labels, 1897 weights=weights, 1898 metrics_collections=metrics_collections, 1899 updates_collections=updates_collections, 1900 name=name) 1901 1902 1903@deprecated(None, 1904 'Please switch to tf.metrics.recall_at_thresholds. Note that the ' 1905 'order of the labels and predictions arguments has been switched.') 1906def streaming_recall_at_thresholds(predictions, 1907 labels, 1908 thresholds, 1909 weights=None, 1910 metrics_collections=None, 1911 updates_collections=None, 1912 name=None): 1913 """Computes various recall values for different `thresholds` on `predictions`. 1914 1915 The `streaming_recall_at_thresholds` function creates four local variables, 1916 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 1917 for various values of thresholds. `recall[i]` is defined as the total weight 1918 of values in `predictions` above `thresholds[i]` whose corresponding entry in 1919 `labels` is `True`, divided by the total weight of `True` values in `labels` 1920 (`true_positives[i] / (true_positives[i] + false_negatives[i])`). 1921 1922 For estimation of the metric over a stream of data, the function creates an 1923 `update_op` operation that updates these variables and returns the `recall`. 1924 1925 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1926 1927 Args: 1928 predictions: A floating point `Tensor` of arbitrary shape and whose values 1929 are in the range `[0, 1]`. 1930 labels: A `bool` `Tensor` whose shape matches `predictions`. 1931 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1932 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1933 must be broadcastable to `labels` (i.e., all dimensions must be either 1934 `1`, or the same as the corresponding `labels` dimension). 1935 metrics_collections: An optional list of collections that `recall` should be 1936 added to. 1937 updates_collections: An optional list of collections that `update_op` should 1938 be added to. 1939 name: An optional variable_scope name. 1940 1941 Returns: 1942 recall: A float `Tensor` of shape `[len(thresholds)]`. 1943 update_op: An operation that increments the `true_positives`, 1944 `true_negatives`, `false_positives` and `false_negatives` variables that 1945 are used in the computation of `recall`. 1946 1947 Raises: 1948 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1949 `weights` is not `None` and its shape doesn't match `predictions`, or if 1950 either `metrics_collections` or `updates_collections` are not a list or 1951 tuple. 1952 """ 1953 return metrics.recall_at_thresholds( 1954 thresholds=thresholds, 1955 predictions=predictions, 1956 labels=labels, 1957 weights=weights, 1958 metrics_collections=metrics_collections, 1959 updates_collections=updates_collections, 1960 name=name) 1961 1962 1963def streaming_false_positive_rate_at_thresholds(predictions, 1964 labels, 1965 thresholds, 1966 weights=None, 1967 metrics_collections=None, 1968 updates_collections=None, 1969 name=None): 1970 """Computes various fpr values for different `thresholds` on `predictions`. 1971 1972 The `streaming_false_positive_rate_at_thresholds` function creates two 1973 local variables, `false_positives`, `true_negatives`, for various values of 1974 thresholds. `false_positive_rate[i]` is defined as the total weight 1975 of values in `predictions` above `thresholds[i]` whose corresponding entry in 1976 `labels` is `False`, divided by the total weight of `False` values in `labels` 1977 (`false_positives[i] / (false_positives[i] + true_negatives[i])`). 1978 1979 For estimation of the metric over a stream of data, the function creates an 1980 `update_op` operation that updates these variables and returns the 1981 `false_positive_rate`. 1982 1983 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1984 1985 Args: 1986 predictions: A floating point `Tensor` of arbitrary shape and whose values 1987 are in the range `[0, 1]`. 1988 labels: A `bool` `Tensor` whose shape matches `predictions`. 1989 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1990 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1991 must be broadcastable to `labels` (i.e., all dimensions must be either 1992 `1`, or the same as the corresponding `labels` dimension). 1993 metrics_collections: An optional list of collections that 1994 `false_positive_rate` should be added to. 1995 updates_collections: An optional list of collections that `update_op` should 1996 be added to. 1997 name: An optional variable_scope name. 1998 1999 Returns: 2000 false_positive_rate: A float `Tensor` of shape `[len(thresholds)]`. 2001 update_op: An operation that increments the `false_positives` and 2002 `true_negatives` variables that are used in the computation of 2003 `false_positive_rate`. 2004 2005 Raises: 2006 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2007 `weights` is not `None` and its shape doesn't match `predictions`, or if 2008 either `metrics_collections` or `updates_collections` are not a list or 2009 tuple. 2010 """ 2011 with variable_scope.variable_scope(name, 'false_positive_rate_at_thresholds', 2012 (predictions, labels, weights)): 2013 values, update_ops = _streaming_confusion_matrix_at_thresholds( 2014 predictions, labels, thresholds, weights, includes=('fp', 'tn')) 2015 2016 # Avoid division by zero. 2017 epsilon = _EPSILON 2018 2019 def compute_fpr(fp, tn, name): 2020 return math_ops.div(fp, epsilon + fp + tn, name='fpr_' + name) 2021 2022 fpr = compute_fpr(values['fp'], values['tn'], 'value') 2023 update_op = compute_fpr(update_ops['fp'], update_ops['tn'], 'update_op') 2024 2025 if metrics_collections: 2026 ops.add_to_collections(metrics_collections, fpr) 2027 2028 if updates_collections: 2029 ops.add_to_collections(updates_collections, update_op) 2030 2031 return fpr, update_op 2032 2033 2034def streaming_false_negative_rate_at_thresholds(predictions, 2035 labels, 2036 thresholds, 2037 weights=None, 2038 metrics_collections=None, 2039 updates_collections=None, 2040 name=None): 2041 """Computes various fnr values for different `thresholds` on `predictions`. 2042 2043 The `streaming_false_negative_rate_at_thresholds` function creates two 2044 local variables, `false_negatives`, `true_positives`, for various values of 2045 thresholds. `false_negative_rate[i]` is defined as the total weight 2046 of values in `predictions` above `thresholds[i]` whose corresponding entry in 2047 `labels` is `False`, divided by the total weight of `True` values in `labels` 2048 (`false_negatives[i] / (false_negatives[i] + true_positives[i])`). 2049 2050 For estimation of the metric over a stream of data, the function creates an 2051 `update_op` operation that updates these variables and returns the 2052 `false_positive_rate`. 2053 2054 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2055 2056 Args: 2057 predictions: A floating point `Tensor` of arbitrary shape and whose values 2058 are in the range `[0, 1]`. 2059 labels: A `bool` `Tensor` whose shape matches `predictions`. 2060 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 2061 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 2062 must be broadcastable to `labels` (i.e., all dimensions must be either 2063 `1`, or the same as the corresponding `labels` dimension). 2064 metrics_collections: An optional list of collections that 2065 `false_negative_rate` should be added to. 2066 updates_collections: An optional list of collections that `update_op` should 2067 be added to. 2068 name: An optional variable_scope name. 2069 2070 Returns: 2071 false_negative_rate: A float `Tensor` of shape `[len(thresholds)]`. 2072 update_op: An operation that increments the `false_negatives` and 2073 `true_positives` variables that are used in the computation of 2074 `false_negative_rate`. 2075 2076 Raises: 2077 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2078 `weights` is not `None` and its shape doesn't match `predictions`, or if 2079 either `metrics_collections` or `updates_collections` are not a list or 2080 tuple. 2081 """ 2082 with variable_scope.variable_scope(name, 'false_negative_rate_at_thresholds', 2083 (predictions, labels, weights)): 2084 values, update_ops = _streaming_confusion_matrix_at_thresholds( 2085 predictions, labels, thresholds, weights, includes=('fn', 'tp')) 2086 2087 # Avoid division by zero. 2088 epsilon = _EPSILON 2089 2090 def compute_fnr(fn, tp, name): 2091 return math_ops.div(fn, epsilon + fn + tp, name='fnr_' + name) 2092 2093 fnr = compute_fnr(values['fn'], values['tp'], 'value') 2094 update_op = compute_fnr(update_ops['fn'], update_ops['tp'], 'update_op') 2095 2096 if metrics_collections: 2097 ops.add_to_collections(metrics_collections, fnr) 2098 2099 if updates_collections: 2100 ops.add_to_collections(updates_collections, update_op) 2101 2102 return fnr, update_op 2103 2104 2105def _at_k_name(name, k=None, class_id=None): 2106 if k is not None: 2107 name = '%s_at_%d' % (name, k) 2108 else: 2109 name = '%s_at_k' % (name) 2110 if class_id is not None: 2111 name = '%s_class%d' % (name, class_id) 2112 return name 2113 2114 2115@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, ' 2116 'and reshape labels from [batch_size] to [batch_size, 1].') 2117def streaming_recall_at_k(predictions, 2118 labels, 2119 k, 2120 weights=None, 2121 metrics_collections=None, 2122 updates_collections=None, 2123 name=None): 2124 """Computes the recall@k of the predictions with respect to dense labels. 2125 2126 The `streaming_recall_at_k` function creates two local variables, `total` and 2127 `count`, that are used to compute the recall@k frequency. This frequency is 2128 ultimately returned as `recall_at_<k>`: an idempotent operation that simply 2129 divides `total` by `count`. 2130 2131 For estimation of the metric over a stream of data, the function creates an 2132 `update_op` operation that updates these variables and returns the 2133 `recall_at_<k>`. Internally, an `in_top_k` operation computes a `Tensor` with 2134 shape [batch_size] whose elements indicate whether or not the corresponding 2135 label is in the top `k` `predictions`. Then `update_op` increments `total` 2136 with the reduced sum of `weights` where `in_top_k` is `True`, and it 2137 increments `count` with the reduced sum of `weights`. 2138 2139 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2140 2141 Args: 2142 predictions: A float `Tensor` of dimension [batch_size, num_classes]. 2143 labels: A `Tensor` of dimension [batch_size] whose type is in `int32`, 2144 `int64`. 2145 k: The number of top elements to look at for computing recall. 2146 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 2147 must be broadcastable to `labels` (i.e., all dimensions must be either 2148 `1`, or the same as the corresponding `labels` dimension). 2149 metrics_collections: An optional list of collections that `recall_at_k` 2150 should be added to. 2151 updates_collections: An optional list of collections `update_op` should be 2152 added to. 2153 name: An optional variable_scope name. 2154 2155 Returns: 2156 recall_at_k: A `Tensor` representing the recall@k, the fraction of labels 2157 which fall into the top `k` predictions. 2158 update_op: An operation that increments the `total` and `count` variables 2159 appropriately and whose value matches `recall_at_k`. 2160 2161 Raises: 2162 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2163 `weights` is not `None` and its shape doesn't match `predictions`, or if 2164 either `metrics_collections` or `updates_collections` are not a list or 2165 tuple. 2166 """ 2167 in_top_k = math_ops.cast(nn.in_top_k(predictions, labels, k), dtypes.float32) 2168 return streaming_mean(in_top_k, weights, metrics_collections, 2169 updates_collections, name or _at_k_name('recall', k)) 2170 2171 2172# TODO(ptucker): Validate range of values in labels? 2173def streaming_sparse_recall_at_k(predictions, 2174 labels, 2175 k, 2176 class_id=None, 2177 weights=None, 2178 metrics_collections=None, 2179 updates_collections=None, 2180 name=None): 2181 """Computes recall@k of the predictions with respect to sparse labels. 2182 2183 If `class_id` is not specified, we'll calculate recall as the ratio of true 2184 positives (i.e., correct predictions, items in the top `k` highest 2185 `predictions` that are found in the corresponding row in `labels`) to 2186 actual positives (the full `labels` row). 2187 If `class_id` is specified, we calculate recall by considering only the rows 2188 in the batch for which `class_id` is in `labels`, and computing the 2189 fraction of them for which `class_id` is in the corresponding row in 2190 `labels`. 2191 2192 `streaming_sparse_recall_at_k` creates two local variables, 2193 `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute 2194 the recall_at_k frequency. This frequency is ultimately returned as 2195 `recall_at_<k>`: an idempotent operation that simply divides 2196 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 2197 `false_negative_at_<k>`). 2198 2199 For estimation of the metric over a stream of data, the function creates an 2200 `update_op` operation that updates these variables and returns the 2201 `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 2202 indicating the top `k` `predictions`. Set operations applied to `top_k` and 2203 `labels` calculate the true positives and false negatives weighted by 2204 `weights`. Then `update_op` increments `true_positive_at_<k>` and 2205 `false_negative_at_<k>` using these values. 2206 2207 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2208 2209 Args: 2210 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 2211 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 2212 The final dimension contains the logit values for each class. [D1, ... DN] 2213 must match `labels`. 2214 labels: `int64` `Tensor` or `SparseTensor` with shape 2215 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2216 target classes for the associated prediction. Commonly, N=1 and `labels` 2217 has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`. 2218 Values should be in range [0, num_classes), where num_classes is the last 2219 dimension of `predictions`. Values outside this range always count 2220 towards `false_negative_at_<k>`. 2221 k: Integer, k for @k metric. 2222 class_id: Integer class ID for which we want binary metrics. This should be 2223 in range [0, num_classes), where num_classes is the last dimension of 2224 `predictions`. If class_id is outside this range, the method returns NAN. 2225 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2226 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2227 dimensions must be either `1`, or the same as the corresponding `labels` 2228 dimension). 2229 metrics_collections: An optional list of collections that values should 2230 be added to. 2231 updates_collections: An optional list of collections that updates should 2232 be added to. 2233 name: Name of new update operation, and namespace for other dependent ops. 2234 2235 Returns: 2236 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 2237 by the sum of `true_positives` and `false_negatives`. 2238 update_op: `Operation` that increments `true_positives` and 2239 `false_negatives` variables appropriately, and whose value matches 2240 `recall`. 2241 2242 Raises: 2243 ValueError: If `weights` is not `None` and its shape doesn't match 2244 `predictions`, or if either `metrics_collections` or `updates_collections` 2245 are not a list or tuple. 2246 """ 2247 return metrics.recall_at_k( 2248 k=k, 2249 class_id=class_id, 2250 predictions=predictions, 2251 labels=labels, 2252 weights=weights, 2253 metrics_collections=metrics_collections, 2254 updates_collections=updates_collections, 2255 name=name) 2256 2257 2258# TODO(ptucker): Validate range of values in labels? 2259def streaming_sparse_precision_at_k(predictions, 2260 labels, 2261 k, 2262 class_id=None, 2263 weights=None, 2264 metrics_collections=None, 2265 updates_collections=None, 2266 name=None): 2267 """Computes precision@k of the predictions with respect to sparse labels. 2268 2269 If `class_id` is not specified, we calculate precision as the ratio of true 2270 positives (i.e., correct predictions, items in the top `k` highest 2271 `predictions` that are found in the corresponding row in `labels`) to 2272 positives (all top `k` `predictions`). 2273 If `class_id` is specified, we calculate precision by considering only the 2274 rows in the batch for which `class_id` is in the top `k` highest 2275 `predictions`, and computing the fraction of them for which `class_id` is 2276 in the corresponding row in `labels`. 2277 2278 We expect precision to decrease as `k` increases. 2279 2280 `streaming_sparse_precision_at_k` creates two local variables, 2281 `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute 2282 the precision@k frequency. This frequency is ultimately returned as 2283 `precision_at_<k>`: an idempotent operation that simply divides 2284 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 2285 `false_positive_at_<k>`). 2286 2287 For estimation of the metric over a stream of data, the function creates an 2288 `update_op` operation that updates these variables and returns the 2289 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 2290 indicating the top `k` `predictions`. Set operations applied to `top_k` and 2291 `labels` calculate the true positives and false positives weighted by 2292 `weights`. Then `update_op` increments `true_positive_at_<k>` and 2293 `false_positive_at_<k>` using these values. 2294 2295 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2296 2297 Args: 2298 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 2299 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 2300 The final dimension contains the logit values for each class. [D1, ... DN] 2301 must match `labels`. 2302 labels: `int64` `Tensor` or `SparseTensor` with shape 2303 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2304 target classes for the associated prediction. Commonly, N=1 and `labels` 2305 has shape [batch_size, num_labels]. [D1, ... DN] must match 2306 `predictions`. Values should be in range [0, num_classes), where 2307 num_classes is the last dimension of `predictions`. Values outside this 2308 range are ignored. 2309 k: Integer, k for @k metric. 2310 class_id: Integer class ID for which we want binary metrics. This should be 2311 in range [0, num_classes], where num_classes is the last dimension of 2312 `predictions`. If `class_id` is outside this range, the method returns 2313 NAN. 2314 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2315 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2316 dimensions must be either `1`, or the same as the corresponding `labels` 2317 dimension). 2318 metrics_collections: An optional list of collections that values should 2319 be added to. 2320 updates_collections: An optional list of collections that updates should 2321 be added to. 2322 name: Name of new update operation, and namespace for other dependent ops. 2323 2324 Returns: 2325 precision: Scalar `float64` `Tensor` with the value of `true_positives` 2326 divided by the sum of `true_positives` and `false_positives`. 2327 update_op: `Operation` that increments `true_positives` and 2328 `false_positives` variables appropriately, and whose value matches 2329 `precision`. 2330 2331 Raises: 2332 ValueError: If `weights` is not `None` and its shape doesn't match 2333 `predictions`, or if either `metrics_collections` or `updates_collections` 2334 are not a list or tuple. 2335 """ 2336 return metrics.precision_at_k( 2337 k=k, 2338 class_id=class_id, 2339 predictions=predictions, 2340 labels=labels, 2341 weights=weights, 2342 metrics_collections=metrics_collections, 2343 updates_collections=updates_collections, 2344 name=name) 2345 2346 2347# TODO(ptucker): Validate range of values in labels? 2348def streaming_sparse_precision_at_top_k(top_k_predictions, 2349 labels, 2350 class_id=None, 2351 weights=None, 2352 metrics_collections=None, 2353 updates_collections=None, 2354 name=None): 2355 """Computes precision@k of top-k predictions with respect to sparse labels. 2356 2357 If `class_id` is not specified, we calculate precision as the ratio of 2358 true positives (i.e., correct predictions, items in `top_k_predictions` 2359 that are found in the corresponding row in `labels`) to positives (all 2360 `top_k_predictions`). 2361 If `class_id` is specified, we calculate precision by considering only the 2362 rows in the batch for which `class_id` is in the top `k` highest 2363 `predictions`, and computing the fraction of them for which `class_id` is 2364 in the corresponding row in `labels`. 2365 2366 We expect precision to decrease as `k` increases. 2367 2368 `streaming_sparse_precision_at_top_k` creates two local variables, 2369 `true_positive_at_k` and `false_positive_at_k`, that are used to compute 2370 the precision@k frequency. This frequency is ultimately returned as 2371 `precision_at_k`: an idempotent operation that simply divides 2372 `true_positive_at_k` by total (`true_positive_at_k` + `false_positive_at_k`). 2373 2374 For estimation of the metric over a stream of data, the function creates an 2375 `update_op` operation that updates these variables and returns the 2376 `precision_at_k`. Internally, set operations applied to `top_k_predictions` 2377 and `labels` calculate the true positives and false positives weighted by 2378 `weights`. Then `update_op` increments `true_positive_at_k` and 2379 `false_positive_at_k` using these values. 2380 2381 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2382 2383 Args: 2384 top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where 2385 N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k]. 2386 The final dimension contains the indices of top-k labels. [D1, ... DN] 2387 must match `labels`. 2388 labels: `int64` `Tensor` or `SparseTensor` with shape 2389 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2390 target classes for the associated prediction. Commonly, N=1 and `labels` 2391 has shape [batch_size, num_labels]. [D1, ... DN] must match 2392 `top_k_predictions`. Values should be in range [0, num_classes), where 2393 num_classes is the last dimension of `predictions`. Values outside this 2394 range are ignored. 2395 class_id: Integer class ID for which we want binary metrics. This should be 2396 in range [0, num_classes), where num_classes is the last dimension of 2397 `predictions`. If `class_id` is outside this range, the method returns 2398 NAN. 2399 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2400 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2401 dimensions must be either `1`, or the same as the corresponding `labels` 2402 dimension). 2403 metrics_collections: An optional list of collections that values should 2404 be added to. 2405 updates_collections: An optional list of collections that updates should 2406 be added to. 2407 name: Name of new update operation, and namespace for other dependent ops. 2408 2409 Returns: 2410 precision: Scalar `float64` `Tensor` with the value of `true_positives` 2411 divided by the sum of `true_positives` and `false_positives`. 2412 update_op: `Operation` that increments `true_positives` and 2413 `false_positives` variables appropriately, and whose value matches 2414 `precision`. 2415 2416 Raises: 2417 ValueError: If `weights` is not `None` and its shape doesn't match 2418 `predictions`, or if either `metrics_collections` or `updates_collections` 2419 are not a list or tuple. 2420 ValueError: If `top_k_predictions` has rank < 2. 2421 """ 2422 default_name = _at_k_name('precision', class_id=class_id) 2423 with ops.name_scope(name, default_name, 2424 (top_k_predictions, labels, weights)) as name_scope: 2425 return metrics_impl.precision_at_top_k( 2426 labels=labels, 2427 predictions_idx=top_k_predictions, 2428 class_id=class_id, 2429 weights=weights, 2430 metrics_collections=metrics_collections, 2431 updates_collections=updates_collections, 2432 name=name_scope) 2433 2434 2435def sparse_recall_at_top_k(labels, 2436 top_k_predictions, 2437 class_id=None, 2438 weights=None, 2439 metrics_collections=None, 2440 updates_collections=None, 2441 name=None): 2442 """Computes recall@k of top-k predictions with respect to sparse labels. 2443 2444 If `class_id` is specified, we calculate recall by considering only the 2445 entries in the batch for which `class_id` is in the label, and computing 2446 the fraction of them for which `class_id` is in the top-k `predictions`. 2447 If `class_id` is not specified, we'll calculate recall as how often on 2448 average a class among the labels of a batch entry is in the top-k 2449 `predictions`. 2450 2451 `sparse_recall_at_top_k` creates two local variables, `true_positive_at_<k>` 2452 and `false_negative_at_<k>`, that are used to compute the recall_at_k 2453 frequency. This frequency is ultimately returned as `recall_at_<k>`: an 2454 idempotent operation that simply divides `true_positive_at_<k>` by total 2455 (`true_positive_at_<k>` + `false_negative_at_<k>`). 2456 2457 For estimation of the metric over a stream of data, the function creates an 2458 `update_op` operation that updates these variables and returns the 2459 `recall_at_<k>`. Set operations applied to `top_k` and `labels` calculate the 2460 true positives and false negatives weighted by `weights`. Then `update_op` 2461 increments `true_positive_at_<k>` and `false_negative_at_<k>` using these 2462 values. 2463 2464 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2465 2466 Args: 2467 labels: `int64` `Tensor` or `SparseTensor` with shape 2468 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2469 target classes for the associated prediction. Commonly, N=1 and `labels` 2470 has shape [batch_size, num_labels]. [D1, ... DN] must match 2471 `top_k_predictions`. Values should be in range [0, num_classes), where 2472 num_classes is the last dimension of `predictions`. Values outside this 2473 range always count towards `false_negative_at_<k>`. 2474 top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where 2475 N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k]. 2476 The final dimension contains the indices of top-k labels. [D1, ... DN] 2477 must match `labels`. 2478 class_id: Integer class ID for which we want binary metrics. This should be 2479 in range [0, num_classes), where num_classes is the last dimension of 2480 `predictions`. If class_id is outside this range, the method returns NAN. 2481 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2482 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2483 dimensions must be either `1`, or the same as the corresponding `labels` 2484 dimension). 2485 metrics_collections: An optional list of collections that values should 2486 be added to. 2487 updates_collections: An optional list of collections that updates should 2488 be added to. 2489 name: Name of new update operation, and namespace for other dependent ops. 2490 2491 Returns: 2492 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 2493 by the sum of `true_positives` and `false_negatives`. 2494 update_op: `Operation` that increments `true_positives` and 2495 `false_negatives` variables appropriately, and whose value matches 2496 `recall`. 2497 2498 Raises: 2499 ValueError: If `weights` is not `None` and its shape doesn't match 2500 `predictions`, or if either `metrics_collections` or `updates_collections` 2501 are not a list or tuple. 2502 """ 2503 default_name = _at_k_name('recall', class_id=class_id) 2504 with ops.name_scope(name, default_name, 2505 (top_k_predictions, labels, weights)) as name_scope: 2506 return metrics_impl.recall_at_top_k( 2507 labels=labels, 2508 predictions_idx=top_k_predictions, 2509 class_id=class_id, 2510 weights=weights, 2511 metrics_collections=metrics_collections, 2512 updates_collections=updates_collections, 2513 name=name_scope) 2514 2515 2516def _compute_recall_at_precision(tp, fp, fn, precision, name, 2517 strict_mode=False): 2518 """Helper function to compute recall at a given `precision`. 2519 2520 Args: 2521 tp: The number of true positives. 2522 fp: The number of false positives. 2523 fn: The number of false negatives. 2524 precision: The precision for which the recall will be calculated. 2525 name: An optional variable_scope name. 2526 strict_mode: If true and there exists a threshold where the precision is 2527 no smaller than the target precision, return the corresponding recall at 2528 the threshold. Otherwise, return 0. If false, find the threshold where the 2529 precision is closest to the target precision and return the recall at the 2530 threshold. 2531 2532 Returns: 2533 The recall at a given `precision`. 2534 """ 2535 precisions = math_ops.div(tp, tp + fp + _EPSILON) 2536 if not strict_mode: 2537 tf_index = math_ops.argmin( 2538 math_ops.abs(precisions - precision), 0, output_type=dtypes.int32) 2539 # Now, we have the implicit threshold, so compute the recall: 2540 return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON, 2541 name) 2542 else: 2543 # We aim to find the threshold where the precision is minimum but no smaller 2544 # than the target precision. 2545 # The rationale: 2546 # 1. Compute the difference between precisions (by different thresholds) and 2547 # the target precision. 2548 # 2. Take the reciprocal of the values by the above step. The intention is 2549 # to make the positive values rank before negative values and also the 2550 # smaller positives rank before larger positives. 2551 tf_index = math_ops.argmax( 2552 math_ops.div(1.0, precisions - precision + _EPSILON), 2553 0, 2554 output_type=dtypes.int32) 2555 2556 def _return_good_recall(): 2557 return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON, 2558 name) 2559 2560 return control_flow_ops.cond(precisions[tf_index] >= precision, 2561 _return_good_recall, lambda: .0) 2562 2563 2564def recall_at_precision(labels, 2565 predictions, 2566 precision, 2567 weights=None, 2568 num_thresholds=200, 2569 metrics_collections=None, 2570 updates_collections=None, 2571 name=None, 2572 strict_mode=False): 2573 """Computes `recall` at `precision`. 2574 2575 The `recall_at_precision` function creates four local variables, 2576 `tp` (true positives), `fp` (false positives) and `fn` (false negatives) 2577 that are used to compute the `recall` at the given `precision` value. The 2578 threshold for the given `precision` value is computed and used to evaluate the 2579 corresponding `recall`. 2580 2581 For estimation of the metric over a stream of data, the function creates an 2582 `update_op` operation that updates these variables and returns the 2583 `recall`. `update_op` increments the `tp`, `fp` and `fn` counts with the 2584 weight of each case found in the `predictions` and `labels`. 2585 2586 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2587 2588 Args: 2589 labels: The ground truth values, a `Tensor` whose dimensions must match 2590 `predictions`. Will be cast to `bool`. 2591 predictions: A floating point `Tensor` of arbitrary shape and whose values 2592 are in the range `[0, 1]`. 2593 precision: A scalar value in range `[0, 1]`. 2594 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2595 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2596 be either `1`, or the same as the corresponding `labels` dimension). 2597 num_thresholds: The number of thresholds to use for matching the given 2598 `precision`. 2599 metrics_collections: An optional list of collections that `recall` 2600 should be added to. 2601 updates_collections: An optional list of collections that `update_op` should 2602 be added to. 2603 name: An optional variable_scope name. 2604 strict_mode: If true and there exists a threshold where the precision is 2605 above the target precision, return the corresponding recall at the 2606 threshold. Otherwise, return 0. If false, find the threshold where the 2607 precision is closest to the target precision and return the recall at the 2608 threshold. 2609 2610 Returns: 2611 recall: A scalar `Tensor` representing the recall at the given 2612 `precision` value. 2613 update_op: An operation that increments the `tp`, `fp` and `fn` 2614 variables appropriately and whose value matches `recall`. 2615 2616 Raises: 2617 ValueError: If `predictions` and `labels` have mismatched shapes, if 2618 `weights` is not `None` and its shape doesn't match `predictions`, or if 2619 `precision` is not between 0 and 1, or if either `metrics_collections` 2620 or `updates_collections` are not a list or tuple. 2621 2622 """ 2623 if not 0 <= precision <= 1: 2624 raise ValueError('`precision` must be in the range [0, 1].') 2625 2626 with variable_scope.variable_scope(name, 'recall_at_precision', 2627 (predictions, labels, weights)): 2628 thresholds = [ 2629 i * 1.0 / (num_thresholds - 1) for i in range(1, num_thresholds - 1) 2630 ] 2631 thresholds = [0.0 - _EPSILON] + thresholds + [1.0 + _EPSILON] 2632 2633 values, update_ops = _streaming_confusion_matrix_at_thresholds( 2634 predictions, labels, thresholds, weights) 2635 2636 recall = _compute_recall_at_precision(values['tp'], values['fp'], 2637 values['fn'], precision, 'value', 2638 strict_mode) 2639 update_op = _compute_recall_at_precision(update_ops['tp'], update_ops['fp'], 2640 update_ops['fn'], precision, 2641 'update_op', strict_mode) 2642 2643 if metrics_collections: 2644 ops.add_to_collections(metrics_collections, recall) 2645 2646 if updates_collections: 2647 ops.add_to_collections(updates_collections, update_op) 2648 2649 return recall, update_op 2650 2651 2652def precision_at_recall(labels, 2653 predictions, 2654 target_recall, 2655 weights=None, 2656 num_thresholds=200, 2657 metrics_collections=None, 2658 updates_collections=None, 2659 name=None): 2660 """Computes the precision at a given recall. 2661 2662 This function creates variables to track the true positives, false positives, 2663 true negatives, and false negatives at a set of thresholds. Among those 2664 thresholds where recall is at least `target_recall`, precision is computed 2665 at the threshold where recall is closest to `target_recall`. 2666 2667 For estimation of the metric over a stream of data, the function creates an 2668 `update_op` operation that updates these variables and returns the 2669 precision at `target_recall`. `update_op` increments the counts of true 2670 positives, false positives, true negatives, and false negatives with the 2671 weight of each case found in the `predictions` and `labels`. 2672 2673 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2674 2675 For additional information about precision and recall, see 2676 http://en.wikipedia.org/wiki/Precision_and_recall 2677 2678 Args: 2679 labels: The ground truth values, a `Tensor` whose dimensions must match 2680 `predictions`. Will be cast to `bool`. 2681 predictions: A floating point `Tensor` of arbitrary shape and whose values 2682 are in the range `[0, 1]`. 2683 target_recall: A scalar value in range `[0, 1]`. 2684 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2685 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2686 be either `1`, or the same as the corresponding `labels` dimension). 2687 num_thresholds: The number of thresholds to use for matching the given 2688 recall. 2689 metrics_collections: An optional list of collections to which `precision` 2690 should be added. 2691 updates_collections: An optional list of collections to which `update_op` 2692 should be added. 2693 name: An optional variable_scope name. 2694 2695 Returns: 2696 precision: A scalar `Tensor` representing the precision at the given 2697 `target_recall` value. 2698 update_op: An operation that increments the variables for tracking the 2699 true positives, false positives, true negatives, and false negatives and 2700 whose value matches `precision`. 2701 2702 Raises: 2703 ValueError: If `predictions` and `labels` have mismatched shapes, if 2704 `weights` is not `None` and its shape doesn't match `predictions`, or if 2705 `target_recall` is not between 0 and 1, or if either `metrics_collections` 2706 or `updates_collections` are not a list or tuple. 2707 RuntimeError: If eager execution is enabled. 2708 """ 2709 if context.executing_eagerly(): 2710 raise RuntimeError('tf.metrics.precision_at_recall is not ' 2711 'supported when eager execution is enabled.') 2712 2713 if target_recall < 0 or target_recall > 1: 2714 raise ValueError('`target_recall` must be in the range [0, 1].') 2715 2716 with variable_scope.variable_scope(name, 'precision_at_recall', 2717 (predictions, labels, weights)): 2718 kepsilon = 1e-7 # Used to avoid division by zero. 2719 thresholds = [ 2720 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) 2721 ] 2722 thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] 2723 2724 values, update_ops = _streaming_confusion_matrix_at_thresholds( 2725 predictions, labels, thresholds, weights) 2726 2727 def compute_precision_at_recall(tp, fp, fn, name): 2728 """Computes the precision at a given recall. 2729 2730 Args: 2731 tp: True positives. 2732 fp: False positives. 2733 fn: False negatives. 2734 name: A name for the operation. 2735 2736 Returns: 2737 The precision at the desired recall. 2738 """ 2739 recalls = math_ops.div(tp, tp + fn + kepsilon) 2740 2741 # Because recall is monotone decreasing as a function of the threshold, 2742 # the smallest recall exceeding target_recall occurs at the largest 2743 # threshold where recall >= target_recall. 2744 admissible_recalls = math_ops.cast( 2745 math_ops.greater_equal(recalls, target_recall), dtypes.int64) 2746 tf_index = math_ops.reduce_sum(admissible_recalls) - 1 2747 2748 # Now we have the threshold at which to compute precision: 2749 return math_ops.div(tp[tf_index] + kepsilon, 2750 tp[tf_index] + fp[tf_index] + kepsilon, 2751 name) 2752 2753 precision_value = compute_precision_at_recall( 2754 values['tp'], values['fp'], values['fn'], 'value') 2755 update_op = compute_precision_at_recall( 2756 update_ops['tp'], update_ops['fp'], update_ops['fn'], 'update_op') 2757 2758 if metrics_collections: 2759 ops.add_to_collections(metrics_collections, precision_value) 2760 2761 if updates_collections: 2762 ops.add_to_collections(updates_collections, update_op) 2763 2764 return precision_value, update_op 2765 2766 2767def streaming_sparse_average_precision_at_k(predictions, 2768 labels, 2769 k, 2770 weights=None, 2771 metrics_collections=None, 2772 updates_collections=None, 2773 name=None): 2774 """Computes average precision@k of predictions with respect to sparse labels. 2775 2776 See `sparse_average_precision_at_k` for details on formula. `weights` are 2777 applied to the result of `sparse_average_precision_at_k` 2778 2779 `streaming_sparse_average_precision_at_k` creates two local variables, 2780 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 2781 are used to compute the frequency. This frequency is ultimately returned as 2782 `average_precision_at_<k>`: an idempotent operation that simply divides 2783 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 2784 2785 For estimation of the metric over a stream of data, the function creates an 2786 `update_op` operation that updates these variables and returns the 2787 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 2788 indicating the top `k` `predictions`. Set operations applied to `top_k` and 2789 `labels` calculate the true positives and false positives weighted by 2790 `weights`. Then `update_op` increments `true_positive_at_<k>` and 2791 `false_positive_at_<k>` using these values. 2792 2793 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2794 2795 Args: 2796 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 2797 N >= 1. Commonly, N=1 and `predictions` has shape 2798 [batch size, num_classes]. The final dimension contains the logit values 2799 for each class. [D1, ... DN] must match `labels`. 2800 labels: `int64` `Tensor` or `SparseTensor` with shape 2801 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2802 target classes for the associated prediction. Commonly, N=1 and `labels` 2803 has shape [batch_size, num_labels]. [D1, ... DN] must match 2804 `predictions_`. Values should be in range [0, num_classes), where 2805 num_classes is the last dimension of `predictions`. Values outside this 2806 range are ignored. 2807 k: Integer, k for @k metric. This will calculate an average precision for 2808 range `[1,k]`, as documented above. 2809 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2810 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2811 dimensions must be either `1`, or the same as the corresponding `labels` 2812 dimension). 2813 metrics_collections: An optional list of collections that values should 2814 be added to. 2815 updates_collections: An optional list of collections that updates should 2816 be added to. 2817 name: Name of new update operation, and namespace for other dependent ops. 2818 2819 Returns: 2820 mean_average_precision: Scalar `float64` `Tensor` with the mean average 2821 precision values. 2822 update: `Operation` that increments variables appropriately, and whose 2823 value matches `metric`. 2824 """ 2825 return metrics.average_precision_at_k( 2826 k=k, 2827 predictions=predictions, 2828 labels=labels, 2829 weights=weights, 2830 metrics_collections=metrics_collections, 2831 updates_collections=updates_collections, 2832 name=name) 2833 2834 2835def streaming_sparse_average_precision_at_top_k(top_k_predictions, 2836 labels, 2837 weights=None, 2838 metrics_collections=None, 2839 updates_collections=None, 2840 name=None): 2841 """Computes average precision@k of predictions with respect to sparse labels. 2842 2843 `streaming_sparse_average_precision_at_top_k` creates two local variables, 2844 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 2845 are used to compute the frequency. This frequency is ultimately returned as 2846 `average_precision_at_<k>`: an idempotent operation that simply divides 2847 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 2848 2849 For estimation of the metric over a stream of data, the function creates an 2850 `update_op` operation that updates these variables and returns the 2851 `precision_at_<k>`. Set operations applied to `top_k` and `labels` calculate 2852 the true positives and false positives weighted by `weights`. Then `update_op` 2853 increments `true_positive_at_<k>` and `false_positive_at_<k>` using these 2854 values. 2855 2856 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2857 2858 Args: 2859 top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 2860 Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final 2861 dimension must be set and contains the top `k` predicted class indices. 2862 [D1, ... DN] must match `labels`. Values should be in range 2863 [0, num_classes). 2864 labels: `int64` `Tensor` or `SparseTensor` with shape 2865 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 2866 num_labels=1. N >= 1 and num_labels is the number of target classes for 2867 the associated prediction. Commonly, N=1 and `labels` has shape 2868 [batch_size, num_labels]. [D1, ... DN] must match `top_k_predictions`. 2869 Values should be in range [0, num_classes). 2870 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2871 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2872 dimensions must be either `1`, or the same as the corresponding `labels` 2873 dimension). 2874 metrics_collections: An optional list of collections that values should 2875 be added to. 2876 updates_collections: An optional list of collections that updates should 2877 be added to. 2878 name: Name of new update operation, and namespace for other dependent ops. 2879 2880 Returns: 2881 mean_average_precision: Scalar `float64` `Tensor` with the mean average 2882 precision values. 2883 update: `Operation` that increments variables appropriately, and whose 2884 value matches `metric`. 2885 2886 Raises: 2887 ValueError: if the last dimension of top_k_predictions is not set. 2888 """ 2889 return metrics_impl._streaming_sparse_average_precision_at_top_k( # pylint: disable=protected-access 2890 predictions_idx=top_k_predictions, 2891 labels=labels, 2892 weights=weights, 2893 metrics_collections=metrics_collections, 2894 updates_collections=updates_collections, 2895 name=name) 2896 2897 2898@deprecated(None, 2899 'Please switch to tf.metrics.mean_absolute_error. Note that the ' 2900 'order of the labels and predictions arguments has been switched.') 2901def streaming_mean_absolute_error(predictions, 2902 labels, 2903 weights=None, 2904 metrics_collections=None, 2905 updates_collections=None, 2906 name=None): 2907 """Computes the mean absolute error between the labels and predictions. 2908 2909 The `streaming_mean_absolute_error` function creates two local variables, 2910 `total` and `count` that are used to compute the mean absolute error. This 2911 average is weighted by `weights`, and it is ultimately returned as 2912 `mean_absolute_error`: an idempotent operation that simply divides `total` by 2913 `count`. 2914 2915 For estimation of the metric over a stream of data, the function creates an 2916 `update_op` operation that updates these variables and returns the 2917 `mean_absolute_error`. Internally, an `absolute_errors` operation computes the 2918 absolute value of the differences between `predictions` and `labels`. Then 2919 `update_op` increments `total` with the reduced sum of the product of 2920 `weights` and `absolute_errors`, and it increments `count` with the reduced 2921 sum of `weights` 2922 2923 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2924 2925 Args: 2926 predictions: A `Tensor` of arbitrary shape. 2927 labels: A `Tensor` of the same shape as `predictions`. 2928 weights: Optional `Tensor` indicating the frequency with which an example is 2929 sampled. Rank must be 0, or the same rank as `labels`, and must be 2930 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 2931 the same as the corresponding `labels` dimension). 2932 metrics_collections: An optional list of collections that 2933 `mean_absolute_error` should be added to. 2934 updates_collections: An optional list of collections that `update_op` should 2935 be added to. 2936 name: An optional variable_scope name. 2937 2938 Returns: 2939 mean_absolute_error: A `Tensor` representing the current mean, the value of 2940 `total` divided by `count`. 2941 update_op: An operation that increments the `total` and `count` variables 2942 appropriately and whose value matches `mean_absolute_error`. 2943 2944 Raises: 2945 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2946 `weights` is not `None` and its shape doesn't match `predictions`, or if 2947 either `metrics_collections` or `updates_collections` are not a list or 2948 tuple. 2949 """ 2950 return metrics.mean_absolute_error( 2951 predictions=predictions, 2952 labels=labels, 2953 weights=weights, 2954 metrics_collections=metrics_collections, 2955 updates_collections=updates_collections, 2956 name=name) 2957 2958 2959def streaming_mean_relative_error(predictions, 2960 labels, 2961 normalizer, 2962 weights=None, 2963 metrics_collections=None, 2964 updates_collections=None, 2965 name=None): 2966 """Computes the mean relative error by normalizing with the given values. 2967 2968 The `streaming_mean_relative_error` function creates two local variables, 2969 `total` and `count` that are used to compute the mean relative absolute error. 2970 This average is weighted by `weights`, and it is ultimately returned as 2971 `mean_relative_error`: an idempotent operation that simply divides `total` by 2972 `count`. 2973 2974 For estimation of the metric over a stream of data, the function creates an 2975 `update_op` operation that updates these variables and returns the 2976 `mean_reative_error`. Internally, a `relative_errors` operation divides the 2977 absolute value of the differences between `predictions` and `labels` by the 2978 `normalizer`. Then `update_op` increments `total` with the reduced sum of the 2979 product of `weights` and `relative_errors`, and it increments `count` with the 2980 reduced sum of `weights`. 2981 2982 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2983 2984 Args: 2985 predictions: A `Tensor` of arbitrary shape. 2986 labels: A `Tensor` of the same shape as `predictions`. 2987 normalizer: A `Tensor` of the same shape as `predictions`. 2988 weights: Optional `Tensor` indicating the frequency with which an example is 2989 sampled. Rank must be 0, or the same rank as `labels`, and must be 2990 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 2991 the same as the corresponding `labels` dimension). 2992 metrics_collections: An optional list of collections that 2993 `mean_relative_error` should be added to. 2994 updates_collections: An optional list of collections that `update_op` should 2995 be added to. 2996 name: An optional variable_scope name. 2997 2998 Returns: 2999 mean_relative_error: A `Tensor` representing the current mean, the value of 3000 `total` divided by `count`. 3001 update_op: An operation that increments the `total` and `count` variables 3002 appropriately and whose value matches `mean_relative_error`. 3003 3004 Raises: 3005 ValueError: If `predictions` and `labels` have mismatched shapes, or if 3006 `weights` is not `None` and its shape doesn't match `predictions`, or if 3007 either `metrics_collections` or `updates_collections` are not a list or 3008 tuple. 3009 """ 3010 return metrics.mean_relative_error( 3011 normalizer=normalizer, 3012 predictions=predictions, 3013 labels=labels, 3014 weights=weights, 3015 metrics_collections=metrics_collections, 3016 updates_collections=updates_collections, 3017 name=name) 3018 3019@deprecated(None, 3020 'Please switch to tf.metrics.mean_squared_error. Note that the ' 3021 'order of the labels and predictions arguments has been switched.') 3022def streaming_mean_squared_error(predictions, 3023 labels, 3024 weights=None, 3025 metrics_collections=None, 3026 updates_collections=None, 3027 name=None): 3028 """Computes the mean squared error between the labels and predictions. 3029 3030 The `streaming_mean_squared_error` function creates two local variables, 3031 `total` and `count` that are used to compute the mean squared error. 3032 This average is weighted by `weights`, and it is ultimately returned as 3033 `mean_squared_error`: an idempotent operation that simply divides `total` by 3034 `count`. 3035 3036 For estimation of the metric over a stream of data, the function creates an 3037 `update_op` operation that updates these variables and returns the 3038 `mean_squared_error`. Internally, a `squared_error` operation computes the 3039 element-wise square of the difference between `predictions` and `labels`. Then 3040 `update_op` increments `total` with the reduced sum of the product of 3041 `weights` and `squared_error`, and it increments `count` with the reduced sum 3042 of `weights`. 3043 3044 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3045 3046 Args: 3047 predictions: A `Tensor` of arbitrary shape. 3048 labels: A `Tensor` of the same shape as `predictions`. 3049 weights: Optional `Tensor` indicating the frequency with which an example is 3050 sampled. Rank must be 0, or the same rank as `labels`, and must be 3051 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 3052 the same as the corresponding `labels` dimension). 3053 metrics_collections: An optional list of collections that 3054 `mean_squared_error` should be added to. 3055 updates_collections: An optional list of collections that `update_op` should 3056 be added to. 3057 name: An optional variable_scope name. 3058 3059 Returns: 3060 mean_squared_error: A `Tensor` representing the current mean, the value of 3061 `total` divided by `count`. 3062 update_op: An operation that increments the `total` and `count` variables 3063 appropriately and whose value matches `mean_squared_error`. 3064 3065 Raises: 3066 ValueError: If `predictions` and `labels` have mismatched shapes, or if 3067 `weights` is not `None` and its shape doesn't match `predictions`, or if 3068 either `metrics_collections` or `updates_collections` are not a list or 3069 tuple. 3070 """ 3071 return metrics.mean_squared_error( 3072 predictions=predictions, 3073 labels=labels, 3074 weights=weights, 3075 metrics_collections=metrics_collections, 3076 updates_collections=updates_collections, 3077 name=name) 3078 3079@deprecated( 3080 None, 3081 'Please switch to tf.metrics.root_mean_squared_error. Note that the ' 3082 'order of the labels and predictions arguments has been switched.') 3083def streaming_root_mean_squared_error(predictions, 3084 labels, 3085 weights=None, 3086 metrics_collections=None, 3087 updates_collections=None, 3088 name=None): 3089 """Computes the root mean squared error between the labels and predictions. 3090 3091 The `streaming_root_mean_squared_error` function creates two local variables, 3092 `total` and `count` that are used to compute the root mean squared error. 3093 This average is weighted by `weights`, and it is ultimately returned as 3094 `root_mean_squared_error`: an idempotent operation that takes the square root 3095 of the division of `total` by `count`. 3096 3097 For estimation of the metric over a stream of data, the function creates an 3098 `update_op` operation that updates these variables and returns the 3099 `root_mean_squared_error`. Internally, a `squared_error` operation computes 3100 the element-wise square of the difference between `predictions` and `labels`. 3101 Then `update_op` increments `total` with the reduced sum of the product of 3102 `weights` and `squared_error`, and it increments `count` with the reduced sum 3103 of `weights`. 3104 3105 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3106 3107 Args: 3108 predictions: A `Tensor` of arbitrary shape. 3109 labels: A `Tensor` of the same shape as `predictions`. 3110 weights: Optional `Tensor` indicating the frequency with which an example is 3111 sampled. Rank must be 0, or the same rank as `labels`, and must be 3112 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 3113 the same as the corresponding `labels` dimension). 3114 metrics_collections: An optional list of collections that 3115 `root_mean_squared_error` should be added to. 3116 updates_collections: An optional list of collections that `update_op` should 3117 be added to. 3118 name: An optional variable_scope name. 3119 3120 Returns: 3121 root_mean_squared_error: A `Tensor` representing the current mean, the value 3122 of `total` divided by `count`. 3123 update_op: An operation that increments the `total` and `count` variables 3124 appropriately and whose value matches `root_mean_squared_error`. 3125 3126 Raises: 3127 ValueError: If `predictions` and `labels` have mismatched shapes, or if 3128 `weights` is not `None` and its shape doesn't match `predictions`, or if 3129 either `metrics_collections` or `updates_collections` are not a list or 3130 tuple. 3131 """ 3132 return metrics.root_mean_squared_error( 3133 predictions=predictions, 3134 labels=labels, 3135 weights=weights, 3136 metrics_collections=metrics_collections, 3137 updates_collections=updates_collections, 3138 name=name) 3139 3140 3141def streaming_covariance(predictions, 3142 labels, 3143 weights=None, 3144 metrics_collections=None, 3145 updates_collections=None, 3146 name=None): 3147 """Computes the unbiased sample covariance between `predictions` and `labels`. 3148 3149 The `streaming_covariance` function creates four local variables, 3150 `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to 3151 compute the sample covariance between predictions and labels across multiple 3152 batches of data. The covariance is ultimately returned as an idempotent 3153 operation that simply divides `comoment` by `count` - 1. We use `count` - 1 3154 in order to get an unbiased estimate. 3155 3156 The algorithm used for this online computation is described in 3157 https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. 3158 Specifically, the formula used to combine two sample comoments is 3159 `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB` 3160 The comoment for a single batch of data is simply 3161 `sum((x - E[x]) * (y - E[y]))`, optionally weighted. 3162 3163 If `weights` is not None, then it is used to compute weighted comoments, 3164 means, and count. NOTE: these weights are treated as "frequency weights", as 3165 opposed to "reliability weights". See discussion of the difference on 3166 https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance 3167 3168 To facilitate the computation of covariance across multiple batches of data, 3169 the function creates an `update_op` operation, which updates underlying 3170 variables and returns the updated covariance. 3171 3172 Args: 3173 predictions: A `Tensor` of arbitrary size. 3174 labels: A `Tensor` of the same size as `predictions`. 3175 weights: Optional `Tensor` indicating the frequency with which an example is 3176 sampled. Rank must be 0, or the same rank as `labels`, and must be 3177 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 3178 the same as the corresponding `labels` dimension). 3179 metrics_collections: An optional list of collections that the metric 3180 value variable should be added to. 3181 updates_collections: An optional list of collections that the metric update 3182 ops should be added to. 3183 name: An optional variable_scope name. 3184 3185 Returns: 3186 covariance: A `Tensor` representing the current unbiased sample covariance, 3187 `comoment` / (`count` - 1). 3188 update_op: An operation that updates the local variables appropriately. 3189 3190 Raises: 3191 ValueError: If labels and predictions are of different sizes or if either 3192 `metrics_collections` or `updates_collections` are not a list or tuple. 3193 """ 3194 with variable_scope.variable_scope(name, 'covariance', 3195 (predictions, labels, weights)): 3196 predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access 3197 predictions, labels, weights) 3198 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 3199 count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') 3200 mean_prediction = metrics_impl.metric_variable( 3201 [], dtypes.float32, name='mean_prediction') 3202 mean_label = metrics_impl.metric_variable( 3203 [], dtypes.float32, name='mean_label') 3204 comoment = metrics_impl.metric_variable( # C_A in update equation 3205 [], dtypes.float32, name='comoment') 3206 3207 if weights is None: 3208 batch_count = math_ops.cast( 3209 array_ops.size(labels), dtypes.float32) # n_B in eqn 3210 weighted_predictions = predictions 3211 weighted_labels = labels 3212 else: 3213 weights = weights_broadcast_ops.broadcast_weights(weights, labels) 3214 batch_count = math_ops.reduce_sum(weights) # n_B in eqn 3215 weighted_predictions = math_ops.multiply(predictions, weights) 3216 weighted_labels = math_ops.multiply(labels, weights) 3217 3218 update_count = state_ops.assign_add(count_, batch_count) # n_AB in eqn 3219 prev_count = update_count - batch_count # n_A in update equation 3220 3221 # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount) 3222 # batch_mean_prediction is E[x_B] in the update equation 3223 batch_mean_prediction = math_ops.div_no_nan( 3224 math_ops.reduce_sum(weighted_predictions), batch_count) 3225 delta_mean_prediction = math_ops.div_no_nan( 3226 (batch_mean_prediction - mean_prediction) * batch_count, update_count) 3227 update_mean_prediction = state_ops.assign_add(mean_prediction, 3228 delta_mean_prediction) 3229 # prev_mean_prediction is E[x_A] in the update equation 3230 prev_mean_prediction = update_mean_prediction - delta_mean_prediction 3231 3232 # batch_mean_label is E[y_B] in the update equation 3233 batch_mean_label = math_ops.div_no_nan( 3234 math_ops.reduce_sum(weighted_labels), batch_count) 3235 delta_mean_label = math_ops.div_no_nan( 3236 (batch_mean_label - mean_label) * batch_count, update_count) 3237 update_mean_label = state_ops.assign_add(mean_label, delta_mean_label) 3238 # prev_mean_label is E[y_A] in the update equation 3239 prev_mean_label = update_mean_label - delta_mean_label 3240 3241 unweighted_batch_coresiduals = ((predictions - batch_mean_prediction) * 3242 (labels - batch_mean_label)) 3243 # batch_comoment is C_B in the update equation 3244 if weights is None: 3245 batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals) 3246 else: 3247 batch_comoment = math_ops.reduce_sum( 3248 unweighted_batch_coresiduals * weights) 3249 3250 # View delta_comoment as = C_AB - C_A in the update equation above. 3251 # Since C_A is stored in a var, by how much do we need to increment that var 3252 # to make the var = C_AB? 3253 delta_comoment = ( 3254 batch_comoment + (prev_mean_prediction - batch_mean_prediction) * 3255 (prev_mean_label - batch_mean_label) * 3256 (prev_count * batch_count / update_count)) 3257 update_comoment = state_ops.assign_add(comoment, delta_comoment) 3258 3259 covariance = array_ops.where( 3260 math_ops.less_equal(count_, 1.), 3261 float('nan'), 3262 math_ops.truediv(comoment, count_ - 1), 3263 name='covariance') 3264 with ops.control_dependencies([update_comoment]): 3265 update_op = array_ops.where( 3266 math_ops.less_equal(count_, 1.), 3267 float('nan'), 3268 math_ops.truediv(comoment, count_ - 1), 3269 name='update_op') 3270 3271 if metrics_collections: 3272 ops.add_to_collections(metrics_collections, covariance) 3273 3274 if updates_collections: 3275 ops.add_to_collections(updates_collections, update_op) 3276 3277 return covariance, update_op 3278 3279 3280def streaming_pearson_correlation(predictions, 3281 labels, 3282 weights=None, 3283 metrics_collections=None, 3284 updates_collections=None, 3285 name=None): 3286 """Computes Pearson correlation coefficient between `predictions`, `labels`. 3287 3288 The `streaming_pearson_correlation` function delegates to 3289 `streaming_covariance` the tracking of three [co]variances: 3290 3291 - `streaming_covariance(predictions, labels)`, i.e. covariance 3292 - `streaming_covariance(predictions, predictions)`, i.e. variance 3293 - `streaming_covariance(labels, labels)`, i.e. variance 3294 3295 The product-moment correlation ultimately returned is an idempotent operation 3296 `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. To 3297 facilitate correlation computation across multiple batches, the function 3298 groups the `update_op`s of the underlying streaming_covariance and returns an 3299 `update_op`. 3300 3301 If `weights` is not None, then it is used to compute a weighted correlation. 3302 NOTE: these weights are treated as "frequency weights", as opposed to 3303 "reliability weights". See discussion of the difference on 3304 https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance 3305 3306 Args: 3307 predictions: A `Tensor` of arbitrary size. 3308 labels: A `Tensor` of the same size as predictions. 3309 weights: Optional `Tensor` indicating the frequency with which an example is 3310 sampled. Rank must be 0, or the same rank as `labels`, and must be 3311 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 3312 the same as the corresponding `labels` dimension). 3313 metrics_collections: An optional list of collections that the metric 3314 value variable should be added to. 3315 updates_collections: An optional list of collections that the metric update 3316 ops should be added to. 3317 name: An optional variable_scope name. 3318 3319 Returns: 3320 pearson_r: A `Tensor` representing the current Pearson product-moment 3321 correlation coefficient, the value of 3322 `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. 3323 update_op: An operation that updates the underlying variables appropriately. 3324 3325 Raises: 3326 ValueError: If `labels` and `predictions` are of different sizes, or if 3327 `weights` is the wrong size, or if either `metrics_collections` or 3328 `updates_collections` are not a `list` or `tuple`. 3329 """ 3330 with variable_scope.variable_scope(name, 'pearson_r', 3331 (predictions, labels, weights)): 3332 predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access 3333 predictions, labels, weights) 3334 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 3335 # Broadcast weights here to avoid duplicate broadcasting in each call to 3336 # `streaming_covariance`. 3337 if weights is not None: 3338 weights = weights_broadcast_ops.broadcast_weights(weights, labels) 3339 cov, update_cov = streaming_covariance( 3340 predictions, labels, weights=weights, name='covariance') 3341 var_predictions, update_var_predictions = streaming_covariance( 3342 predictions, predictions, weights=weights, name='variance_predictions') 3343 var_labels, update_var_labels = streaming_covariance( 3344 labels, labels, weights=weights, name='variance_labels') 3345 3346 pearson_r = math_ops.truediv( 3347 cov, 3348 math_ops.multiply( 3349 math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)), 3350 name='pearson_r') 3351 update_op = math_ops.truediv( 3352 update_cov, 3353 math_ops.multiply( 3354 math_ops.sqrt(update_var_predictions), 3355 math_ops.sqrt(update_var_labels)), 3356 name='update_op') 3357 3358 if metrics_collections: 3359 ops.add_to_collections(metrics_collections, pearson_r) 3360 3361 if updates_collections: 3362 ops.add_to_collections(updates_collections, update_op) 3363 3364 return pearson_r, update_op 3365 3366 3367# TODO(nsilberman): add a 'normalized' flag so that the user can request 3368# normalization if the inputs are not normalized. 3369def streaming_mean_cosine_distance(predictions, 3370 labels, 3371 dim, 3372 weights=None, 3373 metrics_collections=None, 3374 updates_collections=None, 3375 name=None): 3376 """Computes the cosine distance between the labels and predictions. 3377 3378 The `streaming_mean_cosine_distance` function creates two local variables, 3379 `total` and `count` that are used to compute the average cosine distance 3380 between `predictions` and `labels`. This average is weighted by `weights`, 3381 and it is ultimately returned as `mean_distance`, which is an idempotent 3382 operation that simply divides `total` by `count`. 3383 3384 For estimation of the metric over a stream of data, the function creates an 3385 `update_op` operation that updates these variables and returns the 3386 `mean_distance`. 3387 3388 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3389 3390 Args: 3391 predictions: A `Tensor` of the same shape as `labels`. 3392 labels: A `Tensor` of arbitrary shape. 3393 dim: The dimension along which the cosine distance is computed. 3394 weights: An optional `Tensor` whose shape is broadcastable to `predictions`, 3395 and whose dimension `dim` is 1. 3396 metrics_collections: An optional list of collections that the metric 3397 value variable should be added to. 3398 updates_collections: An optional list of collections that the metric update 3399 ops should be added to. 3400 name: An optional variable_scope name. 3401 3402 Returns: 3403 mean_distance: A `Tensor` representing the current mean, the value of 3404 `total` divided by `count`. 3405 update_op: An operation that increments the `total` and `count` variables 3406 appropriately. 3407 3408 Raises: 3409 ValueError: If `predictions` and `labels` have mismatched shapes, or if 3410 `weights` is not `None` and its shape doesn't match `predictions`, or if 3411 either `metrics_collections` or `updates_collections` are not a list or 3412 tuple. 3413 """ 3414 predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access 3415 predictions, labels, weights) 3416 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 3417 radial_diffs = math_ops.multiply(predictions, labels) 3418 radial_diffs = math_ops.reduce_sum( 3419 radial_diffs, axis=[ 3420 dim, 3421 ], keepdims=True) 3422 mean_distance, update_op = streaming_mean(radial_diffs, weights, None, None, 3423 name or 'mean_cosine_distance') 3424 mean_distance = math_ops.subtract(1.0, mean_distance) 3425 update_op = math_ops.subtract(1.0, update_op) 3426 3427 if metrics_collections: 3428 ops.add_to_collections(metrics_collections, mean_distance) 3429 3430 if updates_collections: 3431 ops.add_to_collections(updates_collections, update_op) 3432 3433 return mean_distance, update_op 3434 3435 3436def streaming_percentage_less(values, 3437 threshold, 3438 weights=None, 3439 metrics_collections=None, 3440 updates_collections=None, 3441 name=None): 3442 """Computes the percentage of values less than the given threshold. 3443 3444 The `streaming_percentage_less` function creates two local variables, 3445 `total` and `count` that are used to compute the percentage of `values` that 3446 fall below `threshold`. This rate is weighted by `weights`, and it is 3447 ultimately returned as `percentage` which is an idempotent operation that 3448 simply divides `total` by `count`. 3449 3450 For estimation of the metric over a stream of data, the function creates an 3451 `update_op` operation that updates these variables and returns the 3452 `percentage`. 3453 3454 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3455 3456 Args: 3457 values: A numeric `Tensor` of arbitrary size. 3458 threshold: A scalar threshold. 3459 weights: An optional `Tensor` whose shape is broadcastable to `values`. 3460 metrics_collections: An optional list of collections that the metric 3461 value variable should be added to. 3462 updates_collections: An optional list of collections that the metric update 3463 ops should be added to. 3464 name: An optional variable_scope name. 3465 3466 Returns: 3467 percentage: A `Tensor` representing the current mean, the value of `total` 3468 divided by `count`. 3469 update_op: An operation that increments the `total` and `count` variables 3470 appropriately. 3471 3472 Raises: 3473 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 3474 or if either `metrics_collections` or `updates_collections` are not a list 3475 or tuple. 3476 """ 3477 return metrics.percentage_below( 3478 values=values, 3479 threshold=threshold, 3480 weights=weights, 3481 metrics_collections=metrics_collections, 3482 updates_collections=updates_collections, 3483 name=name) 3484 3485 3486def streaming_mean_iou(predictions, 3487 labels, 3488 num_classes, 3489 weights=None, 3490 metrics_collections=None, 3491 updates_collections=None, 3492 name=None): 3493 """Calculate per-step mean Intersection-Over-Union (mIOU). 3494 3495 Mean Intersection-Over-Union is a common evaluation metric for 3496 semantic image segmentation, which first computes the IOU for each 3497 semantic class and then computes the average over classes. 3498 IOU is defined as follows: 3499 IOU = true_positive / (true_positive + false_positive + false_negative). 3500 The predictions are accumulated in a confusion matrix, weighted by `weights`, 3501 and mIOU is then calculated from it. 3502 3503 For estimation of the metric over a stream of data, the function creates an 3504 `update_op` operation that updates these variables and returns the `mean_iou`. 3505 3506 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3507 3508 Args: 3509 predictions: A `Tensor` of prediction results for semantic labels, whose 3510 shape is [batch size] and type `int32` or `int64`. The tensor will be 3511 flattened, if its rank > 1. 3512 labels: A `Tensor` of ground truth labels with shape [batch size] and of 3513 type `int32` or `int64`. The tensor will be flattened, if its rank > 1. 3514 num_classes: The possible number of labels the prediction task can 3515 have. This value must be provided, since a confusion matrix of 3516 dimension = [num_classes, num_classes] will be allocated. 3517 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 3518 metrics_collections: An optional list of collections that `mean_iou` 3519 should be added to. 3520 updates_collections: An optional list of collections `update_op` should be 3521 added to. 3522 name: An optional variable_scope name. 3523 3524 Returns: 3525 mean_iou: A `Tensor` representing the mean intersection-over-union. 3526 update_op: An operation that increments the confusion matrix. 3527 3528 Raises: 3529 ValueError: If `predictions` and `labels` have mismatched shapes, or if 3530 `weights` is not `None` and its shape doesn't match `predictions`, or if 3531 either `metrics_collections` or `updates_collections` are not a list or 3532 tuple. 3533 """ 3534 return metrics.mean_iou( 3535 num_classes=num_classes, 3536 predictions=predictions, 3537 labels=labels, 3538 weights=weights, 3539 metrics_collections=metrics_collections, 3540 updates_collections=updates_collections, 3541 name=name) 3542 3543 3544def _next_array_size(required_size, growth_factor=1.5): 3545 """Calculate the next size for reallocating a dynamic array. 3546 3547 Args: 3548 required_size: number or tf.Tensor specifying required array capacity. 3549 growth_factor: optional number or tf.Tensor specifying the growth factor 3550 between subsequent allocations. 3551 3552 Returns: 3553 tf.Tensor with dtype=int32 giving the next array size. 3554 """ 3555 exponent = math_ops.ceil( 3556 math_ops.log(math_ops.cast(required_size, dtypes.float32)) / math_ops.log( 3557 math_ops.cast(growth_factor, dtypes.float32))) 3558 return math_ops.cast(math_ops.ceil(growth_factor**exponent), dtypes.int32) 3559 3560 3561def streaming_concat(values, 3562 axis=0, 3563 max_size=None, 3564 metrics_collections=None, 3565 updates_collections=None, 3566 name=None): 3567 """Concatenate values along an axis across batches. 3568 3569 The function `streaming_concat` creates two local variables, `array` and 3570 `size`, that are used to store concatenated values. Internally, `array` is 3571 used as storage for a dynamic array (if `maxsize` is `None`), which ensures 3572 that updates can be run in amortized constant time. 3573 3574 For estimation of the metric over a stream of data, the function creates an 3575 `update_op` operation that appends the values of a tensor and returns the 3576 length of the concatenated axis. 3577 3578 This op allows for evaluating metrics that cannot be updated incrementally 3579 using the same framework as other streaming metrics. 3580 3581 Args: 3582 values: `Tensor` to concatenate. Rank and the shape along all axes other 3583 than the axis to concatenate along must be statically known. 3584 axis: optional integer axis to concatenate along. 3585 max_size: optional integer maximum size of `value` along the given axis. 3586 Once the maximum size is reached, further updates are no-ops. By default, 3587 there is no maximum size: the array is resized as necessary. 3588 metrics_collections: An optional list of collections that `value` 3589 should be added to. 3590 updates_collections: An optional list of collections `update_op` should be 3591 added to. 3592 name: An optional variable_scope name. 3593 3594 Returns: 3595 value: A `Tensor` representing the concatenated values. 3596 update_op: An operation that concatenates the next values. 3597 3598 Raises: 3599 ValueError: if `values` does not have a statically known rank, `axis` is 3600 not in the valid range or the size of `values` is not statically known 3601 along any axis other than `axis`. 3602 """ 3603 with variable_scope.variable_scope(name, 'streaming_concat', (values,)): 3604 # pylint: disable=invalid-slice-index 3605 values_shape = values.get_shape() 3606 if values_shape.dims is None: 3607 raise ValueError('`values` must have known statically known rank') 3608 3609 ndim = len(values_shape) 3610 if axis < 0: 3611 axis += ndim 3612 if not 0 <= axis < ndim: 3613 raise ValueError('axis = %r not in [0, %r)' % (axis, ndim)) 3614 3615 fixed_shape = [dim.value for n, dim in enumerate(values_shape) if n != axis] 3616 if any(value is None for value in fixed_shape): 3617 raise ValueError('all dimensions of `values` other than the dimension to ' 3618 'concatenate along must have statically known size') 3619 3620 # We move `axis` to the front of the internal array so assign ops can be 3621 # applied to contiguous slices 3622 init_size = 0 if max_size is None else max_size 3623 init_shape = [init_size] + fixed_shape 3624 array = metrics_impl.metric_variable( 3625 init_shape, values.dtype, validate_shape=False, name='array') 3626 size = metrics_impl.metric_variable([], dtypes.int32, name='size') 3627 3628 perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)] 3629 valid_array = array[:size] 3630 valid_array.set_shape([None] + fixed_shape) 3631 value = array_ops.transpose(valid_array, perm, name='concat') 3632 3633 values_size = array_ops.shape(values)[axis] 3634 if max_size is None: 3635 batch_size = values_size 3636 else: 3637 batch_size = math_ops.minimum(values_size, max_size - size) 3638 3639 perm = [axis] + [n for n in range(ndim) if n != axis] 3640 batch_values = array_ops.transpose(values, perm)[:batch_size] 3641 3642 def reallocate(): 3643 next_size = _next_array_size(new_size) 3644 next_shape = array_ops.stack([next_size] + fixed_shape) 3645 new_value = array_ops.zeros(next_shape, dtype=values.dtype) 3646 old_value = array.value() 3647 assign_op = state_ops.assign(array, new_value, validate_shape=False) 3648 with ops.control_dependencies([assign_op]): 3649 copy_op = array[:size].assign(old_value[:size]) 3650 # return value needs to be the same dtype as no_op() for cond 3651 with ops.control_dependencies([copy_op]): 3652 return control_flow_ops.no_op() 3653 3654 new_size = size + batch_size 3655 array_size = array_ops.shape_internal(array, optimize=False)[0] 3656 maybe_reallocate_op = control_flow_ops.cond( 3657 new_size > array_size, reallocate, control_flow_ops.no_op) 3658 with ops.control_dependencies([maybe_reallocate_op]): 3659 append_values_op = array[size:new_size].assign(batch_values) 3660 with ops.control_dependencies([append_values_op]): 3661 update_op = size.assign(new_size) 3662 3663 if metrics_collections: 3664 ops.add_to_collections(metrics_collections, value) 3665 3666 if updates_collections: 3667 ops.add_to_collections(updates_collections, update_op) 3668 3669 return value, update_op 3670 # pylint: enable=invalid-slice-index 3671 3672 3673def aggregate_metrics(*value_update_tuples): 3674 """Aggregates the metric value tensors and update ops into two lists. 3675 3676 Args: 3677 *value_update_tuples: a variable number of tuples, each of which contain the 3678 pair of (value_tensor, update_op) from a streaming metric. 3679 3680 Returns: 3681 A list of value `Tensor` objects and a list of update ops. 3682 3683 Raises: 3684 ValueError: if `value_update_tuples` is empty. 3685 """ 3686 if not value_update_tuples: 3687 raise ValueError('Expected at least one value_tensor/update_op pair') 3688 value_ops, update_ops = zip(*value_update_tuples) 3689 return list(value_ops), list(update_ops) 3690 3691 3692def aggregate_metric_map(names_to_tuples): 3693 """Aggregates the metric names to tuple dictionary. 3694 3695 This function is useful for pairing metric names with their associated value 3696 and update ops when the list of metrics is long. For example: 3697 3698 ```python 3699 metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({ 3700 'Mean Absolute Error': new_slim.metrics.streaming_mean_absolute_error( 3701 predictions, labels, weights), 3702 'Mean Relative Error': new_slim.metrics.streaming_mean_relative_error( 3703 predictions, labels, labels, weights), 3704 'RMSE Linear': new_slim.metrics.streaming_root_mean_squared_error( 3705 predictions, labels, weights), 3706 'RMSE Log': new_slim.metrics.streaming_root_mean_squared_error( 3707 predictions, labels, weights), 3708 }) 3709 ``` 3710 3711 Args: 3712 names_to_tuples: a map of metric names to tuples, each of which contain the 3713 pair of (value_tensor, update_op) from a streaming metric. 3714 3715 Returns: 3716 A dictionary from metric names to value ops and a dictionary from metric 3717 names to update ops. 3718 """ 3719 metric_names = names_to_tuples.keys() 3720 value_ops, update_ops = zip(*names_to_tuples.values()) 3721 return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops)) 3722 3723 3724def count(values, 3725 weights=None, 3726 metrics_collections=None, 3727 updates_collections=None, 3728 name=None): 3729 """Computes the number of examples, or sum of `weights`. 3730 3731 This metric keeps track of the denominator in `tf.metrics.mean`. 3732 When evaluating some metric (e.g. mean) on one or more subsets of the data, 3733 this auxiliary metric is useful for keeping track of how many examples there 3734 are in each subset. 3735 3736 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3737 3738 Args: 3739 values: A `Tensor` of arbitrary dimensions. Only it's shape is used. 3740 weights: Optional `Tensor` whose rank is either 0, or the same rank as 3741 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 3742 must be either `1`, or the same as the corresponding `labels` 3743 dimension). 3744 metrics_collections: An optional list of collections that the metric 3745 value variable should be added to. 3746 updates_collections: An optional list of collections that the metric update 3747 ops should be added to. 3748 name: An optional variable_scope name. 3749 3750 Returns: 3751 count: A `Tensor` representing the current value of the metric. 3752 update_op: An operation that accumulates the metric from a batch of data. 3753 3754 Raises: 3755 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 3756 or if either `metrics_collections` or `updates_collections` are not a list 3757 or tuple. 3758 RuntimeError: If eager execution is enabled. 3759 """ 3760 if context.executing_eagerly(): 3761 raise RuntimeError('tf.contrib.metrics.count is not supported when eager ' 3762 'execution is enabled.') 3763 3764 with variable_scope.variable_scope(name, 'count', (values, weights)): 3765 3766 count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') 3767 3768 if weights is None: 3769 num_values = math_ops.cast(array_ops.size(values), dtypes.float32) 3770 else: 3771 values = math_ops.cast(values, dtypes.float32) 3772 values, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access 3773 predictions=values, 3774 labels=None, 3775 weights=weights) 3776 weights = weights_broadcast_ops.broadcast_weights( 3777 math_ops.cast(weights, dtypes.float32), values) 3778 num_values = math_ops.reduce_sum(weights) 3779 3780 with ops.control_dependencies([values]): 3781 update_count_op = state_ops.assign_add(count_, num_values) 3782 3783 count_ = metrics_impl._aggregate_variable(count_, metrics_collections) # pylint: disable=protected-access 3784 3785 if updates_collections: 3786 ops.add_to_collections(updates_collections, update_count_op) 3787 3788 return count_, update_count_op 3789 3790 3791def cohen_kappa(labels, 3792 predictions_idx, 3793 num_classes, 3794 weights=None, 3795 metrics_collections=None, 3796 updates_collections=None, 3797 name=None): 3798 """Calculates Cohen's kappa. 3799 3800 [Cohen's kappa](https://en.wikipedia.org/wiki/Cohen's_kappa) is a statistic 3801 that measures inter-annotator agreement. 3802 3803 The `cohen_kappa` function calculates the confusion matrix, and creates three 3804 local variables to compute the Cohen's kappa: `po`, `pe_row`, and `pe_col`, 3805 which refer to the diagonal part, rows and columns totals of the confusion 3806 matrix, respectively. This value is ultimately returned as `kappa`, an 3807 idempotent operation that is calculated by 3808 3809 pe = (pe_row * pe_col) / N 3810 k = (sum(po) - sum(pe)) / (N - sum(pe)) 3811 3812 For estimation of the metric over a stream of data, the function creates an 3813 `update_op` operation that updates these variables and returns the 3814 `kappa`. `update_op` weights each prediction by the corresponding value in 3815 `weights`. 3816 3817 Class labels are expected to start at 0. E.g., if `num_classes` 3818 was three, then the possible labels would be [0, 1, 2]. 3819 3820 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3821 3822 NOTE: Equivalent to `sklearn.metrics.cohen_kappa_score`, but the method 3823 doesn't support weighted matrix yet. 3824 3825 Args: 3826 labels: 1-D `Tensor` of real labels for the classification task. Must be 3827 one of the following types: int16, int32, int64. 3828 predictions_idx: 1-D `Tensor` of predicted class indices for a given 3829 classification. Must have the same type as `labels`. 3830 num_classes: The possible number of labels. 3831 weights: Optional `Tensor` whose shape matches `predictions`. 3832 metrics_collections: An optional list of collections that `kappa` should 3833 be added to. 3834 updates_collections: An optional list of collections that `update_op` should 3835 be added to. 3836 name: An optional variable_scope name. 3837 3838 Returns: 3839 kappa: Scalar float `Tensor` representing the current Cohen's kappa. 3840 update_op: `Operation` that increments `po`, `pe_row` and `pe_col` 3841 variables appropriately and whose value matches `kappa`. 3842 3843 Raises: 3844 ValueError: If `num_classes` is less than 2, or `predictions` and `labels` 3845 have mismatched shapes, or if `weights` is not `None` and its shape 3846 doesn't match `predictions`, or if either `metrics_collections` or 3847 `updates_collections` are not a list or tuple. 3848 RuntimeError: If eager execution is enabled. 3849 """ 3850 if context.executing_eagerly(): 3851 raise RuntimeError('tf.contrib.metrics.cohen_kappa is not supported ' 3852 'when eager execution is enabled.') 3853 if num_classes < 2: 3854 raise ValueError('`num_classes` must be >= 2.' 3855 'Found: {}'.format(num_classes)) 3856 with variable_scope.variable_scope(name, 'cohen_kappa', 3857 (labels, predictions_idx, weights)): 3858 # Convert 2-dim (num, 1) to 1-dim (num,) 3859 labels.get_shape().with_rank_at_most(2) 3860 if labels.get_shape().ndims == 2: 3861 labels = array_ops.squeeze(labels, axis=[-1]) 3862 predictions_idx, labels, weights = ( 3863 metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access 3864 predictions=predictions_idx, 3865 labels=labels, 3866 weights=weights)) 3867 predictions_idx.get_shape().assert_is_compatible_with(labels.get_shape()) 3868 3869 stat_dtype = ( 3870 dtypes.int64 3871 if weights is None or weights.dtype.is_integer else dtypes.float32) 3872 po = metrics_impl.metric_variable((num_classes,), stat_dtype, name='po') 3873 pe_row = metrics_impl.metric_variable( 3874 (num_classes,), stat_dtype, name='pe_row') 3875 pe_col = metrics_impl.metric_variable( 3876 (num_classes,), stat_dtype, name='pe_col') 3877 3878 # Table of the counts of agreement: 3879 counts_in_table = confusion_matrix.confusion_matrix( 3880 labels, 3881 predictions_idx, 3882 num_classes=num_classes, 3883 weights=weights, 3884 dtype=stat_dtype, 3885 name='counts_in_table') 3886 3887 po_t = array_ops.diag_part(counts_in_table) 3888 pe_row_t = math_ops.reduce_sum(counts_in_table, axis=0) 3889 pe_col_t = math_ops.reduce_sum(counts_in_table, axis=1) 3890 update_po = state_ops.assign_add(po, po_t) 3891 update_pe_row = state_ops.assign_add(pe_row, pe_row_t) 3892 update_pe_col = state_ops.assign_add(pe_col, pe_col_t) 3893 3894 def _calculate_k(po, pe_row, pe_col, name): 3895 po_sum = math_ops.reduce_sum(po) 3896 total = math_ops.reduce_sum(pe_row) 3897 pe_sum = math_ops.reduce_sum( 3898 math_ops.div_no_nan( 3899 math_ops.cast(pe_row * pe_col, dtypes.float64), 3900 math_ops.cast(total, dtypes.float64))) 3901 po_sum, pe_sum, total = (math_ops.cast(po_sum, dtypes.float64), 3902 math_ops.cast(pe_sum, dtypes.float64), 3903 math_ops.cast(total, dtypes.float64)) 3904 # kappa = (po - pe) / (N - pe) 3905 k = metrics_impl._safe_scalar_div( # pylint: disable=protected-access 3906 po_sum - pe_sum, 3907 total - pe_sum, 3908 name=name) 3909 return k 3910 3911 kappa = _calculate_k(po, pe_row, pe_col, name='value') 3912 update_op = _calculate_k( 3913 update_po, update_pe_row, update_pe_col, name='update_op') 3914 3915 if metrics_collections: 3916 ops.add_to_collections(metrics_collections, kappa) 3917 3918 if updates_collections: 3919 ops.add_to_collections(updates_collections, update_op) 3920 3921 return kappa, update_op 3922 3923 3924__all__ = [ 3925 'auc_with_confidence_intervals', 3926 'aggregate_metric_map', 3927 'aggregate_metrics', 3928 'cohen_kappa', 3929 'count', 3930 'precision_recall_at_equal_thresholds', 3931 'recall_at_precision', 3932 'sparse_recall_at_top_k', 3933 'streaming_accuracy', 3934 'streaming_auc', 3935 'streaming_curve_points', 3936 'streaming_dynamic_auc', 3937 'streaming_false_negative_rate', 3938 'streaming_false_negative_rate_at_thresholds', 3939 'streaming_false_negatives', 3940 'streaming_false_negatives_at_thresholds', 3941 'streaming_false_positive_rate', 3942 'streaming_false_positive_rate_at_thresholds', 3943 'streaming_false_positives', 3944 'streaming_false_positives_at_thresholds', 3945 'streaming_mean', 3946 'streaming_mean_absolute_error', 3947 'streaming_mean_cosine_distance', 3948 'streaming_mean_iou', 3949 'streaming_mean_relative_error', 3950 'streaming_mean_squared_error', 3951 'streaming_mean_tensor', 3952 'streaming_percentage_less', 3953 'streaming_precision', 3954 'streaming_precision_at_thresholds', 3955 'streaming_recall', 3956 'streaming_recall_at_k', 3957 'streaming_recall_at_thresholds', 3958 'streaming_root_mean_squared_error', 3959 'streaming_sensitivity_at_specificity', 3960 'streaming_sparse_average_precision_at_k', 3961 'streaming_sparse_average_precision_at_top_k', 3962 'streaming_sparse_precision_at_k', 3963 'streaming_sparse_precision_at_top_k', 3964 'streaming_sparse_recall_at_k', 3965 'streaming_specificity_at_sensitivity', 3966 'streaming_true_negatives', 3967 'streaming_true_negatives_at_thresholds', 3968 'streaming_true_positives', 3969 'streaming_true_positives_at_thresholds', 3970] 3971