# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of tf.metrics module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sets
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


def metric_variable(shape, dtype, validate_shape=True, name=None):
  """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections.

  If running in a `DistributionStrategy` context, the variable will be
  "sync on read". This means:

  * The returned object will be a container with separate variables
    per replica of the model.

  * When writing to the variable, e.g. using `assign_add` in a metric
    update, the update will be applied to the variable local to the
    replica.

  * To get a metric's result value, we need to sum the variable values
    across the replicas before computing the final answer. Furthermore,
    the final answer should be computed once instead of in every
    replica. Both of these are accomplished by running the computation
    of the final result value inside
    `distribution_strategy_context.get_replica_context().merge_call(fn)`.
    Inside the `merge_call()`, ops are only added to the graph once
    and access to a sync on read variable in a computation returns
    the sum across all replicas.

  Args:
    shape: Shape of the created variable.
    dtype: Type of the created variable.
    validate_shape: (Optional) Whether shape validation is enabled for
      the created variable.
    name: (Optional) String name of the created variable.

  Returns:
    A (non-trainable) variable initialized to zero, or if inside a
    `DistributionStrategy` scope a sync on read variable container.
  """
  # Note that synchronization "ON_READ" implies trainable=False.
  return variable_scope.variable(
      lambda: array_ops.zeros(shape, dtype),
      trainable=False,
      collections=[
          ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
      ],
      validate_shape=validate_shape,
      synchronization=variable_scope.VariableSynchronization.ON_READ,
      aggregation=variable_scope.VariableAggregation.SUM,
      name=name)

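# Illustrative sketch (added for exposition, not part of the original file):
# this is how the module's own metrics consume `metric_variable`. Metric state
# such as `total` below is zero-initialized local state that the paired
# `update_op` mutates via `assign_add`; the names are hypothetical.
#
#   total = metric_variable([], dtypes.float32, name='total')
#   update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
#
# Because the variable lives in `GraphKeys.LOCAL_VARIABLES`, callers must run
# `tf.compat.v1.local_variables_initializer()` before evaluating any metric.
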
def _remove_squeezable_dimensions(predictions, labels, weights):
  """Squeeze or expand last dim if needed.

  Squeezes last dim of `predictions` or `labels` if their rank differs by 1
  (using confusion_matrix.remove_squeezable_dimensions).
  Squeezes or expands last dim of `weights` if its rank differs by 1 from the
  new rank of `predictions`.

  If `weights` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Optional label `Tensor` whose dimensions match `predictions`.
    weights: Optional weight scalar or `Tensor` whose dimensions match
      `predictions`.

  Returns:
    Tuple of `predictions`, `labels` and `weights`. Each of them possibly has
    the last dimension squeezed, `weights` could be extended by one dimension.
  """
  predictions = ops.convert_to_tensor(predictions)
  if labels is not None:
    labels, predictions = confusion_matrix.remove_squeezable_dimensions(
        labels, predictions)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  if weights is None:
    return predictions, labels, None

  weights = ops.convert_to_tensor(weights)
  weights_shape = weights.get_shape()
  weights_rank = weights_shape.ndims
  if weights_rank == 0:
    return predictions, labels, weights

  predictions_shape = predictions.get_shape()
  predictions_rank = predictions_shape.ndims
  if (predictions_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - predictions_rank == 1:
      weights = array_ops.squeeze(weights, [-1])
    elif predictions_rank - weights_rank == 1:
      weights = array_ops.expand_dims(weights, [-1])
  else:
    # Use dynamic rank.
    weights_rank_tensor = array_ops.rank(weights)
    rank_diff = weights_rank_tensor - array_ops.rank(predictions)

    def _maybe_expand_weights():
      return control_flow_ops.cond(
          math_ops.equal(rank_diff, -1),
          lambda: array_ops.expand_dims(weights, [-1]), lambda: weights)

    # Don't attempt squeeze if it will fail based on static check.
    if ((weights_rank is not None) and
        (not weights_shape.dims[-1].is_compatible_with(1))):
      maybe_squeeze_weights = lambda: weights
    else:
      maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1])

    def _maybe_adjust_weights():
      return control_flow_ops.cond(
          math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
          _maybe_expand_weights)

    # If weights are scalar, do nothing. Otherwise, try to add or remove a
    # dimension to match predictions.
    weights = control_flow_ops.cond(
        math_ops.equal(weights_rank_tensor, 0), lambda: weights,
        _maybe_adjust_weights)
  return predictions, labels, weights

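# Shape sketch (hypothetical shapes, added for exposition): given
# `predictions` of shape [2, 3], a `labels` tensor of shape [2, 3, 1] is
# squeezed to [2, 3], a `weights` tensor of shape [2, 3, 1] is likewise
# squeezed to [2, 3], and a scalar `weights` passes through unchanged.
#
#   predictions = array_ops.zeros([2, 3])    # rank 2
#   labels = array_ops.zeros([2, 3, 1])      # rank 3 -> squeezed to [2, 3]
#   weights = array_ops.ones([2, 3, 1])      # rank 3 -> squeezed to [2, 3]
#   predictions, labels, weights = _remove_squeezable_dimensions(
#       predictions, labels, weights)
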
def _maybe_expand_labels(labels, predictions):
  """If necessary, expand `labels` along last dimension to match `predictions`.

  Args:
    labels: `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN]. The latter implies
      num_labels=1, in which case the result is an expanded `labels` with shape
      [D1, ... DN, 1].
    predictions: `Tensor` with shape [D1, ... DN, num_classes].

  Returns:
    `labels` with the same rank as `predictions`.

  Raises:
    ValueError: if `labels` has invalid shape.
  """
  with ops.name_scope(None, 'expand_labels', (labels, predictions)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)

    # If sparse, expand sparse shape.
    if isinstance(labels, sparse_tensor.SparseTensor):
      return control_flow_ops.cond(
          math_ops.equal(
              array_ops.rank(predictions),
              array_ops.size(labels.dense_shape) + 1),
          lambda: sparse_ops.sparse_reshape(  # pylint: disable=g-long-lambda
              labels,
              shape=array_ops.concat((labels.dense_shape, (1,)), 0),
              name=scope),
          lambda: labels)

    # Otherwise, try to use static shape.
    labels_rank = labels.get_shape().ndims
    if labels_rank is not None:
      predictions_rank = predictions.get_shape().ndims
      if predictions_rank is not None:
        if predictions_rank == labels_rank:
          return labels
        if predictions_rank == labels_rank + 1:
          return array_ops.expand_dims(labels, -1, name=scope)
        raise ValueError(
            f'Unexpected labels shape {labels.get_shape()} for predictions '
            f'shape {predictions.get_shape()}. Predictions rank should be the '
            'same rank as labels rank or labels rank plus one.')

    # Otherwise, use dynamic shape.
    return control_flow_ops.cond(
        math_ops.equal(array_ops.rank(predictions),
                       array_ops.rank(labels) + 1),
        lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)


def _safe_scalar_div(numerator, denominator, name):
  """Divides two values, returning 0 if the denominator is 0.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  numerator.get_shape().with_rank_at_most(1)
  denominator.get_shape().with_rank_at_most(1)
  return math_ops.div_no_nan(numerator, denominator, name=name)

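# Behavior sketch (hypothetical values, added for exposition): `div_no_nan`
# yields 0 wherever the denominator is 0, so an empty stream produces a
# metric value of 0 rather than NaN.
#
#   num = ops.convert_to_tensor(0.0, dtypes.float64)
#   den = ops.convert_to_tensor(0.0, dtypes.float64)
#   ratio = _safe_scalar_div(num, den, name='ratio')  # evaluates to 0.0
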
def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
  """Calculate a streaming confusion matrix.

  Calculates a confusion matrix. For estimation over a stream of data,
  the function creates an `update_op` operation.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).

  Returns:
    total_cm: A `Tensor` representing the confusion matrix.
    update_op: An operation that increments the confusion matrix.
  """
  # Local variable to accumulate the predictions in the confusion matrix.
  total_cm = metric_variable(
      [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix')

  # Cast the type to int64 required by confusion_matrix_ops.
  predictions = math_ops.cast(predictions, dtypes.int64)
  labels = math_ops.cast(labels, dtypes.int64)
  num_classes = math_ops.cast(num_classes, dtypes.int64)

  # Flatten the input if its rank > 1.
  if predictions.get_shape().ndims > 1:
    predictions = array_ops.reshape(predictions, [-1])

  if labels.get_shape().ndims > 1:
    labels = array_ops.reshape(labels, [-1])

  if (weights is not None) and (weights.get_shape().ndims > 1):
    weights = array_ops.reshape(weights, [-1])

  # Accumulate the prediction to current confusion matrix.
  current_cm = confusion_matrix.confusion_matrix(
      labels, predictions, num_classes, weights=weights, dtype=dtypes.float64)
  update_op = state_ops.assign_add(total_cm, current_cm)
  return total_cm, update_op

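# Usage sketch (illustrative; assumes a TF1 graph and session): the returned
# pair follows the usual (value, update_op) metric contract. Each run of
# `update_op` folds one batch into `total_cm`; reading `total_cm` afterwards
# yields the accumulated [num_classes, num_classes] matrix.
#
#   total_cm, update_op = _streaming_confusion_matrix(
#       labels, predictions, num_classes=3)
#   sess.run(tf.compat.v1.local_variables_initializer())
#   for batch_feed in batch_feeds:  # `batch_feeds` is a hypothetical iterable
#     sess.run(update_op, feed_dict=batch_feed)
#   confusion = sess.run(total_cm)
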
def _aggregate_across_replicas(metrics_collections, metric_value_fn, *args):
  """Aggregate metric value across replicas."""
  def fn(distribution, *a):
    """Call `metric_value_fn` in the correct control flow context."""
    if hasattr(distribution.extended, '_outer_control_flow_context'):
      # If there was an outer context captured before this method was called,
      # then we enter that context to create the metric value op. If the
      # captured context is `None`, ops.control_dependencies(None) gives the
      # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
      # captured context.
      # This special handling is needed because sometimes the metric is created
      # inside a while_loop (and perhaps a TPU rewrite context). But we don't
      # want the value op to be evaluated every step or on the TPU. So we
      # create it outside so that it can be evaluated at the end on the host,
      # once the update ops have been evaluated.

      # pylint: disable=protected-access
      if distribution.extended._outer_control_flow_context is None:
        with ops.control_dependencies(None):
          metric_value = metric_value_fn(distribution, *a)
      else:
        distribution.extended._outer_control_flow_context.Enter()
        metric_value = metric_value_fn(distribution, *a)
        distribution.extended._outer_control_flow_context.Exit()
      # pylint: enable=protected-access
    else:
      metric_value = metric_value_fn(distribution, *a)
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric_value)
    return metric_value

  return distribution_strategy_context.get_replica_context().merge_call(
      fn, args=args)


@tf_export(v1=['metrics.mean'])
def mean(values,
         weights=None,
         metrics_collections=None,
         updates_collections=None,
         name=None):
  """Computes the (weighted) mean of the given values.

  The `mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean` which is an idempotent operation that simply divides
  `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'mean', (values, weights)):
    values = math_ops.cast(values, dtypes.float32)

    total = metric_variable([], dtypes.float32, name='total')
    count = metric_variable([], dtypes.float32, name='count')

    if weights is None:
      num_values = math_ops.cast(array_ops.size(values), dtypes.float32)
    else:
      values, _, weights = _remove_squeezable_dimensions(
          predictions=values, labels=None, weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.cast(weights, dtypes.float32), values)
      values = math_ops.multiply(values, weights)
      num_values = math_ops.reduce_sum(weights)

    update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
    with ops.control_dependencies([values]):
      update_count_op = state_ops.assign_add(count, num_values)

    def compute_mean(_, t, c):
      return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value')

    mean_t = _aggregate_across_replicas(
        metrics_collections, compute_mean, total, count)
    update_op = math_ops.div_no_nan(
        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_t, update_op

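# End-to-end sketch of the (value, update_op) contract for `mean`, assuming
# TF1 graph mode; all names are illustrative:
#
#   values_ph = tf.compat.v1.placeholder(tf.float32, shape=[None])
#   mean_t, update_op = tf.compat.v1.metrics.mean(values_ph)
#   with tf.compat.v1.Session() as sess:
#     sess.run(tf.compat.v1.local_variables_initializer())
#     sess.run(update_op, feed_dict={values_ph: [1.0, 2.0]})  # total=3, count=2
#     sess.run(update_op, feed_dict={values_ph: [3.0]})       # total=6, count=3
#     print(sess.run(mean_t))                                 # 2.0
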
@tf_export(v1=['metrics.accuracy'])
def accuracy(labels,
             predictions,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculates how often `predictions` matches `labels`.

  The `accuracy` function creates two local variables, `total` and
  `count` that are used to compute the frequency with which `predictions`
  matches `labels`. This frequency is ultimately returned as `accuracy`: an
  idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `accuracy`.
  Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
  where the corresponding elements of `predictions` and `labels` match and 0.0
  otherwise. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `is_correct`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    predictions: The predicted values, a `Tensor` of any shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.

  @compatibility(TF2)
  `tf.compat.v1.metrics.accuracy` is not compatible with eager
  execution or `tf.function`.
  Please use `tf.keras.metrics.Accuracy` instead for TF2 migration. After
  instantiating a `tf.keras.metrics.Accuracy` object, you can first call the
  `update_state()` method to record the prediction/labels, and then call the
  `result()` method to get the accuracy eagerly. You can also attach it to a
  Keras model when calling the `compile` method. Please refer to [this
  guide](https://www.tensorflow.org/guide/migrate#new-style_metrics_and_losses)
  for more details.

  #### Structural Mapping to Native TF2

  Before:

  ```python
  accuracy, update_op = tf.compat.v1.metrics.accuracy(
    labels=labels,
    predictions=predictions,
    weights=weights,
    metrics_collections=metrics_collections,
    updates_collections=updates_collections,
    name=name)
  ```

  After:

  ```python
  m = tf.keras.metrics.Accuracy(
    name=name,
    dtype=None)

  m.update_state(
    y_true=labels,
    y_pred=predictions,
    sample_weight=weights)

  accuracy = m.result()
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `labels`              | `y_true`        | In `update_state()` method |
  | `predictions`         | `y_pred`        | In `update_state()` method |
  | `weights`             | `sample_weight` | In `update_state()` method |
  | `metrics_collections` | Not supported   | Metrics should be tracked  |
  :                       :                 : explicitly or with Keras   :
  :                       :                 : APIs, for example,         :
  :                       :                 : [add_metric][add_metric],  :
  :                       :                 : instead of via collections :
  | `updates_collections` | Not supported   | -                          |
  | `name`                | `name`          | In constructor             |

  [add_metric]:https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_metric


  #### Before & After Usage Example

  Before:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   logits = [1, 2, 3]
  ...   labels = [0, 2, 3]
  ...   acc, acc_op = tf.compat.v1.metrics.accuracy(logits, labels)
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   local_init = tf.compat.v1.local_variables_initializer()
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run([global_init, local_init])
  >>> print(sess.run([acc, acc_op]))
  [0.0, 0.66667]


  After:

  >>> m = tf.keras.metrics.Accuracy()
  >>> m.update_state([1, 2, 3], [0, 2, 3])
  >>> m.result().numpy()
  0.66667

  ```python
  # Used within Keras model
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.Accuracy()])
  ```

  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.accuracy is not supported when eager '
                       'execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  if labels.dtype != predictions.dtype:
    predictions = math_ops.cast(predictions, labels.dtype)
  is_correct = math_ops.cast(
      math_ops.equal(predictions, labels), dtypes.float32)
  return mean(is_correct, weights, metrics_collections, updates_collections,
              name or 'accuracy')


def _confusion_matrix_at_thresholds(labels,
                                    predictions,
                                    thresholds,
                                    weights=None,
                                    includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positive[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
      default to all four.

  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
      `includes`.
    update_ops: Dict of operations that increments the `values`. Keys are from
      `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  all_includes = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = all_includes
  else:
    for include in includes:
      if include not in all_includes:
        raise ValueError(f'Invalid key: {include}')

  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          predictions,
          math_ops.cast(0.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]'),
      check_ops.assert_less_equal(
          predictions,
          math_ops.cast(1.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]')
  ]):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtypes.float32),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

  num_thresholds = len(thresholds)

  # Reshape predictions and labels.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Use static shape if known.
  num_predictions = predictions_2d.get_shape().as_list()[0]

  # Otherwise use dynamic shape.
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
      array_ops.stack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
      thresh_tiled)
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  if weights is not None:
    weights = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(weights, dtypes.float32), predictions)
    weights_tiled = array_ops.tile(
        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
  else:
    weights_tiled = None

  values = {}
  update_ops = {}

  if 'tp' in includes:
    true_p = metric_variable(
        [num_thresholds], dtypes.float32, name='true_positives')
    is_true_positive = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_true_positive *= weights_tiled
    update_ops['tp'] = state_ops.assign_add(true_p,
                                            math_ops.reduce_sum(
                                                is_true_positive, 1))
    values['tp'] = true_p

  if 'fn' in includes:
    false_n = metric_variable(
        [num_thresholds], dtypes.float32, name='false_negatives')
    is_false_negative = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_false_negative *= weights_tiled
    update_ops['fn'] = state_ops.assign_add(false_n,
                                            math_ops.reduce_sum(
                                                is_false_negative, 1))
    values['fn'] = false_n

  if 'tn' in includes:
    true_n = metric_variable(
        [num_thresholds], dtypes.float32, name='true_negatives')
    is_true_negative = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_true_negative *= weights_tiled
    update_ops['tn'] = state_ops.assign_add(true_n,
                                            math_ops.reduce_sum(
                                                is_true_negative, 1))
    values['tn'] = true_n

  if 'fp' in includes:
    false_p = metric_variable(
        [num_thresholds], dtypes.float32, name='false_positives')
    is_false_positive = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_false_positive *= weights_tiled
    update_ops['fp'] = state_ops.assign_add(false_p,
                                            math_ops.reduce_sum(
                                                is_false_positive, 1))
    values['fp'] = false_p

  return values, update_ops

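# Result-shape sketch (hypothetical thresholds, added for exposition): with
# thresholds=[0.25, 0.75], each entry of `values` is a variable of shape [2]
# (one count per threshold), and running the matching entry of `update_ops`
# folds one batch's weighted counts into it.
#
#   values, update_ops = _confusion_matrix_at_thresholds(
#       labels, predictions, thresholds=[0.25, 0.75])
#   # values['tp'], values['fn'], values['tn'], values['fp']: shape [2]
#   # update_ops['tp'], ...: ops that accumulate the current batch
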
def _aggregate_variable(v, collections):
  f = lambda distribution, value: distribution.extended.read_var(value)
  return _aggregate_across_replicas(collections, f, v)


@tf_export(v1=['metrics.auc'])
@deprecated(None,
            'The value of AUC returned by this may race with the update so '
            'this is deprecated. Please use tf.keras.metrics.AUC instead.')
def auc(labels,
        predictions,
        weights=None,
        num_thresholds=200,
        metrics_collections=None,
        updates_collections=None,
        curve='ROC',
        name=None,
        summation_method='trapezoidal',
        thresholds=None):
  """Computes the approximate AUC via a Riemann sum.

  The `auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall values
  (computed using the aforementioned variables). The `num_thresholds` variable
  controls the degree of discretization with larger numbers of thresholds more
  closely approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case. Setting `summation_method`
  to 'minoring' or 'majoring' can help quantify the error in the approximation
  by providing lower or upper bound estimates of the AUC. The `thresholds`
  parameter can be used to manually specify thresholds which split the
  predictions more evenly.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.
    summation_method: Specifies the Riemann summation method used
      (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that
      applies the trapezoidal rule; 'careful_interpolation', a variant of it
      differing only by a more correct interpolation scheme for PR-AUC -
      interpolating (true/false) positives but not the ratio that is precision;
      'minoring' that applies left summation for increasing intervals and right
      summation for decreasing intervals; 'majoring' that does the opposite.
      Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
      (to be deprecated soon) as it applies the same method for ROC, and a
      better one (see Davis & Goadrich 2006 for details) for the PR curve.
    thresholds: An optional list of floating point values to use as the
      thresholds for discretizing the curve. If set, the `num_thresholds`
      parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
      equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be
      automatically included with these to correctly handle predictions equal to
      exactly 0 or 1.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.auc is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'auc',
                                     (labels, predictions, weights)):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError(f'Curve must be either ROC or PR. Curve {curve} is '
                       'unknown.')

    kepsilon = 1e-7  # To account for floating point imprecisions.
    if thresholds is not None:
      # If specified, use the supplied thresholds.
      thresholds = sorted(thresholds)
      num_thresholds = len(thresholds) + 2
    else:
      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
      # (0, 1).
      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                    for i in range(num_thresholds - 2)]

    # Add an endpoint "threshold" below zero and above one for either threshold
    # method.
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def interpolate_pr_auc(tp, fp, fn):
      """Interpolation formula inspired by section 4 of (Davis et al., 2006).

      Note here we derive & use a closed formula not present in the paper
      - as follows:
      Modeling all of TP (true positive weight),
      FP (false positive weight) and their sum P = TP + FP (positive weight)
      as varying linearly within each interval [A, B] between successive
      thresholds, we get
        Precision = (TP_A + slope * (P - P_A)) / P
      with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A).
      The area within the interval is thus (slope / total_pos_weight) times
        int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
        int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
      where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
        int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
      Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
        slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
      where dTP == TP_B - TP_A.
      Note that when P_A == 0 the above calculation simplifies into
        int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
      which is really equivalent to imputing constant precision throughout the
      first bucket having >0 true positives.

      Args:
        tp: true positive counts
        fp: false positive counts
        fn: false negative counts

      Returns:
        pr_auc: an approximation of the area under the P-R curve.

      References:
        The Relationship Between Precision-Recall and ROC Curves:
          [Davis et al., 2006](https://dl.acm.org/citation.cfm?id=1143874)
          ([pdf](https://www.biostat.wisc.edu/~page/rocpr.pdf))
      """
      dtp = tp[:num_thresholds - 1] - tp[1:]
      p = tp + fp
      prec_slope = math_ops.div_no_nan(
          dtp,
          math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0),
          name='prec_slope')
      intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
      safe_p_ratio = array_ops.where(
          math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
          math_ops.div_no_nan(
              p[:num_thresholds - 1],
              math_ops.maximum(p[1:], 0),
              name='recall_relative_ratio'), array_ops.ones_like(p[1:]))
      return math_ops.reduce_sum(
          math_ops.div_no_nan(
              prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
              math_ops.maximum(tp[1:] + fn[1:], 0),
              name='pr_auc_increment'),
          name='interpolate_pr_auc')

    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      if curve == 'PR':
        if summation_method == 'trapezoidal':
          logging.warning(
              'Trapezoidal rule is known to produce incorrect PR-AUCs; '
              'please switch to "careful_interpolation" instead.')
        elif summation_method == 'careful_interpolation':
          # This one is a bit tricky and is handled separately.
          return interpolate_pr_auc(tp, fp, fn)
      rec = math_ops.divide(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.divide(fp, fp + tn + epsilon)
        x = fp_rate
        y = rec
      else:  # curve == 'PR'.
        prec = math_ops.divide(tp + epsilon, tp + fp + epsilon)
        x = rec
        y = prec
      if summation_method in ('trapezoidal', 'careful_interpolation'):
        # Note that the case ('PR', 'careful_interpolation') has been handled
        # above.
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              (y[:num_thresholds - 1] + y[1:]) / 2.),
            name=name)
      elif summation_method == 'minoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.minimum(y[:num_thresholds - 1], y[1:])),
            name=name)
      elif summation_method == 'majoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.maximum(y[:num_thresholds - 1], y[1:])),
            name=name)
      else:
        raise ValueError(f'Invalid summation_method: {summation_method}. '
                         'summation_method should be \'trapezoidal\', '
                         '\'careful_interpolation\', \'minoring\', or '
                         '\'majoring\'.')

    # Sum up the areas of all the trapeziums.
    def compute_auc_value(_, values):
      return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'],
                         'value')

    auc_value = _aggregate_across_replicas(
        metrics_collections, compute_auc_value, values)
    update_op = compute_auc(update_ops['tp'], update_ops['fn'],
                            update_ops['tn'], update_ops['fp'], 'update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return auc_value, update_op

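# Usage sketch (assumes TF1 graph mode; names are illustrative). Note the
# value/update race mentioned in the deprecation notice: run `update_op` for
# every batch first, then read the AUC value once at the end.
#
#   labels_ph = tf.compat.v1.placeholder(tf.bool, shape=[None])
#   preds_ph = tf.compat.v1.placeholder(tf.float32, shape=[None])
#   auc_t, auc_update_op = tf.compat.v1.metrics.auc(
#       labels_ph, preds_ph, curve='PR',
#       summation_method='careful_interpolation')
#   with tf.compat.v1.Session() as sess:
#     sess.run(tf.compat.v1.local_variables_initializer())
#     for feed in feeds:  # `feeds` is a hypothetical iterable of feed_dicts
#       sess.run(auc_update_op, feed_dict=feed)
#     print(sess.run(auc_t))
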
@tf_export(v1=['metrics.mean_absolute_error'])
def mean_absolute_error(labels,
                        predictions,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the mean absolute error between the labels and predictions.

  The `mean_absolute_error` function creates two local variables,
  `total` and `count` that are used to compute the mean absolute error. This
  average is weighted by `weights`, and it is ultimately returned as
  `mean_absolute_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
  absolute value of the differences between `predictions` and `labels`. Then
  `update_op` increments `total` with the reduced sum of the product of
  `weights` and `absolute_errors`, and it increments `count` with the reduced
  sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_absolute_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_absolute_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
                       'when eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  absolute_errors = math_ops.abs(predictions - labels)
  return mean(absolute_errors, weights, metrics_collections,
              updates_collections, name or 'mean_absolute_error')


@tf_export(v1=['metrics.mean_cosine_distance'])
def mean_cosine_distance(labels,
                         predictions,
                         dim,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes the cosine distance between the labels and predictions.

  The `mean_cosine_distance` function creates two local variables,
  `total` and `count` that are used to compute the average cosine distance
  between `predictions` and `labels`. This average is weighted by `weights`,
  and it is ultimately returned as `mean_distance`, which is an idempotent
  operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_distance`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of arbitrary shape.
    predictions: A `Tensor` of the same shape as `labels`.
    dim: The dimension along which the cosine distance is computed.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension). Also,
      dimension `dim` must be `1`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  radial_diffs = math_ops.multiply(predictions, labels)
  radial_diffs = math_ops.reduce_sum(
      radial_diffs, axis=[
          dim,
      ], keepdims=True)
  mean_distance, update_op = mean(radial_diffs, weights, None, None, name or
                                  'mean_cosine_distance')
  mean_distance = math_ops.subtract(1.0, mean_distance)
  update_op = math_ops.subtract(1.0, update_op)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op

@tf_export(v1=['metrics.mean_per_class_accuracy'])
def mean_per_class_accuracy(labels,
                            predictions,
                            num_classes,
                            weights=None,
                            metrics_collections=None,
                            updates_collections=None,
                            name=None):
  """Calculates the mean of the per-class accuracies.

  Calculates the accuracy for each class, then takes the mean of that.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates the accuracy of each class and returns
  them.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since two variables with shape =
      [num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_per_class_accuracy` should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_accuracy: A `Tensor` representing the mean per class accuracy.
    update_op: An operation that updates the accuracy tensor.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
                       'when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean_accuracy',
                                     (predictions, labels, weights)):
    labels = math_ops.cast(labels, dtypes.int64)

    # Flatten the input if its rank > 1.
    if labels.get_shape().ndims > 1:
      labels = array_ops.reshape(labels, [-1])

    if predictions.get_shape().ndims > 1:
      predictions = array_ops.reshape(predictions, [-1])

    # Check if shape is compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    total = metric_variable([num_classes], dtypes.float32, name='total')
    count = metric_variable([num_classes], dtypes.float32, name='count')

    ones = array_ops.ones([array_ops.size(labels)], dtypes.float32)

    if labels.dtype != predictions.dtype:
      predictions = math_ops.cast(predictions, labels.dtype)
    is_correct = math_ops.cast(
        math_ops.equal(predictions, labels), dtypes.float32)

    if weights is not None:
      if weights.get_shape().ndims > 1:
        weights = array_ops.reshape(weights, [-1])
      weights = math_ops.cast(weights, dtypes.float32)

      is_correct *= weights
      ones *= weights

    update_total_op = state_ops.scatter_add(total, labels, ones)
    update_count_op = state_ops.scatter_add(count, labels, is_correct)

    def compute_mean_accuracy(_, count, total):
      per_class_accuracy = math_ops.div_no_nan(
          count, math_ops.maximum(total, 0), name=None)
      mean_accuracy_v = math_ops.reduce_mean(
          per_class_accuracy, name='mean_accuracy')
      return mean_accuracy_v

    mean_accuracy_v = _aggregate_across_replicas(
        metrics_collections, compute_mean_accuracy, count, total)

    update_op = math_ops.div_no_nan(
        update_count_op, math_ops.maximum(update_total_op, 0), name='update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_accuracy_v, update_op

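# Worked sketch (hypothetical values, num_classes=2): with labels [0, 0, 1]
# and predictions [0, 1, 1], class 0 is correct 1 out of 2 times (0.5) and
# class 1 is correct 1 out of 1 times (1.0), so the streamed result
# approaches (0.5 + 1.0) / 2 = 0.75.
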
@tf_export(v1=['metrics.mean_iou'])
def mean_iou(labels,
             predictions,
             num_classes,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculate per-step mean Intersection-Over-Union (mIOU).

  Mean Intersection-Over-Union is a common evaluation metric for
  semantic image segmentation, which first computes the IOU for each
  semantic class and then computes the average over classes.
  IOU is defined as follows:
    IOU = true_positive / (true_positive + false_positive + false_negative).
  The predictions are accumulated in a confusion matrix, weighted by `weights`,
  and mIOU is then calculated from it.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean_iou`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `mean_iou`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_iou: A `Tensor` representing the mean intersection-over-union.
    update_op: An operation that increments the confusion matrix.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_iou is not supported when '
                       'eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean_iou',
                                     (predictions, labels, weights)):
    # Check if shape is compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
                                                      num_classes, weights)

    def compute_mean_iou(_, total_cm):
      """Compute the mean intersection-over-union via the confusion matrix."""
      sum_over_row = math_ops.cast(
          math_ops.reduce_sum(total_cm, 0), dtypes.float32)
      sum_over_col = math_ops.cast(
          math_ops.reduce_sum(total_cm, 1), dtypes.float32)
      cm_diag = math_ops.cast(array_ops.diag_part(total_cm), dtypes.float32)
      denominator = sum_over_row + sum_over_col - cm_diag

      # The mean is only computed over classes that appear in the
      # label or prediction tensor. If the denominator is 0, we need to
      # ignore the class.
      num_valid_entries = math_ops.reduce_sum(
          math_ops.cast(
              math_ops.not_equal(denominator, 0), dtype=dtypes.float32))

      # If the value of the denominator is 0, set it to 1 to avoid
      # zero division.
      denominator = array_ops.where(
          math_ops.greater(denominator, 0), denominator,
          array_ops.ones_like(denominator))
      iou = math_ops.divide(cm_diag, denominator)

      # If the number of valid entries is 0 (no classes) we return 0.
      result = array_ops.where(
          math_ops.greater(num_valid_entries, 0),
          math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0)
      return result

    # TODO(priyag): Use outside_compilation if in TPU context.
    mean_iou_v = _aggregate_across_replicas(
        metrics_collections, compute_mean_iou, total_cm)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_iou_v, update_op

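# Worked sketch (hypothetical values, num_classes=2): with labels [0, 0, 1]
# and predictions [0, 1, 1] the accumulated confusion matrix is
# [[1, 1], [0, 1]], so IOU(class 0) = 1 / (2 + 1 - 1) = 0.5 and
# IOU(class 1) = 1 / (1 + 2 - 1) = 0.5, giving mean_iou = 0.5.
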
@tf_export(v1=['metrics.mean_relative_error'])
def mean_relative_error(labels,
                        predictions,
                        normalizer,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the mean relative error by normalizing with the given values.

  The `mean_relative_error` function creates two local variables,
  `total` and `count` that are used to compute the mean relative absolute error.
  This average is weighted by `weights`, and it is ultimately returned as
  `mean_relative_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_relative_error`. Internally, a `relative_errors` operation divides the
  absolute value of the differences between `predictions` and `labels` by the
  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `relative_errors`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    normalizer: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_relative_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_relative_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_relative_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_relative_error is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)

  predictions, normalizer = confusion_matrix.remove_squeezable_dimensions(
      predictions, normalizer)
  predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
  relative_errors = array_ops.where(
      math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels),
      math_ops.divide(math_ops.abs(labels - predictions), normalizer))
  return mean(relative_errors, weights, metrics_collections,
              updates_collections, name or 'mean_relative_error')

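# Worked sketch (hypothetical values): with labels [2.0, 4.0], predictions
# [1.0, 5.0] and normalizer equal to labels, the relative errors are
# |2 - 1| / 2 = 0.5 and |4 - 5| / 4 = 0.25, so the streamed mean approaches
# (0.5 + 0.25) / 2 = 0.375.
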


@tf_export(v1=['metrics.mean_squared_error'])
def mean_squared_error(labels,
                       predictions,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Computes the mean squared error between the labels and predictions.

  The `mean_squared_error` function creates two local variables,
  `total` and `count` that are used to compute the mean squared error.
  This average is weighted by `weights`, and it is ultimately returned as
  `mean_squared_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_squared_error`. Internally, a `squared_error` operation computes the
  element-wise square of the difference between `predictions` and `labels`.
  Then `update_op` increments `total` with the reduced sum of the product of
  `weights` and `squared_error`, and it increments `count` with the reduced sum
  of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_squared_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_squared_error is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  squared_error = math_ops.squared_difference(labels, predictions)
  return mean(squared_error, weights, metrics_collections, updates_collections,
              name or 'mean_squared_error')
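

# Worked example for `mean_squared_error` (an illustrative sketch; assumes the
# `tensorflow.compat.v1 as tf` alias and graph mode):
#
#   labels = tf.constant([0.0, 0.0])
#   predictions = tf.constant([1.0, 3.0])
#   mse, update_op = tf.metrics.mean_squared_error(labels, predictions)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(mse))  # (1 + 9) / 2 = 5.0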


@tf_export(v1=['metrics.mean_tensor'])
def mean_tensor(values,
                weights=None,
                metrics_collections=None,
                updates_collections=None,
                name=None):
  """Computes the element-wise (weighted) mean of the given tensors.

  In contrast to the `mean` function which returns a scalar with the
  mean, this function returns an average tensor with the same shape as the
  input tensors.

  The `mean_tensor` function creates two local variables,
  `total_tensor` and `count_tensor` that are used to compute the average of
  `values`. This average is ultimately returned as `mean` which is an idempotent
  operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A float `Tensor` representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_tensor is not supported when '
                       'eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean', (values, weights)):
    values = math_ops.cast(values, dtypes.float32)
    total = metric_variable(
        values.get_shape(), dtypes.float32, name='total_tensor')
    count = metric_variable(
        values.get_shape(), dtypes.float32, name='count_tensor')

    num_values = array_ops.ones_like(values)
    if weights is not None:
      values, _, weights = _remove_squeezable_dimensions(
          predictions=values, labels=None, weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.cast(weights, dtypes.float32), values)
      values = math_ops.multiply(values, weights)
      num_values = math_ops.multiply(num_values, weights)

    update_total_op = state_ops.assign_add(total, values)
    with ops.control_dependencies([values]):
      update_count_op = state_ops.assign_add(count, num_values)

    compute_mean = lambda _, t, c: math_ops.div_no_nan(  # pylint: disable=g-long-lambda
        t, math_ops.maximum(c, 0), name='value')

    mean_t = _aggregate_across_replicas(
        metrics_collections, compute_mean, total, count)

    update_op = math_ops.div_no_nan(
        update_total_op, math_ops.maximum(update_count_op, 0),
        name='update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_t, update_op
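

# Streaming example for `mean_tensor` (an illustrative sketch; assumes the
# `tensorflow.compat.v1 as tf` alias and graph mode):
#
#   values = tf.placeholder(tf.float32, shape=[2])
#   mean_t, update_op = tf.metrics.mean_tensor(values)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op, feed_dict={values: [1.0, 2.0]})
#     sess.run(update_op, feed_dict={values: [3.0, 4.0]})
#     print(sess.run(mean_t))  # element-wise mean over both batches: [2.0, 3.0]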


@tf_export(v1=['metrics.percentage_below'])
def percentage_below(values,
                     threshold,
                     weights=None,
                     metrics_collections=None,
                     updates_collections=None,
                     name=None):
  """Computes the percentage of values less than the given threshold.

  The `percentage_below` function creates two local variables,
  `total` and `count` that are used to compute the percentage of `values` that
  fall below `threshold`. This rate is weighted by `weights`, and it is
  ultimately returned as `percentage` which is an idempotent operation that
  simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `percentage`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A numeric `Tensor` of arbitrary size.
    threshold: A scalar threshold.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    percentage: A `Tensor` representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.percentage_below is not supported when '
                       'eager execution is enabled.')

  is_below_threshold = math_ops.cast(
      math_ops.less(values, threshold), dtypes.float32)
  return mean(is_below_threshold, weights, metrics_collections,
              updates_collections, name or 'percentage_below_threshold')
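

# Example for `percentage_below` (an illustrative sketch; note the result is a
# fraction in [0, 1] rather than a value scaled to 100; assumes the
# `tensorflow.compat.v1 as tf` alias and graph mode):
#
#   values = tf.constant([1.0, 4.0, 6.0, 8.0])
#   pct, update_op = tf.metrics.percentage_below(values, threshold=5.0)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(pct))  # 2 of 4 values are below 5.0 -> 0.5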


def _count_condition(values,
                     weights=None,
                     metrics_collections=None,
                     updates_collections=None):
  """Sums the weights of cases where the given values are True.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `bool` `Tensor` of arbitrary size.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  check_ops.assert_type(values, dtypes.bool)
  count = metric_variable([], dtypes.float32, name='count')

  values = math_ops.cast(values, dtypes.float32)
  if weights is not None:
    with ops.control_dependencies((check_ops.assert_rank_in(
        weights, (0, array_ops.rank(values))),)):
      weights = math_ops.cast(weights, dtypes.float32)
      values = math_ops.multiply(values, weights)

  value_tensor = _aggregate_variable(count, metrics_collections)

  update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return value_tensor, update_op


@tf_export(v1=['metrics.false_negatives'])
def false_negatives(labels,
                    predictions,
                    weights=None,
                    metrics_collections=None,
                    updates_collections=None,
                    name=None):
  """Computes the total number of false negatives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.false_negatives is not supported when '
                       'eager execution is enabled.')

  with variable_scope.variable_scope(name, 'false_negatives',
                                     (predictions, labels, weights)):

    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)
    is_false_negative = math_ops.logical_and(
        math_ops.equal(labels, True), math_ops.equal(predictions, False))
    return _count_condition(is_false_negative, weights, metrics_collections,
                            updates_collections)
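

# Example for the counting metrics built on `_count_condition` (an
# illustrative sketch; assumes the `tensorflow.compat.v1 as tf` alias and
# graph mode):
#
#   labels = tf.constant([True, True, False, False])
#   predictions = tf.constant([False, True, False, True])
#   fn, update_op = tf.metrics.false_negatives(labels, predictions)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(fn))  # one positive label was predicted False -> 1.0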
1707 """ 1708 if context.executing_eagerly(): 1709 raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not ' 1710 'supported when eager execution is enabled.') 1711 1712 with variable_scope.variable_scope(name, 'false_negatives', 1713 (predictions, labels, weights)): 1714 values, update_ops = _confusion_matrix_at_thresholds( 1715 labels, predictions, thresholds, weights=weights, includes=('fn',)) 1716 1717 fn_value = _aggregate_variable(values['fn'], metrics_collections) 1718 1719 if updates_collections: 1720 ops.add_to_collections(updates_collections, update_ops['fn']) 1721 1722 return fn_value, update_ops['fn'] 1723 1724 1725@tf_export(v1=['metrics.false_positives']) 1726def false_positives(labels, 1727 predictions, 1728 weights=None, 1729 metrics_collections=None, 1730 updates_collections=None, 1731 name=None): 1732 """Sum the weights of false positives. 1733 1734 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1735 1736 Args: 1737 labels: The ground truth values, a `Tensor` whose dimensions must match 1738 `predictions`. Will be cast to `bool`. 1739 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1740 be cast to `bool`. 1741 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1742 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1743 be either `1`, or the same as the corresponding `labels` dimension). 1744 metrics_collections: An optional list of collections that the metric 1745 value variable should be added to. 1746 updates_collections: An optional list of collections that the metric update 1747 ops should be added to. 1748 name: An optional variable_scope name. 1749 1750 Returns: 1751 value_tensor: A `Tensor` representing the current value of the metric. 1752 update_op: An operation that accumulates the error from a batch of data. 1753 1754 Raises: 1755 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1756 `weights` is not `None` and its shape doesn't match `predictions`, or if 1757 either `metrics_collections` or `updates_collections` are not a list or 1758 tuple. 1759 RuntimeError: If eager execution is enabled. 1760 """ 1761 if context.executing_eagerly(): 1762 raise RuntimeError('tf.metrics.false_positives is not supported when ' 1763 'eager execution is enabled.') 1764 1765 with variable_scope.variable_scope(name, 'false_positives', 1766 (predictions, labels, weights)): 1767 1768 predictions, labels, weights = _remove_squeezable_dimensions( 1769 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1770 labels=math_ops.cast(labels, dtype=dtypes.bool), 1771 weights=weights) 1772 is_false_positive = math_ops.logical_and( 1773 math_ops.equal(labels, False), math_ops.equal(predictions, True)) 1774 return _count_condition(is_false_positive, weights, metrics_collections, 1775 updates_collections) 1776 1777 1778@tf_export(v1=['metrics.false_positives_at_thresholds']) 1779def false_positives_at_thresholds(labels, 1780 predictions, 1781 thresholds, 1782 weights=None, 1783 metrics_collections=None, 1784 updates_collections=None, 1785 name=None): 1786 """Computes false positives at provided threshold values. 1787 1788 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1789 1790 Args: 1791 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1792 `bool`. 1793 predictions: A floating point `Tensor` of arbitrary shape and whose values 1794 are in the range `[0, 1]`. 


@tf_export(v1=['metrics.false_positives_at_thresholds'])
def false_positives_at_thresholds(labels,
                                  predictions,
                                  thresholds,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes false positives at provided threshold values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `false_positives`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    false_positives: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that updates the `false_positives` variable and
      returns its current value.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.false_positives_at_thresholds is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'false_positives',
                                     (predictions, labels, weights)):
    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights=weights, includes=('fp',))

    fp_value = _aggregate_variable(values['fp'], metrics_collections)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_ops['fp'])

    return fp_value, update_ops['fp']


@tf_export(v1=['metrics.true_negatives'])
def true_negatives(labels,
                   predictions,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Sum the weights of true negatives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.true_negatives is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'true_negatives',
                                     (predictions, labels, weights)):

    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)
    is_true_negative = math_ops.logical_and(
        math_ops.equal(labels, False), math_ops.equal(predictions, False))
    return _count_condition(is_true_negative, weights, metrics_collections,
                            updates_collections)
1869 """ 1870 if context.executing_eagerly(): 1871 raise RuntimeError('tf.metrics.true_negatives is not ' 1872 'supported when eager execution is enabled.') 1873 1874 with variable_scope.variable_scope(name, 'true_negatives', 1875 (predictions, labels, weights)): 1876 1877 predictions, labels, weights = _remove_squeezable_dimensions( 1878 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1879 labels=math_ops.cast(labels, dtype=dtypes.bool), 1880 weights=weights) 1881 is_true_negative = math_ops.logical_and( 1882 math_ops.equal(labels, False), math_ops.equal(predictions, False)) 1883 return _count_condition(is_true_negative, weights, metrics_collections, 1884 updates_collections) 1885 1886 1887@tf_export(v1=['metrics.true_negatives_at_thresholds']) 1888def true_negatives_at_thresholds(labels, 1889 predictions, 1890 thresholds, 1891 weights=None, 1892 metrics_collections=None, 1893 updates_collections=None, 1894 name=None): 1895 """Computes true negatives at provided threshold values. 1896 1897 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1898 1899 Args: 1900 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1901 `bool`. 1902 predictions: A floating point `Tensor` of arbitrary shape and whose values 1903 are in the range `[0, 1]`. 1904 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1905 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1906 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1907 be either `1`, or the same as the corresponding `labels` dimension). 1908 metrics_collections: An optional list of collections that `true_negatives` 1909 should be added to. 1910 updates_collections: An optional list of collections that `update_op` should 1911 be added to. 1912 name: An optional variable_scope name. 1913 1914 Returns: 1915 true_negatives: A float `Tensor` of shape `[len(thresholds)]`. 1916 update_op: An operation that updates the `true_negatives` variable and 1917 returns its current value. 1918 1919 Raises: 1920 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1921 `weights` is not `None` and its shape doesn't match `predictions`, or if 1922 either `metrics_collections` or `updates_collections` are not a list or 1923 tuple. 1924 RuntimeError: If eager execution is enabled. 1925 """ 1926 if context.executing_eagerly(): 1927 raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not ' 1928 'supported when eager execution is enabled.') 1929 1930 with variable_scope.variable_scope(name, 'true_negatives', 1931 (predictions, labels, weights)): 1932 values, update_ops = _confusion_matrix_at_thresholds( 1933 labels, predictions, thresholds, weights=weights, includes=('tn',)) 1934 1935 tn_value = _aggregate_variable(values['tn'], metrics_collections) 1936 1937 if updates_collections: 1938 ops.add_to_collections(updates_collections, update_ops['tn']) 1939 1940 return tn_value, update_ops['tn'] 1941 1942 1943@tf_export(v1=['metrics.true_positives']) 1944def true_positives(labels, 1945 predictions, 1946 weights=None, 1947 metrics_collections=None, 1948 updates_collections=None, 1949 name=None): 1950 """Sum the weights of true_positives. 1951 1952 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1953 1954 Args: 1955 labels: The ground truth values, a `Tensor` whose dimensions must match 1956 `predictions`. Will be cast to `bool`. 1957 predictions: The predicted values, a `Tensor` of arbitrary dimensions. 
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.true_positives is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'true_positives',
                                     (predictions, labels, weights)):

    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)
    is_true_positive = math_ops.logical_and(
        math_ops.equal(labels, True), math_ops.equal(predictions, True))
    return _count_condition(is_true_positive, weights, metrics_collections,
                            updates_collections)


@tf_export(v1=['metrics.true_positives_at_thresholds'])
def true_positives_at_thresholds(labels,
                                 predictions,
                                 thresholds,
                                 weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes true positives at provided threshold values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `true_positives`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    true_positives: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that updates the `true_positives` variable and
      returns its current value.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.true_positives_at_thresholds is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'true_positives',
                                     (predictions, labels, weights)):
    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights=weights, includes=('tp',))

    tp_value = _aggregate_variable(values['tp'], metrics_collections)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_ops['tp'])

    return tp_value, update_ops['tp']
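

# Example for the *_at_thresholds counting metrics (an illustrative sketch;
# assumes the `tensorflow.compat.v1 as tf` alias and graph mode; predictions
# strictly above a threshold count as positive at that threshold):
#
#   labels = tf.constant([False, True, True])
#   predictions = tf.constant([0.4, 0.6, 0.8])
#   tp, update_op = tf.metrics.true_positives_at_thresholds(
#       labels, predictions, thresholds=[0.3, 0.7])
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(tp))  # [2.0, 1.0]: both positives pass 0.3, one passes 0.7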
2034 """ 2035 if context.executing_eagerly(): 2036 raise RuntimeError('tf.metrics.true_positives_at_thresholds is not ' 2037 'supported when eager execution is enabled.') 2038 2039 with variable_scope.variable_scope(name, 'true_positives', 2040 (predictions, labels, weights)): 2041 values, update_ops = _confusion_matrix_at_thresholds( 2042 labels, predictions, thresholds, weights=weights, includes=('tp',)) 2043 2044 tp_value = _aggregate_variable(values['tp'], metrics_collections) 2045 2046 if updates_collections: 2047 ops.add_to_collections(updates_collections, update_ops['tp']) 2048 2049 return tp_value, update_ops['tp'] 2050 2051 2052@tf_export(v1=['metrics.precision']) 2053def precision(labels, 2054 predictions, 2055 weights=None, 2056 metrics_collections=None, 2057 updates_collections=None, 2058 name=None): 2059 """Computes the precision of the predictions with respect to the labels. 2060 2061 The `precision` function creates two local variables, 2062 `true_positives` and `false_positives`, that are used to compute the 2063 precision. This value is ultimately returned as `precision`, an idempotent 2064 operation that simply divides `true_positives` by the sum of `true_positives` 2065 and `false_positives`. 2066 2067 For estimation of the metric over a stream of data, the function creates an 2068 `update_op` operation that updates these variables and returns the 2069 `precision`. `update_op` weights each prediction by the corresponding value in 2070 `weights`. 2071 2072 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2073 2074 Args: 2075 labels: The ground truth values, a `Tensor` whose dimensions must match 2076 `predictions`. Will be cast to `bool`. 2077 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 2078 be cast to `bool`. 2079 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2080 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2081 be either `1`, or the same as the corresponding `labels` dimension). 2082 metrics_collections: An optional list of collections that `precision` should 2083 be added to. 2084 updates_collections: An optional list of collections that `update_op` should 2085 be added to. 2086 name: An optional variable_scope name. 2087 2088 Returns: 2089 precision: Scalar float `Tensor` with the value of `true_positives` 2090 divided by the sum of `true_positives` and `false_positives`. 2091 update_op: `Operation` that increments `true_positives` and 2092 `false_positives` variables appropriately and whose value matches 2093 `precision`. 2094 2095 Raises: 2096 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2097 `weights` is not `None` and its shape doesn't match `predictions`, or if 2098 either `metrics_collections` or `updates_collections` are not a list or 2099 tuple. 2100 RuntimeError: If eager execution is enabled. 
2101 """ 2102 if context.executing_eagerly(): 2103 raise RuntimeError('tf.metrics.precision is not ' 2104 'supported when eager execution is enabled.') 2105 2106 with variable_scope.variable_scope(name, 'precision', 2107 (predictions, labels, weights)): 2108 2109 predictions, labels, weights = _remove_squeezable_dimensions( 2110 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 2111 labels=math_ops.cast(labels, dtype=dtypes.bool), 2112 weights=weights) 2113 2114 true_p, true_positives_update_op = true_positives( 2115 labels, 2116 predictions, 2117 weights, 2118 metrics_collections=None, 2119 updates_collections=None, 2120 name=None) 2121 false_p, false_positives_update_op = false_positives( 2122 labels, 2123 predictions, 2124 weights, 2125 metrics_collections=None, 2126 updates_collections=None, 2127 name=None) 2128 2129 def compute_precision(tp, fp, name): 2130 return array_ops.where( 2131 math_ops.greater(tp + fp, 0), math_ops.divide(tp, tp + fp), 0, name) 2132 2133 def once_across_replicas(_, true_p, false_p): 2134 return compute_precision(true_p, false_p, 'value') 2135 2136 p = _aggregate_across_replicas(metrics_collections, once_across_replicas, 2137 true_p, false_p) 2138 2139 update_op = compute_precision(true_positives_update_op, 2140 false_positives_update_op, 'update_op') 2141 if updates_collections: 2142 ops.add_to_collections(updates_collections, update_op) 2143 2144 return p, update_op 2145 2146 2147@tf_export(v1=['metrics.precision_at_thresholds']) 2148def precision_at_thresholds(labels, 2149 predictions, 2150 thresholds, 2151 weights=None, 2152 metrics_collections=None, 2153 updates_collections=None, 2154 name=None): 2155 """Computes precision values for different `thresholds` on `predictions`. 2156 2157 The `precision_at_thresholds` function creates four local variables, 2158 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 2159 for various values of thresholds. `precision[i]` is defined as the total 2160 weight of values in `predictions` above `thresholds[i]` whose corresponding 2161 entry in `labels` is `True`, divided by the total weight of values in 2162 `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] + 2163 false_positives[i])`). 2164 2165 For estimation of the metric over a stream of data, the function creates an 2166 `update_op` operation that updates these variables and returns the 2167 `precision`. 2168 2169 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2170 2171 Args: 2172 labels: The ground truth values, a `Tensor` whose dimensions must match 2173 `predictions`. Will be cast to `bool`. 2174 predictions: A floating point `Tensor` of arbitrary shape and whose values 2175 are in the range `[0, 1]`. 2176 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 2177 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2178 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2179 be either `1`, or the same as the corresponding `labels` dimension). 2180 metrics_collections: An optional list of collections that `auc` should be 2181 added to. 2182 updates_collections: An optional list of collections that `update_op` should 2183 be added to. 2184 name: An optional variable_scope name. 2185 2186 Returns: 2187 precision: A float `Tensor` of shape `[len(thresholds)]`. 


@tf_export(v1=['metrics.recall'])
def recall(labels,
           predictions,
           weights=None,
           metrics_collections=None,
           updates_collections=None,
           name=None):
  """Computes the recall of the predictions with respect to the labels.

  The `recall` function creates two local variables, `true_positives`
  and `false_negatives`, that are used to compute the recall. This value is
  ultimately returned as `recall`, an idempotent operation that simply divides
  `true_positives` by the sum of `true_positives` and `false_negatives`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` that updates these variables and returns the `recall`. `update_op`
  weights each prediction by the corresponding value in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `recall` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: Scalar float `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately and whose value matches
      `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.recall is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'recall',
                                     (predictions, labels, weights)):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

    true_p, true_positives_update_op = true_positives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)
    false_n, false_negatives_update_op = false_negatives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)

    def compute_recall(true_p, false_n, name):
      return array_ops.where(
          math_ops.greater(true_p + false_n, 0),
          math_ops.divide(true_p, true_p + false_n), 0, name)

    def once_across_replicas(_, true_p, false_n):
      return compute_recall(true_p, false_n, 'value')

    rec = _aggregate_across_replicas(
        metrics_collections, once_across_replicas, true_p, false_n)

    update_op = compute_recall(true_positives_update_op,
                               false_negatives_update_op, 'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return rec, update_op
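

# Example for `recall` (an illustrative sketch; assumes the
# `tensorflow.compat.v1 as tf` alias and graph mode):
#
#   labels = tf.constant([1, 1, 0])
#   predictions = tf.constant([1, 0, 0])
#   rec, update_op = tf.metrics.recall(labels, predictions)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(rec))  # tp=1, fn=1 -> 0.5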


def _at_k_name(name, k=None, class_id=None):
  if k is not None:
    name = '%s_at_%d' % (name, k)
  else:
    name = '%s_at_k' % (name)
  if class_id is not None:
    name = '%s_class%d' % (name, class_id)
  return name


def _select_class_id(ids, selected_id):
  """Filter all but `selected_id` out of `ids`.

  Args:
    ids: `int64` `Tensor` or `SparseTensor` of IDs.
    selected_id: Int id to select.

  Returns:
    `SparseTensor` of same dimensions as `ids`. This contains only the entries
    equal to `selected_id`.
  """
  ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids)
  if isinstance(ids, sparse_tensor.SparseTensor):
    return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values,
                                                        selected_id))

  # TODO(ptucker): Make this more efficient, maybe add a sparse version of
  # tf.equal and tf.reduce_any?

  # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
  ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
  ids_last_dim = array_ops.size(ids_shape) - 1
  filled_selected_id_shape = math_ops.reduced_shape(ids_shape,
                                                    array_ops.reshape(
                                                        ids_last_dim, [1]))

  # Intersect `ids` with the selected ID.
  filled_selected_id = array_ops.fill(filled_selected_id_shape,
                                      math_ops.cast(selected_id, dtypes.int64))
  result = sets.set_intersection(filled_selected_id, ids)
  return sparse_tensor.SparseTensor(
      indices=result.indices, values=result.values, dense_shape=ids_shape)


def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
  """If class ID is specified, filter all other classes.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
      where N >= 1. Commonly, N=1 and `predictions_idx` has shape
      [batch size, k].
    selected_id: Int id to select.

  Returns:
    Tuple of `labels` and `predictions_idx`, possibly with classes removed.
  """
  if selected_id is None:
    return labels, predictions_idx
  return (_select_class_id(labels, selected_id),
          _select_class_id(predictions_idx, selected_id))


def _sparse_true_positive_at_k(labels,
                               predictions_idx,
                               class_id=None,
                               weights=None,
                               name=None):
  """Calculates true positives for recall@k and precision@k.

  If `class_id` is specified, calculate binary true positives for `class_id`
  only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
  `n` label classes, where `n` is the 2nd dimension of `labels`.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    name: Name of operation.

  Returns:
    A [D1, ... DN] `Tensor` of true positive counts.
  """
  with ops.name_scope(name, 'true_positives',
                      (predictions_idx, labels, weights)):
    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
                                                     class_id)
    tp = sets.set_size(sets.set_intersection(predictions_idx, labels))
    tp = math_ops.cast(tp, dtypes.float64)
    if weights is not None:
      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
          weights, tp),)):
        weights = math_ops.cast(weights, dtypes.float64)
        tp = math_ops.multiply(tp, weights)
    return tp


def _streaming_sparse_true_positive_at_k(labels,
                                         predictions_idx,
                                         k=None,
                                         class_id=None,
                                         weights=None,
                                         name=None):
  """Calculates weighted per step true positives for recall@k and precision@k.

  If `class_id` is specified, calculate binary true positives for `class_id`
  only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
  `n` label classes, where `n` is the 2nd dimension of `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
  """
  with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id),
                      (predictions_idx, labels, weights)) as scope:
    tp = _sparse_true_positive_at_k(
        predictions_idx=predictions_idx,
        labels=labels,
        class_id=class_id,
        weights=weights)
    batch_total_tp = math_ops.cast(math_ops.reduce_sum(tp), dtypes.float64)

    var = metric_variable([], dtypes.float64, name=scope)
    return var, state_ops.assign_add(var, batch_total_tp, name='update')


def _sparse_false_negative_at_k(labels,
                                predictions_idx,
                                class_id=None,
                                weights=None):
  """Calculates false negatives for recall@k.

  If `class_id` is specified, calculate binary false negatives for `class_id`
  only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
  `n` label classes, where `n` is the 2nd dimension of `labels`.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).

  Returns:
    A [D1, ... DN] `Tensor` of false negative counts.
  """
  with ops.name_scope(None, 'false_negatives',
                      (predictions_idx, labels, weights)):
    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
                                                     class_id)
    fn = sets.set_size(
        sets.set_difference(predictions_idx, labels, aminusb=False))
    fn = math_ops.cast(fn, dtypes.float64)
    if weights is not None:
      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
          weights, fn),)):
        weights = math_ops.cast(weights, dtypes.float64)
        fn = math_ops.multiply(fn, weights)
    return fn


def _streaming_sparse_false_negative_at_k(labels,
                                          predictions_idx,
                                          k,
                                          class_id=None,
                                          weights=None,
                                          name=None):
  """Calculates weighted per step false negatives for recall@k.

  If `class_id` is specified, calculate binary false negatives for `class_id`
  only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
  `n` label classes, where `n` is the 2nd dimension of `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
  """
  with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id),
                      (predictions_idx, labels, weights)) as scope:
    fn = _sparse_false_negative_at_k(
        predictions_idx=predictions_idx,
        labels=labels,
        class_id=class_id,
        weights=weights)
    batch_total_fn = math_ops.cast(math_ops.reduce_sum(fn), dtypes.float64)

    var = metric_variable([], dtypes.float64, name=scope)
    return var, state_ops.assign_add(var, batch_total_fn, name='update')


@tf_export(v1=['metrics.recall_at_k'])
def recall_at_k(labels,
                predictions,
                k,
                class_id=None,
                weights=None,
                metrics_collections=None,
                updates_collections=None,
                name=None):
  """Computes recall@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate recall by considering only the
  entries in the batch for which `class_id` is in the label, and computing
  the fraction of them for which `class_id` is in the top-k `predictions`.
  If `class_id` is not specified, we'll calculate recall as how often on
  average a class among the labels of a batch entry is in the top-k
  `predictions`.

  `recall_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
  the recall_at_k frequency. This frequency is ultimately returned as
  `recall_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_negative_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false negatives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_negative_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
      should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range always count
      towards `false_negative_at_<k>`.
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If class_id is outside this range, the method returns NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.recall_at_k is not '
                       'supported when eager execution is enabled.')

  with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
                      (predictions, labels, weights)) as scope:
    _, top_k_idx = nn.top_k(predictions, k)
    return recall_at_top_k(
        labels=labels,
        predictions_idx=top_k_idx,
        k=k,
        class_id=class_id,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=scope)
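

# Example for `recall_at_k` (an illustrative sketch; assumes the
# `tensorflow.compat.v1 as tf` alias and graph mode):
#
#   labels = tf.constant([[0], [1]], dtype=tf.int64)  # one label per entry
#   logits = tf.constant([[0.1, 0.8, 0.1],
#                         [0.2, 0.7, 0.1]])           # 3 classes
#   rec, update_op = tf.metrics.recall_at_k(labels, logits, k=1)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(rec))  # top-1 is class 1 for both rows -> recall 0.5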


@tf_export(v1=['metrics.recall_at_top_k'])
def recall_at_top_k(labels,
                    predictions_idx,
                    k=None,
                    class_id=None,
                    weights=None,
                    metrics_collections=None,
                    updates_collections=None,
                    name=None):
  """Computes recall@k of top-k predictions with respect to sparse labels.

  Differs from `recall_at_k` in that predictions must be in the form of top `k`
  class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k`
  for more details.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
      should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range always count
      towards `false_negative_at_<k>`.
    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
      Commonly, N=1 and predictions has shape [batch size, k]. The final
      dimension contains the top `k` predicted class indices. [D1, ... DN] must
      match `labels`.
    k: Integer, k for @k metric. Only used for the default op name.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If class_id is outside this range, the method returns NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
                      (predictions_idx, labels, weights)) as scope:
    labels = _maybe_expand_labels(labels, predictions_idx)
    top_k_idx = math_ops.cast(predictions_idx, dtypes.int64)
    tp, tp_update = _streaming_sparse_true_positive_at_k(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)
    fn, fn_update = _streaming_sparse_false_negative_at_k(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)

    def compute_recall(_, tp, fn):
      return math_ops.divide(tp, math_ops.add(tp, fn), name=scope)

    metric = _aggregate_across_replicas(
        metrics_collections, compute_recall, tp, fn)

    update = math_ops.divide(
        tp_update, math_ops.add(tp_update, fn_update), name='update')
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update


@tf_export(v1=['metrics.recall_at_thresholds'])
def recall_at_thresholds(labels,
                         predictions,
                         thresholds,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes various recall values for different `thresholds` on `predictions`.

  The `recall_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds. `recall[i]` is defined as the total weight
  of values in `predictions` above `thresholds[i]` whose corresponding entry in
  `labels` is `True`, divided by the total weight of `True` values in `labels`
  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `recall`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `recall` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.recall_at_thresholds is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'recall_at_thresholds',
                                     (predictions, labels, weights)):
    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights, includes=('tp', 'fn'))

    # Avoid division by zero.
    epsilon = 1e-7

    def compute_recall(tp, fn, name):
      return math_ops.divide(tp, epsilon + tp + fn, name='recall_' + name)

    def recall_across_replicas(_, values):
      return compute_recall(values['tp'], values['fn'], 'value')

    rec = _aggregate_across_replicas(
        metrics_collections, recall_across_replicas, values)

    update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return rec, update_op


@tf_export(v1=['metrics.recall_at_thresholds'])
def recall_at_thresholds(labels,
                         predictions,
                         thresholds,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes various recall values for different `thresholds` on `predictions`.

  The `recall_at_thresholds` function creates two local variables,
  `true_positives` and `false_negatives`, for various values of thresholds.
  `recall[i]` is defined as the total weight of values in `predictions` above
  `thresholds[i]` whose corresponding entry in `labels` is `True`, divided by
  the total weight of `True` values in `labels`
  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `recall`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask
  values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that `recall` should
      be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `true_positives` and
      `false_negatives` variables that are used in the computation of
      `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.recall_at_thresholds is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'recall_at_thresholds',
                                     (predictions, labels, weights)):
    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights, includes=('tp', 'fn'))

    # Avoid division by zero.
    epsilon = 1e-7

    def compute_recall(tp, fn, name):
      return math_ops.divide(tp, epsilon + tp + fn, name='recall_' + name)

    def recall_across_replicas(_, values):
      return compute_recall(values['tp'], values['fn'], 'value')

    rec = _aggregate_across_replicas(
        metrics_collections, recall_across_replicas, values)

    update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return rec, update_op
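
# A usage sketch for `recall_at_thresholds` (values and session handling are
# illustrative assumptions): the returned `recall` has one entry per
# threshold.
#
#   labels = tf.constant([True, False, True, True])
#   scores = tf.constant([0.9, 0.6, 0.4, 0.1])
#   recall, update_op = tf.metrics.recall_at_thresholds(
#       labels, scores, thresholds=[0.3, 0.5, 0.7])
#   # After running `update_op` once: recall ~= [2/3, 1/3, 1/3], since of the
#   # three positive labels only those scored above each threshold count.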


@tf_export(v1=['metrics.root_mean_squared_error'])
def root_mean_squared_error(labels,
                            predictions,
                            weights=None,
                            metrics_collections=None,
                            updates_collections=None,
                            name=None):
  """Computes the root mean squared error between the labels and predictions.

  The `root_mean_squared_error` function creates two local variables,
  `total` and `count` that are used to compute the root mean squared error.
  This average is weighted by `weights`, and it is ultimately returned as
  `root_mean_squared_error`: an idempotent operation that takes the square
  root of the division of `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `root_mean_squared_error`. Internally, a `squared_error` operation computes
  the element-wise square of the difference between `predictions` and
  `labels`. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `squared_error`, and it increments `count` with
  the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask
  values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that
      `root_mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    root_mean_squared_error: A `Tensor` representing the current root mean
      squared error, the square root of `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `root_mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.root_mean_squared_error is not '
                       'supported when eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
                                          None, name or
                                          'root_mean_squared_error')

  once_across_replicas = lambda _, mse: math_ops.sqrt(mse)
  rmse = _aggregate_across_replicas(
      metrics_collections, once_across_replicas, mse)

  update_rmse_op = math_ops.sqrt(update_mse_op)
  if updates_collections:
    ops.add_to_collections(updates_collections, update_rmse_op)

  return rmse, update_rmse_op
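
# A usage sketch for `root_mean_squared_error` (tensor values are
# illustrative assumptions):
#
#   labels = tf.constant([1.0, 2.0, 4.0])
#   predictions = tf.constant([1.0, 2.0, 2.0])
#   rmse, update_op = tf.metrics.root_mean_squared_error(labels, predictions)
#   # After one `update_op`: mse = (0 + 0 + 4) / 3, so rmse = sqrt(4/3).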


@tf_export(v1=['metrics.sensitivity_at_specificity'])
def sensitivity_at_specificity(labels,
                               predictions,
                               specificity,
                               weights=None,
                               num_thresholds=200,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Computes the sensitivity at a given specificity.

  The `sensitivity_at_specificity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the sensitivity at the given
  specificity value. The threshold for the given specificity value is
  computed and used to evaluate the corresponding sensitivity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `sensitivity`. `update_op` increments the `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` counts with the
  weight of each case found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask
  values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    specificity: A scalar value in range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    num_thresholds: The number of thresholds to use for matching the given
      specificity.
    metrics_collections: An optional list of collections that `sensitivity`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    sensitivity: A scalar `Tensor` representing the sensitivity at the given
      `specificity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `sensitivity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `specificity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.sensitivity_at_specificity is not '
                       'supported when eager execution is enabled.')

  if specificity < 0 or specificity > 1:
    raise ValueError('`specificity` must be in the range [0, 1]. Currently, '
                     f'`specificity` is {specificity}.')

  with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
                                     (predictions, labels, weights)):
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    def compute_sensitivity_at_specificity(tp, tn, fp, fn, name):
      specificities = math_ops.divide(tn, tn + fp + kepsilon)
      tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the sensitivity:
      return math_ops.divide(tp[tf_index],
                             tp[tf_index] + fn[tf_index] + kepsilon, name)

    def sensitivity_across_replicas(_, values):
      return compute_sensitivity_at_specificity(
          values['tp'], values['tn'], values['fp'], values['fn'], 'value')

    sensitivity = _aggregate_across_replicas(
        metrics_collections, sensitivity_across_replicas, values)

    update_op = compute_sensitivity_at_specificity(
        update_ops['tp'], update_ops['tn'], update_ops['fp'],
        update_ops['fn'], 'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return sensitivity, update_op
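
# A usage sketch for `sensitivity_at_specificity` (illustrative; the tensors
# below are assumptions). The op scans `num_thresholds` candidate thresholds,
# picks the one whose streamed specificity is closest to the target, and
# reports the sensitivity at that threshold:
#
#   labels = tf.constant([0, 0, 1, 1])
#   scores = tf.constant([0.1, 0.4, 0.35, 0.8])
#   sens, update_op = tf.metrics.sensitivity_at_specificity(
#       labels, scores, specificity=0.9)
#   # Run `update_op` per batch, then read `sens`.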

def _expand_and_tile(tensor, multiple, dim=0, name=None):
  """Inserts a new dimension into `tensor`, then tiles along that dimension.

  A new dimension is inserted in the shape of `tensor` before `dim`, then
  values are tiled `multiple` times along the new dimension.

  Args:
    tensor: Input `Tensor` or `SparseTensor`.
    multiple: Integer, number of times to tile.
    dim: Integer, dimension along which to tile.
    name: Name of operation.

  Returns:
    `Tensor` result of expanding and tiling `tensor`.

  Raises:
    ValueError: if `multiple` is less than 1, or `dim` is not in
      `[-rank(tensor), rank(tensor)]`.
  """
  if multiple < 1:
    raise ValueError(f'Invalid argument multiple={multiple} for '
                     'expand_and_tile call. `multiple` must be an integer > 0')
  with ops.name_scope(name, 'expand_and_tile',
                      (tensor, multiple, dim)) as scope:
    # Sparse.
    tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
    if isinstance(tensor, sparse_tensor.SparseTensor):
      if dim < 0:
        expand_dims = array_ops.reshape(
            array_ops.size(tensor.dense_shape) + dim, [1])
      else:
        expand_dims = [dim]
      expanded_shape = array_ops.concat(
          (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1],
           array_ops.slice(tensor.dense_shape, expand_dims, [-1])),
          0,
          name='expanded_shape')
      expanded = sparse_ops.sparse_reshape(
          tensor, shape=expanded_shape, name='expand')
      if multiple == 1:
        return expanded
      return sparse_ops.sparse_concat(
          dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope)

    # Dense.
    expanded = array_ops.expand_dims(
        tensor, dim if (dim >= 0) else (dim - 1), name='expand')
    if multiple == 1:
      return expanded
    ones = array_ops.ones_like(array_ops.shape(tensor))
    tile_multiples = array_ops.concat(
        (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples')
    return array_ops.tile(expanded, tile_multiples, name=scope)


def _num_relevant(labels, k):
  """Computes number of relevant values for each row in labels.

  For labels with shape [D1, ... DN, num_labels], this is the minimum of
  `num_labels` and `k`.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels].
    k: Integer, k for @k metric.

  Returns:
    Integer `Tensor` of shape [D1, ... DN], where each value is the number of
    relevant values for that row.

  Raises:
    ValueError: if inputs have invalid dtypes or values.
  """
  if k < 1:
    raise ValueError(f'Invalid k={k}. `k` must be an integer >= 1.')
  with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
    # For SparseTensor, calculate separate count for each row.
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
    if isinstance(labels, sparse_tensor.SparseTensor):
      return math_ops.minimum(sets.set_size(labels), k, name=scope)

    # The relevant values for each (d1, ... dN) is the minimum of k and the
    # number of labels along the last dimension that are non-negative.
    num_labels = math_ops.reduce_sum(
        array_ops.where_v2(math_ops.greater_equal(labels, 0),
                           array_ops.ones_like(labels),
                           array_ops.zeros_like(labels)),
        axis=-1)
    return math_ops.minimum(num_labels, k, name=scope)
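
# Behavior sketch for `_num_relevant` (module-private; values are
# illustrative): negative entries mark padding and are not counted.
#
#   labels = ops.convert_to_tensor([[0, 1, 4], [2, -1, -1]], dtypes.int64)
#   # Non-negative counts per row are [3, 1]; with k=2 the result is
#   # min([3, 1], 2) = [2, 1].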

def _sparse_average_precision_at_top_k(labels, predictions_idx):
  """Computes average precision@k of predictions with respect to sparse labels.

  From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula
  for each row is:

    AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items

  A "row" is the elements in dimension [D1, ... DN] of `predictions_idx`,
  `labels`, and the result `Tensors`. In the common case, this is
  [batch_size]. Each row of the results contains the average precision for
  that row.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
      Values should be non-negative. Negative values are ignored.
    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
      Commonly, N=1 and `predictions_idx` has shape [batch_size, k]. The final
      dimension must be set and contains the top `k` predicted class indices.
      [D1, ... DN] must match `labels`. Values should be in range
      [0, num_classes).

  Returns:
    `float64` `Tensor` of shape [D1, ... DN], where each value is the average
    precision for that row.

  Raises:
    ValueError: if the last dimension of predictions_idx is not set.
  """
  with ops.name_scope(None, 'average_precision',
                      (predictions_idx, labels)) as scope:
    predictions_idx = math_ops.cast(
        predictions_idx, dtypes.int64, name='predictions_idx')
    if predictions_idx.get_shape().ndims == 0:
      raise ValueError('The rank of `predictions_idx` must be at least 1.')
    k = predictions_idx.get_shape().as_list()[-1]
    if k is None:
      raise ValueError('The last dimension of predictions_idx must be set. '
                       'Currently, it is None.')
    labels = _maybe_expand_labels(labels, predictions_idx)

    # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a
    # separate prediction for each k, so we can calculate separate true
    # positive values for each k.
    predictions_idx_per_k = array_ops.expand_dims(
        predictions_idx, -1, name='predictions_idx_per_k')

    # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor.
    labels_per_k = _expand_and_tile(
        labels, multiple=k, dim=-1, name='labels_per_k')

    # The following tensors are all of shape [D1, ... DN, k], containing
    # values per row, per k value.
    # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at
    #     that k value is correct, 0 otherwise. This is the "rel_{i}" term
    #     from the formula above.
    # `tp_per_k` (int32) - True positive counts.
    # `retrieved_per_k` (int32) - Number of predicted values at each k. This
    #     is the precision denominator.
    # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}"
    #     term from the formula above.
    # `relevant_precision_per_k` (float64) - Relevant precisions; i.e.,
    #     precisions at all k for which relevance indicator is true.
    relevant_per_k = _sparse_true_positive_at_k(
        labels_per_k, predictions_idx_per_k, name='relevant_per_k')
    tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k')
    retrieved_per_k = math_ops.cumsum(
        array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
    precision_per_k = math_ops.divide(
        math_ops.cast(tp_per_k, dtypes.float64),
        math_ops.cast(retrieved_per_k, dtypes.float64),
        name='precision_per_k')
    relevant_precision_per_k = math_ops.multiply(
        precision_per_k,
        math_ops.cast(relevant_per_k, dtypes.float64),
        name='relevant_precision_per_k')

    # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
    precision_sum = math_ops.reduce_sum(
        relevant_precision_per_k, axis=(-1,), name='precision_sum')

    # Divide by number of relevant items to get average precision. These are
    # the "num_relevant_items" and "AveP" terms from the formula above.
    num_relevant_items = math_ops.cast(
        _num_relevant(labels, k), dtypes.float64)
    return math_ops.divide(precision_sum, num_relevant_items, name=scope)
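
# Worked example of the AveP formula above (values are illustrative): for a
# row with predictions_idx = [3, 1, 2] (k=3) and label set {1, 2}:
#   rel  = [0, 1, 1]
#   tp   = [0, 1, 2]          (cumulative sum of rel)
#   P    = [0/1, 1/2, 2/3]    (tp / retrieved)
#   AveP = (1/2 + 2/3) / min(num_labels=2, k=3) = (7/6) / 2 ~= 0.583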

def _streaming_sparse_average_precision_at_top_k(labels,
                                                 predictions_idx,
                                                 weights=None,
                                                 metrics_collections=None,
                                                 updates_collections=None,
                                                 name=None):
  """Computes average precision@k of predictions with respect to sparse labels.

  `sparse_average_precision_at_top_k` creates two local variables,
  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
  are used to compute the mean average precision. The metric is ultimately
  returned as `average_precision_at_<k>`: an idempotent operation that simply
  divides `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `average_precision_at_<k>`. Per-row average precisions, weighted by
  `weights`, are accumulated into `total`, while the maximum attainable sum
  of weights is accumulated into `max`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask
  values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
      Values should be non-negative. Negative values are ignored.
    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
      Commonly, N=1 and `predictions_idx` has shape [batch_size, k]. The final
      dimension contains the top `k` predicted class indices. [D1, ... DN]
      must match `labels`. Values should be in range [0, num_classes).
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    mean_average_precision: Scalar `float64` `Tensor` with the mean average
      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `metric`.
  """
  with ops.name_scope(name, 'average_precision_at_top_k',
                      (predictions_idx, labels, weights)) as scope:
    # Calculate per-example average precision, and apply weights.
    average_precision = _sparse_average_precision_at_top_k(
        predictions_idx=predictions_idx, labels=labels)
    if weights is not None:
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.cast(weights, dtypes.float64), average_precision)
      average_precision = math_ops.multiply(average_precision, weights)

    # Create accumulation variables and update ops for max average precision
    # and total average precision.
    with ops.name_scope(None, 'max', (average_precision,)) as max_scope:
      # `max` is the max possible precision. Since max for any row is 1.0:
      # - For the unweighted case, this is just the number of rows.
      # - For the weighted case, it's the sum of the weights broadcast across
      #   `average_precision` rows.
      max_var = metric_variable([], dtypes.float64, name=max_scope)
      if weights is None:
        batch_max = math_ops.cast(
            array_ops.size(average_precision, name='batch_max'),
            dtypes.float64)
      else:
        batch_max = math_ops.reduce_sum(weights, name='batch_max')
      max_update = state_ops.assign_add(max_var, batch_max, name='update')
    with ops.name_scope(None, 'total', (average_precision,)) as total_scope:
      total_var = metric_variable([], dtypes.float64, name=total_scope)
      batch_total = math_ops.reduce_sum(average_precision, name='batch_total')
      total_update = state_ops.assign_add(total_var, batch_total,
                                          name='update')

    # Divide total by max to get mean, for both vars and the update ops.
    def precision_across_replicas(_, total_var, max_var):
      return _safe_scalar_div(total_var, max_var, name='mean')

    mean_average_precision = _aggregate_across_replicas(
        metrics_collections, precision_across_replicas, total_var, max_var)

    update = _safe_scalar_div(total_update, max_update, name=scope)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)

    return mean_average_precision, update
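
# Accumulation sketch (illustrative): if a batch contributes per-row average
# precisions [0.5, 1.0] with unit weights, `total` grows by 1.5 while `max`
# (1.0 per row, the best attainable) grows by 2.0, so the streamed metric is
# total / max = 0.75.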

def _clean_out_of_range_indices(labels, num_classes):
  """Replaces large out-of-range labels by small out-of-range labels.

  Replaces any value in `labels` that is greater or equal to `num_classes` by
  -1. This is done conditionally, for efficiency, in case there are no such
  values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor`.
    num_classes: `int64` scalar `Tensor`.

  Returns:
    An `int64` `Tensor` or `SparseTensor` as `labels` with indices greater
    or equal to num_classes replaced by -1.
  """

  def _labels_is_sparse():
    """Returns true if `labels` is a sparse tensor."""
    return isinstance(labels, (sparse_tensor.SparseTensor,
                               sparse_tensor.SparseTensorValue))

  def _clean_out_of_range(values):
    """Replaces by -1 any large out-of-range `values`."""
    return array_ops.where_v2(math_ops.greater_equal(values, num_classes),
                              -1 * array_ops.ones_like(values), values)

  def _clean_labels_out_of_range():
    """Replaces by -1 any large out-of-range values in `labels`."""
    if _labels_is_sparse():
      return type(labels)(indices=labels.indices,
                          values=_clean_out_of_range(labels.values),
                          dense_shape=labels.dense_shape)
    else:
      return _clean_out_of_range(labels)

  max_labels = math_ops.reduce_max(
      labels.values if _labels_is_sparse() else labels)
  return control_flow_ops.cond(
      math_ops.greater_equal(max_labels, num_classes),
      _clean_labels_out_of_range,
      lambda: labels)
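
# Behavior sketch for `_clean_out_of_range_indices` (module-private; values
# are illustrative):
#
#   labels = ops.convert_to_tensor([[0, 7], [2, 3]], dtypes.int64)
#   cleaned = _clean_out_of_range_indices(
#       labels, ops.convert_to_tensor(4, dtypes.int64))
#   # With num_classes=4, the out-of-range 7 becomes -1: [[0, -1], [2, 3]].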

@tf_export(v1=['metrics.sparse_average_precision_at_k'])
@deprecated(None, 'Use average_precision_at_k instead')
def sparse_average_precision_at_k(labels,
                                  predictions,
                                  k,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Renamed to `average_precision_at_k`, please use that method instead."""
  return average_precision_at_k(
      labels=labels,
      predictions=predictions,
      k=k,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@tf_export(v1=['metrics.average_precision_at_k'])
def average_precision_at_k(labels,
                           predictions,
                           k,
                           weights=None,
                           metrics_collections=None,
                           updates_collections=None,
                           name=None):
  """Computes average precision@k of predictions with respect to sparse labels.

  `average_precision_at_k` creates two local variables,
  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
  are used to compute the mean average precision. The metric is ultimately
  returned as `average_precision_at_<k>`: an idempotent operation that simply
  divides `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `average_precision_at_<k>`. Internally, a `top_k` operation computes a
  `Tensor` indicating the top `k` `predictions`. Set operations applied to
  `top_k` and `labels` calculate the true positives and false positives
  weighted by `weights`. Then `update_op` increments `true_positive_at_<k>`
  and `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask
  values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
      should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range are ignored.
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and `predictions` has shape
      [batch_size, num_classes]. The final dimension contains the logit values
      for each class. [D1, ... DN] must match `labels`.
    k: Integer, k for @k metric. This will calculate an average precision for
      range `[1,k]`, as documented above.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    mean_average_precision: Scalar `float64` `Tensor` with the mean average
      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `metric`.

  Raises:
    ValueError: if k is invalid.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.average_precision_at_k is not '
                       'supported when eager execution is enabled.')

  if k < 1:
    raise ValueError(f'Invalid k={k}. `k` should be >= 1.')
  with ops.name_scope(name, _at_k_name('average_precision', k),
                      (predictions, labels, weights)) as scope:
    # Calculate top k indices to produce [D1, ... DN, k] tensor.
    _, predictions_idx = nn.top_k(predictions, k)
    # The documentation states that labels should be in [0, ..., num_classes),
    # but num_classes is lost when predictions_idx replaces predictions.
    # For conformity with the documentation, any label >= num_classes, which
    # is ignored, is replaced by -1.
    labels = _clean_out_of_range_indices(
        labels, math_ops.cast(array_ops.shape(predictions)[-1], dtypes.int64))
    return _streaming_sparse_average_precision_at_top_k(
        labels=labels,
        predictions_idx=predictions_idx,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=scope)
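
# A usage sketch for `average_precision_at_k` (tensor values are illustrative
# assumptions):
#
#   labels = tf.constant([[2], [0]], dtype=tf.int64)
#   logits = tf.constant([[0.1, 0.2, 0.7],
#                         [0.3, 0.6, 0.1]])
#   map_at_2, update_op = tf.metrics.average_precision_at_k(
#       labels, logits, k=2)
#   # Row 0: class 2 ranks first -> AveP = 1.0. Row 1: class 0 ranks second
#   # -> AveP = 1/2. The streamed mean average precision is 0.75.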

def _sparse_false_positive_at_k(labels,
                                predictions_idx,
                                class_id=None,
                                weights=None):
  """Calculates false positives for precision@k.

  If `class_id` is specified, calculate binary false positives for `class_id`
  only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
  `n` label classes, where `n` is the 2nd dimension of `labels`.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).

  Returns:
    A [D1, ... DN] `Tensor` of false positive counts.
  """
  with ops.name_scope(None, 'false_positives',
                      (predictions_idx, labels, weights)):
    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
                                                     class_id)
    fp = sets.set_size(
        sets.set_difference(predictions_idx, labels, aminusb=True))
    fp = math_ops.cast(fp, dtypes.float64)
    if weights is not None:
      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
          weights, fp),)):
        weights = math_ops.cast(weights, dtypes.float64)
        fp = math_ops.multiply(fp, weights)
    return fp
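
# Semantics sketch (illustrative): per row, the set difference
# `predictions_idx - labels` yields the predicted classes absent from the
# label set; e.g. a row with predictions_idx [2, 0] and label set {2} has one
# false positive (class 0).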

def _streaming_sparse_false_positive_at_k(labels,
                                          predictions_idx,
                                          k=None,
                                          class_id=None,
                                          weights=None,
                                          name=None):
  """Calculates weighted per step false positives for precision@k.

  If `class_id` is specified, calculate binary false positives for `class_id`
  only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
  `n` label classes, where `n` is the 2nd dimension of `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask
  values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
  """
  with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id),
                      (predictions_idx, labels, weights)) as scope:
    fp = _sparse_false_positive_at_k(
        predictions_idx=predictions_idx,
        labels=labels,
        class_id=class_id,
        weights=weights)
    batch_total_fp = math_ops.cast(math_ops.reduce_sum(fp), dtypes.float64)

    var = metric_variable([], dtypes.float64, name=scope)
    return var, state_ops.assign_add(var, batch_total_fp, name='update')


@tf_export(v1=['metrics.precision_at_top_k'])
def precision_at_top_k(labels,
                       predictions_idx,
                       k=None,
                       class_id=None,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  Differs from `sparse_precision_at_k` in that predictions must be in the form
  of top `k` class indices, whereas `sparse_precision_at_k` expects logits.
  Refer to `sparse_precision_at_k` for more details.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
      Values should be in range [0, num_classes), where num_classes is the
      number of classes scored by the original predictions. Values outside
      this range are ignored.
    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and `predictions_idx` has shape [batch_size, k].
      The final dimension contains the top `k` predicted class indices.
      [D1, ... DN] must match `labels`.
    k: Integer, k for @k metric. Only used for the default op name.
    class_id: Integer class ID for which we want binary metrics. This should
      be in range [0, num_classes), where num_classes is the number of classes
      scored by the original predictions. If `class_id` is outside this range,
      the method returns NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions_idx`, or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.precision_at_top_k is not '
                       'supported when eager execution is enabled.')

  with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
                      (predictions_idx, labels, weights)) as scope:
    labels = _maybe_expand_labels(labels, predictions_idx)
    top_k_idx = math_ops.cast(predictions_idx, dtypes.int64)
    tp, tp_update = _streaming_sparse_true_positive_at_k(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)
    fp, fp_update = _streaming_sparse_false_positive_at_k(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)

    def precision_across_replicas(_, tp, fp):
      return math_ops.divide(tp, math_ops.add(tp, fp), name=scope)

    metric = _aggregate_across_replicas(
        metrics_collections, precision_across_replicas, tp, fp)

    update = math_ops.divide(
        tp_update, math_ops.add(tp_update, fp_update), name='update')
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update
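
# A usage sketch for `precision_at_top_k` (tensors are illustrative
# assumptions):
#
#   labels = tf.constant([[2], [1]], dtype=tf.int64)
#   top_k_idx = tf.constant([[2, 0], [0, 2]], dtype=tf.int64)
#   prec, update_op = tf.metrics.precision_at_top_k(labels, top_k_idx, k=2)
#   # Of the 4 predictions, only class 2 in row 0 is a true positive, so the
#   # streamed precision@2 after one `update_op` is 0.25.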

@tf_export(v1=['metrics.sparse_precision_at_k'])
@deprecated(None, 'Use precision_at_k instead')
def sparse_precision_at_k(labels,
                          predictions,
                          k,
                          class_id=None,
                          weights=None,
                          metrics_collections=None,
                          updates_collections=None,
                          name=None):
  """Renamed to `precision_at_k`, please use that method instead."""
  return precision_at_k(
      labels=labels,
      predictions=predictions,
      k=k,
      class_id=class_id,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)

@tf_export(v1=['metrics.precision_at_k'])
def precision_at_k(labels,
                   predictions,
                   k,
                   class_id=None,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate precision by considering only the
  entries in the batch for which `class_id` is in the top-k highest
  `predictions`, and computing the fraction of them for which `class_id` is
  indeed a correct label.
  If `class_id` is not specified, we'll calculate precision as how often on
  average a class among the top-k classes with the highest predicted values
  of a batch entry is correct and can be found in the label for that entry.

  `precision_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_positive_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask
  values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
      should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range are ignored.
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape
      [batch_size, num_classes]. The final dimension contains the logit values
      for each class. [D1, ... DN] must match `labels`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should
      be in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.precision_at_k is not '
                       'supported when eager execution is enabled.')

  with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
                      (predictions, labels, weights)) as scope:
    _, top_k_idx = nn.top_k(predictions, k)
    return precision_at_top_k(
        labels=labels,
        predictions_idx=top_k_idx,
        k=k,
        class_id=class_id,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=scope)
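
# A usage sketch for `precision_at_k` (tensors and session handling are
# illustrative assumptions; `tf` denotes `tensorflow.compat.v1` as above):
#
#   labels = tf.constant([[2], [1]], dtype=tf.int64)
#   logits = tf.constant([[0.1, 0.3, 0.6],
#                         [0.8, 0.15, 0.05]])
#   prec, update_op = tf.metrics.precision_at_k(labels, logits, k=2)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(prec))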


@tf_export(v1=['metrics.specificity_at_sensitivity'])
def specificity_at_sensitivity(labels,
                               predictions,
                               sensitivity,
                               weights=None,
                               num_thresholds=200,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Computes the specificity at a given sensitivity.

  The `specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is
  computed and used to evaluate the corresponding specificity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `specificity`. `update_op` increments the `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` counts with the
  weight of each case found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask
  values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar `Tensor` representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
                       'supported when eager execution is enabled.')

  if sensitivity < 0 or sensitivity > 1:
    raise ValueError('`sensitivity` must be in the range [0, 1]. Currently, '
                     f'`sensitivity` is {sensitivity}.')

  with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
                                     (predictions, labels, weights)):
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    def compute_specificity_at_sensitivity(tp, tn, fp, fn, name):
      """Computes the specificity at the given sensitivity.

      Args:
        tp: True positives.
        tn: True negatives.
        fp: False positives.
        fn: False negatives.
        name: The name of the operation.

      Returns:
        The specificity using the aggregated values.
      """
      sensitivities = math_ops.divide(tp, tp + fn + kepsilon)

      # We'll need to use this trick until tf.argmax allows us to specify
      # whether we should use the first or last index in case of ties.
      min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
      indices_at_minval = math_ops.equal(
          math_ops.abs(sensitivities - sensitivity), min_val)
      indices_at_minval = math_ops.cast(indices_at_minval, dtypes.int64)
      indices_at_minval = math_ops.cumsum(indices_at_minval)
      tf_index = math_ops.argmax(indices_at_minval, 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the specificity:
      return math_ops.divide(tn[tf_index],
                             tn[tf_index] + fp[tf_index] + kepsilon, name)

    def specificity_across_replicas(_, values):
      return compute_specificity_at_sensitivity(
          values['tp'], values['tn'], values['fp'], values['fn'], 'value')

    specificity = _aggregate_across_replicas(
        metrics_collections, specificity_across_replicas, values)

    update_op = compute_specificity_at_sensitivity(
        update_ops['tp'], update_ops['tn'], update_ops['fp'],
        update_ops['fn'], 'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return specificity, update_op
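
# A usage sketch for `specificity_at_sensitivity` (illustrative; the tensors
# below are assumptions). Mirrors `sensitivity_at_specificity`: the threshold
# whose streamed sensitivity is closest to the target is selected, and the
# specificity at that threshold is reported:
#
#   labels = tf.constant([0, 0, 1, 1])
#   scores = tf.constant([0.1, 0.4, 0.35, 0.8])
#   spec, update_op = tf.metrics.specificity_at_sensitivity(
#       labels, scores, sensitivity=0.8)
#   # Run `update_op` per batch, then read `spec`.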