# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of tf.metrics module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sets
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


def metric_variable(shape, dtype, validate_shape=True, name=None):
  """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections.

  If running in a `DistributionStrategy` context, the variable will be
  "sync on read". This means:

  *   The returned object will be a container with separate variables
      per replica of the model.

  *   When writing to the variable, e.g. using `assign_add` in a metric
      update, the update will be applied to the variable local to the
      replica.

  *   To get a metric's result value, we need to sum the variable values
      across the replicas before computing the final answer. Furthermore,
      the final answer should be computed once instead of in every
      replica. Both of these are accomplished by running the computation
      of the final result value inside
      `distribution_strategy_context.get_replica_context().merge_call(fn)`.
      Inside the `merge_call()`, ops are only added to the graph once
      and access to a sync on read variable in a computation returns
      the sum across all replicas.

  Args:
    shape: Shape of the created variable.
    dtype: Type of the created variable.
    validate_shape: (Optional) Whether shape validation is enabled for
      the created variable.
    name: (Optional) String name of the created variable.

  Returns:
    A (non-trainable) variable initialized to zero, or if inside a
    `DistributionStrategy` scope a sync on read variable container.
  """
  # Note that synchronization "ON_READ" implies trainable=False.
  return variable_scope.variable(
      lambda: array_ops.zeros(shape, dtype),
      trainable=False,
      collections=[
          ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
      ],
      validate_shape=validate_shape,
      synchronization=variable_scope.VariableSynchronization.ON_READ,
      aggregation=variable_scope.VariableAggregation.SUM,
      name=name)


def _remove_squeezable_dimensions(predictions, labels, weights):
  """Squeeze or expand last dim if needed.

  Squeezes last dim of `predictions` or `labels` if their rank differs by 1
  (using confusion_matrix.remove_squeezable_dimensions).
  Squeezes or expands last dim of `weights` if its rank differs by 1 from the
  new rank of `predictions`.

  If `weights` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Optional label `Tensor` whose dimensions match `predictions`.
    weights: Optional weight scalar or `Tensor` whose dimensions match
      `predictions`.

  Returns:
    Tuple of `predictions`, `labels` and `weights`. Each of them possibly has
    the last dimension squeezed, `weights` could be extended by one dimension.
  """
  predictions = ops.convert_to_tensor(predictions)
  if labels is not None:
    labels, predictions = confusion_matrix.remove_squeezable_dimensions(
        labels, predictions)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  if weights is None:
    return predictions, labels, None

  weights = ops.convert_to_tensor(weights)
  weights_shape = weights.get_shape()
  weights_rank = weights_shape.ndims
  if weights_rank == 0:
    return predictions, labels, weights

  predictions_shape = predictions.get_shape()
  predictions_rank = predictions_shape.ndims
  if (predictions_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - predictions_rank == 1:
      weights = array_ops.squeeze(weights, [-1])
    elif predictions_rank - weights_rank == 1:
      weights = array_ops.expand_dims(weights, [-1])
  else:
    # Use dynamic rank.
    weights_rank_tensor = array_ops.rank(weights)
    rank_diff = weights_rank_tensor - array_ops.rank(predictions)

    def _maybe_expand_weights():
      return control_flow_ops.cond(
          math_ops.equal(rank_diff, -1),
          lambda: array_ops.expand_dims(weights, [-1]), lambda: weights)

    # Don't attempt squeeze if it will fail based on static check.
    if ((weights_rank is not None) and
        (not weights_shape.dims[-1].is_compatible_with(1))):
      maybe_squeeze_weights = lambda: weights
    else:
      maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1])

    def _maybe_adjust_weights():
      return control_flow_ops.cond(
          math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
          _maybe_expand_weights)

    # If weights are scalar, do nothing. Otherwise, try to add or remove a
    # dimension to match predictions.
    weights = control_flow_ops.cond(
        math_ops.equal(weights_rank_tensor, 0), lambda: weights,
        _maybe_adjust_weights)
  return predictions, labels, weights


def _maybe_expand_labels(labels, predictions):
  """If necessary, expand `labels` along last dimension to match `predictions`.

  Args:
    labels: `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN]. The latter implies
      num_labels=1, in which case the result is an expanded `labels` with shape
      [D1, ... DN, 1].
    predictions: `Tensor` with shape [D1, ... DN, num_classes].

  Returns:
    `labels` with the same rank as `predictions`.

  Raises:
    ValueError: if `labels` has invalid shape.
  """
  with ops.name_scope(None, 'expand_labels', (labels, predictions)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)

    # If sparse, expand sparse shape.
    if isinstance(labels, sparse_tensor.SparseTensor):
      return control_flow_ops.cond(
          math_ops.equal(
              array_ops.rank(predictions),
              array_ops.size(labels.dense_shape) + 1),
          lambda: sparse_ops.sparse_reshape(  # pylint: disable=g-long-lambda
              labels,
              shape=array_ops.concat((labels.dense_shape, (1,)), 0),
              name=scope),
          lambda: labels)

    # Otherwise, try to use static shape.
    labels_rank = labels.get_shape().ndims
    if labels_rank is not None:
      predictions_rank = predictions.get_shape().ndims
      if predictions_rank is not None:
        if predictions_rank == labels_rank:
          return labels
        if predictions_rank == labels_rank + 1:
          return array_ops.expand_dims(labels, -1, name=scope)
        raise ValueError(
            'Unexpected labels shape %s for predictions shape %s.' %
            (labels.get_shape(), predictions.get_shape()))

    # Otherwise, use dynamic shape.
    return control_flow_ops.cond(
        math_ops.equal(array_ops.rank(predictions),
                       array_ops.rank(labels) + 1),
        lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)


