1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Implementation of tf.metrics module.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.distribute import distribution_strategy_context 22from tensorflow.python.eager import context 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import ops 25from tensorflow.python.framework import sparse_tensor 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import check_ops 28from tensorflow.python.ops import confusion_matrix 29from tensorflow.python.ops import control_flow_ops 30from tensorflow.python.ops import math_ops 31from tensorflow.python.ops import nn 32from tensorflow.python.ops import sets 33from tensorflow.python.ops import sparse_ops 34from tensorflow.python.ops import state_ops 35from tensorflow.python.ops import variable_scope 36from tensorflow.python.ops import weights_broadcast_ops 37from tensorflow.python.platform import tf_logging as logging 38from tensorflow.python.util.deprecation import deprecated 39from tensorflow.python.util.tf_export import tf_export 40 41 42def metric_variable(shape, dtype, validate_shape=True, name=None): 43 """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections. 
44 45 If running in a `DistributionStrategy` context, the variable will be 46 "sync on read". This means: 47 48 * The returned object will be a container with separate variables 49 per replica of the model. 50 51 * When writing to the variable, e.g. using `assign_add` in a metric 52 update, the update will be applied to the variable local to the 53 replica. 54 55 * To get a metric's result value, we need to sum the variable values 56 across the replicas before computing the final answer. Furthermore, 57 the final answer should be computed once instead of in every 58 replica. Both of these are accomplished by running the computation 59 of the final result value inside 60 `distribution_strategy_context.get_replica_context().merge_call(fn)`. 61 Inside the `merge_call()`, ops are only added to the graph once 62 and access to a sync on read variable in a computation returns 63 the sum across all replicas. 64 65 Args: 66 shape: Shape of the created variable. 67 dtype: Type of the created variable. 68 validate_shape: (Optional) Whether shape validation is enabled for 69 the created variable. 70 name: (Optional) String name of the created variable. 71 72 Returns: 73 A (non-trainable) variable initialized to zero, or if inside a 74 `DistributionStrategy` scope a sync on read variable container. 75 """ 76 # Note that synchronization "ON_READ" implies trainable=False. 77 return variable_scope.variable( 78 lambda: array_ops.zeros(shape, dtype), 79 collections=[ 80 ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES 81 ], 82 validate_shape=validate_shape, 83 synchronization=variable_scope.VariableSynchronization.ON_READ, 84 aggregation=variable_scope.VariableAggregation.SUM, 85 name=name) 86 87 88def _remove_squeezable_dimensions(predictions, labels, weights): 89 """Squeeze or expand last dim if needed. 90 91 Squeezes last dim of `predictions` or `labels` if their rank differs by 1 92 (using confusion_matrix.remove_squeezable_dimensions). 
93 Squeezes or expands last dim of `weights` if its rank differs by 1 from the 94 new rank of `predictions`. 95 96 If `weights` is scalar, it is kept scalar. 97 98 This will use static shape if available. Otherwise, it will add graph 99 operations, which could result in a performance hit. 100 101 Args: 102 predictions: Predicted values, a `Tensor` of arbitrary dimensions. 103 labels: Optional label `Tensor` whose dimensions match `predictions`. 104 weights: Optional weight scalar or `Tensor` whose dimensions match 105 `predictions`. 106 107 Returns: 108 Tuple of `predictions`, `labels` and `weights`. Each of them possibly has 109 the last dimension squeezed, `weights` could be extended by one dimension. 110 """ 111 predictions = ops.convert_to_tensor(predictions) 112 if labels is not None: 113 labels, predictions = confusion_matrix.remove_squeezable_dimensions( 114 labels, predictions) 115 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 116 117 if weights is None: 118 return predictions, labels, None 119 120 weights = ops.convert_to_tensor(weights) 121 weights_shape = weights.get_shape() 122 weights_rank = weights_shape.ndims 123 if weights_rank == 0: 124 return predictions, labels, weights 125 126 predictions_shape = predictions.get_shape() 127 predictions_rank = predictions_shape.ndims 128 if (predictions_rank is not None) and (weights_rank is not None): 129 # Use static rank. 130 if weights_rank - predictions_rank == 1: 131 weights = array_ops.squeeze(weights, [-1]) 132 elif predictions_rank - weights_rank == 1: 133 weights = array_ops.expand_dims(weights, [-1]) 134 else: 135 # Use dynamic rank. 
136 weights_rank_tensor = array_ops.rank(weights) 137 rank_diff = weights_rank_tensor - array_ops.rank(predictions) 138 139 def _maybe_expand_weights(): 140 return control_flow_ops.cond( 141 math_ops.equal(rank_diff, -1), 142 lambda: array_ops.expand_dims(weights, [-1]), lambda: weights) 143 144 # Don't attempt squeeze if it will fail based on static check. 145 if ((weights_rank is not None) and 146 (not weights_shape.dims[-1].is_compatible_with(1))): 147 maybe_squeeze_weights = lambda: weights 148 else: 149 maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1]) 150 151 def _maybe_adjust_weights(): 152 return control_flow_ops.cond( 153 math_ops.equal(rank_diff, 1), maybe_squeeze_weights, 154 _maybe_expand_weights) 155 156 # If weights are scalar, do nothing. Otherwise, try to add or remove a 157 # dimension to match predictions. 158 weights = control_flow_ops.cond( 159 math_ops.equal(weights_rank_tensor, 0), lambda: weights, 160 _maybe_adjust_weights) 161 return predictions, labels, weights 162 163 164def _maybe_expand_labels(labels, predictions): 165 """If necessary, expand `labels` along last dimension to match `predictions`. 166 167 Args: 168 labels: `Tensor` or `SparseTensor` with shape 169 [D1, ... DN, num_labels] or [D1, ... DN]. The latter implies 170 num_labels=1, in which case the result is an expanded `labels` with shape 171 [D1, ... DN, 1]. 172 predictions: `Tensor` with shape [D1, ... DN, num_classes]. 173 174 Returns: 175 `labels` with the same rank as `predictions`. 176 177 Raises: 178 ValueError: if `labels` has invalid shape. 179 """ 180 with ops.name_scope(None, 'expand_labels', (labels, predictions)) as scope: 181 labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels) 182 183 # If sparse, expand sparse shape. 
184 if isinstance(labels, sparse_tensor.SparseTensor): 185 return control_flow_ops.cond( 186 math_ops.equal( 187 array_ops.rank(predictions), 188 array_ops.size(labels.dense_shape) + 1), 189 lambda: sparse_ops.sparse_reshape( # pylint: disable=g-long-lambda 190 labels, 191 shape=array_ops.concat((labels.dense_shape, (1,)), 0), 192 name=scope), 193 lambda: labels) 194 195 # Otherwise, try to use static shape. 196 labels_rank = labels.get_shape().ndims 197 if labels_rank is not None: 198 predictions_rank = predictions.get_shape().ndims 199 if predictions_rank is not None: 200 if predictions_rank == labels_rank: 201 return labels 202 if predictions_rank == labels_rank + 1: 203 return array_ops.expand_dims(labels, -1, name=scope) 204 raise ValueError( 205 'Unexpected labels shape %s for predictions shape %s.' % 206 (labels.get_shape(), predictions.get_shape())) 207 208 # Otherwise, use dynamic shape. 209 return control_flow_ops.cond( 210 math_ops.equal(array_ops.rank(predictions), 211 array_ops.rank(labels) + 1), 212 lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels) 213 214 215def _safe_scalar_div(numerator, denominator, name): 216 """Divides two values, returning 0 if the denominator is 0. 217 218 Args: 219 numerator: A scalar `float64` `Tensor`. 220 denominator: A scalar `float64` `Tensor`. 221 name: Name for the returned op. 222 223 Returns: 224 0 if `denominator` == 0, else `numerator` / `denominator` 225 """ 226 numerator.get_shape().with_rank_at_most(1) 227 denominator.get_shape().with_rank_at_most(1) 228 return math_ops.div_no_nan(numerator, denominator, name=name) 229 230 231def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None): 232 """Calculate a streaming confusion matrix. 233 234 Calculates a confusion matrix. For estimation over a stream of data, 235 the function creates an `update_op` operation. 
236 237 Args: 238 labels: A `Tensor` of ground truth labels with shape [batch size] and of 239 type `int32` or `int64`. The tensor will be flattened if its rank > 1. 240 predictions: A `Tensor` of prediction results for semantic labels, whose 241 shape is [batch size] and type `int32` or `int64`. The tensor will be 242 flattened if its rank > 1. 243 num_classes: The possible number of labels the prediction task can 244 have. This value must be provided, since a confusion matrix of 245 dimension = [num_classes, num_classes] will be allocated. 246 weights: Optional `Tensor` whose rank is either 0, or the same rank as 247 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 248 be either `1`, or the same as the corresponding `labels` dimension). 249 250 Returns: 251 total_cm: A `Tensor` representing the confusion matrix. 252 update_op: An operation that increments the confusion matrix. 253 """ 254 # Local variable to accumulate the predictions in the confusion matrix. 255 total_cm = metric_variable( 256 [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix') 257 258 # Cast the type to int64 required by confusion_matrix_ops. 259 predictions = math_ops.cast(predictions, dtypes.int64) 260 labels = math_ops.cast(labels, dtypes.int64) 261 num_classes = math_ops.cast(num_classes, dtypes.int64) 262 263 # Flatten the input if its rank > 1. 264 if predictions.get_shape().ndims > 1: 265 predictions = array_ops.reshape(predictions, [-1]) 266 267 if labels.get_shape().ndims > 1: 268 labels = array_ops.reshape(labels, [-1]) 269 270 if (weights is not None) and (weights.get_shape().ndims > 1): 271 weights = array_ops.reshape(weights, [-1]) 272 273 # Accumulate the prediction to current confusion matrix. 
274 current_cm = confusion_matrix.confusion_matrix( 275 labels, predictions, num_classes, weights=weights, dtype=dtypes.float64) 276 update_op = state_ops.assign_add(total_cm, current_cm) 277 return total_cm, update_op 278 279 280def _aggregate_across_replicas(metrics_collections, metric_value_fn, *args): 281 """Aggregate metric value across replicas.""" 282 def fn(distribution, *a): 283 """Call `metric_value_fn` in the correct control flow context.""" 284 if hasattr(distribution.extended, '_outer_control_flow_context'): 285 # If there was an outer context captured before this method was called, 286 # then we enter that context to create the metric value op. If the 287 # caputred context is `None`, ops.control_dependencies(None) gives the 288 # desired behavior. Else we use `Enter` and `Exit` to enter and exit the 289 # captured context. 290 # This special handling is needed because sometimes the metric is created 291 # inside a while_loop (and perhaps a TPU rewrite context). But we don't 292 # want the value op to be evaluated every step or on the TPU. So we 293 # create it outside so that it can be evaluated at the end on the host, 294 # once the update ops have been evaluted. 
295 296 # pylint: disable=protected-access 297 if distribution.extended._outer_control_flow_context is None: 298 with ops.control_dependencies(None): 299 metric_value = metric_value_fn(distribution, *a) 300 else: 301 distribution.extended._outer_control_flow_context.Enter() 302 metric_value = metric_value_fn(distribution, *a) 303 distribution.extended._outer_control_flow_context.Exit() 304 # pylint: enable=protected-access 305 else: 306 metric_value = metric_value_fn(distribution, *a) 307 if metrics_collections: 308 ops.add_to_collections(metrics_collections, metric_value) 309 return metric_value 310 311 return distribution_strategy_context.get_replica_context().merge_call( 312 fn, args=args) 313 314 315@tf_export(v1=['metrics.mean']) 316def mean(values, 317 weights=None, 318 metrics_collections=None, 319 updates_collections=None, 320 name=None): 321 """Computes the (weighted) mean of the given values. 322 323 The `mean` function creates two local variables, `total` and `count` 324 that are used to compute the average of `values`. This average is ultimately 325 returned as `mean` which is an idempotent operation that simply divides 326 `total` by `count`. 327 328 For estimation of the metric over a stream of data, the function creates an 329 `update_op` operation that updates these variables and returns the `mean`. 330 `update_op` increments `total` with the reduced sum of the product of `values` 331 and `weights`, and it increments `count` with the reduced sum of `weights`. 332 333 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 334 335 Args: 336 values: A `Tensor` of arbitrary dimensions. 337 weights: Optional `Tensor` whose rank is either 0, or the same rank as 338 `values`, and must be broadcastable to `values` (i.e., all dimensions must 339 be either `1`, or the same as the corresponding `values` dimension). 340 metrics_collections: An optional list of collections that `mean` 341 should be added to. 
342 updates_collections: An optional list of collections that `update_op` 343 should be added to. 344 name: An optional variable_scope name. 345 346 Returns: 347 mean: A `Tensor` representing the current mean, the value of `total` divided 348 by `count`. 349 update_op: An operation that increments the `total` and `count` variables 350 appropriately and whose value matches `mean_value`. 351 352 Raises: 353 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 354 or if either `metrics_collections` or `updates_collections` are not a list 355 or tuple. 356 RuntimeError: If eager execution is enabled. 357 """ 358 if context.executing_eagerly(): 359 raise RuntimeError('tf.metrics.mean is not supported when eager execution ' 360 'is enabled.') 361 362 with variable_scope.variable_scope(name, 'mean', (values, weights)): 363 values = math_ops.cast(values, dtypes.float32) 364 365 total = metric_variable([], dtypes.float32, name='total') 366 count = metric_variable([], dtypes.float32, name='count') 367 368 if weights is None: 369 num_values = math_ops.cast(array_ops.size(values), dtypes.float32) 370 else: 371 values, _, weights = _remove_squeezable_dimensions( 372 predictions=values, labels=None, weights=weights) 373 weights = weights_broadcast_ops.broadcast_weights( 374 math_ops.cast(weights, dtypes.float32), values) 375 values = math_ops.multiply(values, weights) 376 num_values = math_ops.reduce_sum(weights) 377 378 update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values)) 379 with ops.control_dependencies([values]): 380 update_count_op = state_ops.assign_add(count, num_values) 381 382 def compute_mean(_, t, c): 383 return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value') 384 385 mean_t = _aggregate_across_replicas( 386 metrics_collections, compute_mean, total, count) 387 update_op = math_ops.div_no_nan( 388 update_total_op, math_ops.maximum(update_count_op, 0), name='update_op') 389 390 if updates_collections: 391 
ops.add_to_collections(updates_collections, update_op) 392 393 return mean_t, update_op 394 395 396@tf_export(v1=['metrics.accuracy']) 397def accuracy(labels, 398 predictions, 399 weights=None, 400 metrics_collections=None, 401 updates_collections=None, 402 name=None): 403 """Calculates how often `predictions` matches `labels`. 404 405 The `accuracy` function creates two local variables, `total` and 406 `count` that are used to compute the frequency with which `predictions` 407 matches `labels`. This frequency is ultimately returned as `accuracy`: an 408 idempotent operation that simply divides `total` by `count`. 409 410 For estimation of the metric over a stream of data, the function creates an 411 `update_op` operation that updates these variables and returns the `accuracy`. 412 Internally, an `is_correct` operation computes a `Tensor` with elements 1.0 413 where the corresponding elements of `predictions` and `labels` match and 0.0 414 otherwise. Then `update_op` increments `total` with the reduced sum of the 415 product of `weights` and `is_correct`, and it increments `count` with the 416 reduced sum of `weights`. 417 418 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 419 420 Args: 421 labels: The ground truth values, a `Tensor` whose shape matches 422 `predictions`. 423 predictions: The predicted values, a `Tensor` of any shape. 424 weights: Optional `Tensor` whose rank is either 0, or the same rank as 425 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 426 be either `1`, or the same as the corresponding `labels` dimension). 427 metrics_collections: An optional list of collections that `accuracy` should 428 be added to. 429 updates_collections: An optional list of collections that `update_op` should 430 be added to. 431 name: An optional variable_scope name. 432 433 Returns: 434 accuracy: A `Tensor` representing the accuracy, the value of `total` divided 435 by `count`. 
436 update_op: An operation that increments the `total` and `count` variables 437 appropriately and whose value matches `accuracy`. 438 439 Raises: 440 ValueError: If `predictions` and `labels` have mismatched shapes, or if 441 `weights` is not `None` and its shape doesn't match `predictions`, or if 442 either `metrics_collections` or `updates_collections` are not a list or 443 tuple. 444 RuntimeError: If eager execution is enabled. 445 """ 446 if context.executing_eagerly(): 447 raise RuntimeError('tf.metrics.accuracy is not supported when eager ' 448 'execution is enabled.') 449 450 predictions, labels, weights = _remove_squeezable_dimensions( 451 predictions=predictions, labels=labels, weights=weights) 452 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 453 if labels.dtype != predictions.dtype: 454 predictions = math_ops.cast(predictions, labels.dtype) 455 is_correct = math_ops.cast( 456 math_ops.equal(predictions, labels), dtypes.float32) 457 return mean(is_correct, weights, metrics_collections, updates_collections, 458 name or 'accuracy') 459 460 461def _confusion_matrix_at_thresholds(labels, 462 predictions, 463 thresholds, 464 weights=None, 465 includes=None): 466 """Computes true_positives, false_negatives, true_negatives, false_positives. 467 468 This function creates up to four local variables, `true_positives`, 469 `true_negatives`, `false_positives` and `false_negatives`. 470 `true_positive[i]` is defined as the total weight of values in `predictions` 471 above `thresholds[i]` whose corresponding entry in `labels` is `True`. 472 `false_negatives[i]` is defined as the total weight of values in `predictions` 473 at most `thresholds[i]` whose corresponding entry in `labels` is `True`. 474 `true_negatives[i]` is defined as the total weight of values in `predictions` 475 at most `thresholds[i]` whose corresponding entry in `labels` is `False`. 
476 `false_positives[i]` is defined as the total weight of values in `predictions` 477 above `thresholds[i]` whose corresponding entry in `labels` is `False`. 478 479 For estimation of these metrics over a stream of data, for each metric the 480 function respectively creates an `update_op` operation that updates the 481 variable and returns its value. 482 483 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 484 485 Args: 486 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 487 `bool`. 488 predictions: A floating point `Tensor` of arbitrary shape and whose values 489 are in the range `[0, 1]`. 490 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 491 weights: Optional `Tensor` whose rank is either 0, or the same rank as 492 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 493 be either `1`, or the same as the corresponding `labels` dimension). 494 includes: Tuple of keys to return, from 'tp', 'fn', 'tn', fp'. If `None`, 495 default to all four. 496 497 Returns: 498 values: Dict of variables of shape `[len(thresholds)]`. Keys are from 499 `includes`. 500 update_ops: Dict of operations that increments the `values`. Keys are from 501 `includes`. 502 503 Raises: 504 ValueError: If `predictions` and `labels` have mismatched shapes, or if 505 `weights` is not `None` and its shape doesn't match `predictions`, or if 506 `includes` contains invalid keys. 507 """ 508 all_includes = ('tp', 'fn', 'tn', 'fp') 509 if includes is None: 510 includes = all_includes 511 else: 512 for include in includes: 513 if include not in all_includes: 514 raise ValueError('Invalid key: %s.' 
% include) 515 516 with ops.control_dependencies([ 517 check_ops.assert_greater_equal( 518 predictions, 519 math_ops.cast(0.0, dtype=predictions.dtype), 520 message='predictions must be in [0, 1]'), 521 check_ops.assert_less_equal( 522 predictions, 523 math_ops.cast(1.0, dtype=predictions.dtype), 524 message='predictions must be in [0, 1]') 525 ]): 526 predictions, labels, weights = _remove_squeezable_dimensions( 527 predictions=math_ops.cast(predictions, dtypes.float32), 528 labels=math_ops.cast(labels, dtype=dtypes.bool), 529 weights=weights) 530 531 num_thresholds = len(thresholds) 532 533 # Reshape predictions and labels. 534 predictions_2d = array_ops.reshape(predictions, [-1, 1]) 535 labels_2d = array_ops.reshape( 536 math_ops.cast(labels, dtype=dtypes.bool), [1, -1]) 537 538 # Use static shape if known. 539 num_predictions = predictions_2d.get_shape().as_list()[0] 540 541 # Otherwise use dynamic shape. 542 if num_predictions is None: 543 num_predictions = array_ops.shape(predictions_2d)[0] 544 thresh_tiled = array_ops.tile( 545 array_ops.expand_dims(array_ops.constant(thresholds), [1]), 546 array_ops.stack([1, num_predictions])) 547 548 # Tile the predictions after thresholding them across different thresholds. 
549 pred_is_pos = math_ops.greater( 550 array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]), 551 thresh_tiled) 552 if ('fn' in includes) or ('tn' in includes): 553 pred_is_neg = math_ops.logical_not(pred_is_pos) 554 555 # Tile labels by number of thresholds 556 label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1]) 557 if ('fp' in includes) or ('tn' in includes): 558 label_is_neg = math_ops.logical_not(label_is_pos) 559 560 if weights is not None: 561 weights = weights_broadcast_ops.broadcast_weights( 562 math_ops.cast(weights, dtypes.float32), predictions) 563 weights_tiled = array_ops.tile( 564 array_ops.reshape(weights, [1, -1]), [num_thresholds, 1]) 565 thresh_tiled.get_shape().assert_is_compatible_with( 566 weights_tiled.get_shape()) 567 else: 568 weights_tiled = None 569 570 values = {} 571 update_ops = {} 572 573 if 'tp' in includes: 574 true_p = metric_variable( 575 [num_thresholds], dtypes.float32, name='true_positives') 576 is_true_positive = math_ops.cast( 577 math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32) 578 if weights_tiled is not None: 579 is_true_positive *= weights_tiled 580 update_ops['tp'] = state_ops.assign_add(true_p, 581 math_ops.reduce_sum( 582 is_true_positive, 1)) 583 values['tp'] = true_p 584 585 if 'fn' in includes: 586 false_n = metric_variable( 587 [num_thresholds], dtypes.float32, name='false_negatives') 588 is_false_negative = math_ops.cast( 589 math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32) 590 if weights_tiled is not None: 591 is_false_negative *= weights_tiled 592 update_ops['fn'] = state_ops.assign_add(false_n, 593 math_ops.reduce_sum( 594 is_false_negative, 1)) 595 values['fn'] = false_n 596 597 if 'tn' in includes: 598 true_n = metric_variable( 599 [num_thresholds], dtypes.float32, name='true_negatives') 600 is_true_negative = math_ops.cast( 601 math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32) 602 if weights_tiled is not None: 603 is_true_negative *= 
weights_tiled 604 update_ops['tn'] = state_ops.assign_add(true_n, 605 math_ops.reduce_sum( 606 is_true_negative, 1)) 607 values['tn'] = true_n 608 609 if 'fp' in includes: 610 false_p = metric_variable( 611 [num_thresholds], dtypes.float32, name='false_positives') 612 is_false_positive = math_ops.cast( 613 math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32) 614 if weights_tiled is not None: 615 is_false_positive *= weights_tiled 616 update_ops['fp'] = state_ops.assign_add(false_p, 617 math_ops.reduce_sum( 618 is_false_positive, 1)) 619 values['fp'] = false_p 620 621 return values, update_ops 622 623 624def _aggregate_variable(v, collections): 625 f = lambda distribution, value: distribution.extended.read_var(value) 626 return _aggregate_across_replicas(collections, f, v) 627 628 629@tf_export(v1=['metrics.auc']) 630def auc(labels, 631 predictions, 632 weights=None, 633 num_thresholds=200, 634 metrics_collections=None, 635 updates_collections=None, 636 curve='ROC', 637 name=None, 638 summation_method='trapezoidal'): 639 """Computes the approximate AUC via a Riemann sum. 640 641 The `auc` function creates four local variables, `true_positives`, 642 `true_negatives`, `false_positives` and `false_negatives` that are used to 643 compute the AUC. To discretize the AUC curve, a linearly spaced set of 644 thresholds is used to compute pairs of recall and precision values. The area 645 under the ROC-curve is therefore computed using the height of the recall 646 values by the false positive rate, while the area under the PR-curve is the 647 computed using the height of the precision values by the recall. 648 649 This value is ultimately returned as `auc`, an idempotent operation that 650 computes the area under a discretized curve of precision versus recall values 651 (computed using the aforementioned variables). The `num_thresholds` variable 652 controls the degree of discretization with larger numbers of thresholds more 653 closely approximating the true AUC. 
The quality of the approximation may vary 654 dramatically depending on `num_thresholds`. 655 656 For best results, `predictions` should be distributed approximately uniformly 657 in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC 658 approximation may be poor if this is not the case. Setting `summation_method` 659 to 'minoring' or 'majoring' can help quantify the error in the approximation 660 by providing lower or upper bound estimate of the AUC. 661 662 For estimation of the metric over a stream of data, the function creates an 663 `update_op` operation that updates these variables and returns the `auc`. 664 665 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 666 667 Args: 668 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 669 `bool`. 670 predictions: A floating point `Tensor` of arbitrary shape and whose values 671 are in the range `[0, 1]`. 672 weights: Optional `Tensor` whose rank is either 0, or the same rank as 673 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 674 be either `1`, or the same as the corresponding `labels` dimension). 675 num_thresholds: The number of thresholds to use when discretizing the roc 676 curve. 677 metrics_collections: An optional list of collections that `auc` should be 678 added to. 679 updates_collections: An optional list of collections that `update_op` should 680 be added to. 681 curve: Specifies the name of the curve to be computed, 'ROC' [default] or 682 'PR' for the Precision-Recall-curve. 683 name: An optional variable_scope name. 
684 summation_method: Specifies the Riemann summation method used 685 (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that 686 applies the trapezoidal rule; 'careful_interpolation', a variant of it 687 differing only by a more correct interpolation scheme for PR-AUC - 688 interpolating (true/false) positives but not the ratio that is precision; 689 'minoring' that applies left summation for increasing intervals and right 690 summation for decreasing intervals; 'majoring' that does the opposite. 691 Note that 'careful_interpolation' is strictly preferred to 'trapezoidal' 692 (to be deprecated soon) as it applies the same method for ROC, and a 693 better one (see Davis & Goadrich 2006 for details) for the PR curve. 694 695 Returns: 696 auc: A scalar `Tensor` representing the current area-under-curve. 697 update_op: An operation that increments the `true_positives`, 698 `true_negatives`, `false_positives` and `false_negatives` variables 699 appropriately and whose value matches `auc`. 700 701 Raises: 702 ValueError: If `predictions` and `labels` have mismatched shapes, or if 703 `weights` is not `None` and its shape doesn't match `predictions`, or if 704 either `metrics_collections` or `updates_collections` are not a list or 705 tuple. 706 RuntimeError: If eager execution is enabled. 
707 """ 708 if context.executing_eagerly(): 709 raise RuntimeError('tf.metrics.auc is not supported when eager execution ' 710 'is enabled.') 711 712 with variable_scope.variable_scope(name, 'auc', 713 (labels, predictions, weights)): 714 if curve != 'ROC' and curve != 'PR': 715 raise ValueError('curve must be either ROC or PR, %s unknown' % (curve)) 716 kepsilon = 1e-7 # to account for floating point imprecisions 717 thresholds = [ 718 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) 719 ] 720 thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] 721 722 values, update_ops = _confusion_matrix_at_thresholds( 723 labels, predictions, thresholds, weights) 724 725 # Add epsilons to avoid dividing by 0. 726 epsilon = 1.0e-6 727 728 def interpolate_pr_auc(tp, fp, fn): 729 """Interpolation formula inspired by section 4 of Davis & Goadrich 2006. 730 731 Note here we derive & use a closed formula not present in the paper 732 - as follows: 733 Modeling all of TP (true positive weight), 734 FP (false positive weight) and their sum P = TP + FP (positive weight) 735 as varying linearly within each interval [A, B] between successive 736 thresholds, we get 737 Precision = (TP_A + slope * (P - P_A)) / P 738 with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A). 739 The area within the interval is thus (slope / total_pos_weight) times 740 int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P} 741 int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P} 742 where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in 743 int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A) 744 Bringing back the factor (slope / total_pos_weight) we'd put aside, we get 745 slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight 746 where dTP == TP_B - TP_A. 
747 Note that when P_A == 0 the above calculation simplifies into 748 int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A) 749 which is really equivalent to imputing constant precision throughout the 750 first bucket having >0 true positives. 751 752 Args: 753 tp: true positive counts 754 fp: false positive counts 755 fn: false negative counts 756 Returns: 757 pr_auc: an approximation of the area under the P-R curve. 758 """ 759 dtp = tp[:num_thresholds - 1] - tp[1:] 760 p = tp + fp 761 prec_slope = math_ops.div_no_nan( 762 dtp, 763 math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0), 764 name='prec_slope') 765 intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:]) 766 safe_p_ratio = array_ops.where( 767 math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0), 768 math_ops.div_no_nan( 769 p[:num_thresholds - 1], 770 math_ops.maximum(p[1:], 0), 771 name='recall_relative_ratio'), array_ops.ones_like(p[1:])) 772 return math_ops.reduce_sum( 773 math_ops.div_no_nan( 774 prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)), 775 math_ops.maximum(tp[1:] + fn[1:], 0), 776 name='pr_auc_increment'), 777 name='interpolate_pr_auc') 778 779 def compute_auc(tp, fn, tn, fp, name): 780 """Computes the roc-auc or pr-auc based on confusion counts.""" 781 if curve == 'PR': 782 if summation_method == 'trapezoidal': 783 logging.warning( 784 'Trapezoidal rule is known to produce incorrect PR-AUCs; ' 785 'please switch to "careful_interpolation" instead.') 786 elif summation_method == 'careful_interpolation': 787 # This one is a bit tricky and is handled separately. 788 return interpolate_pr_auc(tp, fp, fn) 789 rec = math_ops.div(tp + epsilon, tp + fn + epsilon) 790 if curve == 'ROC': 791 fp_rate = math_ops.div(fp, fp + tn + epsilon) 792 x = fp_rate 793 y = rec 794 else: # curve == 'PR'. 
795 prec = math_ops.div(tp + epsilon, tp + fp + epsilon) 796 x = rec 797 y = prec 798 if summation_method in ('trapezoidal', 'careful_interpolation'): 799 # Note that the case ('PR', 'careful_interpolation') has been handled 800 # above. 801 return math_ops.reduce_sum( 802 math_ops.multiply(x[:num_thresholds - 1] - x[1:], 803 (y[:num_thresholds - 1] + y[1:]) / 2.), 804 name=name) 805 elif summation_method == 'minoring': 806 return math_ops.reduce_sum( 807 math_ops.multiply(x[:num_thresholds - 1] - x[1:], 808 math_ops.minimum(y[:num_thresholds - 1], y[1:])), 809 name=name) 810 elif summation_method == 'majoring': 811 return math_ops.reduce_sum( 812 math_ops.multiply(x[:num_thresholds - 1] - x[1:], 813 math_ops.maximum(y[:num_thresholds - 1], y[1:])), 814 name=name) 815 else: 816 raise ValueError('Invalid summation_method: %s' % summation_method) 817 818 # sum up the areas of all the trapeziums 819 def compute_auc_value(_, values): 820 return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'], 821 'value') 822 823 auc_value = _aggregate_across_replicas( 824 metrics_collections, compute_auc_value, values) 825 update_op = compute_auc(update_ops['tp'], update_ops['fn'], 826 update_ops['tn'], update_ops['fp'], 'update_op') 827 828 if updates_collections: 829 ops.add_to_collections(updates_collections, update_op) 830 831 return auc_value, update_op 832 833 834@tf_export(v1=['metrics.mean_absolute_error']) 835def mean_absolute_error(labels, 836 predictions, 837 weights=None, 838 metrics_collections=None, 839 updates_collections=None, 840 name=None): 841 """Computes the mean absolute error between the labels and predictions. 842 843 The `mean_absolute_error` function creates two local variables, 844 `total` and `count` that are used to compute the mean absolute error. This 845 average is weighted by `weights`, and it is ultimately returned as 846 `mean_absolute_error`: an idempotent operation that simply divides `total` by 847 `count`. 
848 849 For estimation of the metric over a stream of data, the function creates an 850 `update_op` operation that updates these variables and returns the 851 `mean_absolute_error`. Internally, an `absolute_errors` operation computes the 852 absolute value of the differences between `predictions` and `labels`. Then 853 `update_op` increments `total` with the reduced sum of the product of 854 `weights` and `absolute_errors`, and it increments `count` with the reduced 855 sum of `weights` 856 857 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 858 859 Args: 860 labels: A `Tensor` of the same shape as `predictions`. 861 predictions: A `Tensor` of arbitrary shape. 862 weights: Optional `Tensor` whose rank is either 0, or the same rank as 863 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 864 be either `1`, or the same as the corresponding `labels` dimension). 865 metrics_collections: An optional list of collections that 866 `mean_absolute_error` should be added to. 867 updates_collections: An optional list of collections that `update_op` should 868 be added to. 869 name: An optional variable_scope name. 870 871 Returns: 872 mean_absolute_error: A `Tensor` representing the current mean, the value of 873 `total` divided by `count`. 874 update_op: An operation that increments the `total` and `count` variables 875 appropriately and whose value matches `mean_absolute_error`. 876 877 Raises: 878 ValueError: If `predictions` and `labels` have mismatched shapes, or if 879 `weights` is not `None` and its shape doesn't match `predictions`, or if 880 either `metrics_collections` or `updates_collections` are not a list or 881 tuple. 882 RuntimeError: If eager execution is enabled. 
883 """ 884 if context.executing_eagerly(): 885 raise RuntimeError('tf.metrics.mean_absolute_error is not supported ' 886 'when eager execution is enabled.') 887 888 predictions, labels, weights = _remove_squeezable_dimensions( 889 predictions=predictions, labels=labels, weights=weights) 890 absolute_errors = math_ops.abs(predictions - labels) 891 return mean(absolute_errors, weights, metrics_collections, 892 updates_collections, name or 'mean_absolute_error') 893 894 895@tf_export(v1=['metrics.mean_cosine_distance']) 896def mean_cosine_distance(labels, 897 predictions, 898 dim, 899 weights=None, 900 metrics_collections=None, 901 updates_collections=None, 902 name=None): 903 """Computes the cosine distance between the labels and predictions. 904 905 The `mean_cosine_distance` function creates two local variables, 906 `total` and `count` that are used to compute the average cosine distance 907 between `predictions` and `labels`. This average is weighted by `weights`, 908 and it is ultimately returned as `mean_distance`, which is an idempotent 909 operation that simply divides `total` by `count`. 910 911 For estimation of the metric over a stream of data, the function creates an 912 `update_op` operation that updates these variables and returns the 913 `mean_distance`. 914 915 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 916 917 Args: 918 labels: A `Tensor` of arbitrary shape. 919 predictions: A `Tensor` of the same shape as `labels`. 920 dim: The dimension along which the cosine distance is computed. 921 weights: Optional `Tensor` whose rank is either 0, or the same rank as 922 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 923 be either `1`, or the same as the corresponding `labels` dimension). Also, 924 dimension `dim` must be `1`. 925 metrics_collections: An optional list of collections that the metric 926 value variable should be added to. 
927 updates_collections: An optional list of collections that the metric update 928 ops should be added to. 929 name: An optional variable_scope name. 930 931 Returns: 932 mean_distance: A `Tensor` representing the current mean, the value of 933 `total` divided by `count`. 934 update_op: An operation that increments the `total` and `count` variables 935 appropriately. 936 937 Raises: 938 ValueError: If `predictions` and `labels` have mismatched shapes, or if 939 `weights` is not `None` and its shape doesn't match `predictions`, or if 940 either `metrics_collections` or `updates_collections` are not a list or 941 tuple. 942 RuntimeError: If eager execution is enabled. 943 """ 944 if context.executing_eagerly(): 945 raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when ' 946 'eager execution is enabled.') 947 948 predictions, labels, weights = _remove_squeezable_dimensions( 949 predictions=predictions, labels=labels, weights=weights) 950 radial_diffs = math_ops.multiply(predictions, labels) 951 radial_diffs = math_ops.reduce_sum( 952 radial_diffs, axis=[ 953 dim, 954 ], keepdims=True) 955 mean_distance, update_op = mean(radial_diffs, weights, None, None, name or 956 'mean_cosine_distance') 957 mean_distance = math_ops.subtract(1.0, mean_distance) 958 update_op = math_ops.subtract(1.0, update_op) 959 960 if metrics_collections: 961 ops.add_to_collections(metrics_collections, mean_distance) 962 963 if updates_collections: 964 ops.add_to_collections(updates_collections, update_op) 965 966 return mean_distance, update_op 967 968 969@tf_export(v1=['metrics.mean_per_class_accuracy']) 970def mean_per_class_accuracy(labels, 971 predictions, 972 num_classes, 973 weights=None, 974 metrics_collections=None, 975 updates_collections=None, 976 name=None): 977 """Calculates the mean of the per-class accuracies. 978 979 Calculates the accuracy for each class, then takes the mean of that. 
980 981 For estimation of the metric over a stream of data, the function creates an 982 `update_op` operation that updates the accuracy of each class and returns 983 them. 984 985 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 986 987 Args: 988 labels: A `Tensor` of ground truth labels with shape [batch size] and of 989 type `int32` or `int64`. The tensor will be flattened if its rank > 1. 990 predictions: A `Tensor` of prediction results for semantic labels, whose 991 shape is [batch size] and type `int32` or `int64`. The tensor will be 992 flattened if its rank > 1. 993 num_classes: The possible number of labels the prediction task can 994 have. This value must be provided, since two variables with shape = 995 [num_classes] will be allocated. 996 weights: Optional `Tensor` whose rank is either 0, or the same rank as 997 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 998 be either `1`, or the same as the corresponding `labels` dimension). 999 metrics_collections: An optional list of collections that 1000 `mean_per_class_accuracy' 1001 should be added to. 1002 updates_collections: An optional list of collections `update_op` should be 1003 added to. 1004 name: An optional variable_scope name. 1005 1006 Returns: 1007 mean_accuracy: A `Tensor` representing the mean per class accuracy. 1008 update_op: An operation that updates the accuracy tensor. 1009 1010 Raises: 1011 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1012 `weights` is not `None` and its shape doesn't match `predictions`, or if 1013 either `metrics_collections` or `updates_collections` are not a list or 1014 tuple. 1015 RuntimeError: If eager execution is enabled. 
1016 """ 1017 if context.executing_eagerly(): 1018 raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported ' 1019 'when eager execution is enabled.') 1020 1021 with variable_scope.variable_scope(name, 'mean_accuracy', 1022 (predictions, labels, weights)): 1023 labels = math_ops.cast(labels, dtypes.int64) 1024 1025 # Flatten the input if its rank > 1. 1026 if labels.get_shape().ndims > 1: 1027 labels = array_ops.reshape(labels, [-1]) 1028 1029 if predictions.get_shape().ndims > 1: 1030 predictions = array_ops.reshape(predictions, [-1]) 1031 1032 # Check if shape is compatible. 1033 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1034 1035 total = metric_variable([num_classes], dtypes.float32, name='total') 1036 count = metric_variable([num_classes], dtypes.float32, name='count') 1037 1038 ones = array_ops.ones([array_ops.size(labels)], dtypes.float32) 1039 1040 if labels.dtype != predictions.dtype: 1041 predictions = math_ops.cast(predictions, labels.dtype) 1042 is_correct = math_ops.cast( 1043 math_ops.equal(predictions, labels), dtypes.float32) 1044 1045 if weights is not None: 1046 if weights.get_shape().ndims > 1: 1047 weights = array_ops.reshape(weights, [-1]) 1048 weights = math_ops.cast(weights, dtypes.float32) 1049 1050 is_correct *= weights 1051 ones *= weights 1052 1053 update_total_op = state_ops.scatter_add(total, labels, ones) 1054 update_count_op = state_ops.scatter_add(count, labels, is_correct) 1055 1056 def compute_mean_accuracy(_, count, total): 1057 per_class_accuracy = math_ops.div_no_nan( 1058 count, math_ops.maximum(total, 0), name=None) 1059 mean_accuracy_v = math_ops.reduce_mean( 1060 per_class_accuracy, name='mean_accuracy') 1061 return mean_accuracy_v 1062 1063 mean_accuracy_v = _aggregate_across_replicas( 1064 metrics_collections, compute_mean_accuracy, count, total) 1065 1066 update_op = math_ops.div_no_nan( 1067 update_count_op, math_ops.maximum(update_total_op, 0), name='update_op') 1068 if 
updates_collections: 1069 ops.add_to_collections(updates_collections, update_op) 1070 1071 return mean_accuracy_v, update_op 1072 1073 1074@tf_export(v1=['metrics.mean_iou']) 1075def mean_iou(labels, 1076 predictions, 1077 num_classes, 1078 weights=None, 1079 metrics_collections=None, 1080 updates_collections=None, 1081 name=None): 1082 """Calculate per-step mean Intersection-Over-Union (mIOU). 1083 1084 Mean Intersection-Over-Union is a common evaluation metric for 1085 semantic image segmentation, which first computes the IOU for each 1086 semantic class and then computes the average over classes. 1087 IOU is defined as follows: 1088 IOU = true_positive / (true_positive + false_positive + false_negative). 1089 The predictions are accumulated in a confusion matrix, weighted by `weights`, 1090 and mIOU is then calculated from it. 1091 1092 For estimation of the metric over a stream of data, the function creates an 1093 `update_op` operation that updates these variables and returns the `mean_iou`. 1094 1095 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1096 1097 Args: 1098 labels: A `Tensor` of ground truth labels with shape [batch size] and of 1099 type `int32` or `int64`. The tensor will be flattened if its rank > 1. 1100 predictions: A `Tensor` of prediction results for semantic labels, whose 1101 shape is [batch size] and type `int32` or `int64`. The tensor will be 1102 flattened if its rank > 1. 1103 num_classes: The possible number of labels the prediction task can 1104 have. This value must be provided, since a confusion matrix of 1105 dimension = [num_classes, num_classes] will be allocated. 1106 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1107 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1108 be either `1`, or the same as the corresponding `labels` dimension). 1109 metrics_collections: An optional list of collections that `mean_iou` 1110 should be added to. 
1111 updates_collections: An optional list of collections `update_op` should be 1112 added to. 1113 name: An optional variable_scope name. 1114 1115 Returns: 1116 mean_iou: A `Tensor` representing the mean intersection-over-union. 1117 update_op: An operation that increments the confusion matrix. 1118 1119 Raises: 1120 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1121 `weights` is not `None` and its shape doesn't match `predictions`, or if 1122 either `metrics_collections` or `updates_collections` are not a list or 1123 tuple. 1124 RuntimeError: If eager execution is enabled. 1125 """ 1126 if context.executing_eagerly(): 1127 raise RuntimeError('tf.metrics.mean_iou is not supported when ' 1128 'eager execution is enabled.') 1129 1130 with variable_scope.variable_scope(name, 'mean_iou', 1131 (predictions, labels, weights)): 1132 # Check if shape is compatible. 1133 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1134 1135 total_cm, update_op = _streaming_confusion_matrix(labels, predictions, 1136 num_classes, weights) 1137 1138 def compute_mean_iou(_, total_cm): 1139 """Compute the mean intersection-over-union via the confusion matrix.""" 1140 sum_over_row = math_ops.cast( 1141 math_ops.reduce_sum(total_cm, 0), dtypes.float32) 1142 sum_over_col = math_ops.cast( 1143 math_ops.reduce_sum(total_cm, 1), dtypes.float32) 1144 cm_diag = math_ops.cast(array_ops.diag_part(total_cm), dtypes.float32) 1145 denominator = sum_over_row + sum_over_col - cm_diag 1146 1147 # The mean is only computed over classes that appear in the 1148 # label or prediction tensor. If the denominator is 0, we need to 1149 # ignore the class. 1150 num_valid_entries = math_ops.reduce_sum( 1151 math_ops.cast( 1152 math_ops.not_equal(denominator, 0), dtype=dtypes.float32)) 1153 1154 # If the value of the denominator is 0, set it to 1 to avoid 1155 # zero division. 
1156 denominator = array_ops.where( 1157 math_ops.greater(denominator, 0), denominator, 1158 array_ops.ones_like(denominator)) 1159 iou = math_ops.div(cm_diag, denominator) 1160 1161 # If the number of valid entries is 0 (no classes) we return 0. 1162 result = array_ops.where( 1163 math_ops.greater(num_valid_entries, 0), 1164 math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0) 1165 return result 1166 1167 # TODO(priyag): Use outside_compilation if in TPU context. 1168 mean_iou_v = _aggregate_across_replicas( 1169 metrics_collections, compute_mean_iou, total_cm) 1170 1171 if updates_collections: 1172 ops.add_to_collections(updates_collections, update_op) 1173 1174 return mean_iou_v, update_op 1175 1176 1177@tf_export(v1=['metrics.mean_relative_error']) 1178def mean_relative_error(labels, 1179 predictions, 1180 normalizer, 1181 weights=None, 1182 metrics_collections=None, 1183 updates_collections=None, 1184 name=None): 1185 """Computes the mean relative error by normalizing with the given values. 1186 1187 The `mean_relative_error` function creates two local variables, 1188 `total` and `count` that are used to compute the mean relative absolute error. 1189 This average is weighted by `weights`, and it is ultimately returned as 1190 `mean_relative_error`: an idempotent operation that simply divides `total` by 1191 `count`. 1192 1193 For estimation of the metric over a stream of data, the function creates an 1194 `update_op` operation that updates these variables and returns the 1195 `mean_reative_error`. Internally, a `relative_errors` operation divides the 1196 absolute value of the differences between `predictions` and `labels` by the 1197 `normalizer`. Then `update_op` increments `total` with the reduced sum of the 1198 product of `weights` and `relative_errors`, and it increments `count` with the 1199 reduced sum of `weights`. 1200 1201 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 
1202 1203 Args: 1204 labels: A `Tensor` of the same shape as `predictions`. 1205 predictions: A `Tensor` of arbitrary shape. 1206 normalizer: A `Tensor` of the same shape as `predictions`. 1207 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1208 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1209 be either `1`, or the same as the corresponding `labels` dimension). 1210 metrics_collections: An optional list of collections that 1211 `mean_relative_error` should be added to. 1212 updates_collections: An optional list of collections that `update_op` should 1213 be added to. 1214 name: An optional variable_scope name. 1215 1216 Returns: 1217 mean_relative_error: A `Tensor` representing the current mean, the value of 1218 `total` divided by `count`. 1219 update_op: An operation that increments the `total` and `count` variables 1220 appropriately and whose value matches `mean_relative_error`. 1221 1222 Raises: 1223 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1224 `weights` is not `None` and its shape doesn't match `predictions`, or if 1225 either `metrics_collections` or `updates_collections` are not a list or 1226 tuple. 1227 RuntimeError: If eager execution is enabled. 
1228 """ 1229 if context.executing_eagerly(): 1230 raise RuntimeError('tf.metrics.mean_relative_error is not supported when ' 1231 'eager execution is enabled.') 1232 1233 predictions, labels, weights = _remove_squeezable_dimensions( 1234 predictions=predictions, labels=labels, weights=weights) 1235 1236 predictions, normalizer = confusion_matrix.remove_squeezable_dimensions( 1237 predictions, normalizer) 1238 predictions.get_shape().assert_is_compatible_with(normalizer.get_shape()) 1239 relative_errors = array_ops.where( 1240 math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels), 1241 math_ops.div(math_ops.abs(labels - predictions), normalizer)) 1242 return mean(relative_errors, weights, metrics_collections, 1243 updates_collections, name or 'mean_relative_error') 1244 1245 1246@tf_export(v1=['metrics.mean_squared_error']) 1247def mean_squared_error(labels, 1248 predictions, 1249 weights=None, 1250 metrics_collections=None, 1251 updates_collections=None, 1252 name=None): 1253 """Computes the mean squared error between the labels and predictions. 1254 1255 The `mean_squared_error` function creates two local variables, 1256 `total` and `count` that are used to compute the mean squared error. 1257 This average is weighted by `weights`, and it is ultimately returned as 1258 `mean_squared_error`: an idempotent operation that simply divides `total` by 1259 `count`. 1260 1261 For estimation of the metric over a stream of data, the function creates an 1262 `update_op` operation that updates these variables and returns the 1263 `mean_squared_error`. Internally, a `squared_error` operation computes the 1264 element-wise square of the difference between `predictions` and `labels`. Then 1265 `update_op` increments `total` with the reduced sum of the product of 1266 `weights` and `squared_error`, and it increments `count` with the reduced sum 1267 of `weights`. 1268 1269 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 
1270 1271 Args: 1272 labels: A `Tensor` of the same shape as `predictions`. 1273 predictions: A `Tensor` of arbitrary shape. 1274 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1275 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1276 be either `1`, or the same as the corresponding `labels` dimension). 1277 metrics_collections: An optional list of collections that 1278 `mean_squared_error` should be added to. 1279 updates_collections: An optional list of collections that `update_op` should 1280 be added to. 1281 name: An optional variable_scope name. 1282 1283 Returns: 1284 mean_squared_error: A `Tensor` representing the current mean, the value of 1285 `total` divided by `count`. 1286 update_op: An operation that increments the `total` and `count` variables 1287 appropriately and whose value matches `mean_squared_error`. 1288 1289 Raises: 1290 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1291 `weights` is not `None` and its shape doesn't match `predictions`, or if 1292 either `metrics_collections` or `updates_collections` are not a list or 1293 tuple. 1294 RuntimeError: If eager execution is enabled. 1295 """ 1296 if context.executing_eagerly(): 1297 raise RuntimeError('tf.metrics.mean_squared_error is not supported when ' 1298 'eager execution is enabled.') 1299 1300 predictions, labels, weights = _remove_squeezable_dimensions( 1301 predictions=predictions, labels=labels, weights=weights) 1302 squared_error = math_ops.squared_difference(labels, predictions) 1303 return mean(squared_error, weights, metrics_collections, updates_collections, 1304 name or 'mean_squared_error') 1305 1306 1307@tf_export(v1=['metrics.mean_tensor']) 1308def mean_tensor(values, 1309 weights=None, 1310 metrics_collections=None, 1311 updates_collections=None, 1312 name=None): 1313 """Computes the element-wise (weighted) mean of the given tensors. 
1314 1315 In contrast to the `mean` function which returns a scalar with the 1316 mean, this function returns an average tensor with the same shape as the 1317 input tensors. 1318 1319 The `mean_tensor` function creates two local variables, 1320 `total_tensor` and `count_tensor` that are used to compute the average of 1321 `values`. This average is ultimately returned as `mean` which is an idempotent 1322 operation that simply divides `total` by `count`. 1323 1324 For estimation of the metric over a stream of data, the function creates an 1325 `update_op` operation that updates these variables and returns the `mean`. 1326 `update_op` increments `total` with the reduced sum of the product of `values` 1327 and `weights`, and it increments `count` with the reduced sum of `weights`. 1328 1329 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1330 1331 Args: 1332 values: A `Tensor` of arbitrary dimensions. 1333 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1334 `values`, and must be broadcastable to `values` (i.e., all dimensions must 1335 be either `1`, or the same as the corresponding `values` dimension). 1336 metrics_collections: An optional list of collections that `mean` 1337 should be added to. 1338 updates_collections: An optional list of collections that `update_op` 1339 should be added to. 1340 name: An optional variable_scope name. 1341 1342 Returns: 1343 mean: A float `Tensor` representing the current mean, the value of `total` 1344 divided by `count`. 1345 update_op: An operation that increments the `total` and `count` variables 1346 appropriately and whose value matches `mean_value`. 1347 1348 Raises: 1349 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 1350 or if either `metrics_collections` or `updates_collections` are not a list 1351 or tuple. 1352 RuntimeError: If eager execution is enabled. 
1353 """ 1354 if context.executing_eagerly(): 1355 raise RuntimeError('tf.metrics.mean_tensor is not supported when ' 1356 'eager execution is enabled.') 1357 1358 with variable_scope.variable_scope(name, 'mean', (values, weights)): 1359 values = math_ops.cast(values, dtypes.float32) 1360 total = metric_variable( 1361 values.get_shape(), dtypes.float32, name='total_tensor') 1362 count = metric_variable( 1363 values.get_shape(), dtypes.float32, name='count_tensor') 1364 1365 num_values = array_ops.ones_like(values) 1366 if weights is not None: 1367 values, _, weights = _remove_squeezable_dimensions( 1368 predictions=values, labels=None, weights=weights) 1369 weights = weights_broadcast_ops.broadcast_weights( 1370 math_ops.cast(weights, dtypes.float32), values) 1371 values = math_ops.multiply(values, weights) 1372 num_values = math_ops.multiply(num_values, weights) 1373 1374 update_total_op = state_ops.assign_add(total, values) 1375 with ops.control_dependencies([values]): 1376 update_count_op = state_ops.assign_add(count, num_values) 1377 1378 compute_mean = lambda _, t, c: math_ops.div_no_nan( # pylint: disable=g-long-lambda 1379 t, math_ops.maximum(c, 0), name='value') 1380 1381 mean_t = _aggregate_across_replicas( 1382 metrics_collections, compute_mean, total, count) 1383 1384 update_op = math_ops.div_no_nan( 1385 update_total_op, math_ops.maximum(update_count_op, 0), name='update_op') 1386 if updates_collections: 1387 ops.add_to_collections(updates_collections, update_op) 1388 1389 return mean_t, update_op 1390 1391 1392@tf_export(v1=['metrics.percentage_below']) 1393def percentage_below(values, 1394 threshold, 1395 weights=None, 1396 metrics_collections=None, 1397 updates_collections=None, 1398 name=None): 1399 """Computes the percentage of values less than the given threshold. 1400 1401 The `percentage_below` function creates two local variables, 1402 `total` and `count` that are used to compute the percentage of `values` that 1403 fall below `threshold`. 
This rate is weighted by `weights`, and it is 1404 ultimately returned as `percentage` which is an idempotent operation that 1405 simply divides `total` by `count`. 1406 1407 For estimation of the metric over a stream of data, the function creates an 1408 `update_op` operation that updates these variables and returns the 1409 `percentage`. 1410 1411 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1412 1413 Args: 1414 values: A numeric `Tensor` of arbitrary size. 1415 threshold: A scalar threshold. 1416 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1417 `values`, and must be broadcastable to `values` (i.e., all dimensions must 1418 be either `1`, or the same as the corresponding `values` dimension). 1419 metrics_collections: An optional list of collections that the metric 1420 value variable should be added to. 1421 updates_collections: An optional list of collections that the metric update 1422 ops should be added to. 1423 name: An optional variable_scope name. 1424 1425 Returns: 1426 percentage: A `Tensor` representing the current mean, the value of `total` 1427 divided by `count`. 1428 update_op: An operation that increments the `total` and `count` variables 1429 appropriately. 1430 1431 Raises: 1432 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 1433 or if either `metrics_collections` or `updates_collections` are not a list 1434 or tuple. 1435 RuntimeError: If eager execution is enabled. 
1436 """ 1437 if context.executing_eagerly(): 1438 raise RuntimeError('tf.metrics.percentage_below is not supported when ' 1439 'eager execution is enabled.') 1440 1441 is_below_threshold = math_ops.cast( 1442 math_ops.less(values, threshold), dtypes.float32) 1443 return mean(is_below_threshold, weights, metrics_collections, 1444 updates_collections, name or 'percentage_below_threshold') 1445 1446 1447def _count_condition(values, 1448 weights=None, 1449 metrics_collections=None, 1450 updates_collections=None): 1451 """Sums the weights of cases where the given values are True. 1452 1453 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1454 1455 Args: 1456 values: A `bool` `Tensor` of arbitrary size. 1457 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1458 `values`, and must be broadcastable to `values` (i.e., all dimensions must 1459 be either `1`, or the same as the corresponding `values` dimension). 1460 metrics_collections: An optional list of collections that the metric 1461 value variable should be added to. 1462 updates_collections: An optional list of collections that the metric update 1463 ops should be added to. 1464 1465 Returns: 1466 value_tensor: A `Tensor` representing the current value of the metric. 1467 update_op: An operation that accumulates the error from a batch of data. 1468 1469 Raises: 1470 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 1471 or if either `metrics_collections` or `updates_collections` are not a list 1472 or tuple. 
1473 """ 1474 check_ops.assert_type(values, dtypes.bool) 1475 count = metric_variable([], dtypes.float32, name='count') 1476 1477 values = math_ops.cast(values, dtypes.float32) 1478 if weights is not None: 1479 with ops.control_dependencies((check_ops.assert_rank_in( 1480 weights, (0, array_ops.rank(values))),)): 1481 weights = math_ops.cast(weights, dtypes.float32) 1482 values = math_ops.multiply(values, weights) 1483 1484 value_tensor = _aggregate_variable(count, metrics_collections) 1485 1486 update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) 1487 if updates_collections: 1488 ops.add_to_collections(updates_collections, update_op) 1489 1490 return value_tensor, update_op 1491 1492 1493@tf_export(v1=['metrics.false_negatives']) 1494def false_negatives(labels, 1495 predictions, 1496 weights=None, 1497 metrics_collections=None, 1498 updates_collections=None, 1499 name=None): 1500 """Computes the total number of false negatives. 1501 1502 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1503 1504 Args: 1505 labels: The ground truth values, a `Tensor` whose dimensions must match 1506 `predictions`. Will be cast to `bool`. 1507 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1508 be cast to `bool`. 1509 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1510 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1511 be either `1`, or the same as the corresponding `labels` dimension). 1512 metrics_collections: An optional list of collections that the metric 1513 value variable should be added to. 1514 updates_collections: An optional list of collections that the metric update 1515 ops should be added to. 1516 name: An optional variable_scope name. 1517 1518 Returns: 1519 value_tensor: A `Tensor` representing the current value of the metric. 1520 update_op: An operation that accumulates the error from a batch of data. 
1521 1522 Raises: 1523 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 1524 or if either `metrics_collections` or `updates_collections` are not a list 1525 or tuple. 1526 RuntimeError: If eager execution is enabled. 1527 """ 1528 if context.executing_eagerly(): 1529 raise RuntimeError('tf.metrics.false_negatives is not supported when ' 1530 'eager execution is enabled.') 1531 1532 with variable_scope.variable_scope(name, 'false_negatives', 1533 (predictions, labels, weights)): 1534 1535 predictions, labels, weights = _remove_squeezable_dimensions( 1536 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1537 labels=math_ops.cast(labels, dtype=dtypes.bool), 1538 weights=weights) 1539 is_false_negative = math_ops.logical_and( 1540 math_ops.equal(labels, True), math_ops.equal(predictions, False)) 1541 return _count_condition(is_false_negative, weights, metrics_collections, 1542 updates_collections) 1543 1544 1545@tf_export(v1=['metrics.false_negatives_at_thresholds']) 1546def false_negatives_at_thresholds(labels, 1547 predictions, 1548 thresholds, 1549 weights=None, 1550 metrics_collections=None, 1551 updates_collections=None, 1552 name=None): 1553 """Computes false negatives at provided threshold values. 1554 1555 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1556 1557 Args: 1558 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1559 `bool`. 1560 predictions: A floating point `Tensor` of arbitrary shape and whose values 1561 are in the range `[0, 1]`. 1562 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1563 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1564 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1565 be either `1`, or the same as the corresponding `labels` dimension). 1566 metrics_collections: An optional list of collections that `false_negatives` 1567 should be added to. 
@tf_export(v1=['metrics.false_positives'])
def false_positives(labels,
                    predictions,
                    weights=None,
                    metrics_collections=None,
                    updates_collections=None,
                    name=None):
  """Sum the weights of false positives.

  An element is counted when its label casts to `False` while its prediction
  casts to `True`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.false_positives is not supported when '
                       'eager execution is enabled.')

  scope_inputs = (predictions, labels, weights)
  with variable_scope.variable_scope(name, 'false_positives', scope_inputs):
    # Predictions are cast before labels so ops are created in a fixed order.
    bool_predictions = math_ops.cast(predictions, dtype=dtypes.bool)
    bool_labels = math_ops.cast(labels, dtype=dtypes.bool)
    bool_predictions, bool_labels, weights = _remove_squeezable_dimensions(
        predictions=bool_predictions, labels=bool_labels, weights=weights)

    label_is_negative = math_ops.equal(bool_labels, False)
    prediction_is_positive = math_ops.equal(bool_predictions, True)
    false_alarm = math_ops.logical_and(label_is_negative,
                                       prediction_is_positive)
    return _count_condition(false_alarm, weights, metrics_collections,
                            updates_collections)
1663 1664 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1665 1666 Args: 1667 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1668 `bool`. 1669 predictions: A floating point `Tensor` of arbitrary shape and whose values 1670 are in the range `[0, 1]`. 1671 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1672 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1673 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1674 be either `1`, or the same as the corresponding `labels` dimension). 1675 metrics_collections: An optional list of collections that `false_positives` 1676 should be added to. 1677 updates_collections: An optional list of collections that `update_op` should 1678 be added to. 1679 name: An optional variable_scope name. 1680 1681 Returns: 1682 false_positives: A float `Tensor` of shape `[len(thresholds)]`. 1683 update_op: An operation that updates the `false_positives` variable and 1684 returns its current value. 1685 1686 Raises: 1687 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1688 `weights` is not `None` and its shape doesn't match `predictions`, or if 1689 either `metrics_collections` or `updates_collections` are not a list or 1690 tuple. 1691 RuntimeError: If eager execution is enabled. 
@tf_export(v1=['metrics.true_negatives'])
def true_negatives(labels,
                   predictions,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Sum the weights of true_negatives.

  An element is counted when both its label and its prediction cast to
  `False`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.true_negatives is not '
                       'supported when eager execution is enabled.')

  scope_inputs = (predictions, labels, weights)
  with variable_scope.variable_scope(name, 'true_negatives', scope_inputs):
    # Predictions are cast before labels so ops are created in a fixed order.
    bool_predictions = math_ops.cast(predictions, dtype=dtypes.bool)
    bool_labels = math_ops.cast(labels, dtype=dtypes.bool)
    bool_predictions, bool_labels, weights = _remove_squeezable_dimensions(
        predictions=bool_predictions, labels=bool_labels, weights=weights)

    label_is_negative = math_ops.equal(bool_labels, False)
    prediction_is_negative = math_ops.equal(bool_predictions, False)
    correct_rejection = math_ops.logical_and(label_is_negative,
                                             prediction_is_negative)
    return _count_condition(correct_rejection, weights, metrics_collections,
                            updates_collections)
@tf_export(v1=['metrics.true_positives'])
def true_positives(labels,
                   predictions,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Sum the weights of true_positives.

  An element is counted when both its label and its prediction cast to
  `True`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.true_positives is not '
                       'supported when eager execution is enabled.')

  scope_inputs = (predictions, labels, weights)
  with variable_scope.variable_scope(name, 'true_positives', scope_inputs):
    # Predictions are cast before labels so ops are created in a fixed order.
    bool_predictions = math_ops.cast(predictions, dtype=dtypes.bool)
    bool_labels = math_ops.cast(labels, dtype=dtypes.bool)
    bool_predictions, bool_labels, weights = _remove_squeezable_dimensions(
        predictions=bool_predictions, labels=bool_labels, weights=weights)

    label_is_positive = math_ops.equal(bool_labels, True)
    prediction_is_positive = math_ops.equal(bool_predictions, True)
    correct_hit = math_ops.logical_and(label_is_positive,
                                       prediction_is_positive)
    return _count_condition(correct_hit, weights, metrics_collections,
                            updates_collections)
1881 1882 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1883 1884 Args: 1885 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1886 `bool`. 1887 predictions: A floating point `Tensor` of arbitrary shape and whose values 1888 are in the range `[0, 1]`. 1889 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1890 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1891 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1892 be either `1`, or the same as the corresponding `labels` dimension). 1893 metrics_collections: An optional list of collections that `true_positives` 1894 should be added to. 1895 updates_collections: An optional list of collections that `update_op` should 1896 be added to. 1897 name: An optional variable_scope name. 1898 1899 Returns: 1900 true_positives: A float `Tensor` of shape `[len(thresholds)]`. 1901 update_op: An operation that updates the `true_positives` variable and 1902 returns its current value. 1903 1904 Raises: 1905 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1906 `weights` is not `None` and its shape doesn't match `predictions`, or if 1907 either `metrics_collections` or `updates_collections` are not a list or 1908 tuple. 1909 RuntimeError: If eager execution is enabled. 
@tf_export(v1=['metrics.precision'])
def precision(labels,
              predictions,
              weights=None,
              metrics_collections=None,
              updates_collections=None,
              name=None):
  """Computes the precision of the predictions with respect to the labels.

  The `precision` function creates two local variables,
  `true_positives` and `false_positives`, that are used to compute the
  precision. This value is ultimately returned as `precision`, an idempotent
  operation that simply divides `true_positives` by the sum of `true_positives`
  and `false_positives`. When that sum is not greater than zero the result
  is 0.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision`. `update_op` weights each prediction by the corresponding value in
  `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `precision` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: Scalar float `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately and whose value matches
      `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.precision is not '
                       'supported when eager execution is enabled.')

  scope_inputs = (predictions, labels, weights)
  with variable_scope.variable_scope(name, 'precision', scope_inputs):
    bool_predictions = math_ops.cast(predictions, dtype=dtypes.bool)
    bool_labels = math_ops.cast(labels, dtype=dtypes.bool)
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=bool_predictions, labels=bool_labels, weights=weights)

    # The component metrics are not registered in any collection; only the
    # combined precision value and update op are (below).
    tp_value, tp_update = true_positives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)
    fp_value, fp_update = false_positives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)

    def _ratio_or_zero(tp, fp, op_name):
      # Yields tp / (tp + fp) where the denominator is positive, else 0.
      denominator_is_positive = math_ops.greater(tp + fp, 0)
      ratio = math_ops.div(tp, tp + fp)
      return array_ops.where(denominator_is_positive, ratio, 0, op_name)

    def _value_once(_, tp, fp):
      return _ratio_or_zero(tp, fp, 'value')

    precision_value = _aggregate_across_replicas(
        metrics_collections, _value_once, tp_value, fp_value)

    update_op = _ratio_or_zero(tp_update, fp_update, 'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return precision_value, update_op
`precision[i]` is defined as the total 2036 weight of values in `predictions` above `thresholds[i]` whose corresponding 2037 entry in `labels` is `True`, divided by the total weight of values in 2038 `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] + 2039 false_positives[i])`). 2040 2041 For estimation of the metric over a stream of data, the function creates an 2042 `update_op` operation that updates these variables and returns the 2043 `precision`. 2044 2045 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2046 2047 Args: 2048 labels: The ground truth values, a `Tensor` whose dimensions must match 2049 `predictions`. Will be cast to `bool`. 2050 predictions: A floating point `Tensor` of arbitrary shape and whose values 2051 are in the range `[0, 1]`. 2052 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 2053 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2054 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2055 be either `1`, or the same as the corresponding `labels` dimension). 2056 metrics_collections: An optional list of collections that `auc` should be 2057 added to. 2058 updates_collections: An optional list of collections that `update_op` should 2059 be added to. 2060 name: An optional variable_scope name. 2061 2062 Returns: 2063 precision: A float `Tensor` of shape `[len(thresholds)]`. 2064 update_op: An operation that increments the `true_positives`, 2065 `true_negatives`, `false_positives` and `false_negatives` variables that 2066 are used in the computation of `precision`. 2067 2068 Raises: 2069 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2070 `weights` is not `None` and its shape doesn't match `predictions`, or if 2071 either `metrics_collections` or `updates_collections` are not a list or 2072 tuple. 2073 RuntimeError: If eager execution is enabled. 
@tf_export(v1=['metrics.recall'])
def recall(labels,
           predictions,
           weights=None,
           metrics_collections=None,
           updates_collections=None,
           name=None):
  """Computes the recall of the predictions with respect to the labels.

  The `recall` function creates two local variables, `true_positives`
  and `false_negatives`, that are used to compute the recall. This value is
  ultimately returned as `recall`, an idempotent operation that simply divides
  `true_positives` by the sum of `true_positives` and `false_negatives`. When
  that sum is not greater than zero the result is 0.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` that updates these variables and returns the `recall`. `update_op`
  weights each prediction by the corresponding value in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `recall` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: Scalar float `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately and whose value matches
      `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    # The message previously repeated "is not supported" twice.
    raise RuntimeError('tf.metrics.recall is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'recall',
                                     (predictions, labels, weights)):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

    # The component metrics are not registered in any collection; only the
    # combined recall value and update op are (below).
    true_p, true_positives_update_op = true_positives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)
    false_n, false_negatives_update_op = false_negatives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)

    def compute_recall(true_p, false_n, name):
      # Yields true_p / (true_p + false_n) where the denominator is positive,
      # else 0.
      return array_ops.where(
          math_ops.greater(true_p + false_n, 0),
          math_ops.div(true_p, true_p + false_n), 0, name)

    def once_across_replicas(_, true_p, false_n):
      return compute_recall(true_p, false_n, 'value')

    rec = _aggregate_across_replicas(
        metrics_collections, once_across_replicas, true_p, false_n)

    update_op = compute_recall(true_positives_update_op,
                               false_negatives_update_op, 'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return rec, update_op
This contains only the entries 2216 equal to `selected_id`. 2217 """ 2218 ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids) 2219 if isinstance(ids, sparse_tensor.SparseTensor): 2220 return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values, 2221 selected_id)) 2222 2223 # TODO(ptucker): Make this more efficient, maybe add a sparse version of 2224 # tf.equal and tf.reduce_any? 2225 2226 # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1. 2227 ids_shape = array_ops.shape(ids, out_type=dtypes.int64) 2228 ids_last_dim = array_ops.size(ids_shape) - 1 2229 filled_selected_id_shape = math_ops.reduced_shape(ids_shape, 2230 array_ops.reshape( 2231 ids_last_dim, [1])) 2232 2233 # Intersect `ids` with the selected ID. 2234 filled_selected_id = array_ops.fill(filled_selected_id_shape, 2235 math_ops.cast(selected_id, dtypes.int64)) 2236 result = sets.set_intersection(filled_selected_id, ids) 2237 return sparse_tensor.SparseTensor( 2238 indices=result.indices, values=result.values, dense_shape=ids_shape) 2239 2240 2241def _maybe_select_class_id(labels, predictions_idx, selected_id=None): 2242 """If class ID is specified, filter all other classes. 2243 2244 Args: 2245 labels: `int64` `Tensor` or `SparseTensor` with shape 2246 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2247 target classes for the associated prediction. Commonly, N=1 and `labels` 2248 has shape [batch_size, num_labels]. [D1, ... DN] must match 2249 `predictions_idx`. 2250 predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k] 2251 where N >= 1. Commonly, N=1 and `predictions_idx` has shape 2252 [batch size, k]. 2253 selected_id: Int id to select. 2254 2255 Returns: 2256 Tuple of `labels` and `predictions_idx`, possibly with classes removed. 
2257 """ 2258 if selected_id is None: 2259 return labels, predictions_idx 2260 return (_select_class_id(labels, selected_id), 2261 _select_class_id(predictions_idx, selected_id)) 2262 2263 2264def _sparse_true_positive_at_k(labels, 2265 predictions_idx, 2266 class_id=None, 2267 weights=None, 2268 name=None): 2269 """Calculates true positives for recall@k and precision@k. 2270 2271 If `class_id` is specified, calculate binary true positives for `class_id` 2272 only. 2273 If `class_id` is not specified, calculate metrics for `k` predicted vs 2274 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 2275 2276 Args: 2277 labels: `int64` `Tensor` or `SparseTensor` with shape 2278 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2279 target classes for the associated prediction. Commonly, N=1 and `labels` 2280 has shape [batch_size, num_labels]. [D1, ... DN] must match 2281 `predictions_idx`. 2282 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2283 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2284 match `labels`. 2285 class_id: Class for which we want binary metrics. 2286 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2287 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2288 dimensions must be either `1`, or the same as the corresponding `labels` 2289 dimension). 2290 name: Name of operation. 2291 2292 Returns: 2293 A [D1, ... DN] `Tensor` of true positive counts. 
def _streaming_sparse_true_positive_at_k(labels,
                                         predictions_idx,
                                         k=None,
                                         class_id=None,
                                         weights=None,
                                         name=None):
  """Calculates weighted per step true positives for recall@k and precision@k.

  If `class_id` is specified, calculate binary true positives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
  """
  default_scope_name = _at_k_name('true_positive', k, class_id=class_id)
  scope_inputs = (predictions_idx, labels, weights)
  with ops.name_scope(name, default_scope_name, scope_inputs) as scope:
    per_entry_tp = _sparse_true_positive_at_k(
        labels=labels,
        predictions_idx=predictions_idx,
        class_id=class_id,
        weights=weights)
    batch_total_tp = math_ops.cast(
        math_ops.reduce_sum(per_entry_tp), dtypes.float64)

    # The accumulator variable takes the enclosing scope string as its name.
    accumulator = metric_variable([], dtypes.float64, name=scope)
    update_op = state_ops.assign_add(
        accumulator, batch_total_tp, name='update')
    return accumulator, update_op
