1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Utils related to keras metrics. 17""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import functools 23import weakref 24 25from enum import Enum 26 27from tensorflow.python.distribute import distribution_strategy_context 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.keras import backend 31from tensorflow.python.keras.utils import losses_utils 32from tensorflow.python.keras.utils import tf_utils 33from tensorflow.python.keras.utils.generic_utils import to_list 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import check_ops 36from tensorflow.python.ops import control_flow_ops 37from tensorflow.python.ops import gen_math_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import nn_ops 40from tensorflow.python.ops import weights_broadcast_ops 41from tensorflow.python.ops.ragged import ragged_tensor 42from tensorflow.python.util import tf_decorator 43 44NEG_INF = -1e10 45 46 47class Reduction(Enum): 48 """Types of metrics reduction. 49 50 Contains the following values: 51 52 * `SUM`: Scalar sum of weighted values. 53 * `SUM_OVER_BATCH_SIZE`: Scalar sum of weighted values divided by 54 number of elements. 55 * `WEIGHTED_MEAN`: Scalar sum of weighted values divided by sum of weights. 56 """ 57 SUM = 'sum' 58 SUM_OVER_BATCH_SIZE = 'sum_over_batch_size' 59 WEIGHTED_MEAN = 'weighted_mean' 60 61 62def update_state_wrapper(update_state_fn): 63 """Decorator to wrap metric `update_state()` with `add_update()`. 64 65 Args: 66 update_state_fn: function that accumulates metric statistics. 67 68 Returns: 69 Decorated function that wraps `update_state_fn()` with `add_update()`. 70 """ 71 72 def decorated(metric_obj, *args, **kwargs): 73 """Decorated function with `add_update()`.""" 74 strategy = distribution_strategy_context.get_strategy() 75 # TODO(b/142574744): Remove this check if a better solution is found for 76 # declaring keras Metric outside of TPUStrategy and then updating it per 77 # replica. 78 79 for weight in metric_obj.weights: 80 if (backend.is_tpu_strategy(strategy) and 81 not strategy.extended.variable_created_in_scope(weight) 82 and not distribution_strategy_context.in_cross_replica_context()): 83 raise ValueError( 84 'Trying to run metric.update_state in replica context when ' 85 'the metric was not created in TPUStrategy scope. ' 86 'Make sure the keras Metric is created in TPUstrategy scope. ') 87 88 with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs): 89 update_op = update_state_fn(*args, **kwargs) 90 if update_op is not None: # update_op will be None in eager execution. 91 metric_obj.add_update(update_op) 92 return update_op 93 94 return tf_decorator.make_decorator(update_state_fn, decorated) 95 96 97def result_wrapper(result_fn): 98 """Decorator to wrap metric `result()` function in `merge_call()`. 99 100 Result computation is an idempotent operation that simply calculates the 101 metric value using the state variables. 102 103 If metric state variables are distributed across replicas/devices and 104 `result()` is requested from the context of one device - This function wraps 105 `result()` in a distribution strategy `merge_call()`. With this, 106 the metric state variables will be aggregated across devices. 107 108 Args: 109 result_fn: function that computes the metric result. 110 111 Returns: 112 Decorated function that wraps `result_fn()` in distribution strategy 113 `merge_call()`. 114 """ 115 116 def decorated(metric_obj, *args): 117 """Decorated function with merge_call.""" 118 has_strategy = distribution_strategy_context.has_strategy() 119 replica_context = distribution_strategy_context.get_replica_context() 120 if not has_strategy or replica_context is None: 121 result_t = array_ops.identity(result_fn(*args)) 122 else: 123 # TODO(psv): Test distribution of metrics using different distribution 124 # strategies. 125 126 # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn 127 # with distribution object as the first parameter. We create a wrapper 128 # here so that the result function need not have that parameter. 129 def merge_fn_wrapper(distribution, merge_fn, *args): 130 # We will get `PerReplica` merge function. Taking the first one as all 131 # are identical copies of the function that we had passed below. 132 result = distribution.experimental_local_results(merge_fn)[0](*args) 133 134 # Wrapping result in identity so that control dependency between 135 # update_op from `update_state` and result works in case result returns 136 # a tensor. 137 return array_ops.identity(result) 138 139 # Wrapping result in merge_call. merge_call is used when we want to leave 140 # replica mode and compute a value in cross replica mode. 141 result_t = replica_context.merge_call( 142 merge_fn_wrapper, args=(result_fn,) + args) 143 144 # We are saving the result op here to be used in train/test execution 145 # functions. This basically gives the result op that was generated with a 146 # control dep to the updates for these workflows. 147 metric_obj._call_result = result_t 148 return result_t 149 150 return tf_decorator.make_decorator(result_fn, decorated) 151 152 153def weakmethod(method): 154 """Creates a weak reference to the bound method.""" 155 156 cls = method.im_class 157 func = method.im_func 158 instance_ref = weakref.ref(method.im_self) 159 160 @functools.wraps(method) 161 def inner(*args, **kwargs): 162 return func.__get__(instance_ref(), cls)(*args, **kwargs) 163 164 del method 165 return inner 166 167 168def assert_thresholds_range(thresholds): 169 if thresholds is not None: 170 invalid_thresholds = [t for t in thresholds if t is None or t < 0 or t > 1] 171 if invalid_thresholds: 172 raise ValueError( 173 'Threshold values must be in [0, 1]. Invalid values: {}'.format( 174 invalid_thresholds)) 175 176 177def parse_init_thresholds(thresholds, default_threshold=0.5): 178 if thresholds is not None: 179 assert_thresholds_range(to_list(thresholds)) 180 thresholds = to_list(default_threshold if thresholds is None else thresholds) 181 return thresholds 182 183 184class ConfusionMatrix(Enum): 185 TRUE_POSITIVES = 'tp' 186 FALSE_POSITIVES = 'fp' 187 TRUE_NEGATIVES = 'tn' 188 FALSE_NEGATIVES = 'fn' 189 190 191class AUCCurve(Enum): 192 """Type of AUC Curve (ROC or PR).""" 193 ROC = 'ROC' 194 PR = 'PR' 195 196 @staticmethod 197 def from_str(key): 198 if key in ('pr', 'PR'): 199 return AUCCurve.PR 200 elif key in ('roc', 'ROC'): 201 return AUCCurve.ROC 202 else: 203 raise ValueError('Invalid AUC curve value "%s".' % key) 204 205 206class AUCSummationMethod(Enum): 207 """Type of AUC summation method. 208 209 https://en.wikipedia.org/wiki/Riemann_sum) 210 211 Contains the following values: 212 * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For 213 `PR` curve, interpolates (true/false) positives but not the ratio that is 214 precision (see Davis & Goadrich 2006 for details). 215 * 'minoring': Applies left summation for increasing intervals and right 216 summation for decreasing intervals. 217 * 'majoring': Applies right summation for increasing intervals and left 218 summation for decreasing intervals. 219 """ 220 INTERPOLATION = 'interpolation' 221 MAJORING = 'majoring' 222 MINORING = 'minoring' 223 224 @staticmethod 225 def from_str(key): 226 if key in ('interpolation', 'Interpolation'): 227 return AUCSummationMethod.INTERPOLATION 228 elif key in ('majoring', 'Majoring'): 229 return AUCSummationMethod.MAJORING 230 elif key in ('minoring', 'Minoring'): 231 return AUCSummationMethod.MINORING 232 else: 233 raise ValueError('Invalid AUC summation method value "%s".' % key) 234 235 236def update_confusion_matrix_variables(variables_to_update, 237 y_true, 238 y_pred, 239 thresholds, 240 top_k=None, 241 class_id=None, 242 sample_weight=None, 243 multi_label=False, 244 label_weights=None): 245 """Returns op to update the given confusion matrix variables. 246 247 For every pair of values in y_true and y_pred: 248 249 true_positive: y_true == True and y_pred > thresholds 250 false_negatives: y_true == True and y_pred <= thresholds 251 true_negatives: y_true == False and y_pred <= thresholds 252 false_positive: y_true == False and y_pred > thresholds 253 254 The results will be weighted and added together. When multiple thresholds are 255 provided, we will repeat the same for every threshold. 256 257 For estimation of these metrics over a stream of data, the function creates an 258 `update_op` operation that updates the given variables. 259 260 If `sample_weight` is `None`, weights default to 1. 261 Use weights of 0 to mask values. 262 263 Args: 264 variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys 265 and corresponding variables to update as values. 266 y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`. 267 y_pred: A floating point `Tensor` of arbitrary shape and whose values are in 268 the range `[0, 1]`. 269 thresholds: A float value, float tensor, python list, or tuple of float 270 thresholds in `[0, 1]`, or NEG_INF (used when top_k is set). 271 top_k: Optional int, indicates that the positive labels should be limited to 272 the top k predictions. 273 class_id: Optional int, limits the prediction and labels to the class 274 specified by this argument. 275 sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as 276 `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions must 277 be either `1`, or the same as the corresponding `y_true` dimension). 278 multi_label: Optional boolean indicating whether multidimensional 279 prediction/labels should be treated as multilabel responses, or flattened 280 into a single label. When True, the valus of `variables_to_update` must 281 have a second dimension equal to the number of labels in y_true and 282 y_pred, and those tensors must not be RaggedTensors. 283 label_weights: (optional) tensor of non-negative weights for multilabel 284 data. The weights are applied when calculating TP, FP, FN, and TN without 285 explicit multilabel handling (i.e. when the data is to be flattened). 286 287 Returns: 288 Update op. 289 290 Raises: 291 ValueError: If `y_pred` and `y_true` have mismatched shapes, or if 292 `sample_weight` is not `None` and its shape doesn't match `y_pred`, or if 293 `variables_to_update` contains invalid keys. 294 """ 295 if multi_label and label_weights is not None: 296 raise ValueError('`label_weights` for multilabel data should be handled ' 297 'outside of `update_confusion_matrix_variables` when ' 298 '`multi_label` is True.') 299 if variables_to_update is None: 300 return 301 if not any( 302 key for key in variables_to_update if key in list(ConfusionMatrix)): 303 raise ValueError( 304 'Please provide at least one valid confusion matrix ' 305 'variable to update. Valid variable key options are: "{}". ' 306 'Received: "{}"'.format( 307 list(ConfusionMatrix), variables_to_update.keys())) 308 309 variable_dtype = list(variables_to_update.values())[0].dtype 310 311 y_true = math_ops.cast(y_true, dtype=variable_dtype) 312 y_pred = math_ops.cast(y_pred, dtype=variable_dtype) 313 thresholds = ops.convert_to_tensor_v2_with_dispatch( 314 thresholds, dtype=variable_dtype) 315 num_thresholds = thresholds.shape[0] 316 if multi_label: 317 one_thresh = math_ops.equal( 318 math_ops.cast(1, dtype=dtypes.int32), 319 array_ops.rank(thresholds), 320 name='one_set_of_thresholds_cond') 321 else: 322 [y_pred, 323 y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true], 324 sample_weight) 325 one_thresh = math_ops.cast(True, dtype=dtypes.bool) 326 327 invalid_keys = [ 328 key for key in variables_to_update if key not in list(ConfusionMatrix) 329 ] 330 if invalid_keys: 331 raise ValueError( 332 'Invalid keys: {}. Valid variable key options are: "{}"'.format( 333 invalid_keys, list(ConfusionMatrix))) 334 335 with ops.control_dependencies([ 336 check_ops.assert_greater_equal( 337 y_pred, 338 math_ops.cast(0.0, dtype=y_pred.dtype), 339 message='predictions must be >= 0'), 340 check_ops.assert_less_equal( 341 y_pred, 342 math_ops.cast(1.0, dtype=y_pred.dtype), 343 message='predictions must be <= 1') 344 ]): 345 if sample_weight is None: 346 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions( 347 y_pred, y_true) 348 else: 349 sample_weight = math_ops.cast(sample_weight, dtype=variable_dtype) 350 y_pred, y_true, sample_weight = ( 351 losses_utils.squeeze_or_expand_dimensions( 352 y_pred, y_true, sample_weight=sample_weight)) 353 y_pred.shape.assert_is_compatible_with(y_true.shape) 354 355 if top_k is not None: 356 y_pred = _filter_top_k(y_pred, top_k) 357 if class_id is not None: 358 y_true = y_true[..., class_id] 359 y_pred = y_pred[..., class_id] 360 361 pred_shape = array_ops.shape(y_pred) 362 num_predictions = pred_shape[0] 363 if y_pred.shape.ndims == 1: 364 num_labels = 1 365 else: 366 num_labels = gen_math_ops.Prod(input=pred_shape[1:], axis=0) 367 thresh_label_tile = control_flow_ops.cond( 368 one_thresh, lambda: num_labels, 369 lambda: math_ops.cast(1, dtype=dtypes.int32)) 370 371 # Reshape predictions and labels, adding a dim for thresholding. 372 if multi_label: 373 predictions_extra_dim = array_ops.expand_dims(y_pred, 0) 374 labels_extra_dim = array_ops.expand_dims( 375 math_ops.cast(y_true, dtype=dtypes.bool), 0) 376 else: 377 # Flatten predictions and labels when not multilabel. 378 predictions_extra_dim = array_ops.reshape(y_pred, [1, -1]) 379 labels_extra_dim = array_ops.reshape( 380 math_ops.cast(y_true, dtype=dtypes.bool), [1, -1]) 381 382 # Tile the thresholds for every prediction. 383 if multi_label: 384 thresh_pretile_shape = [num_thresholds, 1, -1] 385 thresh_tiles = [1, num_predictions, thresh_label_tile] 386 data_tiles = [num_thresholds, 1, 1] 387 else: 388 thresh_pretile_shape = [num_thresholds, -1] 389 thresh_tiles = [1, num_predictions * num_labels] 390 data_tiles = [num_thresholds, 1] 391 392 thresh_tiled = array_ops.tile( 393 array_ops.reshape(thresholds, thresh_pretile_shape), 394 array_ops.stack(thresh_tiles)) 395 396 # Tile the predictions for every threshold. 397 preds_tiled = array_ops.tile(predictions_extra_dim, data_tiles) 398 399 # Compare predictions and threshold. 400 pred_is_pos = math_ops.greater(preds_tiled, thresh_tiled) 401 402 # Tile labels by number of thresholds 403 label_is_pos = array_ops.tile(labels_extra_dim, data_tiles) 404 405 if sample_weight is not None: 406 sample_weight = weights_broadcast_ops.broadcast_weights( 407 math_ops.cast(sample_weight, dtype=variable_dtype), y_pred) 408 weights_tiled = array_ops.tile( 409 array_ops.reshape(sample_weight, thresh_tiles), data_tiles) 410 else: 411 weights_tiled = None 412 413 if label_weights is not None and not multi_label: 414 label_weights = array_ops.expand_dims(label_weights, 0) 415 label_weights = weights_broadcast_ops.broadcast_weights(label_weights, 416 y_pred) 417 label_weights_tiled = array_ops.tile( 418 array_ops.reshape(label_weights, thresh_tiles), data_tiles) 419 if weights_tiled is None: 420 weights_tiled = label_weights_tiled 421 else: 422 weights_tiled = math_ops.multiply(weights_tiled, label_weights_tiled) 423 424 update_ops = [] 425 426 def weighted_assign_add(label, pred, weights, var): 427 label_and_pred = math_ops.cast( 428 math_ops.logical_and(label, pred), dtype=var.dtype) 429 if weights is not None: 430 label_and_pred *= math_ops.cast(weights, dtype=var.dtype) 431 return var.assign_add(math_ops.reduce_sum(label_and_pred, 1)) 432 433 loop_vars = { 434 ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos), 435 } 436 update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update 437 update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update 438 update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update 439 440 if update_fn or update_tn: 441 pred_is_neg = math_ops.logical_not(pred_is_pos) 442 loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg) 443 444 if update_fp or update_tn: 445 label_is_neg = math_ops.logical_not(label_is_pos) 446 loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos) 447 if update_tn: 448 loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (label_is_neg, pred_is_neg) 449 450 for matrix_cond, (label, pred) in loop_vars.items(): 451 452 if matrix_cond in variables_to_update: 453 update_ops.append( 454 weighted_assign_add(label, pred, weights_tiled, 455 variables_to_update[matrix_cond])) 456 457 return control_flow_ops.group(update_ops) 458 459 460def _filter_top_k(x, k): 461 """Filters top-k values in the last dim of x and set the rest to NEG_INF. 462 463 Used for computing top-k prediction values in dense labels (which has the same 464 shape as predictions) for recall and precision top-k metrics. 465 466 Args: 467 x: tensor with any dimensions. 468 k: the number of values to keep. 469 470 Returns: 471 tensor with same shape and dtype as x. 472 """ 473 _, top_k_idx = nn_ops.top_k(x, k, sorted=False) 474 top_k_mask = math_ops.reduce_sum( 475 array_ops.one_hot(top_k_idx, array_ops.shape(x)[-1], axis=-1), axis=-2) 476 return x * top_k_mask + NEG_INF * (1 - top_k_mask) 477 478 479def ragged_assert_compatible_and_get_flat_values(values, mask=None): 480 """If ragged, it checks the compatibility and then returns the flat_values. 481 482 Note: If two tensors are dense, it does not check their compatibility. 483 Note: Although two ragged tensors with different ragged ranks could have 484 identical overall rank and dimension sizes and hence be compatible, 485 we do not support those cases. 486 Args: 487 values: A list of potentially ragged tensor of the same ragged_rank. 488 mask: A potentially ragged tensor of the same ragged_rank as elements in 489 Values. 490 491 Returns: 492 A tuple in which the first element is the list of tensors and the second 493 is the mask tensor. ([Values], mask). Mask and the element in Values 494 are equal to the flat_values of the input arguments (if they were ragged). 495 """ 496 if isinstance(values, list): 497 is_all_ragged = \ 498 all(isinstance(rt, ragged_tensor.RaggedTensor) for rt in values) 499 is_any_ragged = \ 500 any(isinstance(rt, ragged_tensor.RaggedTensor) for rt in values) 501 else: 502 is_all_ragged = isinstance(values, ragged_tensor.RaggedTensor) 503 is_any_ragged = is_all_ragged 504 if (is_all_ragged and 505 ((mask is None) or isinstance(mask, ragged_tensor.RaggedTensor))): 506 to_be_stripped = False 507 if not isinstance(values, list): 508 values = [values] 509 to_be_stripped = True 510 511 # NOTE: we leave the flat_values compatibility to 512 # tf.TensorShape `assert_is_compatible_with` 513 # check if both dynamic dimensions are equal and then use the flat_values. 514 nested_row_split_list = [rt.nested_row_splits for rt in values] 515 assertion_list = _assert_splits_match(nested_row_split_list) 516 517 # if both are ragged sample_weights also should be ragged with same dims. 518 if isinstance(mask, ragged_tensor.RaggedTensor): 519 assertion_list_for_mask = _assert_splits_match( 520 [nested_row_split_list[0], mask.nested_row_splits]) 521 with ops.control_dependencies(assertion_list_for_mask): 522 mask = array_ops.expand_dims(mask.flat_values, -1) 523 524 # values has at least 1 element. 525 flat_values = [] 526 for value in values: 527 with ops.control_dependencies(assertion_list): 528 flat_values.append(array_ops.expand_dims(value.flat_values, -1)) 529 530 values = flat_values[0] if to_be_stripped else flat_values 531 532 elif is_any_ragged: 533 raise TypeError('One of the inputs does not have acceptable types.') 534 # values are empty or value are not ragged and mask is ragged. 535 elif isinstance(mask, ragged_tensor.RaggedTensor): 536 raise TypeError('Ragged mask is not allowed with non-ragged inputs.') 537 538 return values, mask 539 540 541def _assert_splits_match(nested_splits_lists): 542 """Checks that the given splits lists are identical. 543 544 Performs static tests to ensure that the given splits lists are identical, 545 and returns a list of control dependency op tensors that check that they are 546 fully identical. 547 548 Args: 549 nested_splits_lists: A list of nested_splits_lists, where each split_list is 550 a list of `splits` tensors from a `RaggedTensor`, ordered from outermost 551 ragged dimension to innermost ragged dimension. 552 553 Returns: 554 A list of control dependency op tensors. 555 Raises: 556 ValueError: If the splits are not identical. 557 """ 558 error_msg = 'Inputs must have identical ragged splits' 559 for splits_list in nested_splits_lists: 560 if len(splits_list) != len(nested_splits_lists[0]): 561 raise ValueError(error_msg) 562 return [ 563 check_ops.assert_equal(s1, s2, message=error_msg) # pylint: disable=g-complex-comprehension 564 for splits_list in nested_splits_lists[1:] 565 for (s1, s2) in zip(nested_splits_lists[0], splits_list) 566 ] 567