def _safe_scalar_div(numerator, denominator, name):
  """Divides two values, returning 0 if the denominator is 0.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  numerator.get_shape().with_rank_at_most(1)
  denominator.get_shape().with_rank_at_most(1)
  return math_ops.div_no_nan(numerator, denominator, name=name)


def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
  """Calculate a streaming confusion matrix.

  Calculates a confusion matrix. For estimation over a stream of data,
  the function creates an `update_op` operation.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1
      or if its rank is statically unknown.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1 or if its rank is statically unknown.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` (or value convertible to one) whose rank is
      either 0, or the same rank as `labels`, and must be broadcastable to
      `labels` (i.e., all dimensions must be either `1`, or the same as the
      corresponding `labels` dimension).

  Returns:
    total_cm: A `Tensor` representing the confusion matrix.
    update_op: An operation that increments the confusion matrix.
  """
  # Local variable to accumulate the predictions in the confusion matrix.
  total_cm = metric_variable(
      [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix')

  # Cast the type to int64 required by confusion_matrix_ops.
  predictions = math_ops.cast(predictions, dtypes.int64)
  labels = math_ops.cast(labels, dtypes.int64)
  num_classes = math_ops.cast(num_classes, dtypes.int64)

  def _flatten(tensor):
    """Reshapes `tensor` to rank 1 unless its static rank is known to be <= 1."""
    # `ndims` is None when the static rank is unknown. Comparing None with an
    # int raises TypeError on Python 3, so unknown ranks are flattened as well
    # (reshaping to [-1] is always valid).
    ndims = tensor.get_shape().ndims
    if ndims is None or ndims > 1:
      return array_ops.reshape(tensor, [-1])
    return tensor

  # Flatten the inputs if their rank is > 1 (or unknown).
  predictions = _flatten(predictions)
  labels = _flatten(labels)
  if weights is not None:
    # Convert first so that non-`Tensor` weights (lists, numpy arrays) have a
    # static shape to inspect.
    weights = _flatten(ops.convert_to_tensor(weights))

  # Accumulate the prediction to current confusion matrix.
  current_cm = confusion_matrix.confusion_matrix(
      labels, predictions, num_classes, weights=weights, dtype=dtypes.float64)
  update_op = state_ops.assign_add(total_cm, current_cm)
  return total_cm, update_op


def _aggregate_across_replicas(metrics_collections, metric_value_fn, *args):
  """Aggregate metric value across replicas.

  Calls `metric_value_fn` once, via `merge_call` on the current replica
  context, so that sync-on-read variables created by `metric_variable` and
  passed in `args` are read as their sum across replicas. The resulting value
  is added to `metrics_collections` when that is non-empty.

  Args:
    metrics_collections: Optional list of collection names that the metric
      value should be added to.
    metric_value_fn: Callable invoked as `metric_value_fn(distribution, *args)`
      that builds the metric value op.
    *args: Positional arguments forwarded to `metric_value_fn`.

  Returns:
    The value returned by `metric_value_fn`.
  """
  def fn(distribution, *a):
    """Call `metric_value_fn` in the correct control flow context."""
    if hasattr(distribution.extended, '_outer_control_flow_context'):
      # If there was an outer context captured before this method was called,
      # then we enter that context to create the metric value op. If the
      # captured context is `None`, ops.control_dependencies(None) gives the
      # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
      # captured context.
      # This special handling is needed because sometimes the metric is created
      # inside a while_loop (and perhaps a TPU rewrite context). But we don't
      # want the value op to be evaluated every step or on the TPU. So we
      # create it outside so that it can be evaluated at the end on the host,
      # once the update ops have been evaluated.

      # pylint: disable=protected-access
      outer_context = distribution.extended._outer_control_flow_context
      # pylint: enable=protected-access
      if outer_context is None:
        with ops.control_dependencies(None):
          metric_value = metric_value_fn(distribution, *a)
      else:
        outer_context.Enter()
        try:
          metric_value = metric_value_fn(distribution, *a)
        finally:
          # Always leave the captured context, even if building the value op
          # raised, so the graph is not left inside it.
          outer_context.Exit()
    else:
      metric_value = metric_value_fn(distribution, *a)
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric_value)
    return metric_value

  return distribution_strategy_context.get_replica_context().merge_call(
      fn, args=args)


@tf_export(v1=['metrics.mean'])
def mean(values,
         weights=None,
         metrics_collections=None,
         updates_collections=None,
         name=None):
  """Computes the (weighted) mean of the given values.

  The `mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean` which is an idempotent operation that simply divides
  `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'mean', (values, weights)):
    values = math_ops.cast(values, dtypes.float32)

    total = metric_variable([], dtypes.float32, name='total')
    count = metric_variable([], dtypes.float32, name='count')

    if weights is None:
      num_values = math_ops.cast(array_ops.size(values), dtypes.float32)
    else:
      values, _, weights = _remove_squeezable_dimensions(
          predictions=values, labels=None, weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.cast(weights, dtypes.float32), values)
      values = math_ops.multiply(values, weights)
      num_values = math_ops.reduce_sum(weights)

    update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
    with ops.control_dependencies([values]):
      update_count_op = state_ops.assign_add(count, num_values)

    def compute_mean(_, t, c):
      return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value')

    mean_t = _aggregate_across_replicas(
        metrics_collections, compute_mean, total, count)
    update_op = math_ops.div_no_nan(
        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_t, update_op


@tf_export(v1=['metrics.accuracy'])
def accuracy(labels,
             predictions,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculates how often `predictions` matches `labels`.

  The `accuracy` function creates two local variables, `total` and
  `count` that are used to compute the frequency with which `predictions`
  matches `labels`. This frequency is ultimately returned as `accuracy`: an
  idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `accuracy`.
  Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
  where the corresponding elements of `predictions` and `labels` match and 0.0
  otherwise. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `is_correct`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    predictions: The predicted values, a `Tensor` of any shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.accuracy is not supported when eager '
                       'execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  if labels.dtype != predictions.dtype:
    predictions = math_ops.cast(predictions, labels.dtype)
  is_correct = math_ops.cast(
      math_ops.equal(predictions, labels), dtypes.float32)
  return mean(is_correct, weights, metrics_collections, updates_collections,
              name or 'accuracy')


def _confusion_matrix_at_thresholds(labels,
                                    predictions,
                                    thresholds,
                                    weights=None,
                                    includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
      default to all four.

  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
        `includes`.
    update_ops: Dict of operations that increments the `values`. Keys are from
        `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  all_includes = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = all_includes
  else:
    for include in includes:
      if include not in all_includes:
        raise ValueError('Invalid key: %s.' % include)

  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          predictions,
          math_ops.cast(0.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]'),
      check_ops.assert_less_equal(
          predictions,
          math_ops.cast(1.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]')
  ]):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtypes.float32),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

  num_thresholds = len(thresholds)

  # Reshape predictions and labels.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Use static shape if known.
  num_predictions = predictions_2d.get_shape().as_list()[0]

  # Otherwise use dynamic shape.
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
      array_ops.stack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
      thresh_tiled)
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  if weights is not None:
    weights = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(weights, dtypes.float32), predictions)
    weights_tiled = array_ops.tile(
        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
  else:
    weights_tiled = None

  values = {}
  update_ops = {}

  if 'tp' in includes:
    true_p = metric_variable(
        [num_thresholds], dtypes.float32, name='true_positives')
    is_true_positive = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_true_positive *= weights_tiled
    update_ops['tp'] = state_ops.assign_add(true_p,
                                            math_ops.reduce_sum(
                                                is_true_positive, 1))
    values['tp'] = true_p

  if 'fn' in includes:
    false_n = metric_variable(
        [num_thresholds], dtypes.float32, name='false_negatives')
    is_false_negative = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_false_negative *= weights_tiled
    update_ops['fn'] = state_ops.assign_add(false_n,
                                            math_ops.reduce_sum(
                                                is_false_negative, 1))
    values['fn'] = false_n

  if 'tn' in includes:
    true_n = metric_variable(
        [num_thresholds], dtypes.float32, name='true_negatives')
    is_true_negative = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_true_negative *= weights_tiled
    update_ops['tn'] = state_ops.assign_add(true_n,
                                            math_ops.reduce_sum(
                                                is_true_negative, 1))
    values['tn'] = true_n

  if 'fp' in includes:
    false_p = metric_variable(
        [num_thresholds], dtypes.float32, name='false_positives')
    is_false_positive = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_false_positive *= weights_tiled
    update_ops['fp'] = state_ops.assign_add(false_p,
                                            math_ops.reduce_sum(
                                                is_false_positive, 1))
    values['fp'] = false_p

  return values, update_ops


def _aggregate_variable(v, collections):
  """Reads `v` via `merge_call` and optionally adds it to `collections`.

  Args:
    v: Variable (typically created by `metric_variable`) to read.
    collections: Optional list of collection names the read value is added to.

  Returns:
    The result of `distribution.extended.read_var(v)`, evaluated inside the
    replica context's `merge_call`.
  """
  f = lambda distribution, value: distribution.extended.read_var(value)
  return _aggregate_across_replicas(collections, f, v)


@tf_export(v1=['metrics.auc'])
@deprecated(None,
            'The value of AUC returned by this may race with the update so '
            'this is deprecated. Please use tf.keras.metrics.AUC instead.')
def auc(labels,
        predictions,
        weights=None,
        num_thresholds=200,
        metrics_collections=None,
        updates_collections=None,
        curve='ROC',
        name=None,
        summation_method='trapezoidal',
        thresholds=None):
  """Computes the approximate AUC via a Riemann sum.

  The `auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall values
  (computed using the aforementioned variables). The `num_thresholds` variable
  controls the degree of discretization with larger numbers of thresholds more
  closely approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case. Setting `summation_method`
  to 'minoring' or 'majoring' can help quantify the error in the approximation
  by providing lower or upper bound estimate of the AUC. The `thresholds`
  parameter can be used to manually specify thresholds which split the
  predictions more evenly.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve. Values below 2 behave like 2 (only the two endpoint thresholds).
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.
    summation_method: Specifies the Riemann summation method used
      (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that
      applies the trapezoidal rule; 'careful_interpolation', a variant of it
      differing only by a more correct interpolation scheme for PR-AUC -
      interpolating (true/false) positives but not the ratio that is precision;
      'minoring' that applies left summation for increasing intervals and right
      summation for decreasing intervals; 'majoring' that does the opposite.
      Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
      (to be deprecated soon) as it applies the same method for ROC, and a
      better one (see Davis & Goadrich 2006 for details) for the PR curve.
    thresholds: An optional list of floating point values to use as the
      thresholds for discretizing the curve. If set, the `num_thresholds`
      parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
      equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be
      automatically included with these to correctly handle predictions equal to
       exactly 0 or 1.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.auc is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'auc',
                                     (labels, predictions, weights)):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))

    kepsilon = 1e-7  # To account for floating point imprecisions.
    if thresholds is not None:
      # If specified, use the supplied thresholds.
      thresholds = sorted(thresholds)
    else:
      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
      # (0, 1).
      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                    for i in range(num_thresholds - 2)]

    # Add an endpoint "threshold" below zero and above one for either threshold
    # method.
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
    # Derive the count from the final list so the slicing below always agrees
    # with the shape of the confusion-matrix variables (which are sized by
    # len(thresholds)). For num_thresholds >= 2 this equals the requested value.
    num_thresholds = len(thresholds)

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def interpolate_pr_auc(tp, fp, fn):
      """Interpolation formula inspired by section 4 of (Davis et al., 2006).

      Note here we derive & use a closed formula not present in the paper
      - as follows:
      Modeling all of TP (true positive weight),
      FP (false positive weight) and their sum P = TP + FP (positive weight)
      as varying linearly within each interval [A, B] between successive
      thresholds, we get
        Precision = (TP_A + slope * (P - P_A)) / P
      with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A).
      The area within the interval is thus (slope / total_pos_weight) times
        int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
        int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
      where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
        int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
      Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
         slope * [dTP + intercept *  log(P_B / P_A)] / total_pos_weight
      where dTP == TP_B - TP_A.
      Note that when P_A == 0 the above calculation simplifies into
        int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
      which is really equivalent to imputing constant precision throughout the
      first bucket having >0 true positives.

      Args:
        tp: true positive counts
        fp: false positive counts
        fn: false negative counts

      Returns:
        pr_auc: an approximation of the area under the P-R curve.

      References:
        The Relationship Between Precision-Recall and ROC Curves:
          [Davis et al., 2006](https://dl.acm.org/citation.cfm?id=1143874)
          ([pdf](https://www.biostat.wisc.edu/~page/rocpr.pdf))
      """
      dtp = tp[:num_thresholds - 1] - tp[1:]
      p = tp + fp
      prec_slope = math_ops.div_no_nan(
          dtp,
          math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0),
          name='prec_slope')
      intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
      safe_p_ratio = array_ops.where(
          math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
          math_ops.div_no_nan(
              p[:num_thresholds - 1],
              math_ops.maximum(p[1:], 0),
              name='recall_relative_ratio'), array_ops.ones_like(p[1:]))
      return math_ops.reduce_sum(
          math_ops.div_no_nan(
              prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
              math_ops.maximum(tp[1:] + fn[1:], 0),
              name='pr_auc_increment'),
          name='interpolate_pr_auc')

    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      if curve == 'PR':
        if summation_method == 'trapezoidal':
          logging.warning(
              'Trapezoidal rule is known to produce incorrect PR-AUCs; '
              'please switch to "careful_interpolation" instead.')
        elif summation_method == 'careful_interpolation':
          # This one is a bit tricky and is handled separately.
          return interpolate_pr_auc(tp, fp, fn)
      rec = math_ops.divide(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.divide(fp, fp + tn + epsilon)
        x = fp_rate
        y = rec
      else:  # curve == 'PR'.
        prec = math_ops.divide(tp + epsilon, tp + fp + epsilon)
        x = rec
        y = prec
      if summation_method in ('trapezoidal', 'careful_interpolation'):
        # Note that the case ('PR', 'careful_interpolation') has been handled
        # above.
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              (y[:num_thresholds - 1] + y[1:]) / 2.),
            name=name)
      elif summation_method == 'minoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.minimum(y[:num_thresholds - 1], y[1:])),
            name=name)
      elif summation_method == 'majoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.maximum(y[:num_thresholds - 1], y[1:])),
            name=name)
      else:
        raise ValueError('Invalid summation_method: %s' % summation_method)

    # sum up the areas of all the trapeziums
    def compute_auc_value(_, values):
      return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'],
                         'value')

    auc_value = _aggregate_across_replicas(
        metrics_collections, compute_auc_value, values)
    update_op = compute_auc(update_ops['tp'], update_ops['fn'],
                            update_ops['tn'], update_ops['fp'], 'update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return auc_value, update_op
862@tf_export(v1=['metrics.mean_absolute_error']) 863def mean_absolute_error(labels, 864 predictions, 865 weights=None, 866 metrics_collections=None, 867 updates_collections=None, 868 name=None): 869 """Computes the mean absolute error between the labels and predictions. 870 871 The `mean_absolute_error` function creates two local variables, 872 `total` and `count` that are used to compute the mean absolute error. This 873 average is weighted by `weights`, and it is ultimately returned as 874 `mean_absolute_error`: an idempotent operation that simply divides `total` by 875 `count`. 876 877 For estimation of the metric over a stream of data, the function creates an 878 `update_op` operation that updates these variables and returns the 879 `mean_absolute_error`. Internally, an `absolute_errors` operation computes the 880 absolute value of the differences between `predictions` and `labels`. Then 881 `update_op` increments `total` with the reduced sum of the product of 882 `weights` and `absolute_errors`, and it increments `count` with the reduced 883 sum of `weights` 884 885 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 886 887 Args: 888 labels: A `Tensor` of the same shape as `predictions`. 889 predictions: A `Tensor` of arbitrary shape. 890 weights: Optional `Tensor` whose rank is either 0, or the same rank as 891 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 892 be either `1`, or the same as the corresponding `labels` dimension). 893 metrics_collections: An optional list of collections that 894 `mean_absolute_error` should be added to. 895 updates_collections: An optional list of collections that `update_op` should 896 be added to. 897 name: An optional variable_scope name. 898 899 Returns: 900 mean_absolute_error: A `Tensor` representing the current mean, the value of 901 `total` divided by `count`. 
902 update_op: An operation that increments the `total` and `count` variables 903 appropriately and whose value matches `mean_absolute_error`. 904 905 Raises: 906 ValueError: If `predictions` and `labels` have mismatched shapes, or if 907 `weights` is not `None` and its shape doesn't match `predictions`, or if 908 either `metrics_collections` or `updates_collections` are not a list or 909 tuple. 910 RuntimeError: If eager execution is enabled. 911 """ 912 if context.executing_eagerly(): 913 raise RuntimeError('tf.metrics.mean_absolute_error is not supported ' 914 'when eager execution is enabled.') 915 916 predictions, labels, weights = _remove_squeezable_dimensions( 917 predictions=predictions, labels=labels, weights=weights) 918 absolute_errors = math_ops.abs(predictions - labels) 919 return mean(absolute_errors, weights, metrics_collections, 920 updates_collections, name or 'mean_absolute_error') 921 922 923@tf_export(v1=['metrics.mean_cosine_distance']) 924def mean_cosine_distance(labels, 925 predictions, 926 dim, 927 weights=None, 928 metrics_collections=None, 929 updates_collections=None, 930 name=None): 931 """Computes the cosine distance between the labels and predictions. 932 933 The `mean_cosine_distance` function creates two local variables, 934 `total` and `count` that are used to compute the average cosine distance 935 between `predictions` and `labels`. This average is weighted by `weights`, 936 and it is ultimately returned as `mean_distance`, which is an idempotent 937 operation that simply divides `total` by `count`. 938 939 For estimation of the metric over a stream of data, the function creates an 940 `update_op` operation that updates these variables and returns the 941 `mean_distance`. 942 943 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 944 945 Args: 946 labels: A `Tensor` of arbitrary shape. 947 predictions: A `Tensor` of the same shape as `labels`. 
948 dim: The dimension along which the cosine distance is computed. 949 weights: Optional `Tensor` whose rank is either 0, or the same rank as 950 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 951 be either `1`, or the same as the corresponding `labels` dimension). Also, 952 dimension `dim` must be `1`. 953 metrics_collections: An optional list of collections that the metric 954 value variable should be added to. 955 updates_collections: An optional list of collections that the metric update 956 ops should be added to. 957 name: An optional variable_scope name. 958 959 Returns: 960 mean_distance: A `Tensor` representing the current mean, the value of 961 `total` divided by `count`. 962 update_op: An operation that increments the `total` and `count` variables 963 appropriately. 964 965 Raises: 966 ValueError: If `predictions` and `labels` have mismatched shapes, or if 967 `weights` is not `None` and its shape doesn't match `predictions`, or if 968 either `metrics_collections` or `updates_collections` are not a list or 969 tuple. 970 RuntimeError: If eager execution is enabled. 
971 """ 972 if context.executing_eagerly(): 973 raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when ' 974 'eager execution is enabled.') 975 976 predictions, labels, weights = _remove_squeezable_dimensions( 977 predictions=predictions, labels=labels, weights=weights) 978 radial_diffs = math_ops.multiply(predictions, labels) 979 radial_diffs = math_ops.reduce_sum( 980 radial_diffs, axis=[ 981 dim, 982 ], keepdims=True) 983 mean_distance, update_op = mean(radial_diffs, weights, None, None, name or 984 'mean_cosine_distance') 985 mean_distance = math_ops.subtract(1.0, mean_distance) 986 update_op = math_ops.subtract(1.0, update_op) 987 988 if metrics_collections: 989 ops.add_to_collections(metrics_collections, mean_distance) 990 991 if updates_collections: 992 ops.add_to_collections(updates_collections, update_op) 993 994 return mean_distance, update_op 995 996 997@tf_export(v1=['metrics.mean_per_class_accuracy']) 998def mean_per_class_accuracy(labels, 999 predictions, 1000 num_classes, 1001 weights=None, 1002 metrics_collections=None, 1003 updates_collections=None, 1004 name=None): 1005 """Calculates the mean of the per-class accuracies. 1006 1007 Calculates the accuracy for each class, then takes the mean of that. 1008 1009 For estimation of the metric over a stream of data, the function creates an 1010 `update_op` operation that updates the accuracy of each class and returns 1011 them. 1012 1013 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1014 1015 Args: 1016 labels: A `Tensor` of ground truth labels with shape [batch size] and of 1017 type `int32` or `int64`. The tensor will be flattened if its rank > 1. 1018 predictions: A `Tensor` of prediction results for semantic labels, whose 1019 shape is [batch size] and type `int32` or `int64`. The tensor will be 1020 flattened if its rank > 1. 1021 num_classes: The possible number of labels the prediction task can 1022 have. 
This value must be provided, since two variables with shape = 1023 [num_classes] will be allocated. 1024 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1025 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1026 be either `1`, or the same as the corresponding `labels` dimension). 1027 metrics_collections: An optional list of collections that 1028 `mean_per_class_accuracy' 1029 should be added to. 1030 updates_collections: An optional list of collections `update_op` should be 1031 added to. 1032 name: An optional variable_scope name. 1033 1034 Returns: 1035 mean_accuracy: A `Tensor` representing the mean per class accuracy. 1036 update_op: An operation that updates the accuracy tensor. 1037 1038 Raises: 1039 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1040 `weights` is not `None` and its shape doesn't match `predictions`, or if 1041 either `metrics_collections` or `updates_collections` are not a list or 1042 tuple. 1043 RuntimeError: If eager execution is enabled. 1044 """ 1045 if context.executing_eagerly(): 1046 raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported ' 1047 'when eager execution is enabled.') 1048 1049 with variable_scope.variable_scope(name, 'mean_accuracy', 1050 (predictions, labels, weights)): 1051 labels = math_ops.cast(labels, dtypes.int64) 1052 1053 # Flatten the input if its rank > 1. 1054 if labels.get_shape().ndims > 1: 1055 labels = array_ops.reshape(labels, [-1]) 1056 1057 if predictions.get_shape().ndims > 1: 1058 predictions = array_ops.reshape(predictions, [-1]) 1059 1060 # Check if shape is compatible. 
1061 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1062 1063 total = metric_variable([num_classes], dtypes.float32, name='total') 1064 count = metric_variable([num_classes], dtypes.float32, name='count') 1065 1066 ones = array_ops.ones([array_ops.size(labels)], dtypes.float32) 1067 1068 if labels.dtype != predictions.dtype: 1069 predictions = math_ops.cast(predictions, labels.dtype) 1070 is_correct = math_ops.cast( 1071 math_ops.equal(predictions, labels), dtypes.float32) 1072 1073 if weights is not None: 1074 if weights.get_shape().ndims > 1: 1075 weights = array_ops.reshape(weights, [-1]) 1076 weights = math_ops.cast(weights, dtypes.float32) 1077 1078 is_correct *= weights 1079 ones *= weights 1080 1081 update_total_op = state_ops.scatter_add(total, labels, ones) 1082 update_count_op = state_ops.scatter_add(count, labels, is_correct) 1083 1084 def compute_mean_accuracy(_, count, total): 1085 per_class_accuracy = math_ops.div_no_nan( 1086 count, math_ops.maximum(total, 0), name=None) 1087 mean_accuracy_v = math_ops.reduce_mean( 1088 per_class_accuracy, name='mean_accuracy') 1089 return mean_accuracy_v 1090 1091 mean_accuracy_v = _aggregate_across_replicas( 1092 metrics_collections, compute_mean_accuracy, count, total) 1093 1094 update_op = math_ops.div_no_nan( 1095 update_count_op, math_ops.maximum(update_total_op, 0), name='update_op') 1096 if updates_collections: 1097 ops.add_to_collections(updates_collections, update_op) 1098 1099 return mean_accuracy_v, update_op 1100 1101 1102@tf_export(v1=['metrics.mean_iou']) 1103def mean_iou(labels, 1104 predictions, 1105 num_classes, 1106 weights=None, 1107 metrics_collections=None, 1108 updates_collections=None, 1109 name=None): 1110 """Calculate per-step mean Intersection-Over-Union (mIOU). 1111 1112 Mean Intersection-Over-Union is a common evaluation metric for 1113 semantic image segmentation, which first computes the IOU for each 1114 semantic class and then computes the average over classes. 
1115 IOU is defined as follows: 1116 IOU = true_positive / (true_positive + false_positive + false_negative). 1117 The predictions are accumulated in a confusion matrix, weighted by `weights`, 1118 and mIOU is then calculated from it. 1119 1120 For estimation of the metric over a stream of data, the function creates an 1121 `update_op` operation that updates these variables and returns the `mean_iou`. 1122 1123 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1124 1125 Args: 1126 labels: A `Tensor` of ground truth labels with shape [batch size] and of 1127 type `int32` or `int64`. The tensor will be flattened if its rank > 1. 1128 predictions: A `Tensor` of prediction results for semantic labels, whose 1129 shape is [batch size] and type `int32` or `int64`. The tensor will be 1130 flattened if its rank > 1. 1131 num_classes: The possible number of labels the prediction task can 1132 have. This value must be provided, since a confusion matrix of 1133 dimension = [num_classes, num_classes] will be allocated. 1134 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1135 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1136 be either `1`, or the same as the corresponding `labels` dimension). 1137 metrics_collections: An optional list of collections that `mean_iou` 1138 should be added to. 1139 updates_collections: An optional list of collections `update_op` should be 1140 added to. 1141 name: An optional variable_scope name. 1142 1143 Returns: 1144 mean_iou: A `Tensor` representing the mean intersection-over-union. 1145 update_op: An operation that increments the confusion matrix. 1146 1147 Raises: 1148 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1149 `weights` is not `None` and its shape doesn't match `predictions`, or if 1150 either `metrics_collections` or `updates_collections` are not a list or 1151 tuple. 1152 RuntimeError: If eager execution is enabled. 
1153 """ 1154 if context.executing_eagerly(): 1155 raise RuntimeError('tf.metrics.mean_iou is not supported when ' 1156 'eager execution is enabled.') 1157 1158 with variable_scope.variable_scope(name, 'mean_iou', 1159 (predictions, labels, weights)): 1160 # Check if shape is compatible. 1161 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1162 1163 total_cm, update_op = _streaming_confusion_matrix(labels, predictions, 1164 num_classes, weights) 1165 1166 def compute_mean_iou(_, total_cm): 1167 """Compute the mean intersection-over-union via the confusion matrix.""" 1168 sum_over_row = math_ops.cast( 1169 math_ops.reduce_sum(total_cm, 0), dtypes.float32) 1170 sum_over_col = math_ops.cast( 1171 math_ops.reduce_sum(total_cm, 1), dtypes.float32) 1172 cm_diag = math_ops.cast(array_ops.diag_part(total_cm), dtypes.float32) 1173 denominator = sum_over_row + sum_over_col - cm_diag 1174 1175 # The mean is only computed over classes that appear in the 1176 # label or prediction tensor. If the denominator is 0, we need to 1177 # ignore the class. 1178 num_valid_entries = math_ops.reduce_sum( 1179 math_ops.cast( 1180 math_ops.not_equal(denominator, 0), dtype=dtypes.float32)) 1181 1182 # If the value of the denominator is 0, set it to 1 to avoid 1183 # zero division. 1184 denominator = array_ops.where( 1185 math_ops.greater(denominator, 0), denominator, 1186 array_ops.ones_like(denominator)) 1187 iou = math_ops.divide(cm_diag, denominator) 1188 1189 # If the number of valid entries is 0 (no classes) we return 0. 1190 result = array_ops.where( 1191 math_ops.greater(num_valid_entries, 0), 1192 math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0) 1193 return result 1194 1195 # TODO(priyag): Use outside_compilation if in TPU context. 
1196 mean_iou_v = _aggregate_across_replicas( 1197 metrics_collections, compute_mean_iou, total_cm) 1198 1199 if updates_collections: 1200 ops.add_to_collections(updates_collections, update_op) 1201 1202 return mean_iou_v, update_op 1203 1204 1205@tf_export(v1=['metrics.mean_relative_error']) 1206def mean_relative_error(labels, 1207 predictions, 1208 normalizer, 1209 weights=None, 1210 metrics_collections=None, 1211 updates_collections=None, 1212 name=None): 1213 """Computes the mean relative error by normalizing with the given values. 1214 1215 The `mean_relative_error` function creates two local variables, 1216 `total` and `count` that are used to compute the mean relative absolute error. 1217 This average is weighted by `weights`, and it is ultimately returned as 1218 `mean_relative_error`: an idempotent operation that simply divides `total` by 1219 `count`. 1220 1221 For estimation of the metric over a stream of data, the function creates an 1222 `update_op` operation that updates these variables and returns the 1223 `mean_reative_error`. Internally, a `relative_errors` operation divides the 1224 absolute value of the differences between `predictions` and `labels` by the 1225 `normalizer`. Then `update_op` increments `total` with the reduced sum of the 1226 product of `weights` and `relative_errors`, and it increments `count` with the 1227 reduced sum of `weights`. 1228 1229 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1230 1231 Args: 1232 labels: A `Tensor` of the same shape as `predictions`. 1233 predictions: A `Tensor` of arbitrary shape. 1234 normalizer: A `Tensor` of the same shape as `predictions`. 1235 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1236 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1237 be either `1`, or the same as the corresponding `labels` dimension). 
1238 metrics_collections: An optional list of collections that 1239 `mean_relative_error` should be added to. 1240 updates_collections: An optional list of collections that `update_op` should 1241 be added to. 1242 name: An optional variable_scope name. 1243 1244 Returns: 1245 mean_relative_error: A `Tensor` representing the current mean, the value of 1246 `total` divided by `count`. 1247 update_op: An operation that increments the `total` and `count` variables 1248 appropriately and whose value matches `mean_relative_error`. 1249 1250 Raises: 1251 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1252 `weights` is not `None` and its shape doesn't match `predictions`, or if 1253 either `metrics_collections` or `updates_collections` are not a list or 1254 tuple. 1255 RuntimeError: If eager execution is enabled. 1256 """ 1257 if context.executing_eagerly(): 1258 raise RuntimeError('tf.metrics.mean_relative_error is not supported when ' 1259 'eager execution is enabled.') 1260 1261 predictions, labels, weights = _remove_squeezable_dimensions( 1262 predictions=predictions, labels=labels, weights=weights) 1263 1264 predictions, normalizer = confusion_matrix.remove_squeezable_dimensions( 1265 predictions, normalizer) 1266 predictions.get_shape().assert_is_compatible_with(normalizer.get_shape()) 1267 relative_errors = array_ops.where( 1268 math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels), 1269 math_ops.divide(math_ops.abs(labels - predictions), normalizer)) 1270 return mean(relative_errors, weights, metrics_collections, 1271 updates_collections, name or 'mean_relative_error') 1272 1273 1274@tf_export(v1=['metrics.mean_squared_error']) 1275def mean_squared_error(labels, 1276 predictions, 1277 weights=None, 1278 metrics_collections=None, 1279 updates_collections=None, 1280 name=None): 1281 """Computes the mean squared error between the labels and predictions. 
1282 1283 The `mean_squared_error` function creates two local variables, 1284 `total` and `count` that are used to compute the mean squared error. 1285 This average is weighted by `weights`, and it is ultimately returned as 1286 `mean_squared_error`: an idempotent operation that simply divides `total` by 1287 `count`. 1288 1289 For estimation of the metric over a stream of data, the function creates an 1290 `update_op` operation that updates these variables and returns the 1291 `mean_squared_error`. Internally, a `squared_error` operation computes the 1292 element-wise square of the difference between `predictions` and `labels`. Then 1293 `update_op` increments `total` with the reduced sum of the product of 1294 `weights` and `squared_error`, and it increments `count` with the reduced sum 1295 of `weights`. 1296 1297 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1298 1299 Args: 1300 labels: A `Tensor` of the same shape as `predictions`. 1301 predictions: A `Tensor` of arbitrary shape. 1302 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1303 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1304 be either `1`, or the same as the corresponding `labels` dimension). 1305 metrics_collections: An optional list of collections that 1306 `mean_squared_error` should be added to. 1307 updates_collections: An optional list of collections that `update_op` should 1308 be added to. 1309 name: An optional variable_scope name. 1310 1311 Returns: 1312 mean_squared_error: A `Tensor` representing the current mean, the value of 1313 `total` divided by `count`. 1314 update_op: An operation that increments the `total` and `count` variables 1315 appropriately and whose value matches `mean_squared_error`. 
1316 1317 Raises: 1318 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1319 `weights` is not `None` and its shape doesn't match `predictions`, or if 1320 either `metrics_collections` or `updates_collections` are not a list or 1321 tuple. 1322 RuntimeError: If eager execution is enabled. 1323 """ 1324 if context.executing_eagerly(): 1325 raise RuntimeError('tf.metrics.mean_squared_error is not supported when ' 1326 'eager execution is enabled.') 1327 1328 predictions, labels, weights = _remove_squeezable_dimensions( 1329 predictions=predictions, labels=labels, weights=weights) 1330 squared_error = math_ops.squared_difference(labels, predictions) 1331 return mean(squared_error, weights, metrics_collections, updates_collections, 1332 name or 'mean_squared_error') 1333 1334 1335@tf_export(v1=['metrics.mean_tensor']) 1336def mean_tensor(values, 1337 weights=None, 1338 metrics_collections=None, 1339 updates_collections=None, 1340 name=None): 1341 """Computes the element-wise (weighted) mean of the given tensors. 1342 1343 In contrast to the `mean` function which returns a scalar with the 1344 mean, this function returns an average tensor with the same shape as the 1345 input tensors. 1346 1347 The `mean_tensor` function creates two local variables, 1348 `total_tensor` and `count_tensor` that are used to compute the average of 1349 `values`. This average is ultimately returned as `mean` which is an idempotent 1350 operation that simply divides `total` by `count`. 1351 1352 For estimation of the metric over a stream of data, the function creates an 1353 `update_op` operation that updates these variables and returns the `mean`. 1354 `update_op` increments `total` with the reduced sum of the product of `values` 1355 and `weights`, and it increments `count` with the reduced sum of `weights`. 1356 1357 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1358 1359 Args: 1360 values: A `Tensor` of arbitrary dimensions. 
1361 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1362 `values`, and must be broadcastable to `values` (i.e., all dimensions must 1363 be either `1`, or the same as the corresponding `values` dimension). 1364 metrics_collections: An optional list of collections that `mean` 1365 should be added to. 1366 updates_collections: An optional list of collections that `update_op` 1367 should be added to. 1368 name: An optional variable_scope name. 1369 1370 Returns: 1371 mean: A float `Tensor` representing the current mean, the value of `total` 1372 divided by `count`. 1373 update_op: An operation that increments the `total` and `count` variables 1374 appropriately and whose value matches `mean_value`. 1375 1376 Raises: 1377 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 1378 or if either `metrics_collections` or `updates_collections` are not a list 1379 or tuple. 1380 RuntimeError: If eager execution is enabled. 1381 """ 1382 if context.executing_eagerly(): 1383 raise RuntimeError('tf.metrics.mean_tensor is not supported when ' 1384 'eager execution is enabled.') 1385 1386 with variable_scope.variable_scope(name, 'mean', (values, weights)): 1387 values = math_ops.cast(values, dtypes.float32) 1388 total = metric_variable( 1389 values.get_shape(), dtypes.float32, name='total_tensor') 1390 count = metric_variable( 1391 values.get_shape(), dtypes.float32, name='count_tensor') 1392 1393 num_values = array_ops.ones_like(values) 1394 if weights is not None: 1395 values, _, weights = _remove_squeezable_dimensions( 1396 predictions=values, labels=None, weights=weights) 1397 weights = weights_broadcast_ops.broadcast_weights( 1398 math_ops.cast(weights, dtypes.float32), values) 1399 values = math_ops.multiply(values, weights) 1400 num_values = math_ops.multiply(num_values, weights) 1401 1402 update_total_op = state_ops.assign_add(total, values) 1403 with ops.control_dependencies([values]): 1404 update_count_op = 
state_ops.assign_add(count, num_values) 1405 1406 compute_mean = lambda _, t, c: math_ops.div_no_nan( # pylint: disable=g-long-lambda 1407 t, math_ops.maximum(c, 0), name='value') 1408 1409 mean_t = _aggregate_across_replicas( 1410 metrics_collections, compute_mean, total, count) 1411 1412 update_op = math_ops.div_no_nan( 1413 update_total_op, math_ops.maximum(update_count_op, 0), name='update_op') 1414 if updates_collections: 1415 ops.add_to_collections(updates_collections, update_op) 1416 1417 return mean_t, update_op 1418 1419 1420@tf_export(v1=['metrics.percentage_below']) 1421def percentage_below(values, 1422 threshold, 1423 weights=None, 1424 metrics_collections=None, 1425 updates_collections=None, 1426 name=None): 1427 """Computes the percentage of values less than the given threshold. 1428 1429 The `percentage_below` function creates two local variables, 1430 `total` and `count` that are used to compute the percentage of `values` that 1431 fall below `threshold`. This rate is weighted by `weights`, and it is 1432 ultimately returned as `percentage` which is an idempotent operation that 1433 simply divides `total` by `count`. 1434 1435 For estimation of the metric over a stream of data, the function creates an 1436 `update_op` operation that updates these variables and returns the 1437 `percentage`. 1438 1439 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1440 1441 Args: 1442 values: A numeric `Tensor` of arbitrary size. 1443 threshold: A scalar threshold. 1444 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1445 `values`, and must be broadcastable to `values` (i.e., all dimensions must 1446 be either `1`, or the same as the corresponding `values` dimension). 1447 metrics_collections: An optional list of collections that the metric 1448 value variable should be added to. 1449 updates_collections: An optional list of collections that the metric update 1450 ops should be added to. 
1451 name: An optional variable_scope name. 1452 1453 Returns: 1454 percentage: A `Tensor` representing the current mean, the value of `total` 1455 divided by `count`. 1456 update_op: An operation that increments the `total` and `count` variables 1457 appropriately. 1458 1459 Raises: 1460 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 1461 or if either `metrics_collections` or `updates_collections` are not a list 1462 or tuple. 1463 RuntimeError: If eager execution is enabled. 1464 """ 1465 if context.executing_eagerly(): 1466 raise RuntimeError('tf.metrics.percentage_below is not supported when ' 1467 'eager execution is enabled.') 1468 1469 is_below_threshold = math_ops.cast( 1470 math_ops.less(values, threshold), dtypes.float32) 1471 return mean(is_below_threshold, weights, metrics_collections, 1472 updates_collections, name or 'percentage_below_threshold') 1473 1474 1475def _count_condition(values, 1476 weights=None, 1477 metrics_collections=None, 1478 updates_collections=None): 1479 """Sums the weights of cases where the given values are True. 1480 1481 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1482 1483 Args: 1484 values: A `bool` `Tensor` of arbitrary size. 1485 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1486 `values`, and must be broadcastable to `values` (i.e., all dimensions must 1487 be either `1`, or the same as the corresponding `values` dimension). 1488 metrics_collections: An optional list of collections that the metric 1489 value variable should be added to. 1490 updates_collections: An optional list of collections that the metric update 1491 ops should be added to. 1492 1493 Returns: 1494 value_tensor: A `Tensor` representing the current value of the metric. 1495 update_op: An operation that accumulates the error from a batch of data. 
1496 1497 Raises: 1498 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 1499 or if either `metrics_collections` or `updates_collections` are not a list 1500 or tuple. 1501 """ 1502 check_ops.assert_type(values, dtypes.bool) 1503 count = metric_variable([], dtypes.float32, name='count') 1504 1505 values = math_ops.cast(values, dtypes.float32) 1506 if weights is not None: 1507 with ops.control_dependencies((check_ops.assert_rank_in( 1508 weights, (0, array_ops.rank(values))),)): 1509 weights = math_ops.cast(weights, dtypes.float32) 1510 values = math_ops.multiply(values, weights) 1511 1512 value_tensor = _aggregate_variable(count, metrics_collections) 1513 1514 update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) 1515 if updates_collections: 1516 ops.add_to_collections(updates_collections, update_op) 1517 1518 return value_tensor, update_op 1519 1520 1521@tf_export(v1=['metrics.false_negatives']) 1522def false_negatives(labels, 1523 predictions, 1524 weights=None, 1525 metrics_collections=None, 1526 updates_collections=None, 1527 name=None): 1528 """Computes the total number of false negatives. 1529 1530 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1531 1532 Args: 1533 labels: The ground truth values, a `Tensor` whose dimensions must match 1534 `predictions`. Will be cast to `bool`. 1535 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1536 be cast to `bool`. 1537 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1538 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1539 be either `1`, or the same as the corresponding `labels` dimension). 1540 metrics_collections: An optional list of collections that the metric 1541 value variable should be added to. 1542 updates_collections: An optional list of collections that the metric update 1543 ops should be added to. 1544 name: An optional variable_scope name. 
1545 1546 Returns: 1547 value_tensor: A `Tensor` representing the current value of the metric. 1548 update_op: An operation that accumulates the error from a batch of data. 1549 1550 Raises: 1551 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 1552 or if either `metrics_collections` or `updates_collections` are not a list 1553 or tuple. 1554 RuntimeError: If eager execution is enabled. 1555 """ 1556 if context.executing_eagerly(): 1557 raise RuntimeError('tf.metrics.false_negatives is not supported when ' 1558 'eager execution is enabled.') 1559 1560 with variable_scope.variable_scope(name, 'false_negatives', 1561 (predictions, labels, weights)): 1562 1563 predictions, labels, weights = _remove_squeezable_dimensions( 1564 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1565 labels=math_ops.cast(labels, dtype=dtypes.bool), 1566 weights=weights) 1567 is_false_negative = math_ops.logical_and( 1568 math_ops.equal(labels, True), math_ops.equal(predictions, False)) 1569 return _count_condition(is_false_negative, weights, metrics_collections, 1570 updates_collections) 1571 1572 1573@tf_export(v1=['metrics.false_negatives_at_thresholds']) 1574def false_negatives_at_thresholds(labels, 1575 predictions, 1576 thresholds, 1577 weights=None, 1578 metrics_collections=None, 1579 updates_collections=None, 1580 name=None): 1581 """Computes false negatives at provided threshold values. 1582 1583 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1584 1585 Args: 1586 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1587 `bool`. 1588 predictions: A floating point `Tensor` of arbitrary shape and whose values 1589 are in the range `[0, 1]`. 1590 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 
1591 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1592 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1593 be either `1`, or the same as the corresponding `labels` dimension). 1594 metrics_collections: An optional list of collections that `false_negatives` 1595 should be added to. 1596 updates_collections: An optional list of collections that `update_op` should 1597 be added to. 1598 name: An optional variable_scope name. 1599 1600 Returns: 1601 false_negatives: A float `Tensor` of shape `[len(thresholds)]`. 1602 update_op: An operation that updates the `false_negatives` variable and 1603 returns its current value. 1604 1605 Raises: 1606 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1607 `weights` is not `None` and its shape doesn't match `predictions`, or if 1608 either `metrics_collections` or `updates_collections` are not a list or 1609 tuple. 1610 RuntimeError: If eager execution is enabled. 1611 """ 1612 if context.executing_eagerly(): 1613 raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not ' 1614 'supported when eager execution is enabled.') 1615 1616 with variable_scope.variable_scope(name, 'false_negatives', 1617 (predictions, labels, weights)): 1618 values, update_ops = _confusion_matrix_at_thresholds( 1619 labels, predictions, thresholds, weights=weights, includes=('fn',)) 1620 1621 fn_value = _aggregate_variable(values['fn'], metrics_collections) 1622 1623 if updates_collections: 1624 ops.add_to_collections(updates_collections, update_ops['fn']) 1625 1626 return fn_value, update_ops['fn'] 1627 1628 1629@tf_export(v1=['metrics.false_positives']) 1630def false_positives(labels, 1631 predictions, 1632 weights=None, 1633 metrics_collections=None, 1634 updates_collections=None, 1635 name=None): 1636 """Sum the weights of false positives. 1637 1638 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 
1639 1640 Args: 1641 labels: The ground truth values, a `Tensor` whose dimensions must match 1642 `predictions`. Will be cast to `bool`. 1643 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1644 be cast to `bool`. 1645 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1646 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1647 be either `1`, or the same as the corresponding `labels` dimension). 1648 metrics_collections: An optional list of collections that the metric 1649 value variable should be added to. 1650 updates_collections: An optional list of collections that the metric update 1651 ops should be added to. 1652 name: An optional variable_scope name. 1653 1654 Returns: 1655 value_tensor: A `Tensor` representing the current value of the metric. 1656 update_op: An operation that accumulates the error from a batch of data. 1657 1658 Raises: 1659 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1660 `weights` is not `None` and its shape doesn't match `predictions`, or if 1661 either `metrics_collections` or `updates_collections` are not a list or 1662 tuple. 1663 RuntimeError: If eager execution is enabled. 
1664 """ 1665 if context.executing_eagerly(): 1666 raise RuntimeError('tf.metrics.false_positives is not supported when ' 1667 'eager execution is enabled.') 1668 1669 with variable_scope.variable_scope(name, 'false_positives', 1670 (predictions, labels, weights)): 1671 1672 predictions, labels, weights = _remove_squeezable_dimensions( 1673 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1674 labels=math_ops.cast(labels, dtype=dtypes.bool), 1675 weights=weights) 1676 is_false_positive = math_ops.logical_and( 1677 math_ops.equal(labels, False), math_ops.equal(predictions, True)) 1678 return _count_condition(is_false_positive, weights, metrics_collections, 1679 updates_collections) 1680 1681 1682@tf_export(v1=['metrics.false_positives_at_thresholds']) 1683def false_positives_at_thresholds(labels, 1684 predictions, 1685 thresholds, 1686 weights=None, 1687 metrics_collections=None, 1688 updates_collections=None, 1689 name=None): 1690 """Computes false positives at provided threshold values. 1691 1692 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1693 1694 Args: 1695 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1696 `bool`. 1697 predictions: A floating point `Tensor` of arbitrary shape and whose values 1698 are in the range `[0, 1]`. 1699 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1700 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1701 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1702 be either `1`, or the same as the corresponding `labels` dimension). 1703 metrics_collections: An optional list of collections that `false_positives` 1704 should be added to. 1705 updates_collections: An optional list of collections that `update_op` should 1706 be added to. 1707 name: An optional variable_scope name. 1708 1709 Returns: 1710 false_positives: A float `Tensor` of shape `[len(thresholds)]`. 
1711 update_op: An operation that updates the `false_positives` variable and 1712 returns its current value. 1713 1714 Raises: 1715 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1716 `weights` is not `None` and its shape doesn't match `predictions`, or if 1717 either `metrics_collections` or `updates_collections` are not a list or 1718 tuple. 1719 RuntimeError: If eager execution is enabled. 1720 """ 1721 if context.executing_eagerly(): 1722 raise RuntimeError('tf.metrics.false_positives_at_thresholds is not ' 1723 'supported when eager execution is enabled.') 1724 1725 with variable_scope.variable_scope(name, 'false_positives', 1726 (predictions, labels, weights)): 1727 values, update_ops = _confusion_matrix_at_thresholds( 1728 labels, predictions, thresholds, weights=weights, includes=('fp',)) 1729 1730 fp_value = _aggregate_variable(values['fp'], metrics_collections) 1731 1732 if updates_collections: 1733 ops.add_to_collections(updates_collections, update_ops['fp']) 1734 1735 return fp_value, update_ops['fp'] 1736 1737 1738@tf_export(v1=['metrics.true_negatives']) 1739def true_negatives(labels, 1740 predictions, 1741 weights=None, 1742 metrics_collections=None, 1743 updates_collections=None, 1744 name=None): 1745 """Sum the weights of true_negatives. 1746 1747 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1748 1749 Args: 1750 labels: The ground truth values, a `Tensor` whose dimensions must match 1751 `predictions`. Will be cast to `bool`. 1752 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1753 be cast to `bool`. 1754 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1755 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1756 be either `1`, or the same as the corresponding `labels` dimension). 1757 metrics_collections: An optional list of collections that the metric 1758 value variable should be added to. 
1759 updates_collections: An optional list of collections that the metric update 1760 ops should be added to. 1761 name: An optional variable_scope name. 1762 1763 Returns: 1764 value_tensor: A `Tensor` representing the current value of the metric. 1765 update_op: An operation that accumulates the error from a batch of data. 1766 1767 Raises: 1768 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1769 `weights` is not `None` and its shape doesn't match `predictions`, or if 1770 either `metrics_collections` or `updates_collections` are not a list or 1771 tuple. 1772 RuntimeError: If eager execution is enabled. 1773 """ 1774 if context.executing_eagerly(): 1775 raise RuntimeError('tf.metrics.true_negatives is not ' 1776 'supported when eager execution is enabled.') 1777 1778 with variable_scope.variable_scope(name, 'true_negatives', 1779 (predictions, labels, weights)): 1780 1781 predictions, labels, weights = _remove_squeezable_dimensions( 1782 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1783 labels=math_ops.cast(labels, dtype=dtypes.bool), 1784 weights=weights) 1785 is_true_negative = math_ops.logical_and( 1786 math_ops.equal(labels, False), math_ops.equal(predictions, False)) 1787 return _count_condition(is_true_negative, weights, metrics_collections, 1788 updates_collections) 1789 1790 1791@tf_export(v1=['metrics.true_negatives_at_thresholds']) 1792def true_negatives_at_thresholds(labels, 1793 predictions, 1794 thresholds, 1795 weights=None, 1796 metrics_collections=None, 1797 updates_collections=None, 1798 name=None): 1799 """Computes true negatives at provided threshold values. 1800 1801 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1802 1803 Args: 1804 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1805 `bool`. 1806 predictions: A floating point `Tensor` of arbitrary shape and whose values 1807 are in the range `[0, 1]`. 
1808 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1809 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1810 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1811 be either `1`, or the same as the corresponding `labels` dimension). 1812 metrics_collections: An optional list of collections that `true_negatives` 1813 should be added to. 1814 updates_collections: An optional list of collections that `update_op` should 1815 be added to. 1816 name: An optional variable_scope name. 1817 1818 Returns: 1819 true_negatives: A float `Tensor` of shape `[len(thresholds)]`. 1820 update_op: An operation that updates the `true_negatives` variable and 1821 returns its current value. 1822 1823 Raises: 1824 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1825 `weights` is not `None` and its shape doesn't match `predictions`, or if 1826 either `metrics_collections` or `updates_collections` are not a list or 1827 tuple. 1828 RuntimeError: If eager execution is enabled. 1829 """ 1830 if context.executing_eagerly(): 1831 raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not ' 1832 'supported when eager execution is enabled.') 1833 1834 with variable_scope.variable_scope(name, 'true_negatives', 1835 (predictions, labels, weights)): 1836 values, update_ops = _confusion_matrix_at_thresholds( 1837 labels, predictions, thresholds, weights=weights, includes=('tn',)) 1838 1839 tn_value = _aggregate_variable(values['tn'], metrics_collections) 1840 1841 if updates_collections: 1842 ops.add_to_collections(updates_collections, update_ops['tn']) 1843 1844 return tn_value, update_ops['tn'] 1845 1846 1847@tf_export(v1=['metrics.true_positives']) 1848def true_positives(labels, 1849 predictions, 1850 weights=None, 1851 metrics_collections=None, 1852 updates_collections=None, 1853 name=None): 1854 """Sum the weights of true_positives. 1855 1856 If `weights` is `None`, weights default to 1. 
Use weights of 0 to mask values. 1857 1858 Args: 1859 labels: The ground truth values, a `Tensor` whose dimensions must match 1860 `predictions`. Will be cast to `bool`. 1861 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1862 be cast to `bool`. 1863 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1864 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1865 be either `1`, or the same as the corresponding `labels` dimension). 1866 metrics_collections: An optional list of collections that the metric 1867 value variable should be added to. 1868 updates_collections: An optional list of collections that the metric update 1869 ops should be added to. 1870 name: An optional variable_scope name. 1871 1872 Returns: 1873 value_tensor: A `Tensor` representing the current value of the metric. 1874 update_op: An operation that accumulates the error from a batch of data. 1875 1876 Raises: 1877 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1878 `weights` is not `None` and its shape doesn't match `predictions`, or if 1879 either `metrics_collections` or `updates_collections` are not a list or 1880 tuple. 1881 RuntimeError: If eager execution is enabled. 
1882 """ 1883 if context.executing_eagerly(): 1884 raise RuntimeError('tf.metrics.true_positives is not ' 1885 'supported when eager execution is enabled.') 1886 1887 with variable_scope.variable_scope(name, 'true_positives', 1888 (predictions, labels, weights)): 1889 1890 predictions, labels, weights = _remove_squeezable_dimensions( 1891 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1892 labels=math_ops.cast(labels, dtype=dtypes.bool), 1893 weights=weights) 1894 is_true_positive = math_ops.logical_and( 1895 math_ops.equal(labels, True), math_ops.equal(predictions, True)) 1896 return _count_condition(is_true_positive, weights, metrics_collections, 1897 updates_collections) 1898 1899 1900@tf_export(v1=['metrics.true_positives_at_thresholds']) 1901def true_positives_at_thresholds(labels, 1902 predictions, 1903 thresholds, 1904 weights=None, 1905 metrics_collections=None, 1906 updates_collections=None, 1907 name=None): 1908 """Computes true positives at provided threshold values. 1909 1910 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1911 1912 Args: 1913 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1914 `bool`. 1915 predictions: A floating point `Tensor` of arbitrary shape and whose values 1916 are in the range `[0, 1]`. 1917 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1918 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1919 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1920 be either `1`, or the same as the corresponding `labels` dimension). 1921 metrics_collections: An optional list of collections that `true_positives` 1922 should be added to. 1923 updates_collections: An optional list of collections that `update_op` should 1924 be added to. 1925 name: An optional variable_scope name. 1926 1927 Returns: 1928 true_positives: A float `Tensor` of shape `[len(thresholds)]`. 
1929 update_op: An operation that updates the `true_positives` variable and 1930 returns its current value. 1931 1932 Raises: 1933 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1934 `weights` is not `None` and its shape doesn't match `predictions`, or if 1935 either `metrics_collections` or `updates_collections` are not a list or 1936 tuple. 1937 RuntimeError: If eager execution is enabled. 1938 """ 1939 if context.executing_eagerly(): 1940 raise RuntimeError('tf.metrics.true_positives_at_thresholds is not ' 1941 'supported when eager execution is enabled.') 1942 1943 with variable_scope.variable_scope(name, 'true_positives', 1944 (predictions, labels, weights)): 1945 values, update_ops = _confusion_matrix_at_thresholds( 1946 labels, predictions, thresholds, weights=weights, includes=('tp',)) 1947 1948 tp_value = _aggregate_variable(values['tp'], metrics_collections) 1949 1950 if updates_collections: 1951 ops.add_to_collections(updates_collections, update_ops['tp']) 1952 1953 return tp_value, update_ops['tp'] 1954 1955 1956@tf_export(v1=['metrics.precision']) 1957def precision(labels, 1958 predictions, 1959 weights=None, 1960 metrics_collections=None, 1961 updates_collections=None, 1962 name=None): 1963 """Computes the precision of the predictions with respect to the labels. 1964 1965 The `precision` function creates two local variables, 1966 `true_positives` and `false_positives`, that are used to compute the 1967 precision. This value is ultimately returned as `precision`, an idempotent 1968 operation that simply divides `true_positives` by the sum of `true_positives` 1969 and `false_positives`. 1970 1971 For estimation of the metric over a stream of data, the function creates an 1972 `update_op` operation that updates these variables and returns the 1973 `precision`. `update_op` weights each prediction by the corresponding value in 1974 `weights`. 1975 1976 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 
1977 1978 Args: 1979 labels: The ground truth values, a `Tensor` whose dimensions must match 1980 `predictions`. Will be cast to `bool`. 1981 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1982 be cast to `bool`. 1983 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1984 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1985 be either `1`, or the same as the corresponding `labels` dimension). 1986 metrics_collections: An optional list of collections that `precision` should 1987 be added to. 1988 updates_collections: An optional list of collections that `update_op` should 1989 be added to. 1990 name: An optional variable_scope name. 1991 1992 Returns: 1993 precision: Scalar float `Tensor` with the value of `true_positives` 1994 divided by the sum of `true_positives` and `false_positives`. 1995 update_op: `Operation` that increments `true_positives` and 1996 `false_positives` variables appropriately and whose value matches 1997 `precision`. 1998 1999 Raises: 2000 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2001 `weights` is not `None` and its shape doesn't match `predictions`, or if 2002 either `metrics_collections` or `updates_collections` are not a list or 2003 tuple. 2004 RuntimeError: If eager execution is enabled. 
2005 """ 2006 if context.executing_eagerly(): 2007 raise RuntimeError('tf.metrics.precision is not ' 2008 'supported when eager execution is enabled.') 2009 2010 with variable_scope.variable_scope(name, 'precision', 2011 (predictions, labels, weights)): 2012 2013 predictions, labels, weights = _remove_squeezable_dimensions( 2014 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 2015 labels=math_ops.cast(labels, dtype=dtypes.bool), 2016 weights=weights) 2017 2018 true_p, true_positives_update_op = true_positives( 2019 labels, 2020 predictions, 2021 weights, 2022 metrics_collections=None, 2023 updates_collections=None, 2024 name=None) 2025 false_p, false_positives_update_op = false_positives( 2026 labels, 2027 predictions, 2028 weights, 2029 metrics_collections=None, 2030 updates_collections=None, 2031 name=None) 2032 2033 def compute_precision(tp, fp, name): 2034 return array_ops.where( 2035 math_ops.greater(tp + fp, 0), math_ops.divide(tp, tp + fp), 0, name) 2036 2037 def once_across_replicas(_, true_p, false_p): 2038 return compute_precision(true_p, false_p, 'value') 2039 2040 p = _aggregate_across_replicas(metrics_collections, once_across_replicas, 2041 true_p, false_p) 2042 2043 update_op = compute_precision(true_positives_update_op, 2044 false_positives_update_op, 'update_op') 2045 if updates_collections: 2046 ops.add_to_collections(updates_collections, update_op) 2047 2048 return p, update_op 2049 2050 2051@tf_export(v1=['metrics.precision_at_thresholds']) 2052def precision_at_thresholds(labels, 2053 predictions, 2054 thresholds, 2055 weights=None, 2056 metrics_collections=None, 2057 updates_collections=None, 2058 name=None): 2059 """Computes precision values for different `thresholds` on `predictions`. 2060 2061 The `precision_at_thresholds` function creates four local variables, 2062 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 2063 for various values of thresholds. 
`precision[i]` is defined as the total 2064 weight of values in `predictions` above `thresholds[i]` whose corresponding 2065 entry in `labels` is `True`, divided by the total weight of values in 2066 `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] + 2067 false_positives[i])`). 2068 2069 For estimation of the metric over a stream of data, the function creates an 2070 `update_op` operation that updates these variables and returns the 2071 `precision`. 2072 2073 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2074 2075 Args: 2076 labels: The ground truth values, a `Tensor` whose dimensions must match 2077 `predictions`. Will be cast to `bool`. 2078 predictions: A floating point `Tensor` of arbitrary shape and whose values 2079 are in the range `[0, 1]`. 2080 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 2081 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2082 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2083 be either `1`, or the same as the corresponding `labels` dimension). 2084 metrics_collections: An optional list of collections that `auc` should be 2085 added to. 2086 updates_collections: An optional list of collections that `update_op` should 2087 be added to. 2088 name: An optional variable_scope name. 2089 2090 Returns: 2091 precision: A float `Tensor` of shape `[len(thresholds)]`. 2092 update_op: An operation that increments the `true_positives`, 2093 `true_negatives`, `false_positives` and `false_negatives` variables that 2094 are used in the computation of `precision`. 2095 2096 Raises: 2097 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2098 `weights` is not `None` and its shape doesn't match `predictions`, or if 2099 either `metrics_collections` or `updates_collections` are not a list or 2100 tuple. 2101 RuntimeError: If eager execution is enabled. 
2102 """ 2103 if context.executing_eagerly(): 2104 raise RuntimeError('tf.metrics.precision_at_thresholds is not ' 2105 'supported when eager execution is enabled.') 2106 2107 with variable_scope.variable_scope(name, 'precision_at_thresholds', 2108 (predictions, labels, weights)): 2109 values, update_ops = _confusion_matrix_at_thresholds( 2110 labels, predictions, thresholds, weights, includes=('tp', 'fp')) 2111 2112 # Avoid division by zero. 2113 epsilon = 1e-7 2114 2115 def compute_precision(tp, fp, name): 2116 return math_ops.divide(tp, epsilon + tp + fp, name='precision_' + name) 2117 2118 def precision_across_replicas(_, values): 2119 return compute_precision(values['tp'], values['fp'], 'value') 2120 2121 prec = _aggregate_across_replicas( 2122 metrics_collections, precision_across_replicas, values) 2123 2124 update_op = compute_precision(update_ops['tp'], update_ops['fp'], 2125 'update_op') 2126 if updates_collections: 2127 ops.add_to_collections(updates_collections, update_op) 2128 2129 return prec, update_op 2130 2131 2132@tf_export(v1=['metrics.recall']) 2133def recall(labels, 2134 predictions, 2135 weights=None, 2136 metrics_collections=None, 2137 updates_collections=None, 2138 name=None): 2139 """Computes the recall of the predictions with respect to the labels. 2140 2141 The `recall` function creates two local variables, `true_positives` 2142 and `false_negatives`, that are used to compute the recall. This value is 2143 ultimately returned as `recall`, an idempotent operation that simply divides 2144 `true_positives` by the sum of `true_positives` and `false_negatives`. 2145 2146 For estimation of the metric over a stream of data, the function creates an 2147 `update_op` that updates these variables and returns the `recall`. `update_op` 2148 weights each prediction by the corresponding value in `weights`. 2149 2150 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 
2151 2152 Args: 2153 labels: The ground truth values, a `Tensor` whose dimensions must match 2154 `predictions`. Will be cast to `bool`. 2155 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 2156 be cast to `bool`. 2157 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2158 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2159 be either `1`, or the same as the corresponding `labels` dimension). 2160 metrics_collections: An optional list of collections that `recall` should 2161 be added to. 2162 updates_collections: An optional list of collections that `update_op` should 2163 be added to. 2164 name: An optional variable_scope name. 2165 2166 Returns: 2167 recall: Scalar float `Tensor` with the value of `true_positives` divided 2168 by the sum of `true_positives` and `false_negatives`. 2169 update_op: `Operation` that increments `true_positives` and 2170 `false_negatives` variables appropriately and whose value matches 2171 `recall`. 2172 2173 Raises: 2174 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2175 `weights` is not `None` and its shape doesn't match `predictions`, or if 2176 either `metrics_collections` or `updates_collections` are not a list or 2177 tuple. 2178 RuntimeError: If eager execution is enabled. 
2179 """ 2180 if context.executing_eagerly(): 2181 raise RuntimeError('tf.metrics.recall is not supported is not ' 2182 'supported when eager execution is enabled.') 2183 2184 with variable_scope.variable_scope(name, 'recall', 2185 (predictions, labels, weights)): 2186 predictions, labels, weights = _remove_squeezable_dimensions( 2187 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 2188 labels=math_ops.cast(labels, dtype=dtypes.bool), 2189 weights=weights) 2190 2191 true_p, true_positives_update_op = true_positives( 2192 labels, 2193 predictions, 2194 weights, 2195 metrics_collections=None, 2196 updates_collections=None, 2197 name=None) 2198 false_n, false_negatives_update_op = false_negatives( 2199 labels, 2200 predictions, 2201 weights, 2202 metrics_collections=None, 2203 updates_collections=None, 2204 name=None) 2205 2206 def compute_recall(true_p, false_n, name): 2207 return array_ops.where( 2208 math_ops.greater(true_p + false_n, 0), 2209 math_ops.divide(true_p, true_p + false_n), 0, name) 2210 2211 def once_across_replicas(_, true_p, false_n): 2212 return compute_recall(true_p, false_n, 'value') 2213 2214 rec = _aggregate_across_replicas( 2215 metrics_collections, once_across_replicas, true_p, false_n) 2216 2217 update_op = compute_recall(true_positives_update_op, 2218 false_negatives_update_op, 'update_op') 2219 if updates_collections: 2220 ops.add_to_collections(updates_collections, update_op) 2221 2222 return rec, update_op 2223 2224 2225def _at_k_name(name, k=None, class_id=None): 2226 if k is not None: 2227 name = '%s_at_%d' % (name, k) 2228 else: 2229 name = '%s_at_k' % (name) 2230 if class_id is not None: 2231 name = '%s_class%d' % (name, class_id) 2232 return name 2233 2234 2235def _select_class_id(ids, selected_id): 2236 """Filter all but `selected_id` out of `ids`. 2237 2238 Args: 2239 ids: `int64` `Tensor` or `SparseTensor` of IDs. 2240 selected_id: Int id to select. 2241 2242 Returns: 2243 `SparseTensor` of same dimensions as `ids`. 
This contains only the entries 2244 equal to `selected_id`. 2245 """ 2246 ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids) 2247 if isinstance(ids, sparse_tensor.SparseTensor): 2248 return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values, 2249 selected_id)) 2250 2251 # TODO(ptucker): Make this more efficient, maybe add a sparse version of 2252 # tf.equal and tf.reduce_any? 2253 2254 # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1. 2255 ids_shape = array_ops.shape(ids, out_type=dtypes.int64) 2256 ids_last_dim = array_ops.size(ids_shape) - 1 2257 filled_selected_id_shape = math_ops.reduced_shape(ids_shape, 2258 array_ops.reshape( 2259 ids_last_dim, [1])) 2260 2261 # Intersect `ids` with the selected ID. 2262 filled_selected_id = array_ops.fill(filled_selected_id_shape, 2263 math_ops.cast(selected_id, dtypes.int64)) 2264 result = sets.set_intersection(filled_selected_id, ids) 2265 return sparse_tensor.SparseTensor( 2266 indices=result.indices, values=result.values, dense_shape=ids_shape) 2267 2268 2269def _maybe_select_class_id(labels, predictions_idx, selected_id=None): 2270 """If class ID is specified, filter all other classes. 2271 2272 Args: 2273 labels: `int64` `Tensor` or `SparseTensor` with shape 2274 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2275 target classes for the associated prediction. Commonly, N=1 and `labels` 2276 has shape [batch_size, num_labels]. [D1, ... DN] must match 2277 `predictions_idx`. 2278 predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k] 2279 where N >= 1. Commonly, N=1 and `predictions_idx` has shape 2280 [batch size, k]. 2281 selected_id: Int id to select. 2282 2283 Returns: 2284 Tuple of `labels` and `predictions_idx`, possibly with classes removed. 
2285 """ 2286 if selected_id is None: 2287 return labels, predictions_idx 2288 return (_select_class_id(labels, selected_id), 2289 _select_class_id(predictions_idx, selected_id)) 2290 2291 2292def _sparse_true_positive_at_k(labels, 2293 predictions_idx, 2294 class_id=None, 2295 weights=None, 2296 name=None): 2297 """Calculates true positives for recall@k and precision@k. 2298 2299 If `class_id` is specified, calculate binary true positives for `class_id` 2300 only. 2301 If `class_id` is not specified, calculate metrics for `k` predicted vs 2302 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 2303 2304 Args: 2305 labels: `int64` `Tensor` or `SparseTensor` with shape 2306 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2307 target classes for the associated prediction. Commonly, N=1 and `labels` 2308 has shape [batch_size, num_labels]. [D1, ... DN] must match 2309 `predictions_idx`. 2310 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2311 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2312 match `labels`. 2313 class_id: Class for which we want binary metrics. 2314 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2315 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2316 dimensions must be either `1`, or the same as the corresponding `labels` 2317 dimension). 2318 name: Name of operation. 2319 2320 Returns: 2321 A [D1, ... DN] `Tensor` of true positive counts. 
2322 """ 2323 with ops.name_scope(name, 'true_positives', 2324 (predictions_idx, labels, weights)): 2325 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, 2326 class_id) 2327 tp = sets.set_size(sets.set_intersection(predictions_idx, labels)) 2328 tp = math_ops.cast(tp, dtypes.float64) 2329 if weights is not None: 2330 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( 2331 weights, tp),)): 2332 weights = math_ops.cast(weights, dtypes.float64) 2333 tp = math_ops.multiply(tp, weights) 2334 return tp 2335 2336 2337def _streaming_sparse_true_positive_at_k(labels, 2338 predictions_idx, 2339 k=None, 2340 class_id=None, 2341 weights=None, 2342 name=None): 2343 """Calculates weighted per step true positives for recall@k and precision@k. 2344 2345 If `class_id` is specified, calculate binary true positives for `class_id` 2346 only. 2347 If `class_id` is not specified, calculate metrics for `k` predicted vs 2348 `n` label classes, where `n` is the 2nd dimension of `labels`. 2349 2350 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2351 2352 Args: 2353 labels: `int64` `Tensor` or `SparseTensor` with shape 2354 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2355 target classes for the associated prediction. Commonly, N=1 and `labels` 2356 has shape [batch_size, num_labels]. [D1, ... DN] must match 2357 `predictions_idx`. 2358 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2359 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2360 match `labels`. 2361 k: Integer, k for @k metric. This is only used for default op name. 2362 class_id: Class for which we want binary metrics. 2363 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2364 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2365 dimensions must be either `1`, or the same as the corresponding `labels` 2366 dimension). 
2367 name: Name of new variable, and namespace for other dependent ops. 2368 2369 Returns: 2370 A tuple of `Variable` and update `Operation`. 2371 2372 Raises: 2373 ValueError: If `weights` is not `None` and has an incompatible shape. 2374 """ 2375 with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id), 2376 (predictions_idx, labels, weights)) as scope: 2377 tp = _sparse_true_positive_at_k( 2378 predictions_idx=predictions_idx, 2379 labels=labels, 2380 class_id=class_id, 2381 weights=weights) 2382 batch_total_tp = math_ops.cast(math_ops.reduce_sum(tp), dtypes.float64) 2383 2384 var = metric_variable([], dtypes.float64, name=scope) 2385 return var, state_ops.assign_add(var, batch_total_tp, name='update') 2386 2387 2388def _sparse_false_negative_at_k(labels, 2389 predictions_idx, 2390 class_id=None, 2391 weights=None): 2392 """Calculates false negatives for recall@k. 2393 2394 If `class_id` is specified, calculate binary true positives for `class_id` 2395 only. 2396 If `class_id` is not specified, calculate metrics for `k` predicted vs 2397 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 2398 2399 Args: 2400 labels: `int64` `Tensor` or `SparseTensor` with shape 2401 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2402 target classes for the associated prediction. Commonly, N=1 and `labels` 2403 has shape [batch_size, num_labels]. [D1, ... DN] must match 2404 `predictions_idx`. 2405 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2406 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2407 match `labels`. 2408 class_id: Class for which we want binary metrics. 2409 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2410 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2411 dimensions must be either `1`, or the same as the corresponding `labels` 2412 dimension). 2413 2414 Returns: 2415 A [D1, ... 
DN] `Tensor` of false negative counts. 2416 """ 2417 with ops.name_scope(None, 'false_negatives', 2418 (predictions_idx, labels, weights)): 2419 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, 2420 class_id) 2421 fn = sets.set_size( 2422 sets.set_difference(predictions_idx, labels, aminusb=False)) 2423 fn = math_ops.cast(fn, dtypes.float64) 2424 if weights is not None: 2425 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( 2426 weights, fn),)): 2427 weights = math_ops.cast(weights, dtypes.float64) 2428 fn = math_ops.multiply(fn, weights) 2429 return fn 2430 2431 2432def _streaming_sparse_false_negative_at_k(labels, 2433 predictions_idx, 2434 k, 2435 class_id=None, 2436 weights=None, 2437 name=None): 2438 """Calculates weighted per step false negatives for recall@k. 2439 2440 If `class_id` is specified, calculate binary true positives for `class_id` 2441 only. 2442 If `class_id` is not specified, calculate metrics for `k` predicted vs 2443 `n` label classes, where `n` is the 2nd dimension of `labels`. 2444 2445 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2446 2447 Args: 2448 labels: `int64` `Tensor` or `SparseTensor` with shape 2449 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2450 target classes for the associated prediction. Commonly, N=1 and `labels` 2451 has shape [batch_size, num_labels]. [D1, ... DN] must match 2452 `predictions_idx`. 2453 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2454 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2455 match `labels`. 2456 k: Integer, k for @k metric. This is only used for default op name. 2457 class_id: Class for which we want binary metrics. 2458 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2459 `labels`. 
If the latter, it must be broadcastable to `labels` (i.e., all 2460 dimensions must be either `1`, or the same as the corresponding `labels` 2461 dimension). 2462 name: Name of new variable, and namespace for other dependent ops. 2463 2464 Returns: 2465 A tuple of `Variable` and update `Operation`. 2466 2467 Raises: 2468 ValueError: If `weights` is not `None` and has an incompatible shape. 2469 """ 2470 with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id), 2471 (predictions_idx, labels, weights)) as scope: 2472 fn = _sparse_false_negative_at_k( 2473 predictions_idx=predictions_idx, 2474 labels=labels, 2475 class_id=class_id, 2476 weights=weights) 2477 batch_total_fn = math_ops.cast(math_ops.reduce_sum(fn), dtypes.float64) 2478 2479 var = metric_variable([], dtypes.float64, name=scope) 2480 return var, state_ops.assign_add(var, batch_total_fn, name='update') 2481 2482 2483@tf_export(v1=['metrics.recall_at_k']) 2484def recall_at_k(labels, 2485 predictions, 2486 k, 2487 class_id=None, 2488 weights=None, 2489 metrics_collections=None, 2490 updates_collections=None, 2491 name=None): 2492 """Computes recall@k of the predictions with respect to sparse labels. 2493 2494 If `class_id` is specified, we calculate recall by considering only the 2495 entries in the batch for which `class_id` is in the label, and computing 2496 the fraction of them for which `class_id` is in the top-k `predictions`. 2497 If `class_id` is not specified, we'll calculate recall as how often on 2498 average a class among the labels of a batch entry is in the top-k 2499 `predictions`. 2500 2501 `sparse_recall_at_k` creates two local variables, 2502 `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute 2503 the recall_at_k frequency. This frequency is ultimately returned as 2504 `recall_at_<k>`: an idempotent operation that simply divides 2505 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 2506 `false_negative_at_<k>`). 
2507 2508 For estimation of the metric over a stream of data, the function creates an 2509 `update_op` operation that updates these variables and returns the 2510 `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 2511 indicating the top `k` `predictions`. Set operations applied to `top_k` and 2512 `labels` calculate the true positives and false negatives weighted by 2513 `weights`. Then `update_op` increments `true_positive_at_<k>` and 2514 `false_negative_at_<k>` using these values. 2515 2516 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2517 2518 Args: 2519 labels: `int64` `Tensor` or `SparseTensor` with shape 2520 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 2521 num_labels=1. N >= 1 and num_labels is the number of target classes for 2522 the associated prediction. Commonly, N=1 and `labels` has shape 2523 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 2524 should be in range [0, num_classes), where num_classes is the last 2525 dimension of `predictions`. Values outside this range always count 2526 towards `false_negative_at_<k>`. 2527 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 2528 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 2529 The final dimension contains the logit values for each class. [D1, ... DN] 2530 must match `labels`. 2531 k: Integer, k for @k metric. 2532 class_id: Integer class ID for which we want binary metrics. This should be 2533 in range [0, num_classes), where num_classes is the last dimension of 2534 `predictions`. If class_id is outside this range, the method returns NAN. 2535 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2536 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2537 dimensions must be either `1`, or the same as the corresponding `labels` 2538 dimension). 
2539 metrics_collections: An optional list of collections that values should 2540 be added to. 2541 updates_collections: An optional list of collections that updates should 2542 be added to. 2543 name: Name of new update operation, and namespace for other dependent ops. 2544 2545 Returns: 2546 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 2547 by the sum of `true_positives` and `false_negatives`. 2548 update_op: `Operation` that increments `true_positives` and 2549 `false_negatives` variables appropriately, and whose value matches 2550 `recall`. 2551 2552 Raises: 2553 ValueError: If `weights` is not `None` and its shape doesn't match 2554 `predictions`, or if either `metrics_collections` or `updates_collections` 2555 are not a list or tuple. 2556 RuntimeError: If eager execution is enabled. 2557 """ 2558 if context.executing_eagerly(): 2559 raise RuntimeError('tf.metrics.recall_at_k is not ' 2560 'supported when eager execution is enabled.') 2561 2562 with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id), 2563 (predictions, labels, weights)) as scope: 2564 _, top_k_idx = nn.top_k(predictions, k) 2565 return recall_at_top_k( 2566 labels=labels, 2567 predictions_idx=top_k_idx, 2568 k=k, 2569 class_id=class_id, 2570 weights=weights, 2571 metrics_collections=metrics_collections, 2572 updates_collections=updates_collections, 2573 name=scope) 2574 2575 2576@tf_export(v1=['metrics.recall_at_top_k']) 2577def recall_at_top_k(labels, 2578 predictions_idx, 2579 k=None, 2580 class_id=None, 2581 weights=None, 2582 metrics_collections=None, 2583 updates_collections=None, 2584 name=None): 2585 """Computes recall@k of top-k predictions with respect to sparse labels. 2586 2587 Differs from `recall_at_k` in that predictions must be in the form of top `k` 2588 class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k` 2589 for more details. 
2590 2591 Args: 2592 labels: `int64` `Tensor` or `SparseTensor` with shape 2593 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 2594 num_labels=1. N >= 1 and num_labels is the number of target classes for 2595 the associated prediction. Commonly, N=1 and `labels` has shape 2596 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 2597 should be in range [0, num_classes), where num_classes is the last 2598 dimension of `predictions`. Values outside this range always count 2599 towards `false_negative_at_<k>`. 2600 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 2601 Commonly, N=1 and predictions has shape [batch size, k]. The final 2602 dimension contains the top `k` predicted class indices. [D1, ... DN] must 2603 match `labels`. 2604 k: Integer, k for @k metric. Only used for the default op name. 2605 class_id: Integer class ID for which we want binary metrics. This should be 2606 in range [0, num_classes), where num_classes is the last dimension of 2607 `predictions`. If class_id is outside this range, the method returns NAN. 2608 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2609 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2610 dimensions must be either `1`, or the same as the corresponding `labels` 2611 dimension). 2612 metrics_collections: An optional list of collections that values should 2613 be added to. 2614 updates_collections: An optional list of collections that updates should 2615 be added to. 2616 name: Name of new update operation, and namespace for other dependent ops. 2617 2618 Returns: 2619 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 2620 by the sum of `true_positives` and `false_negatives`. 2621 update_op: `Operation` that increments `true_positives` and 2622 `false_negatives` variables appropriately, and whose value matches 2623 `recall`. 
2624 2625 Raises: 2626 ValueError: If `weights` is not `None` and its shape doesn't match 2627 `predictions`, or if either `metrics_collections` or `updates_collections` 2628 are not a list or tuple. 2629 """ 2630 with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id), 2631 (predictions_idx, labels, weights)) as scope: 2632 labels = _maybe_expand_labels(labels, predictions_idx) 2633 top_k_idx = math_ops.cast(predictions_idx, dtypes.int64) 2634 tp, tp_update = _streaming_sparse_true_positive_at_k( 2635 predictions_idx=top_k_idx, 2636 labels=labels, 2637 k=k, 2638 class_id=class_id, 2639 weights=weights) 2640 fn, fn_update = _streaming_sparse_false_negative_at_k( 2641 predictions_idx=top_k_idx, 2642 labels=labels, 2643 k=k, 2644 class_id=class_id, 2645 weights=weights) 2646 2647 def compute_recall(_, tp, fn): 2648 return math_ops.divide(tp, math_ops.add(tp, fn), name=scope) 2649 2650 metric = _aggregate_across_replicas( 2651 metrics_collections, compute_recall, tp, fn) 2652 2653 update = math_ops.divide( 2654 tp_update, math_ops.add(tp_update, fn_update), name='update') 2655 if updates_collections: 2656 ops.add_to_collections(updates_collections, update) 2657 return metric, update 2658 2659 2660@tf_export(v1=['metrics.recall_at_thresholds']) 2661def recall_at_thresholds(labels, 2662 predictions, 2663 thresholds, 2664 weights=None, 2665 metrics_collections=None, 2666 updates_collections=None, 2667 name=None): 2668 """Computes various recall values for different `thresholds` on `predictions`. 2669 2670 The `recall_at_thresholds` function creates four local variables, 2671 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 2672 for various values of thresholds. 
`recall[i]` is defined as the total weight 2673 of values in `predictions` above `thresholds[i]` whose corresponding entry in 2674 `labels` is `True`, divided by the total weight of `True` values in `labels` 2675 (`true_positives[i] / (true_positives[i] + false_negatives[i])`). 2676 2677 For estimation of the metric over a stream of data, the function creates an 2678 `update_op` operation that updates these variables and returns the `recall`. 2679 2680 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2681 2682 Args: 2683 labels: The ground truth values, a `Tensor` whose dimensions must match 2684 `predictions`. Will be cast to `bool`. 2685 predictions: A floating point `Tensor` of arbitrary shape and whose values 2686 are in the range `[0, 1]`. 2687 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 2688 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2689 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2690 be either `1`, or the same as the corresponding `labels` dimension). 2691 metrics_collections: An optional list of collections that `recall` should be 2692 added to. 2693 updates_collections: An optional list of collections that `update_op` should 2694 be added to. 2695 name: An optional variable_scope name. 2696 2697 Returns: 2698 recall: A float `Tensor` of shape `[len(thresholds)]`. 2699 update_op: An operation that increments the `true_positives`, 2700 `true_negatives`, `false_positives` and `false_negatives` variables that 2701 are used in the computation of `recall`. 2702 2703 Raises: 2704 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2705 `weights` is not `None` and its shape doesn't match `predictions`, or if 2706 either `metrics_collections` or `updates_collections` are not a list or 2707 tuple. 2708 RuntimeError: If eager execution is enabled. 
2709 """ 2710 if context.executing_eagerly(): 2711 raise RuntimeError('tf.metrics.recall_at_thresholds is not ' 2712 'supported when eager execution is enabled.') 2713 2714 with variable_scope.variable_scope(name, 'recall_at_thresholds', 2715 (predictions, labels, weights)): 2716 values, update_ops = _confusion_matrix_at_thresholds( 2717 labels, predictions, thresholds, weights, includes=('tp', 'fn')) 2718 2719 # Avoid division by zero. 2720 epsilon = 1e-7 2721 2722 def compute_recall(tp, fn, name): 2723 return math_ops.divide(tp, epsilon + tp + fn, name='recall_' + name) 2724 2725 def recall_across_replicas(_, values): 2726 return compute_recall(values['tp'], values['fn'], 'value') 2727 2728 rec = _aggregate_across_replicas( 2729 metrics_collections, recall_across_replicas, values) 2730 2731 update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') 2732 if updates_collections: 2733 ops.add_to_collections(updates_collections, update_op) 2734 2735 return rec, update_op 2736 2737 2738@tf_export(v1=['metrics.root_mean_squared_error']) 2739def root_mean_squared_error(labels, 2740 predictions, 2741 weights=None, 2742 metrics_collections=None, 2743 updates_collections=None, 2744 name=None): 2745 """Computes the root mean squared error between the labels and predictions. 2746 2747 The `root_mean_squared_error` function creates two local variables, 2748 `total` and `count` that are used to compute the root mean squared error. 2749 This average is weighted by `weights`, and it is ultimately returned as 2750 `root_mean_squared_error`: an idempotent operation that takes the square root 2751 of the division of `total` by `count`. 2752 2753 For estimation of the metric over a stream of data, the function creates an 2754 `update_op` operation that updates these variables and returns the 2755 `root_mean_squared_error`. Internally, a `squared_error` operation computes 2756 the element-wise square of the difference between `predictions` and `labels`. 
2757 Then `update_op` increments `total` with the reduced sum of the product of 2758 `weights` and `squared_error`, and it increments `count` with the reduced sum 2759 of `weights`. 2760 2761 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2762 2763 Args: 2764 labels: A `Tensor` of the same shape as `predictions`. 2765 predictions: A `Tensor` of arbitrary shape. 2766 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2767 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2768 be either `1`, or the same as the corresponding `labels` dimension). 2769 metrics_collections: An optional list of collections that 2770 `root_mean_squared_error` should be added to. 2771 updates_collections: An optional list of collections that `update_op` should 2772 be added to. 2773 name: An optional variable_scope name. 2774 2775 Returns: 2776 root_mean_squared_error: A `Tensor` representing the current mean, the value 2777 of `total` divided by `count`. 2778 update_op: An operation that increments the `total` and `count` variables 2779 appropriately and whose value matches `root_mean_squared_error`. 2780 2781 Raises: 2782 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2783 `weights` is not `None` and its shape doesn't match `predictions`, or if 2784 either `metrics_collections` or `updates_collections` are not a list or 2785 tuple. 2786 RuntimeError: If eager execution is enabled. 
2787 """ 2788 if context.executing_eagerly(): 2789 raise RuntimeError('tf.metrics.root_mean_squared_error is not ' 2790 'supported when eager execution is enabled.') 2791 2792 predictions, labels, weights = _remove_squeezable_dimensions( 2793 predictions=predictions, labels=labels, weights=weights) 2794 mse, update_mse_op = mean_squared_error(labels, predictions, weights, None, 2795 None, name or 2796 'root_mean_squared_error') 2797 2798 once_across_replicas = lambda _, mse: math_ops.sqrt(mse) 2799 rmse = _aggregate_across_replicas( 2800 metrics_collections, once_across_replicas, mse) 2801 2802 update_rmse_op = math_ops.sqrt(update_mse_op) 2803 if updates_collections: 2804 ops.add_to_collections(updates_collections, update_rmse_op) 2805 2806 return rmse, update_rmse_op 2807 2808 2809@tf_export(v1=['metrics.sensitivity_at_specificity']) 2810def sensitivity_at_specificity(labels, 2811 predictions, 2812 specificity, 2813 weights=None, 2814 num_thresholds=200, 2815 metrics_collections=None, 2816 updates_collections=None, 2817 name=None): 2818 """Computes the specificity at a given sensitivity. 2819 2820 The `sensitivity_at_specificity` function creates four local 2821 variables, `true_positives`, `true_negatives`, `false_positives` and 2822 `false_negatives` that are used to compute the sensitivity at the given 2823 specificity value. The threshold for the given specificity value is computed 2824 and used to evaluate the corresponding sensitivity. 2825 2826 For estimation of the metric over a stream of data, the function creates an 2827 `update_op` operation that updates these variables and returns the 2828 `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`, 2829 `false_positives` and `false_negatives` counts with the weight of each case 2830 found in the `predictions` and `labels`. 2831 2832 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 
2833 2834 For additional information about specificity and sensitivity, see the 2835 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 2836 2837 Args: 2838 labels: The ground truth values, a `Tensor` whose dimensions must match 2839 `predictions`. Will be cast to `bool`. 2840 predictions: A floating point `Tensor` of arbitrary shape and whose values 2841 are in the range `[0, 1]`. 2842 specificity: A scalar value in range `[0, 1]`. 2843 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2844 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2845 be either `1`, or the same as the corresponding `labels` dimension). 2846 num_thresholds: The number of thresholds to use for matching the given 2847 specificity. 2848 metrics_collections: An optional list of collections that `sensitivity` 2849 should be added to. 2850 updates_collections: An optional list of collections that `update_op` should 2851 be added to. 2852 name: An optional variable_scope name. 2853 2854 Returns: 2855 sensitivity: A scalar `Tensor` representing the sensitivity at the given 2856 `specificity` value. 2857 update_op: An operation that increments the `true_positives`, 2858 `true_negatives`, `false_positives` and `false_negatives` variables 2859 appropriately and whose value matches `sensitivity`. 2860 2861 Raises: 2862 ValueError: If `predictions` and `labels` have mismatched shapes, if 2863 `weights` is not `None` and its shape doesn't match `predictions`, or if 2864 `specificity` is not between 0 and 1, or if either `metrics_collections` 2865 or `updates_collections` are not a list or tuple. 2866 RuntimeError: If eager execution is enabled. 
2867 """ 2868 if context.executing_eagerly(): 2869 raise RuntimeError('tf.metrics.sensitivity_at_specificity is not ' 2870 'supported when eager execution is enabled.') 2871 2872 if specificity < 0 or specificity > 1: 2873 raise ValueError('`specificity` must be in the range [0, 1].') 2874 2875 with variable_scope.variable_scope(name, 'sensitivity_at_specificity', 2876 (predictions, labels, weights)): 2877 kepsilon = 1e-7 # to account for floating point imprecisions 2878 thresholds = [ 2879 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) 2880 ] 2881 thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] 2882 2883 values, update_ops = _confusion_matrix_at_thresholds( 2884 labels, predictions, thresholds, weights) 2885 2886 def compute_sensitivity_at_specificity(tp, tn, fp, fn, name): 2887 specificities = math_ops.divide(tn, tn + fp + kepsilon) 2888 tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0) 2889 tf_index = math_ops.cast(tf_index, dtypes.int32) 2890 2891 # Now, we have the implicit threshold, so compute the sensitivity: 2892 return math_ops.divide(tp[tf_index], 2893 tp[tf_index] + fn[tf_index] + kepsilon, name) 2894 2895 def sensitivity_across_replicas(_, values): 2896 return compute_sensitivity_at_specificity( 2897 values['tp'], values['tn'], values['fp'], values['fn'], 'value') 2898 2899 sensitivity = _aggregate_across_replicas( 2900 metrics_collections, sensitivity_across_replicas, values) 2901 2902 update_op = compute_sensitivity_at_specificity( 2903 update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 2904 'update_op') 2905 if updates_collections: 2906 ops.add_to_collections(updates_collections, update_op) 2907 2908 return sensitivity, update_op 2909 2910 2911def _expand_and_tile(tensor, multiple, dim=0, name=None): 2912 """Slice `tensor` shape in 2, then tile along the sliced dimension. 
2913 2914 A new dimension is inserted in shape of `tensor` before `dim`, then values are 2915 tiled `multiple` times along the new dimension. 2916 2917 Args: 2918 tensor: Input `Tensor` or `SparseTensor`. 2919 multiple: Integer, number of times to tile. 2920 dim: Integer, dimension along which to tile. 2921 name: Name of operation. 2922 2923 Returns: 2924 `Tensor` result of expanding and tiling `tensor`. 2925 2926 Raises: 2927 ValueError: if `multiple` is less than 1, or `dim` is not in 2928 `[-rank(tensor), rank(tensor)]`. 2929 """ 2930 if multiple < 1: 2931 raise ValueError('Invalid multiple %s, must be > 0.' % multiple) 2932 with ops.name_scope(name, 'expand_and_tile', 2933 (tensor, multiple, dim)) as scope: 2934 # Sparse. 2935 tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor) 2936 if isinstance(tensor, sparse_tensor.SparseTensor): 2937 if dim < 0: 2938 expand_dims = array_ops.reshape( 2939 array_ops.size(tensor.dense_shape) + dim, [1]) 2940 else: 2941 expand_dims = [dim] 2942 expanded_shape = array_ops.concat( 2943 (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1], 2944 array_ops.slice(tensor.dense_shape, expand_dims, [-1])), 2945 0, 2946 name='expanded_shape') 2947 expanded = sparse_ops.sparse_reshape( 2948 tensor, shape=expanded_shape, name='expand') 2949 if multiple == 1: 2950 return expanded 2951 return sparse_ops.sparse_concat( 2952 dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope) 2953 2954 # Dense. 2955 expanded = array_ops.expand_dims( 2956 tensor, dim if (dim >= 0) else (dim - 1), name='expand') 2957 if multiple == 1: 2958 return expanded 2959 ones = array_ops.ones_like(array_ops.shape(tensor)) 2960 tile_multiples = array_ops.concat( 2961 (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples') 2962 return array_ops.tile(expanded, tile_multiples, name=scope) 2963 2964 2965def _num_relevant(labels, k): 2966 """Computes number of relevant values for each row in labels. 2967 2968 For labels with shape [D1, ... 
DN, num_labels], this is the minimum of 2969 `num_labels` and `k`. 2970 2971 Args: 2972 labels: `int64` `Tensor` or `SparseTensor` with shape 2973 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2974 target classes for the associated prediction. Commonly, N=1 and `labels` 2975 has shape [batch_size, num_labels]. 2976 k: Integer, k for @k metric. 2977 2978 Returns: 2979 Integer `Tensor` of shape [D1, ... DN], where each value is the number of 2980 relevant values for that row. 2981 2982 Raises: 2983 ValueError: if inputs have invalid dtypes or values. 2984 """ 2985 if k < 1: 2986 raise ValueError('Invalid k=%s.' % k) 2987 with ops.name_scope(None, 'num_relevant', (labels,)) as scope: 2988 # For SparseTensor, calculate separate count for each row. 2989 labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels) 2990 if isinstance(labels, sparse_tensor.SparseTensor): 2991 return math_ops.minimum(sets.set_size(labels), k, name=scope) 2992 2993 # The relevant values for each (d1, ... dN) is the minimum of k and the 2994 # number of labels along the last dimension that are non-negative. 2995 num_labels = math_ops.reduce_sum( 2996 array_ops.where_v2(math_ops.greater_equal(labels, 0), 2997 array_ops.ones_like(labels), 2998 array_ops.zeros_like(labels)), 2999 axis=-1) 3000 return math_ops.minimum(num_labels, k, name=scope) 3001 3002 3003def _sparse_average_precision_at_top_k(labels, predictions_idx): 3004 """Computes average precision@k of predictions with respect to sparse labels. 3005 3006 From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula 3007 for each row is: 3008 3009 AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items 3010 3011 A "row" is the elements in dimension [D1, ... DN] of `predictions_idx`, 3012 `labels`, and the result `Tensors`. In the common case, this is [batch_size]. 3013 Each row of the results contains the average precision for that row. 
3014 3015 Args: 3016 labels: `int64` `Tensor` or `SparseTensor` with shape 3017 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3018 num_labels=1. N >= 1 and num_labels is the number of target classes for 3019 the associated prediction. Commonly, N=1 and `labels` has shape 3020 [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`. 3021 Values should be non-negative. Negative values are ignored. 3022 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 3023 Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final 3024 dimension must be set and contains the top `k` predicted class indices. 3025 [D1, ... DN] must match `labels`. Values should be in range 3026 [0, num_classes). 3027 3028 Returns: 3029 `float64` `Tensor` of shape [D1, ... DN], where each value is the average 3030 precision for that row. 3031 3032 Raises: 3033 ValueError: if the last dimension of predictions_idx is not set. 3034 """ 3035 with ops.name_scope(None, 'average_precision', 3036 (predictions_idx, labels)) as scope: 3037 predictions_idx = math_ops.cast( 3038 predictions_idx, dtypes.int64, name='predictions_idx') 3039 if predictions_idx.get_shape().ndims == 0: 3040 raise ValueError('The rank of predictions_idx must be at least 1.') 3041 k = predictions_idx.get_shape().as_list()[-1] 3042 if k is None: 3043 raise ValueError('The last dimension of predictions_idx must be set.') 3044 labels = _maybe_expand_labels(labels, predictions_idx) 3045 3046 # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate 3047 # prediction for each k, so we can calculate separate true positive values 3048 # for each k. 3049 predictions_idx_per_k = array_ops.expand_dims( 3050 predictions_idx, -1, name='predictions_idx_per_k') 3051 3052 # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor. 
3053 labels_per_k = _expand_and_tile( 3054 labels, multiple=k, dim=-1, name='labels_per_k') 3055 3056 # The following tensors are all of shape [D1, ... DN, k], containing values 3057 # per row, per k value. 3058 # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at 3059 # that k value is correct, 0 otherwise. This is the "rel_{i}" term from 3060 # the formula above. 3061 # `tp_per_k` (int32) - True positive counts. 3062 # `retrieved_per_k` (int32) - Number of predicted values at each k. This is 3063 # the precision denominator. 3064 # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}" 3065 # term from the formula above. 3066 # `relevant_precision_per_k` (float64) - Relevant precisions; i.e., 3067 # precisions at all k for which relevance indicator is true. 3068 relevant_per_k = _sparse_true_positive_at_k( 3069 labels_per_k, predictions_idx_per_k, name='relevant_per_k') 3070 tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k') 3071 retrieved_per_k = math_ops.cumsum( 3072 array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k') 3073 precision_per_k = math_ops.divide( 3074 math_ops.cast(tp_per_k, dtypes.float64), 3075 math_ops.cast(retrieved_per_k, dtypes.float64), 3076 name='precision_per_k') 3077 relevant_precision_per_k = math_ops.multiply( 3078 precision_per_k, 3079 math_ops.cast(relevant_per_k, dtypes.float64), 3080 name='relevant_precision_per_k') 3081 3082 # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor. 3083 precision_sum = math_ops.reduce_sum( 3084 relevant_precision_per_k, axis=(-1,), name='precision_sum') 3085 3086 # Divide by number of relevant items to get average precision. These are 3087 # the "num_relevant_items" and "AveP" terms from the formula above. 
3088 num_relevant_items = math_ops.cast(_num_relevant(labels, k), dtypes.float64) 3089 return math_ops.divide(precision_sum, num_relevant_items, name=scope) 3090 3091 3092def _streaming_sparse_average_precision_at_top_k(labels, 3093 predictions_idx, 3094 weights=None, 3095 metrics_collections=None, 3096 updates_collections=None, 3097 name=None): 3098 """Computes average precision@k of predictions with respect to sparse labels. 3099 3100 `sparse_average_precision_at_top_k` creates two local variables, 3101 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 3102 are used to compute the frequency. This frequency is ultimately returned as 3103 `average_precision_at_<k>`: an idempotent operation that simply divides 3104 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 3105 3106 For estimation of the metric over a stream of data, the function creates an 3107 `update_op` operation that updates these variables and returns the 3108 `precision_at_<k>`. Set operations applied to `top_k` and `labels` calculate 3109 the true positives and false positives weighted by `weights`. Then `update_op` 3110 increments `true_positive_at_<k>` and `false_positive_at_<k>` using these 3111 values. 3112 3113 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3114 3115 Args: 3116 labels: `int64` `Tensor` or `SparseTensor` with shape 3117 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3118 num_labels=1. N >= 1 and num_labels is the number of target classes for 3119 the associated prediction. Commonly, N=1 and `labels` has shape 3120 [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`. 3121 Values should be non-negative. Negative values are ignored. 3122 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 3123 Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final 3124 dimension contains the top `k` predicted class indices. [D1, ... 
DN] must 3125 match `labels`. Values should be in range [0, num_classes). 3126 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3127 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3128 dimensions must be either `1`, or the same as the corresponding `labels` 3129 dimension). 3130 metrics_collections: An optional list of collections that values should 3131 be added to. 3132 updates_collections: An optional list of collections that updates should 3133 be added to. 3134 name: Name of new update operation, and namespace for other dependent ops. 3135 3136 Returns: 3137 mean_average_precision: Scalar `float64` `Tensor` with the mean average 3138 precision values. 3139 update: `Operation` that increments variables appropriately, and whose 3140 value matches `metric`. 3141 """ 3142 with ops.name_scope(name, 'average_precision_at_top_k', 3143 (predictions_idx, labels, weights)) as scope: 3144 # Calculate per-example average precision, and apply weights. 3145 average_precision = _sparse_average_precision_at_top_k( 3146 predictions_idx=predictions_idx, labels=labels) 3147 if weights is not None: 3148 weights = weights_broadcast_ops.broadcast_weights( 3149 math_ops.cast(weights, dtypes.float64), average_precision) 3150 average_precision = math_ops.multiply(average_precision, weights) 3151 3152 # Create accumulation variables and update ops for max average precision and 3153 # total average precision. 3154 with ops.name_scope(None, 'max', (average_precision,)) as max_scope: 3155 # `max` is the max possible precision. Since max for any row is 1.0: 3156 # - For the unweighted case, this is just the number of rows. 3157 # - For the weighted case, it's the sum of the weights broadcast across 3158 # `average_precision` rows. 
3159 max_var = metric_variable([], dtypes.float64, name=max_scope) 3160 if weights is None: 3161 batch_max = math_ops.cast( 3162 array_ops.size(average_precision, name='batch_max'), dtypes.float64) 3163 else: 3164 batch_max = math_ops.reduce_sum(weights, name='batch_max') 3165 max_update = state_ops.assign_add(max_var, batch_max, name='update') 3166 with ops.name_scope(None, 'total', (average_precision,)) as total_scope: 3167 total_var = metric_variable([], dtypes.float64, name=total_scope) 3168 batch_total = math_ops.reduce_sum(average_precision, name='batch_total') 3169 total_update = state_ops.assign_add(total_var, batch_total, name='update') 3170 3171 # Divide total by max to get mean, for both vars and the update ops. 3172 def precision_across_replicas(_, total_var, max_var): 3173 return _safe_scalar_div(total_var, max_var, name='mean') 3174 3175 mean_average_precision = _aggregate_across_replicas( 3176 metrics_collections, precision_across_replicas, total_var, max_var) 3177 3178 update = _safe_scalar_div(total_update, max_update, name=scope) 3179 if updates_collections: 3180 ops.add_to_collections(updates_collections, update) 3181 3182 return mean_average_precision, update 3183 3184 3185def _clean_out_of_range_indices(labels, num_classes): 3186 """Replaces large out-of-range labels by small out-of-range labels. 3187 3188 Replaces any value in `labels` that is greater or equal to `num_classes` by 3189 -1. Do this conditionally for efficiency in case there are no such values. 3190 3191 Args: 3192 labels: `int64` `Tensor` or `SparseTensor`. 3193 num_classes: `int64` scalar `Tensor`. 3194 Returns: 3195 An `int64` `Tensor` or `SparseTensor` as `labels` with indices greater 3196 or equal to num_classes replaced by -1. 
3197 """ 3198 3199 def _labels_is_sparse(): 3200 """Returns true is `labels` is a sparse tensor.""" 3201 return isinstance(labels, (sparse_tensor.SparseTensor, 3202 sparse_tensor.SparseTensorValue)) 3203 3204 def _clean_out_of_range(values): 3205 """Replaces by -1 any large out-of-range `values`.""" 3206 return array_ops.where_v2(math_ops.greater_equal(values, num_classes), 3207 -1 * array_ops.ones_like(values), values) 3208 3209 def _clean_labels_out_of_range(): 3210 """Replaces by -1 ane large out-of-range values in `labels`.""" 3211 if _labels_is_sparse(): 3212 return type(labels)(indices=labels.indices, 3213 values=_clean_out_of_range(labels.values), 3214 dense_shape=labels.dense_shape) 3215 else: 3216 return _clean_out_of_range(labels) 3217 3218 max_labels = math_ops.reduce_max( 3219 labels.values if _labels_is_sparse() else labels) 3220 return control_flow_ops.cond( 3221 math_ops.greater_equal(max_labels, num_classes), 3222 _clean_labels_out_of_range, 3223 lambda: labels) 3224 3225 3226@tf_export(v1=['metrics.sparse_average_precision_at_k']) 3227@deprecated(None, 'Use average_precision_at_k instead') 3228def sparse_average_precision_at_k(labels, 3229 predictions, 3230 k, 3231 weights=None, 3232 metrics_collections=None, 3233 updates_collections=None, 3234 name=None): 3235 """Renamed to `average_precision_at_k`, please use that method instead.""" 3236 return average_precision_at_k( 3237 labels=labels, 3238 predictions=predictions, 3239 k=k, 3240 weights=weights, 3241 metrics_collections=metrics_collections, 3242 updates_collections=updates_collections, 3243 name=name) 3244 3245 3246@tf_export(v1=['metrics.average_precision_at_k']) 3247def average_precision_at_k(labels, 3248 predictions, 3249 k, 3250 weights=None, 3251 metrics_collections=None, 3252 updates_collections=None, 3253 name=None): 3254 """Computes average precision@k of predictions with respect to sparse labels. 
3255 3256 `average_precision_at_k` creates two local variables, 3257 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 3258 are used to compute the frequency. This frequency is ultimately returned as 3259 `average_precision_at_<k>`: an idempotent operation that simply divides 3260 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 3261 3262 For estimation of the metric over a stream of data, the function creates an 3263 `update_op` operation that updates these variables and returns the 3264 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 3265 indicating the top `k` `predictions`. Set operations applied to `top_k` and 3266 `labels` calculate the true positives and false positives weighted by 3267 `weights`. Then `update_op` increments `true_positive_at_<k>` and 3268 `false_positive_at_<k>` using these values. 3269 3270 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3271 3272 Args: 3273 labels: `int64` `Tensor` or `SparseTensor` with shape 3274 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3275 num_labels=1. N >= 1 and num_labels is the number of target classes for 3276 the associated prediction. Commonly, N=1 and `labels` has shape 3277 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 3278 should be in range [0, num_classes), where num_classes is the last 3279 dimension of `predictions`. Values outside this range are ignored. 3280 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 3281 N >= 1. Commonly, N=1 and `predictions` has shape 3282 [batch size, num_classes]. The final dimension contains the logit values 3283 for each class. [D1, ... DN] must match `labels`. 3284 k: Integer, k for @k metric. This will calculate an average precision for 3285 range `[1,k]`, as documented above. 3286 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3287 `labels`. 
If the latter, it must be broadcastable to `labels` (i.e., all 3288 dimensions must be either `1`, or the same as the corresponding `labels` 3289 dimension). 3290 metrics_collections: An optional list of collections that values should 3291 be added to. 3292 updates_collections: An optional list of collections that updates should 3293 be added to. 3294 name: Name of new update operation, and namespace for other dependent ops. 3295 3296 Returns: 3297 mean_average_precision: Scalar `float64` `Tensor` with the mean average 3298 precision values. 3299 update: `Operation` that increments variables appropriately, and whose 3300 value matches `metric`. 3301 3302 Raises: 3303 ValueError: if k is invalid. 3304 RuntimeError: If eager execution is enabled. 3305 """ 3306 if context.executing_eagerly(): 3307 raise RuntimeError('tf.metrics.sparse_average_precision_at_k is not ' 3308 'supported when eager execution is enabled.') 3309 3310 if k < 1: 3311 raise ValueError('Invalid k=%s.' % k) 3312 with ops.name_scope(name, _at_k_name('average_precision', k), 3313 (predictions, labels, weights)) as scope: 3314 # Calculate top k indices to produce [D1, ... DN, k] tensor. 3315 _, predictions_idx = nn.top_k(predictions, k) 3316 # The documentation states that labels should be in [0, ..., num_classes), 3317 # but num_classes is lost when predictions_idx replaces predictions. 3318 # For conformity with the documentation, any label >= num_classes, which is 3319 # ignored, is replaced by -1. 
3320 labels = _clean_out_of_range_indices( 3321 labels, math_ops.cast(array_ops.shape(predictions)[-1], dtypes.int64)) 3322 return _streaming_sparse_average_precision_at_top_k( 3323 labels=labels, 3324 predictions_idx=predictions_idx, 3325 weights=weights, 3326 metrics_collections=metrics_collections, 3327 updates_collections=updates_collections, 3328 name=scope) 3329 3330 3331def _sparse_false_positive_at_k(labels, 3332 predictions_idx, 3333 class_id=None, 3334 weights=None): 3335 """Calculates false positives for precision@k. 3336 3337 If `class_id` is specified, calculate binary true positives for `class_id` 3338 only. 3339 If `class_id` is not specified, calculate metrics for `k` predicted vs 3340 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 3341 3342 Args: 3343 labels: `int64` `Tensor` or `SparseTensor` with shape 3344 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 3345 target classes for the associated prediction. Commonly, N=1 and `labels` 3346 has shape [batch_size, num_labels]. [D1, ... DN] must match 3347 `predictions_idx`. 3348 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 3349 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 3350 match `labels`. 3351 class_id: Class for which we want binary metrics. 3352 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3353 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3354 dimensions must be either `1`, or the same as the corresponding `labels` 3355 dimension). 3356 3357 Returns: 3358 A [D1, ... DN] `Tensor` of false positive counts. 
3359 """ 3360 with ops.name_scope(None, 'false_positives', 3361 (predictions_idx, labels, weights)): 3362 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, 3363 class_id) 3364 fp = sets.set_size( 3365 sets.set_difference(predictions_idx, labels, aminusb=True)) 3366 fp = math_ops.cast(fp, dtypes.float64) 3367 if weights is not None: 3368 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( 3369 weights, fp),)): 3370 weights = math_ops.cast(weights, dtypes.float64) 3371 fp = math_ops.multiply(fp, weights) 3372 return fp 3373 3374 3375def _streaming_sparse_false_positive_at_k(labels, 3376 predictions_idx, 3377 k=None, 3378 class_id=None, 3379 weights=None, 3380 name=None): 3381 """Calculates weighted per step false positives for precision@k. 3382 3383 If `class_id` is specified, calculate binary true positives for `class_id` 3384 only. 3385 If `class_id` is not specified, calculate metrics for `k` predicted vs 3386 `n` label classes, where `n` is the 2nd dimension of `labels`. 3387 3388 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3389 3390 Args: 3391 labels: `int64` `Tensor` or `SparseTensor` with shape 3392 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 3393 target classes for the associated prediction. Commonly, N=1 and `labels` 3394 has shape [batch_size, num_labels]. [D1, ... DN] must match 3395 `predictions_idx`. 3396 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 3397 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 3398 match `labels`. 3399 k: Integer, k for @k metric. This is only used for default op name. 3400 class_id: Class for which we want binary metrics. 3401 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3402 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3403 dimensions must be either `1`, or the same as the corresponding `labels` 3404 dimension). 
3405 name: Name of new variable, and namespace for other dependent ops. 3406 3407 Returns: 3408 A tuple of `Variable` and update `Operation`. 3409 3410 Raises: 3411 ValueError: If `weights` is not `None` and has an incompatible shape. 3412 """ 3413 with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id), 3414 (predictions_idx, labels, weights)) as scope: 3415 fp = _sparse_false_positive_at_k( 3416 predictions_idx=predictions_idx, 3417 labels=labels, 3418 class_id=class_id, 3419 weights=weights) 3420 batch_total_fp = math_ops.cast(math_ops.reduce_sum(fp), dtypes.float64) 3421 3422 var = metric_variable([], dtypes.float64, name=scope) 3423 return var, state_ops.assign_add(var, batch_total_fp, name='update') 3424 3425 3426@tf_export(v1=['metrics.precision_at_top_k']) 3427def precision_at_top_k(labels, 3428 predictions_idx, 3429 k=None, 3430 class_id=None, 3431 weights=None, 3432 metrics_collections=None, 3433 updates_collections=None, 3434 name=None): 3435 """Computes precision@k of the predictions with respect to sparse labels. 3436 3437 Differs from `sparse_precision_at_k` in that predictions must be in the form 3438 of top `k` class indices, whereas `sparse_precision_at_k` expects logits. 3439 Refer to `sparse_precision_at_k` for more details. 3440 3441 Args: 3442 labels: `int64` `Tensor` or `SparseTensor` with shape 3443 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3444 num_labels=1. N >= 1 and num_labels is the number of target classes for 3445 the associated prediction. Commonly, N=1 and `labels` has shape 3446 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 3447 should be in range [0, num_classes), where num_classes is the last 3448 dimension of `predictions`. Values outside this range are ignored. 3449 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where 3450 N >= 1. Commonly, N=1 and predictions has shape [batch size, k]. 
3451 The final dimension contains the top `k` predicted class indices. 3452 [D1, ... DN] must match `labels`. 3453 k: Integer, k for @k metric. Only used for the default op name. 3454 class_id: Integer class ID for which we want binary metrics. This should be 3455 in range [0, num_classes], where num_classes is the last dimension of 3456 `predictions`. If `class_id` is outside this range, the method returns 3457 NAN. 3458 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3459 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3460 dimensions must be either `1`, or the same as the corresponding `labels` 3461 dimension). 3462 metrics_collections: An optional list of collections that values should 3463 be added to. 3464 updates_collections: An optional list of collections that updates should 3465 be added to. 3466 name: Name of new update operation, and namespace for other dependent ops. 3467 3468 Returns: 3469 precision: Scalar `float64` `Tensor` with the value of `true_positives` 3470 divided by the sum of `true_positives` and `false_positives`. 3471 update_op: `Operation` that increments `true_positives` and 3472 `false_positives` variables appropriately, and whose value matches 3473 `precision`. 3474 3475 Raises: 3476 ValueError: If `weights` is not `None` and its shape doesn't match 3477 `predictions`, or if either `metrics_collections` or `updates_collections` 3478 are not a list or tuple. 3479 RuntimeError: If eager execution is enabled. 
3480 """ 3481 if context.executing_eagerly(): 3482 raise RuntimeError('tf.metrics.precision_at_top_k is not ' 3483 'supported when eager execution is enabled.') 3484 3485 with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id), 3486 (predictions_idx, labels, weights)) as scope: 3487 labels = _maybe_expand_labels(labels, predictions_idx) 3488 top_k_idx = math_ops.cast(predictions_idx, dtypes.int64) 3489 tp, tp_update = _streaming_sparse_true_positive_at_k( 3490 predictions_idx=top_k_idx, 3491 labels=labels, 3492 k=k, 3493 class_id=class_id, 3494 weights=weights) 3495 fp, fp_update = _streaming_sparse_false_positive_at_k( 3496 predictions_idx=top_k_idx, 3497 labels=labels, 3498 k=k, 3499 class_id=class_id, 3500 weights=weights) 3501 3502 def precision_across_replicas(_, tp, fp): 3503 return math_ops.divide(tp, math_ops.add(tp, fp), name=scope) 3504 3505 metric = _aggregate_across_replicas( 3506 metrics_collections, precision_across_replicas, tp, fp) 3507 3508 update = math_ops.divide( 3509 tp_update, math_ops.add(tp_update, fp_update), name='update') 3510 if updates_collections: 3511 ops.add_to_collections(updates_collections, update) 3512 return metric, update 3513 3514 3515@tf_export(v1=['metrics.sparse_precision_at_k']) 3516@deprecated(None, 'Use precision_at_k instead') 3517def sparse_precision_at_k(labels, 3518 predictions, 3519 k, 3520 class_id=None, 3521 weights=None, 3522 metrics_collections=None, 3523 updates_collections=None, 3524 name=None): 3525 """Renamed to `precision_at_k`, please use that method instead.""" 3526 return precision_at_k( 3527 labels=labels, 3528 predictions=predictions, 3529 k=k, 3530 class_id=class_id, 3531 weights=weights, 3532 metrics_collections=metrics_collections, 3533 updates_collections=updates_collections, 3534 name=name) 3535 3536 3537@tf_export(v1=['metrics.precision_at_k']) 3538def precision_at_k(labels, 3539 predictions, 3540 k, 3541 class_id=None, 3542 weights=None, 3543 metrics_collections=None, 3544 
updates_collections=None, 3545 name=None): 3546 """Computes precision@k of the predictions with respect to sparse labels. 3547 3548 If `class_id` is specified, we calculate precision by considering only the 3549 entries in the batch for which `class_id` is in the top-k highest 3550 `predictions`, and computing the fraction of them for which `class_id` is 3551 indeed a correct label. 3552 If `class_id` is not specified, we'll calculate precision as how often on 3553 average a class among the top-k classes with the highest predicted values 3554 of a batch entry is correct and can be found in the label for that entry. 3555 3556 `precision_at_k` creates two local variables, 3557 `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute 3558 the precision@k frequency. This frequency is ultimately returned as 3559 `precision_at_<k>`: an idempotent operation that simply divides 3560 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 3561 `false_positive_at_<k>`). 3562 3563 For estimation of the metric over a stream of data, the function creates an 3564 `update_op` operation that updates these variables and returns the 3565 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 3566 indicating the top `k` `predictions`. Set operations applied to `top_k` and 3567 `labels` calculate the true positives and false positives weighted by 3568 `weights`. Then `update_op` increments `true_positive_at_<k>` and 3569 `false_positive_at_<k>` using these values. 3570 3571 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3572 3573 Args: 3574 labels: `int64` `Tensor` or `SparseTensor` with shape 3575 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3576 num_labels=1. N >= 1 and num_labels is the number of target classes for 3577 the associated prediction. Commonly, N=1 and `labels` has shape 3578 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. 
Values 3579 should be in range [0, num_classes), where num_classes is the last 3580 dimension of `predictions`. Values outside this range are ignored. 3581 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 3582 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 3583 The final dimension contains the logit values for each class. [D1, ... DN] 3584 must match `labels`. 3585 k: Integer, k for @k metric. 3586 class_id: Integer class ID for which we want binary metrics. This should be 3587 in range [0, num_classes], where num_classes is the last dimension of 3588 `predictions`. If `class_id` is outside this range, the method returns 3589 NAN. 3590 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3591 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3592 dimensions must be either `1`, or the same as the corresponding `labels` 3593 dimension). 3594 metrics_collections: An optional list of collections that values should 3595 be added to. 3596 updates_collections: An optional list of collections that updates should 3597 be added to. 3598 name: Name of new update operation, and namespace for other dependent ops. 3599 3600 Returns: 3601 precision: Scalar `float64` `Tensor` with the value of `true_positives` 3602 divided by the sum of `true_positives` and `false_positives`. 3603 update_op: `Operation` that increments `true_positives` and 3604 `false_positives` variables appropriately, and whose value matches 3605 `precision`. 3606 3607 Raises: 3608 ValueError: If `weights` is not `None` and its shape doesn't match 3609 `predictions`, or if either `metrics_collections` or `updates_collections` 3610 are not a list or tuple. 3611 RuntimeError: If eager execution is enabled. 
3612 """ 3613 if context.executing_eagerly(): 3614 raise RuntimeError('tf.metrics.sparse_precision_at_k is not ' 3615 'supported when eager execution is enabled.') 3616 3617 with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id), 3618 (predictions, labels, weights)) as scope: 3619 _, top_k_idx = nn.top_k(predictions, k) 3620 return precision_at_top_k( 3621 labels=labels, 3622 predictions_idx=top_k_idx, 3623 k=k, 3624 class_id=class_id, 3625 weights=weights, 3626 metrics_collections=metrics_collections, 3627 updates_collections=updates_collections, 3628 name=scope) 3629 3630 3631@tf_export(v1=['metrics.specificity_at_sensitivity']) 3632def specificity_at_sensitivity(labels, 3633 predictions, 3634 sensitivity, 3635 weights=None, 3636 num_thresholds=200, 3637 metrics_collections=None, 3638 updates_collections=None, 3639 name=None): 3640 """Computes the specificity at a given sensitivity. 3641 3642 The `specificity_at_sensitivity` function creates four local 3643 variables, `true_positives`, `true_negatives`, `false_positives` and 3644 `false_negatives` that are used to compute the specificity at the given 3645 sensitivity value. The threshold for the given sensitivity value is computed 3646 and used to evaluate the corresponding specificity. 3647 3648 For estimation of the metric over a stream of data, the function creates an 3649 `update_op` operation that updates these variables and returns the 3650 `specificity`. `update_op` increments the `true_positives`, `true_negatives`, 3651 `false_positives` and `false_negatives` counts with the weight of each case 3652 found in the `predictions` and `labels`. 3653 3654 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3655 3656 For additional information about specificity and sensitivity, see the 3657 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 3658 3659 Args: 3660 labels: The ground truth values, a `Tensor` whose dimensions must match 3661 `predictions`. 
Will be cast to `bool`. 3662 predictions: A floating point `Tensor` of arbitrary shape and whose values 3663 are in the range `[0, 1]`. 3664 sensitivity: A scalar value in range `[0, 1]`. 3665 weights: Optional `Tensor` whose rank is either 0, or the same rank as 3666 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 3667 be either `1`, or the same as the corresponding `labels` dimension). 3668 num_thresholds: The number of thresholds to use for matching the given 3669 sensitivity. 3670 metrics_collections: An optional list of collections that `specificity` 3671 should be added to. 3672 updates_collections: An optional list of collections that `update_op` should 3673 be added to. 3674 name: An optional variable_scope name. 3675 3676 Returns: 3677 specificity: A scalar `Tensor` representing the specificity at the given 3678 `sensitivity` value. 3679 update_op: An operation that increments the `true_positives`, 3680 `true_negatives`, `false_positives` and `false_negatives` variables 3681 appropriately and whose value matches `specificity`. 3682 3683 Raises: 3684 ValueError: If `predictions` and `labels` have mismatched shapes, if 3685 `weights` is not `None` and its shape doesn't match `predictions`, or if 3686 `sensitivity` is not between 0 and 1, or if either `metrics_collections` 3687 or `updates_collections` are not a list or tuple. 3688 RuntimeError: If eager execution is enabled. 
3689 """ 3690 if context.executing_eagerly(): 3691 raise RuntimeError('tf.metrics.specificity_at_sensitivity is not ' 3692 'supported when eager execution is enabled.') 3693 3694 if sensitivity < 0 or sensitivity > 1: 3695 raise ValueError('`sensitivity` must be in the range [0, 1].') 3696 3697 with variable_scope.variable_scope(name, 'specificity_at_sensitivity', 3698 (predictions, labels, weights)): 3699 kepsilon = 1e-7 # to account for floating point imprecisions 3700 thresholds = [ 3701 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) 3702 ] 3703 thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon] 3704 3705 values, update_ops = _confusion_matrix_at_thresholds( 3706 labels, predictions, thresholds, weights) 3707 3708 def compute_specificity_at_sensitivity(tp, tn, fp, fn, name): 3709 """Computes the specificity at the given sensitivity. 3710 3711 Args: 3712 tp: True positives. 3713 tn: True negatives. 3714 fp: False positives. 3715 fn: False negatives. 3716 name: The name of the operation. 3717 3718 Returns: 3719 The specificity using the aggregated values. 3720 """ 3721 sensitivities = math_ops.divide(tp, tp + fn + kepsilon) 3722 3723 # We'll need to use this trick until tf.argmax allows us to specify 3724 # whether we should use the first or last index in case of ties. 
3725 min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity)) 3726 indices_at_minval = math_ops.equal( 3727 math_ops.abs(sensitivities - sensitivity), min_val) 3728 indices_at_minval = math_ops.cast(indices_at_minval, dtypes.int64) 3729 indices_at_minval = math_ops.cumsum(indices_at_minval) 3730 tf_index = math_ops.argmax(indices_at_minval, 0) 3731 tf_index = math_ops.cast(tf_index, dtypes.int32) 3732 3733 # Now, we have the implicit threshold, so compute the specificity: 3734 return math_ops.divide(tn[tf_index], 3735 tn[tf_index] + fp[tf_index] + kepsilon, name) 3736 3737 def specificity_across_replicas(_, values): 3738 return compute_specificity_at_sensitivity( 3739 values['tp'], values['tn'], values['fp'], values['fn'], 'value') 3740 3741 specificity = _aggregate_across_replicas( 3742 metrics_collections, specificity_across_replicas, values) 3743 3744 update_op = compute_specificity_at_sensitivity( 3745 update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 3746 'update_op') 3747 if updates_collections: 3748 ops.add_to_collections(updates_collections, update_op) 3749 3750 return specificity, update_op 3751