DN] `Tensor` of false negative counts. 2388 """ 2389 with ops.name_scope(None, 'false_negatives', 2390 (predictions_idx, labels, weights)): 2391 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, 2392 class_id) 2393 fn = sets.set_size( 2394 sets.set_difference(predictions_idx, labels, aminusb=False)) 2395 fn = math_ops.cast(fn, dtypes.float64) 2396 if weights is not None: 2397 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( 2398 weights, fn),)): 2399 weights = math_ops.cast(weights, dtypes.float64) 2400 fn = math_ops.multiply(fn, weights) 2401 return fn 2402 2403 2404def _streaming_sparse_false_negative_at_k(labels, 2405 predictions_idx, 2406 k, 2407 class_id=None, 2408 weights=None, 2409 name=None): 2410 """Calculates weighted per step false negatives for recall@k. 2411 2412 If `class_id` is specified, calculate binary true positives for `class_id` 2413 only. 2414 If `class_id` is not specified, calculate metrics for `k` predicted vs 2415 `n` label classes, where `n` is the 2nd dimension of `labels`. 2416 2417 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2418 2419 Args: 2420 labels: `int64` `Tensor` or `SparseTensor` with shape 2421 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2422 target classes for the associated prediction. Commonly, N=1 and `labels` 2423 has shape [batch_size, num_labels]. [D1, ... DN] must match 2424 `predictions_idx`. 2425 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2426 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2427 match `labels`. 2428 k: Integer, k for @k metric. This is only used for default op name. 2429 class_id: Class for which we want binary metrics. 2430 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2431 `labels`. 
If the latter, it must be broadcastable to `labels` (i.e., all 2432 dimensions must be either `1`, or the same as the corresponding `labels` 2433 dimension). 2434 name: Name of new variable, and namespace for other dependent ops. 2435 2436 Returns: 2437 A tuple of `Variable` and update `Operation`. 2438 2439 Raises: 2440 ValueError: If `weights` is not `None` and has an incompatible shape. 2441 """ 2442 with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id), 2443 (predictions_idx, labels, weights)) as scope: 2444 fn = _sparse_false_negative_at_k( 2445 predictions_idx=predictions_idx, 2446 labels=labels, 2447 class_id=class_id, 2448 weights=weights) 2449 batch_total_fn = math_ops.cast(math_ops.reduce_sum(fn), dtypes.float64) 2450 2451 var = metric_variable([], dtypes.float64, name=scope) 2452 return var, state_ops.assign_add(var, batch_total_fn, name='update') 2453 2454 2455@tf_export(v1=['metrics.recall_at_k']) 2456def recall_at_k(labels, 2457 predictions, 2458 k, 2459 class_id=None, 2460 weights=None, 2461 metrics_collections=None, 2462 updates_collections=None, 2463 name=None): 2464 """Computes recall@k of the predictions with respect to sparse labels. 2465 2466 If `class_id` is specified, we calculate recall by considering only the 2467 entries in the batch for which `class_id` is in the label, and computing 2468 the fraction of them for which `class_id` is in the top-k `predictions`. 2469 If `class_id` is not specified, we'll calculate recall as how often on 2470 average a class among the labels of a batch entry is in the top-k 2471 `predictions`. 2472 2473 `sparse_recall_at_k` creates two local variables, 2474 `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute 2475 the recall_at_k frequency. This frequency is ultimately returned as 2476 `recall_at_<k>`: an idempotent operation that simply divides 2477 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 2478 `false_negative_at_<k>`). 
2479 2480 For estimation of the metric over a stream of data, the function creates an 2481 `update_op` operation that updates these variables and returns the 2482 `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 2483 indicating the top `k` `predictions`. Set operations applied to `top_k` and 2484 `labels` calculate the true positives and false negatives weighted by 2485 `weights`. Then `update_op` increments `true_positive_at_<k>` and 2486 `false_negative_at_<k>` using these values. 2487 2488 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2489 2490 Args: 2491 labels: `int64` `Tensor` or `SparseTensor` with shape 2492 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 2493 num_labels=1. N >= 1 and num_labels is the number of target classes for 2494 the associated prediction. Commonly, N=1 and `labels` has shape 2495 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 2496 should be in range [0, num_classes), where num_classes is the last 2497 dimension of `predictions`. Values outside this range always count 2498 towards `false_negative_at_<k>`. 2499 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 2500 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 2501 The final dimension contains the logit values for each class. [D1, ... DN] 2502 must match `labels`. 2503 k: Integer, k for @k metric. 2504 class_id: Integer class ID for which we want binary metrics. This should be 2505 in range [0, num_classes), where num_classes is the last dimension of 2506 `predictions`. If class_id is outside this range, the method returns NAN. 2507 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2508 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2509 dimensions must be either `1`, or the same as the corresponding `labels` 2510 dimension). 
2511 metrics_collections: An optional list of collections that values should 2512 be added to. 2513 updates_collections: An optional list of collections that updates should 2514 be added to. 2515 name: Name of new update operation, and namespace for other dependent ops. 2516 2517 Returns: 2518 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 2519 by the sum of `true_positives` and `false_negatives`. 2520 update_op: `Operation` that increments `true_positives` and 2521 `false_negatives` variables appropriately, and whose value matches 2522 `recall`. 2523 2524 Raises: 2525 ValueError: If `weights` is not `None` and its shape doesn't match 2526 `predictions`, or if either `metrics_collections` or `updates_collections` 2527 are not a list or tuple. 2528 RuntimeError: If eager execution is enabled. 2529 """ 2530 if context.executing_eagerly(): 2531 raise RuntimeError('tf.metrics.recall_at_k is not ' 2532 'supported when eager execution is enabled.') 2533 2534 with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id), 2535 (predictions, labels, weights)) as scope: 2536 _, top_k_idx = nn.top_k(predictions, k) 2537 return recall_at_top_k( 2538 labels=labels, 2539 predictions_idx=top_k_idx, 2540 k=k, 2541 class_id=class_id, 2542 weights=weights, 2543 metrics_collections=metrics_collections, 2544 updates_collections=updates_collections, 2545 name=scope) 2546 2547 2548@tf_export(v1=['metrics.recall_at_top_k']) 2549def recall_at_top_k(labels, 2550 predictions_idx, 2551 k=None, 2552 class_id=None, 2553 weights=None, 2554 metrics_collections=None, 2555 updates_collections=None, 2556 name=None): 2557 """Computes recall@k of top-k predictions with respect to sparse labels. 2558 2559 Differs from `recall_at_k` in that predictions must be in the form of top `k` 2560 class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k` 2561 for more details. 
2562 2563 Args: 2564 labels: `int64` `Tensor` or `SparseTensor` with shape 2565 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 2566 num_labels=1. N >= 1 and num_labels is the number of target classes for 2567 the associated prediction. Commonly, N=1 and `labels` has shape 2568 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 2569 should be in range [0, num_classes), where num_classes is the last 2570 dimension of `predictions`. Values outside this range always count 2571 towards `false_negative_at_<k>`. 2572 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 2573 Commonly, N=1 and predictions has shape [batch size, k]. The final 2574 dimension contains the top `k` predicted class indices. [D1, ... DN] must 2575 match `labels`. 2576 k: Integer, k for @k metric. Only used for the default op name. 2577 class_id: Integer class ID for which we want binary metrics. This should be 2578 in range [0, num_classes), where num_classes is the last dimension of 2579 `predictions`. If class_id is outside this range, the method returns NAN. 2580 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2581 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2582 dimensions must be either `1`, or the same as the corresponding `labels` 2583 dimension). 2584 metrics_collections: An optional list of collections that values should 2585 be added to. 2586 updates_collections: An optional list of collections that updates should 2587 be added to. 2588 name: Name of new update operation, and namespace for other dependent ops. 2589 2590 Returns: 2591 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 2592 by the sum of `true_positives` and `false_negatives`. 2593 update_op: `Operation` that increments `true_positives` and 2594 `false_negatives` variables appropriately, and whose value matches 2595 `recall`. 
2596 2597 Raises: 2598 ValueError: If `weights` is not `None` and its shape doesn't match 2599 `predictions`, or if either `metrics_collections` or `updates_collections` 2600 are not a list or tuple. 2601 """ 2602 with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id), 2603 (predictions_idx, labels, weights)) as scope: 2604 labels = _maybe_expand_labels(labels, predictions_idx) 2605 top_k_idx = math_ops.cast(predictions_idx, dtypes.int64) 2606 tp, tp_update = _streaming_sparse_true_positive_at_k( 2607 predictions_idx=top_k_idx, 2608 labels=labels, 2609 k=k, 2610 class_id=class_id, 2611 weights=weights) 2612 fn, fn_update = _streaming_sparse_false_negative_at_k( 2613 predictions_idx=top_k_idx, 2614 labels=labels, 2615 k=k, 2616 class_id=class_id, 2617 weights=weights) 2618 2619 def compute_recall(_, tp, fn): 2620 return math_ops.div(tp, math_ops.add(tp, fn), name=scope) 2621 2622 metric = _aggregate_across_replicas( 2623 metrics_collections, compute_recall, tp, fn) 2624 2625 update = math_ops.div( 2626 tp_update, math_ops.add(tp_update, fn_update), name='update') 2627 if updates_collections: 2628 ops.add_to_collections(updates_collections, update) 2629 return metric, update 2630 2631 2632@tf_export(v1=['metrics.recall_at_thresholds']) 2633def recall_at_thresholds(labels, 2634 predictions, 2635 thresholds, 2636 weights=None, 2637 metrics_collections=None, 2638 updates_collections=None, 2639 name=None): 2640 """Computes various recall values for different `thresholds` on `predictions`. 2641 2642 The `recall_at_thresholds` function creates four local variables, 2643 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 2644 for various values of thresholds. `recall[i]` is defined as the total weight 2645 of values in `predictions` above `thresholds[i]` whose corresponding entry in 2646 `labels` is `True`, divided by the total weight of `True` values in `labels` 2647 (`true_positives[i] / (true_positives[i] + false_negatives[i])`). 
2648 2649 For estimation of the metric over a stream of data, the function creates an 2650 `update_op` operation that updates these variables and returns the `recall`. 2651 2652 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2653 2654 Args: 2655 labels: The ground truth values, a `Tensor` whose dimensions must match 2656 `predictions`. Will be cast to `bool`. 2657 predictions: A floating point `Tensor` of arbitrary shape and whose values 2658 are in the range `[0, 1]`. 2659 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 2660 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2661 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2662 be either `1`, or the same as the corresponding `labels` dimension). 2663 metrics_collections: An optional list of collections that `recall` should be 2664 added to. 2665 updates_collections: An optional list of collections that `update_op` should 2666 be added to. 2667 name: An optional variable_scope name. 2668 2669 Returns: 2670 recall: A float `Tensor` of shape `[len(thresholds)]`. 2671 update_op: An operation that increments the `true_positives`, 2672 `true_negatives`, `false_positives` and `false_negatives` variables that 2673 are used in the computation of `recall`. 2674 2675 Raises: 2676 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2677 `weights` is not `None` and its shape doesn't match `predictions`, or if 2678 either `metrics_collections` or `updates_collections` are not a list or 2679 tuple. 2680 RuntimeError: If eager execution is enabled. 
2681 """ 2682 if context.executing_eagerly(): 2683 raise RuntimeError('tf.metrics.recall_at_thresholds is not ' 2684 'supported when eager execution is enabled.') 2685 2686 with variable_scope.variable_scope(name, 'recall_at_thresholds', 2687 (predictions, labels, weights)): 2688 values, update_ops = _confusion_matrix_at_thresholds( 2689 labels, predictions, thresholds, weights, includes=('tp', 'fn')) 2690 2691 # Avoid division by zero. 2692 epsilon = 1e-7 2693 2694 def compute_recall(tp, fn, name): 2695 return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name) 2696 2697 def recall_across_replicas(_, values): 2698 return compute_recall(values['tp'], values['fn'], 'value') 2699 2700 rec = _aggregate_across_replicas( 2701 metrics_collections, recall_across_replicas, values) 2702 2703 update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') 2704 if updates_collections: 2705 ops.add_to_collections(updates_collections, update_op) 2706 2707 return rec, update_op 2708 2709 2710@tf_export(v1=['metrics.root_mean_squared_error']) 2711def root_mean_squared_error(labels, 2712 predictions, 2713 weights=None, 2714 metrics_collections=None, 2715 updates_collections=None, 2716 name=None): 2717 """Computes the root mean squared error between the labels and predictions. 2718 2719 The `root_mean_squared_error` function creates two local variables, 2720 `total` and `count` that are used to compute the root mean squared error. 2721 This average is weighted by `weights`, and it is ultimately returned as 2722 `root_mean_squared_error`: an idempotent operation that takes the square root 2723 of the division of `total` by `count`. 2724 2725 For estimation of the metric over a stream of data, the function creates an 2726 `update_op` operation that updates these variables and returns the 2727 `root_mean_squared_error`. Internally, a `squared_error` operation computes 2728 the element-wise square of the difference between `predictions` and `labels`. 
2729 Then `update_op` increments `total` with the reduced sum of the product of 2730 `weights` and `squared_error`, and it increments `count` with the reduced sum 2731 of `weights`. 2732 2733 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2734 2735 Args: 2736 labels: A `Tensor` of the same shape as `predictions`. 2737 predictions: A `Tensor` of arbitrary shape. 2738 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2739 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2740 be either `1`, or the same as the corresponding `labels` dimension). 2741 metrics_collections: An optional list of collections that 2742 `root_mean_squared_error` should be added to. 2743 updates_collections: An optional list of collections that `update_op` should 2744 be added to. 2745 name: An optional variable_scope name. 2746 2747 Returns: 2748 root_mean_squared_error: A `Tensor` representing the current mean, the value 2749 of `total` divided by `count`. 2750 update_op: An operation that increments the `total` and `count` variables 2751 appropriately and whose value matches `root_mean_squared_error`. 2752 2753 Raises: 2754 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2755 `weights` is not `None` and its shape doesn't match `predictions`, or if 2756 either `metrics_collections` or `updates_collections` are not a list or 2757 tuple. 2758 RuntimeError: If eager execution is enabled. 
2759 """ 2760 if context.executing_eagerly(): 2761 raise RuntimeError('tf.metrics.root_mean_squared_error is not ' 2762 'supported when eager execution is enabled.') 2763 2764 predictions, labels, weights = _remove_squeezable_dimensions( 2765 predictions=predictions, labels=labels, weights=weights) 2766 mse, update_mse_op = mean_squared_error(labels, predictions, weights, None, 2767 None, name or 2768 'root_mean_squared_error') 2769 2770 once_across_replicas = lambda _, mse: math_ops.sqrt(mse) 2771 rmse = _aggregate_across_replicas( 2772 metrics_collections, once_across_replicas, mse) 2773 2774 update_rmse_op = math_ops.sqrt(update_mse_op) 2775 if updates_collections: 2776 ops.add_to_collections(updates_collections, update_rmse_op) 2777 2778 return rmse, update_rmse_op 2779 2780 2781@tf_export(v1=['metrics.sensitivity_at_specificity']) 2782def sensitivity_at_specificity(labels, 2783 predictions, 2784 specificity, 2785 weights=None, 2786 num_thresholds=200, 2787 metrics_collections=None, 2788 updates_collections=None, 2789 name=None): 2790 """Computes the specificity at a given sensitivity. 2791 2792 The `sensitivity_at_specificity` function creates four local 2793 variables, `true_positives`, `true_negatives`, `false_positives` and 2794 `false_negatives` that are used to compute the sensitivity at the given 2795 specificity value. The threshold for the given specificity value is computed 2796 and used to evaluate the corresponding sensitivity. 2797 2798 For estimation of the metric over a stream of data, the function creates an 2799 `update_op` operation that updates these variables and returns the 2800 `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`, 2801 `false_positives` and `false_negatives` counts with the weight of each case 2802 found in the `predictions` and `labels`. 2803 2804 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 
2805 2806 For additional information about specificity and sensitivity, see the 2807 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 2808 2809 Args: 2810 labels: The ground truth values, a `Tensor` whose dimensions must match 2811 `predictions`. Will be cast to `bool`. 2812 predictions: A floating point `Tensor` of arbitrary shape and whose values 2813 are in the range `[0, 1]`. 2814 specificity: A scalar value in range `[0, 1]`. 2815 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2816 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2817 be either `1`, or the same as the corresponding `labels` dimension). 2818 num_thresholds: The number of thresholds to use for matching the given 2819 specificity. 2820 metrics_collections: An optional list of collections that `sensitivity` 2821 should be added to. 2822 updates_collections: An optional list of collections that `update_op` should 2823 be added to. 2824 name: An optional variable_scope name. 2825 2826 Returns: 2827 sensitivity: A scalar `Tensor` representing the sensitivity at the given 2828 `specificity` value. 2829 update_op: An operation that increments the `true_positives`, 2830 `true_negatives`, `false_positives` and `false_negatives` variables 2831 appropriately and whose value matches `sensitivity`. 2832 2833 Raises: 2834 ValueError: If `predictions` and `labels` have mismatched shapes, if 2835 `weights` is not `None` and its shape doesn't match `predictions`, or if 2836 `specificity` is not between 0 and 1, or if either `metrics_collections` 2837 or `updates_collections` are not a list or tuple. 2838 RuntimeError: If eager execution is enabled. 
2839 """ 2840 if context.executing_eagerly(): 2841 raise RuntimeError('tf.metrics.sensitivity_at_specificity is not ' 2842 'supported when eager execution is enabled.') 2843 2844 if specificity < 0 or specificity > 1: 2845 raise ValueError('`specificity` must be in the range [0, 1].') 2846 2847 with variable_scope.variable_scope(name, 'sensitivity_at_specificity', 2848 (predictions, labels, weights)): 2849 kepsilon = 1e-7 # to account for floating point imprecisions 2850 thresholds = [ 2851 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) 2852 ] 2853 thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] 2854 2855 values, update_ops = _confusion_matrix_at_thresholds( 2856 labels, predictions, thresholds, weights) 2857 2858 def compute_sensitivity_at_specificity(tp, tn, fp, fn, name): 2859 specificities = math_ops.div(tn, tn + fp + kepsilon) 2860 tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0) 2861 tf_index = math_ops.cast(tf_index, dtypes.int32) 2862 2863 # Now, we have the implicit threshold, so compute the sensitivity: 2864 return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon, 2865 name) 2866 2867 def sensitivity_across_replicas(_, values): 2868 return compute_sensitivity_at_specificity( 2869 values['tp'], values['tn'], values['fp'], values['fn'], 'value') 2870 2871 sensitivity = _aggregate_across_replicas( 2872 metrics_collections, sensitivity_across_replicas, values) 2873 2874 update_op = compute_sensitivity_at_specificity( 2875 update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 2876 'update_op') 2877 if updates_collections: 2878 ops.add_to_collections(updates_collections, update_op) 2879 2880 return sensitivity, update_op 2881 2882 2883def _expand_and_tile(tensor, multiple, dim=0, name=None): 2884 """Slice `tensor` shape in 2, then tile along the sliced dimension. 
2885 2886 A new dimension is inserted in shape of `tensor` before `dim`, then values are 2887 tiled `multiple` times along the new dimension. 2888 2889 Args: 2890 tensor: Input `Tensor` or `SparseTensor`. 2891 multiple: Integer, number of times to tile. 2892 dim: Integer, dimension along which to tile. 2893 name: Name of operation. 2894 2895 Returns: 2896 `Tensor` result of expanding and tiling `tensor`. 2897 2898 Raises: 2899 ValueError: if `multiple` is less than 1, or `dim` is not in 2900 `[-rank(tensor), rank(tensor)]`. 2901 """ 2902 if multiple < 1: 2903 raise ValueError('Invalid multiple %s, must be > 0.' % multiple) 2904 with ops.name_scope(name, 'expand_and_tile', 2905 (tensor, multiple, dim)) as scope: 2906 # Sparse. 2907 tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor) 2908 if isinstance(tensor, sparse_tensor.SparseTensor): 2909 if dim < 0: 2910 expand_dims = array_ops.reshape( 2911 array_ops.size(tensor.dense_shape) + dim, [1]) 2912 else: 2913 expand_dims = [dim] 2914 expanded_shape = array_ops.concat( 2915 (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1], 2916 array_ops.slice(tensor.dense_shape, expand_dims, [-1])), 2917 0, 2918 name='expanded_shape') 2919 expanded = sparse_ops.sparse_reshape( 2920 tensor, shape=expanded_shape, name='expand') 2921 if multiple == 1: 2922 return expanded 2923 return sparse_ops.sparse_concat( 2924 dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope) 2925 2926 # Dense. 2927 expanded = array_ops.expand_dims( 2928 tensor, dim if (dim >= 0) else (dim - 1), name='expand') 2929 if multiple == 1: 2930 return expanded 2931 ones = array_ops.ones_like(array_ops.shape(tensor)) 2932 tile_multiples = array_ops.concat( 2933 (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples') 2934 return array_ops.tile(expanded, tile_multiples, name=scope) 2935 2936 2937def _num_relevant(labels, k): 2938 """Computes number of relevant values for each row in labels. 2939 2940 For labels with shape [D1, ... 
DN, num_labels], this is the minimum of 2941 `num_labels` and `k`. 2942 2943 Args: 2944 labels: `int64` `Tensor` or `SparseTensor` with shape 2945 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2946 target classes for the associated prediction. Commonly, N=1 and `labels` 2947 has shape [batch_size, num_labels]. 2948 k: Integer, k for @k metric. 2949 2950 Returns: 2951 Integer `Tensor` of shape [D1, ... DN], where each value is the number of 2952 relevant values for that row. 2953 2954 Raises: 2955 ValueError: if inputs have invalid dtypes or values. 2956 """ 2957 if k < 1: 2958 raise ValueError('Invalid k=%s.' % k) 2959 with ops.name_scope(None, 'num_relevant', (labels,)) as scope: 2960 # For SparseTensor, calculate separate count for each row. 2961 labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels) 2962 if isinstance(labels, sparse_tensor.SparseTensor): 2963 return math_ops.minimum(sets.set_size(labels), k, name=scope) 2964 2965 # For dense Tensor, calculate scalar count based on last dimension, and 2966 # tile across labels shape. 2967 labels_shape = array_ops.shape(labels) 2968 labels_size = labels_shape[-1] 2969 num_relevant_scalar = math_ops.minimum(labels_size, k) 2970 return array_ops.fill(labels_shape[0:-1], num_relevant_scalar, name=scope) 2971 2972 2973def _sparse_average_precision_at_top_k(labels, predictions_idx): 2974 """Computes average precision@k of predictions with respect to sparse labels. 2975 2976 From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula 2977 for each row is: 2978 2979 AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items 2980 2981 A "row" is the elements in dimension [D1, ... DN] of `predictions_idx`, 2982 `labels`, and the result `Tensors`. In the common case, this is [batch_size]. 2983 Each row of the results contains the average precision for that row. 2984 2985 Args: 2986 labels: `int64` `Tensor` or `SparseTensor` with shape 2987 [D1, ... 
DN, num_labels] or [D1, ... DN], where the latter implies 2988 num_labels=1. N >= 1 and num_labels is the number of target classes for 2989 the associated prediction. Commonly, N=1 and `labels` has shape 2990 [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`. 2991 Values should be in range [0, num_classes). 2992 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 2993 Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final 2994 dimension must be set and contains the top `k` predicted class indices. 2995 [D1, ... DN] must match `labels`. Values should be in range 2996 [0, num_classes). 2997 2998 Returns: 2999 `float64` `Tensor` of shape [D1, ... DN], where each value is the average 3000 precision for that row. 3001 3002 Raises: 3003 ValueError: if the last dimension of predictions_idx is not set. 3004 """ 3005 with ops.name_scope(None, 'average_precision', 3006 (predictions_idx, labels)) as scope: 3007 predictions_idx = math_ops.cast( 3008 predictions_idx, dtypes.int64, name='predictions_idx') 3009 if predictions_idx.get_shape().ndims == 0: 3010 raise ValueError('The rank of predictions_idx must be at least 1.') 3011 k = predictions_idx.get_shape().as_list()[-1] 3012 if k is None: 3013 raise ValueError('The last dimension of predictions_idx must be set.') 3014 labels = _maybe_expand_labels(labels, predictions_idx) 3015 3016 # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate 3017 # prediction for each k, so we can calculate separate true positive values 3018 # for each k. 3019 predictions_idx_per_k = array_ops.expand_dims( 3020 predictions_idx, -1, name='predictions_idx_per_k') 3021 3022 # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor. 3023 labels_per_k = _expand_and_tile( 3024 labels, multiple=k, dim=-1, name='labels_per_k') 3025 3026 # The following tensors are all of shape [D1, ... DN, k], containing values 3027 # per row, per k value. 
3028 # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at 3029 # that k value is correct, 0 otherwise. This is the "rel_{i}" term from 3030 # the formula above. 3031 # `tp_per_k` (int32) - True positive counts. 3032 # `retrieved_per_k` (int32) - Number of predicted values at each k. This is 3033 # the precision denominator. 3034 # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}" 3035 # term from the formula above. 3036 # `relevant_precision_per_k` (float64) - Relevant precisions; i.e., 3037 # precisions at all k for which relevance indicator is true. 3038 relevant_per_k = _sparse_true_positive_at_k( 3039 labels_per_k, predictions_idx_per_k, name='relevant_per_k') 3040 tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k') 3041 retrieved_per_k = math_ops.cumsum( 3042 array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k') 3043 precision_per_k = math_ops.div( 3044 math_ops.cast(tp_per_k, dtypes.float64), 3045 math_ops.cast(retrieved_per_k, dtypes.float64), 3046 name='precision_per_k') 3047 relevant_precision_per_k = math_ops.multiply( 3048 precision_per_k, 3049 math_ops.cast(relevant_per_k, dtypes.float64), 3050 name='relevant_precision_per_k') 3051 3052 # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor. 3053 precision_sum = math_ops.reduce_sum( 3054 relevant_precision_per_k, axis=(-1,), name='precision_sum') 3055 3056 # Divide by number of relevant items to get average precision. These are 3057 # the "num_relevant_items" and "AveP" terms from the formula above. 
3058 num_relevant_items = math_ops.cast(_num_relevant(labels, k), dtypes.float64) 3059 return math_ops.div(precision_sum, num_relevant_items, name=scope) 3060 3061 3062def _streaming_sparse_average_precision_at_top_k(labels, 3063 predictions_idx, 3064 weights=None, 3065 metrics_collections=None, 3066 updates_collections=None, 3067 name=None): 3068 """Computes average precision@k of predictions with respect to sparse labels. 3069 3070 `sparse_average_precision_at_top_k` creates two local variables, 3071 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 3072 are used to compute the frequency. This frequency is ultimately returned as 3073 `average_precision_at_<k>`: an idempotent operation that simply divides 3074 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 3075 3076 For estimation of the metric over a stream of data, the function creates an 3077 `update_op` operation that updates these variables and returns the 3078 `precision_at_<k>`. Set operations applied to `top_k` and `labels` calculate 3079 the true positives and false positives weighted by `weights`. Then `update_op` 3080 increments `true_positive_at_<k>` and `false_positive_at_<k>` using these 3081 values. 3082 3083 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3084 3085 Args: 3086 labels: `int64` `Tensor` or `SparseTensor` with shape 3087 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3088 num_labels=1. N >= 1 and num_labels is the number of target classes for 3089 the associated prediction. Commonly, N=1 and `labels` has shape 3090 [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`. 3091 Values should be in range [0, num_classes). 3092 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 3093 Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final 3094 dimension contains the top `k` predicted class indices. [D1, ... DN] must 3095 match `labels`. 
Values should be in range [0, num_classes). 3096 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3097 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3098 dimensions must be either `1`, or the same as the corresponding `labels` 3099 dimension). 3100 metrics_collections: An optional list of collections that values should 3101 be added to. 3102 updates_collections: An optional list of collections that updates should 3103 be added to. 3104 name: Name of new update operation, and namespace for other dependent ops. 3105 3106 Returns: 3107 mean_average_precision: Scalar `float64` `Tensor` with the mean average 3108 precision values. 3109 update: `Operation` that increments variables appropriately, and whose 3110 value matches `metric`. 3111 """ 3112 with ops.name_scope(name, 'average_precision_at_top_k', 3113 (predictions_idx, labels, weights)) as scope: 3114 # Calculate per-example average precision, and apply weights. 3115 average_precision = _sparse_average_precision_at_top_k( 3116 predictions_idx=predictions_idx, labels=labels) 3117 if weights is not None: 3118 weights = weights_broadcast_ops.broadcast_weights( 3119 math_ops.cast(weights, dtypes.float64), average_precision) 3120 average_precision = math_ops.multiply(average_precision, weights) 3121 3122 # Create accumulation variables and update ops for max average precision and 3123 # total average precision. 3124 with ops.name_scope(None, 'max', (average_precision,)) as max_scope: 3125 # `max` is the max possible precision. Since max for any row is 1.0: 3126 # - For the unweighted case, this is just the number of rows. 3127 # - For the weighted case, it's the sum of the weights broadcast across 3128 # `average_precision` rows. 
3129 max_var = metric_variable([], dtypes.float64, name=max_scope) 3130 if weights is None: 3131 batch_max = math_ops.cast( 3132 array_ops.size(average_precision, name='batch_max'), dtypes.float64) 3133 else: 3134 batch_max = math_ops.reduce_sum(weights, name='batch_max') 3135 max_update = state_ops.assign_add(max_var, batch_max, name='update') 3136 with ops.name_scope(None, 'total', (average_precision,)) as total_scope: 3137 total_var = metric_variable([], dtypes.float64, name=total_scope) 3138 batch_total = math_ops.reduce_sum(average_precision, name='batch_total') 3139 total_update = state_ops.assign_add(total_var, batch_total, name='update') 3140 3141 # Divide total by max to get mean, for both vars and the update ops. 3142 def precision_across_replicas(_, total_var, max_var): 3143 return _safe_scalar_div(total_var, max_var, name='mean') 3144 3145 mean_average_precision = _aggregate_across_replicas( 3146 metrics_collections, precision_across_replicas, total_var, max_var) 3147 3148 update = _safe_scalar_div(total_update, max_update, name=scope) 3149 if updates_collections: 3150 ops.add_to_collections(updates_collections, update) 3151 3152 return mean_average_precision, update 3153 3154 3155@tf_export(v1=['metrics.sparse_average_precision_at_k']) 3156@deprecated(None, 'Use average_precision_at_k instead') 3157def sparse_average_precision_at_k(labels, 3158 predictions, 3159 k, 3160 weights=None, 3161 metrics_collections=None, 3162 updates_collections=None, 3163 name=None): 3164 """Renamed to `average_precision_at_k`, please use that method instead.""" 3165 return average_precision_at_k( 3166 labels=labels, 3167 predictions=predictions, 3168 k=k, 3169 weights=weights, 3170 metrics_collections=metrics_collections, 3171 updates_collections=updates_collections, 3172 name=name) 3173 3174 3175@tf_export(v1=['metrics.average_precision_at_k']) 3176def average_precision_at_k(labels, 3177 predictions, 3178 k, 3179 weights=None, 3180 metrics_collections=None, 3181 
updates_collections=None, 3182 name=None): 3183 """Computes average precision@k of predictions with respect to sparse labels. 3184 3185 `average_precision_at_k` creates two local variables, 3186 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 3187 are used to compute the frequency. This frequency is ultimately returned as 3188 `average_precision_at_<k>`: an idempotent operation that simply divides 3189 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 3190 3191 For estimation of the metric over a stream of data, the function creates an 3192 `update_op` operation that updates these variables and returns the 3193 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 3194 indicating the top `k` `predictions`. Set operations applied to `top_k` and 3195 `labels` calculate the true positives and false positives weighted by 3196 `weights`. Then `update_op` increments `true_positive_at_<k>` and 3197 `false_positive_at_<k>` using these values. 3198 3199 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3200 3201 Args: 3202 labels: `int64` `Tensor` or `SparseTensor` with shape 3203 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3204 num_labels=1. N >= 1 and num_labels is the number of target classes for 3205 the associated prediction. Commonly, N=1 and `labels` has shape 3206 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 3207 should be in range [0, num_classes), where num_classes is the last 3208 dimension of `predictions`. Values outside this range are ignored. 3209 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 3210 N >= 1. Commonly, N=1 and `predictions` has shape 3211 [batch size, num_classes]. The final dimension contains the logit values 3212 for each class. [D1, ... DN] must match `labels`. 3213 k: Integer, k for @k metric. This will calculate an average precision for 3214 range `[1,k]`, as documented above. 
3215 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3216 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3217 dimensions must be either `1`, or the same as the corresponding `labels` 3218 dimension). 3219 metrics_collections: An optional list of collections that values should 3220 be added to. 3221 updates_collections: An optional list of collections that updates should 3222 be added to. 3223 name: Name of new update operation, and namespace for other dependent ops. 3224 3225 Returns: 3226 mean_average_precision: Scalar `float64` `Tensor` with the mean average 3227 precision values. 3228 update: `Operation` that increments variables appropriately, and whose 3229 value matches `metric`. 3230 3231 Raises: 3232 ValueError: if k is invalid. 3233 RuntimeError: If eager execution is enabled. 3234 """ 3235 if context.executing_eagerly(): 3236 raise RuntimeError('tf.metrics.sparse_average_precision_at_k is not ' 3237 'supported when eager execution is enabled.') 3238 3239 if k < 1: 3240 raise ValueError('Invalid k=%s.' % k) 3241 with ops.name_scope(name, _at_k_name('average_precision', k), 3242 (predictions, labels, weights)) as scope: 3243 # Calculate top k indices to produce [D1, ... DN, k] tensor. 3244 _, predictions_idx = nn.top_k(predictions, k) 3245 return _streaming_sparse_average_precision_at_top_k( 3246 labels=labels, 3247 predictions_idx=predictions_idx, 3248 weights=weights, 3249 metrics_collections=metrics_collections, 3250 updates_collections=updates_collections, 3251 name=scope) 3252 3253 3254def _sparse_false_positive_at_k(labels, 3255 predictions_idx, 3256 class_id=None, 3257 weights=None): 3258 """Calculates false positives for precision@k. 3259 3260 If `class_id` is specified, calculate binary true positives for `class_id` 3261 only. 3262 If `class_id` is not specified, calculate metrics for `k` predicted vs 3263 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 
3264 3265 Args: 3266 labels: `int64` `Tensor` or `SparseTensor` with shape 3267 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 3268 target classes for the associated prediction. Commonly, N=1 and `labels` 3269 has shape [batch_size, num_labels]. [D1, ... DN] must match 3270 `predictions_idx`. 3271 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 3272 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 3273 match `labels`. 3274 class_id: Class for which we want binary metrics. 3275 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3276 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3277 dimensions must be either `1`, or the same as the corresponding `labels` 3278 dimension). 3279 3280 Returns: 3281 A [D1, ... DN] `Tensor` of false positive counts. 3282 """ 3283 with ops.name_scope(None, 'false_positives', 3284 (predictions_idx, labels, weights)): 3285 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, 3286 class_id) 3287 fp = sets.set_size( 3288 sets.set_difference(predictions_idx, labels, aminusb=True)) 3289 fp = math_ops.cast(fp, dtypes.float64) 3290 if weights is not None: 3291 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( 3292 weights, fp),)): 3293 weights = math_ops.cast(weights, dtypes.float64) 3294 fp = math_ops.multiply(fp, weights) 3295 return fp 3296 3297 3298def _streaming_sparse_false_positive_at_k(labels, 3299 predictions_idx, 3300 k=None, 3301 class_id=None, 3302 weights=None, 3303 name=None): 3304 """Calculates weighted per step false positives for precision@k. 3305 3306 If `class_id` is specified, calculate binary true positives for `class_id` 3307 only. 3308 If `class_id` is not specified, calculate metrics for `k` predicted vs 3309 `n` label classes, where `n` is the 2nd dimension of `labels`. 3310 3311 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 
3312 3313 Args: 3314 labels: `int64` `Tensor` or `SparseTensor` with shape 3315 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 3316 target classes for the associated prediction. Commonly, N=1 and `labels` 3317 has shape [batch_size, num_labels]. [D1, ... DN] must match 3318 `predictions_idx`. 3319 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 3320 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 3321 match `labels`. 3322 k: Integer, k for @k metric. This is only used for default op name. 3323 class_id: Class for which we want binary metrics. 3324 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3325 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3326 dimensions must be either `1`, or the same as the corresponding `labels` 3327 dimension). 3328 name: Name of new variable, and namespace for other dependent ops. 3329 3330 Returns: 3331 A tuple of `Variable` and update `Operation`. 3332 3333 Raises: 3334 ValueError: If `weights` is not `None` and has an incompatible shape. 3335 """ 3336 with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id), 3337 (predictions_idx, labels, weights)) as scope: 3338 fp = _sparse_false_positive_at_k( 3339 predictions_idx=predictions_idx, 3340 labels=labels, 3341 class_id=class_id, 3342 weights=weights) 3343 batch_total_fp = math_ops.cast(math_ops.reduce_sum(fp), dtypes.float64) 3344 3345 var = metric_variable([], dtypes.float64, name=scope) 3346 return var, state_ops.assign_add(var, batch_total_fp, name='update') 3347 3348 3349@tf_export(v1=['metrics.precision_at_top_k']) 3350def precision_at_top_k(labels, 3351 predictions_idx, 3352 k=None, 3353 class_id=None, 3354 weights=None, 3355 metrics_collections=None, 3356 updates_collections=None, 3357 name=None): 3358 """Computes precision@k of the predictions with respect to sparse labels. 
3359 3360 Differs from `sparse_precision_at_k` in that predictions must be in the form 3361 of top `k` class indices, whereas `sparse_precision_at_k` expects logits. 3362 Refer to `sparse_precision_at_k` for more details. 3363 3364 Args: 3365 labels: `int64` `Tensor` or `SparseTensor` with shape 3366 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3367 num_labels=1. N >= 1 and num_labels is the number of target classes for 3368 the associated prediction. Commonly, N=1 and `labels` has shape 3369 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 3370 should be in range [0, num_classes), where num_classes is the last 3371 dimension of `predictions`. Values outside this range are ignored. 3372 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where 3373 N >= 1. Commonly, N=1 and predictions has shape [batch size, k]. 3374 The final dimension contains the top `k` predicted class indices. 3375 [D1, ... DN] must match `labels`. 3376 k: Integer, k for @k metric. Only used for the default op name. 3377 class_id: Integer class ID for which we want binary metrics. This should be 3378 in range [0, num_classes], where num_classes is the last dimension of 3379 `predictions`. If `class_id` is outside this range, the method returns 3380 NAN. 3381 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3382 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3383 dimensions must be either `1`, or the same as the corresponding `labels` 3384 dimension). 3385 metrics_collections: An optional list of collections that values should 3386 be added to. 3387 updates_collections: An optional list of collections that updates should 3388 be added to. 3389 name: Name of new update operation, and namespace for other dependent ops. 3390 3391 Returns: 3392 precision: Scalar `float64` `Tensor` with the value of `true_positives` 3393 divided by the sum of `true_positives` and `false_positives`. 
3394 update_op: `Operation` that increments `true_positives` and 3395 `false_positives` variables appropriately, and whose value matches 3396 `precision`. 3397 3398 Raises: 3399 ValueError: If `weights` is not `None` and its shape doesn't match 3400 `predictions`, or if either `metrics_collections` or `updates_collections` 3401 are not a list or tuple. 3402 RuntimeError: If eager execution is enabled. 3403 """ 3404 if context.executing_eagerly(): 3405 raise RuntimeError('tf.metrics.precision_at_top_k is not ' 3406 'supported when eager execution is enabled.') 3407 3408 with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id), 3409 (predictions_idx, labels, weights)) as scope: 3410 labels = _maybe_expand_labels(labels, predictions_idx) 3411 top_k_idx = math_ops.cast(predictions_idx, dtypes.int64) 3412 tp, tp_update = _streaming_sparse_true_positive_at_k( 3413 predictions_idx=top_k_idx, 3414 labels=labels, 3415 k=k, 3416 class_id=class_id, 3417 weights=weights) 3418 fp, fp_update = _streaming_sparse_false_positive_at_k( 3419 predictions_idx=top_k_idx, 3420 labels=labels, 3421 k=k, 3422 class_id=class_id, 3423 weights=weights) 3424 3425 def precision_across_replicas(_, tp, fp): 3426 return math_ops.div(tp, math_ops.add(tp, fp), name=scope) 3427 3428 metric = _aggregate_across_replicas( 3429 metrics_collections, precision_across_replicas, tp, fp) 3430 3431 update = math_ops.div( 3432 tp_update, math_ops.add(tp_update, fp_update), name='update') 3433 if updates_collections: 3434 ops.add_to_collections(updates_collections, update) 3435 return metric, update 3436 3437 3438@tf_export(v1=['metrics.sparse_precision_at_k']) 3439@deprecated(None, 'Use precision_at_k instead') 3440def sparse_precision_at_k(labels, 3441 predictions, 3442 k, 3443 class_id=None, 3444 weights=None, 3445 metrics_collections=None, 3446 updates_collections=None, 3447 name=None): 3448 """Renamed to `precision_at_k`, please use that method instead.""" 3449 return precision_at_k( 3450 
labels=labels, 3451 predictions=predictions, 3452 k=k, 3453 class_id=class_id, 3454 weights=weights, 3455 metrics_collections=metrics_collections, 3456 updates_collections=updates_collections, 3457 name=name) 3458 3459 3460@tf_export(v1=['metrics.precision_at_k']) 3461def precision_at_k(labels, 3462 predictions, 3463 k, 3464 class_id=None, 3465 weights=None, 3466 metrics_collections=None, 3467 updates_collections=None, 3468 name=None): 3469 """Computes precision@k of the predictions with respect to sparse labels. 3470 3471 If `class_id` is specified, we calculate precision by considering only the 3472 entries in the batch for which `class_id` is in the top-k highest 3473 `predictions`, and computing the fraction of them for which `class_id` is 3474 indeed a correct label. 3475 If `class_id` is not specified, we'll calculate precision as how often on 3476 average a class among the top-k classes with the highest predicted values 3477 of a batch entry is correct and can be found in the label for that entry. 3478 3479 `precision_at_k` creates two local variables, 3480 `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute 3481 the precision@k frequency. This frequency is ultimately returned as 3482 `precision_at_<k>`: an idempotent operation that simply divides 3483 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 3484 `false_positive_at_<k>`). 3485 3486 For estimation of the metric over a stream of data, the function creates an 3487 `update_op` operation that updates these variables and returns the 3488 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 3489 indicating the top `k` `predictions`. Set operations applied to `top_k` and 3490 `labels` calculate the true positives and false positives weighted by 3491 `weights`. Then `update_op` increments `true_positive_at_<k>` and 3492 `false_positive_at_<k>` using these values. 3493 3494 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 
3495 3496 Args: 3497 labels: `int64` `Tensor` or `SparseTensor` with shape 3498 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3499 num_labels=1. N >= 1 and num_labels is the number of target classes for 3500 the associated prediction. Commonly, N=1 and `labels` has shape 3501 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 3502 should be in range [0, num_classes), where num_classes is the last 3503 dimension of `predictions`. Values outside this range are ignored. 3504 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 3505 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 3506 The final dimension contains the logit values for each class. [D1, ... DN] 3507 must match `labels`. 3508 k: Integer, k for @k metric. 3509 class_id: Integer class ID for which we want binary metrics. This should be 3510 in range [0, num_classes], where num_classes is the last dimension of 3511 `predictions`. If `class_id` is outside this range, the method returns 3512 NAN. 3513 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3514 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3515 dimensions must be either `1`, or the same as the corresponding `labels` 3516 dimension). 3517 metrics_collections: An optional list of collections that values should 3518 be added to. 3519 updates_collections: An optional list of collections that updates should 3520 be added to. 3521 name: Name of new update operation, and namespace for other dependent ops. 3522 3523 Returns: 3524 precision: Scalar `float64` `Tensor` with the value of `true_positives` 3525 divided by the sum of `true_positives` and `false_positives`. 3526 update_op: `Operation` that increments `true_positives` and 3527 `false_positives` variables appropriately, and whose value matches 3528 `precision`. 
3529 3530 Raises: 3531 ValueError: If `weights` is not `None` and its shape doesn't match 3532 `predictions`, or if either `metrics_collections` or `updates_collections` 3533 are not a list or tuple. 3534 RuntimeError: If eager execution is enabled. 3535 """ 3536 if context.executing_eagerly(): 3537 raise RuntimeError('tf.metrics.sparse_precision_at_k is not ' 3538 'supported when eager execution is enabled.') 3539 3540 with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id), 3541 (predictions, labels, weights)) as scope: 3542 _, top_k_idx = nn.top_k(predictions, k) 3543 return precision_at_top_k( 3544 labels=labels, 3545 predictions_idx=top_k_idx, 3546 k=k, 3547 class_id=class_id, 3548 weights=weights, 3549 metrics_collections=metrics_collections, 3550 updates_collections=updates_collections, 3551 name=scope) 3552 3553 3554@tf_export(v1=['metrics.specificity_at_sensitivity']) 3555def specificity_at_sensitivity(labels, 3556 predictions, 3557 sensitivity, 3558 weights=None, 3559 num_thresholds=200, 3560 metrics_collections=None, 3561 updates_collections=None, 3562 name=None): 3563 """Computes the specificity at a given sensitivity. 3564 3565 The `specificity_at_sensitivity` function creates four local 3566 variables, `true_positives`, `true_negatives`, `false_positives` and 3567 `false_negatives` that are used to compute the specificity at the given 3568 sensitivity value. The threshold for the given sensitivity value is computed 3569 and used to evaluate the corresponding specificity. 3570 3571 For estimation of the metric over a stream of data, the function creates an 3572 `update_op` operation that updates these variables and returns the 3573 `specificity`. `update_op` increments the `true_positives`, `true_negatives`, 3574 `false_positives` and `false_negatives` counts with the weight of each case 3575 found in the `predictions` and `labels`. 3576 3577 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 
3578 3579 For additional information about specificity and sensitivity, see the 3580 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 3581 3582 Args: 3583 labels: The ground truth values, a `Tensor` whose dimensions must match 3584 `predictions`. Will be cast to `bool`. 3585 predictions: A floating point `Tensor` of arbitrary shape and whose values 3586 are in the range `[0, 1]`. 3587 sensitivity: A scalar value in range `[0, 1]`. 3588 weights: Optional `Tensor` whose rank is either 0, or the same rank as 3589 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 3590 be either `1`, or the same as the corresponding `labels` dimension). 3591 num_thresholds: The number of thresholds to use for matching the given 3592 sensitivity. 3593 metrics_collections: An optional list of collections that `specificity` 3594 should be added to. 3595 updates_collections: An optional list of collections that `update_op` should 3596 be added to. 3597 name: An optional variable_scope name. 3598 3599 Returns: 3600 specificity: A scalar `Tensor` representing the specificity at the given 3601 `specificity` value. 3602 update_op: An operation that increments the `true_positives`, 3603 `true_negatives`, `false_positives` and `false_negatives` variables 3604 appropriately and whose value matches `specificity`. 3605 3606 Raises: 3607 ValueError: If `predictions` and `labels` have mismatched shapes, if 3608 `weights` is not `None` and its shape doesn't match `predictions`, or if 3609 `sensitivity` is not between 0 and 1, or if either `metrics_collections` 3610 or `updates_collections` are not a list or tuple. 3611 RuntimeError: If eager execution is enabled. 
3612 """ 3613 if context.executing_eagerly(): 3614 raise RuntimeError('tf.metrics.specificity_at_sensitivity is not ' 3615 'supported when eager execution is enabled.') 3616 3617 if sensitivity < 0 or sensitivity > 1: 3618 raise ValueError('`sensitivity` must be in the range [0, 1].') 3619 3620 with variable_scope.variable_scope(name, 'specificity_at_sensitivity', 3621 (predictions, labels, weights)): 3622 kepsilon = 1e-7 # to account for floating point imprecisions 3623 thresholds = [ 3624 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) 3625 ] 3626 thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon] 3627 3628 values, update_ops = _confusion_matrix_at_thresholds( 3629 labels, predictions, thresholds, weights) 3630 3631 def compute_specificity_at_sensitivity(tp, tn, fp, fn, name): 3632 """Computes the specificity at the given sensitivity. 3633 3634 Args: 3635 tp: True positives. 3636 tn: True negatives. 3637 fp: False positives. 3638 fn: False negatives. 3639 name: The name of the operation. 3640 3641 Returns: 3642 The specificity using the aggregated values. 3643 """ 3644 sensitivities = math_ops.div(tp, tp + fn + kepsilon) 3645 3646 # We'll need to use this trick until tf.argmax allows us to specify 3647 # whether we should use the first or last index in case of ties. 
3648 min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity)) 3649 indices_at_minval = math_ops.equal( 3650 math_ops.abs(sensitivities - sensitivity), min_val) 3651 indices_at_minval = math_ops.cast(indices_at_minval, dtypes.int64) 3652 indices_at_minval = math_ops.cumsum(indices_at_minval) 3653 tf_index = math_ops.argmax(indices_at_minval, 0) 3654 tf_index = math_ops.cast(tf_index, dtypes.int32) 3655 3656 # Now, we have the implicit threshold, so compute the specificity: 3657 return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon, 3658 name) 3659 3660 def specificity_across_replicas(_, values): 3661 return compute_specificity_at_sensitivity( 3662 values['tp'], values['tn'], values['fp'], values['fn'], 'value') 3663 3664 specificity = _aggregate_across_replicas( 3665 metrics_collections, specificity_across_replicas, values) 3666 3667 update_op = compute_specificity_at_sensitivity( 3668 update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 3669 'update_op') 3670 if updates_collections: 3671 ops.add_to_collections(updates_collections, update_op) 3672 3673 return specificity, update_op 3674