# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utils related to keras metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import weakref

from enum import Enum

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.utils.generic_utils import to_list
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.util import tf_decorator

# Value written by `_filter_top_k` into every position that is not among the
# top-k predictions.
NEG_INF = -1e10


class Reduction(Enum):
  """Types of metrics reduction.

  Contains the following values:

  * `SUM`: Scalar sum of weighted values.
  * `SUM_OVER_BATCH_SIZE`: Scalar sum of weighted values divided by
        number of elements.
  * `WEIGHTED_MEAN`: Scalar sum of weighted values divided by sum of weights.
  """
  SUM = 'sum'
  SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
  WEIGHTED_MEAN = 'weighted_mean'


def update_state_wrapper(update_state_fn):
  """Decorator to wrap metric `update_state()` with `add_update()`.

  Args:
    update_state_fn: function that accumulates metric statistics.

  Returns:
    Decorated function that wraps `update_state_fn()` with `add_update()`.
  """

  def decorated(metric_obj, *args, **kwargs):
    """Decorated function with `add_update()`."""

    # NOTE(review): `metric_obj` is not forwarded to `update_state_fn`, so the
    # wrapped function is presumably already bound to the metric -- confirm
    # against the call site.
    update_op = update_state_fn(*args, **kwargs)
    if update_op is not None:  # update_op will be None in eager execution.
      metric_obj.add_update(update_op, inputs=True)
    return update_op

  return tf_decorator.make_decorator(update_state_fn, decorated)


def result_wrapper(result_fn):
  """Decorator to wrap metric `result()` function in `merge_call()`.

  Result computation is an idempotent operation that simply calculates the
  metric value using the state variables.

  If metric state variables are distributed across replicas/devices and
  `result()` is requested from the context of one device - This function wraps
  `result()` in a distribution strategy `merge_call()`. With this,
  the metric state variables will be aggregated across devices.

  Args:
    result_fn: function that computes the metric result.

  Returns:
    Decorated function that wraps `result_fn()` in distribution strategy
    `merge_call()`.
  """

  def decorated(_, *args):
    """Decorated function with merge_call."""
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context is None:  # if in cross replica context already
      result_t = array_ops.identity(result_fn(*args))
    else:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
      # with distribution object as the first parameter. We create a wrapper
      # here so that the result function need not have that parameter.
      def merge_fn_wrapper(distribution, merge_fn, *args):
        # We will get `PerDevice` merge function. Taking the first one as all
        # are identical copies of the function that we had passed below.
        merged_result_fn = distribution.unwrap(merge_fn)[0](*args)

        # Wrapping result in identity so that control dependency between
        # update_op from `update_state` and result works in case result returns
        # a tensor.
        return array_ops.identity(merged_result_fn)

      # Wrapping result in merge_call. merge_call is used when we want to leave
      # replica mode and compute a value in cross replica mode.
      result_t = replica_context.merge_call(
          merge_fn_wrapper, args=(result_fn,) + args)
    return result_t

  return tf_decorator.make_decorator(result_fn, decorated)


def weakmethod(method):
  """Creates a weak reference to the bound method.

  The returned callable re-binds the underlying function to the instance on
  every call, holding the instance only through a `weakref.ref`.

  Args:
    method: a bound method.

  Returns:
    A function that forwards its arguments to `method` without keeping the
    instance that `method` is bound to alive.
  """
  # `__self__` / `__func__` are available on bound methods in both Python 2.6+
  # and Python 3, unlike the Python 2 only `im_self` / `im_func` / `im_class`.
  instance = method.__self__
  cls = instance.__class__
  func = method.__func__
  instance_ref = weakref.ref(instance)

  # Wrap the plain function rather than the bound method: on Python 3
  # `functools.wraps` stores its argument in `inner.__wrapped__`, and a bound
  # method stored there would keep a strong reference to the instance.
  @functools.wraps(func)
  def inner(*args, **kwargs):
    return func.__get__(instance_ref(), cls)(*args, **kwargs)

  # Drop the local strong references to the instance before returning.
  del instance
  del method
  return inner


def assert_thresholds_range(thresholds):
  """Checks that every threshold is a value in `[0, 1]`.

  Args:
    thresholds: iterable of threshold values, or `None` to skip the check.

  Raises:
    ValueError: if any threshold is `None`, negative or greater than 1. The
      message lists all offending values.
  """
  if thresholds is not None:
    invalid_thresholds = [t for t in thresholds if t is None or t < 0 or t > 1]
    if invalid_thresholds:
      raise ValueError(
          'Threshold values must be in [0, 1]. Invalid values: {}'.format(
              invalid_thresholds))


def parse_init_thresholds(thresholds, default_threshold=0.5):
  """Validates `thresholds` and normalizes them with `to_list`.

  Args:
    thresholds: threshold value(s) to validate, or `None` to fall back to
      `default_threshold`.
    default_threshold: value used when `thresholds` is `None`. It is not range
      checked.

  Returns:
    The result of `to_list` applied to `thresholds`, or to `default_threshold`
    when `thresholds` is `None`.

  Raises:
    ValueError: if a provided threshold is outside `[0, 1]`.
  """
  if thresholds is None:
    return to_list(default_threshold)
  thresholds = to_list(thresholds)
  assert_thresholds_range(thresholds)
  return thresholds


class ConfusionMatrix(Enum):
  """Keys identifying the confusion matrix variables that can be updated."""
  TRUE_POSITIVES = 'tp'
  FALSE_POSITIVES = 'fp'
  TRUE_NEGATIVES = 'tn'
  FALSE_NEGATIVES = 'fn'


class AUCCurve(Enum):
  """Type of AUC Curve (ROC or PR)."""
  ROC = 'ROC'
  PR = 'PR'

  @staticmethod
  def from_str(key):
    """Returns the `AUCCurve` member named by `key`.

    Args:
      key: one of 'pr', 'PR', 'roc' or 'ROC'.

    Returns:
      The matching `AUCCurve` member.

    Raises:
      ValueError: if `key` is not one of the accepted spellings.
    """
    if key in ('pr', 'PR'):
      return AUCCurve.PR
    elif key in ('roc', 'ROC'):
      return AUCCurve.ROC
    else:
      raise ValueError('Invalid AUC curve value "%s".' % key)


class AUCSummationMethod(Enum):
  """Type of AUC summation method.

  (https://en.wikipedia.org/wiki/Riemann_sum)

  Contains the following values:
  * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For
    `PR` curve, interpolates (true/false) positives but not the ratio that is
    precision (see Davis & Goadrich 2006 for details).
  * 'minoring': Applies left summation for increasing intervals and right
    summation for decreasing intervals.
  * 'majoring': Applies right summation for increasing intervals and left
    summation for decreasing intervals.
  """
  INTERPOLATION = 'interpolation'
  MAJORING = 'majoring'
  MINORING = 'minoring'

  @staticmethod
  def from_str(key):
    """Returns the `AUCSummationMethod` member named by `key`.

    Args:
      key: the method name, either all lower case or capitalized (for example
        'interpolation' or 'Interpolation').

    Returns:
      The matching `AUCSummationMethod` member.

    Raises:
      ValueError: if `key` is not one of the accepted spellings.
    """
    if key in ('interpolation', 'Interpolation'):
      return AUCSummationMethod.INTERPOLATION
    elif key in ('majoring', 'Majoring'):
      return AUCSummationMethod.MAJORING
    elif key in ('minoring', 'Minoring'):
      return AUCSummationMethod.MINORING
    else:
      raise ValueError('Invalid AUC summation method value "%s".' % key)


def update_confusion_matrix_variables(variables_to_update,
                                      y_true,
                                      y_pred,
                                      thresholds,
                                      top_k=None,
                                      class_id=None,
                                      sample_weight=None):
  """Returns op to update the given confusion matrix variables.

  For every pair of values in y_true and y_pred:

  true_positive: y_true == True and y_pred > thresholds
  false_negatives: y_true == True and y_pred <= thresholds
  true_negatives: y_true == False and y_pred <= thresholds
  false_positive: y_true == False and y_pred > thresholds

  The results will be weighted and added together. When multiple thresholds are
  provided, we will repeat the same for every threshold.

  For estimation of these metrics over a stream of data, the function creates an
  `update_op` operation that updates the given variables.

  If `sample_weight` is `None`, weights default to 1.
  Use weights of 0 to mask values.

  Args:
    variables_to_update: Dictionary whose keys are `ConfusionMatrix` members
      (`TRUE_POSITIVES`, `FALSE_NEGATIVES`, `TRUE_NEGATIVES`,
      `FALSE_POSITIVES`) and whose values are the corresponding variables to
      update. Keys are matched against the enum members themselves, so the
      plain strings 'tp', 'fn', 'tn', 'fp' are rejected as invalid.
    y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
    y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
      the range `[0, 1]`.
    thresholds: A float value or a python list or tuple of float thresholds in
      `[0, 1]`, or NEG_INF (used when top_k is set).
    top_k: Optional int, indicates that the positive labels should be limited to
      the top k predictions.
    class_id: Optional int, limits the prediction and labels to the class
      specified by this argument.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
      `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `y_true` dimension).

  Returns:
    Update op, or `None` when `variables_to_update` is `None`.

  Raises:
    ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
      `sample_weight` is not `None` and its shape doesn't match `y_pred`, or if
      `variables_to_update` contains invalid keys.
  """
  if variables_to_update is None:
    return
  y_true = math_ops.cast(y_true, dtype=dtypes.float32)
  y_pred = math_ops.cast(y_pred, dtype=dtypes.float32)
  y_pred.shape.assert_is_compatible_with(y_true.shape)

  # Build the list of accepted keys once instead of once per checked key.
  valid_keys = list(ConfusionMatrix)

  # At least one valid key is required; this also rejects an empty dictionary.
  if not any(key for key in variables_to_update if key in valid_keys):
    raise ValueError(
        'Please provide at least one valid confusion matrix '
        'variable to update. Valid variable key options are: "{}". '
        'Received: "{}"'.format(valid_keys, variables_to_update.keys()))

  invalid_keys = [key for key in variables_to_update if key not in valid_keys]
  if invalid_keys:
    raise ValueError(
        'Invalid keys: {}. Valid variable key options are: "{}"'.format(
            invalid_keys, valid_keys))

  # The range assertions on `y_pred` are attached as control dependencies of
  # the ops created inside this block.
  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          y_pred,
          math_ops.cast(0.0, dtype=y_pred.dtype),
          message='predictions must be >= 0'),
      check_ops.assert_less_equal(
          y_pred,
          math_ops.cast(1.0, dtype=y_pred.dtype),
          message='predictions must be <= 1')
  ]):
    y_pred, y_true, sample_weight = squeeze_or_expand_dimensions(
        y_pred, y_true, sample_weight)

  if top_k is not None:
    y_pred = _filter_top_k(y_pred, top_k)
  if class_id is not None:
    y_true = y_true[..., class_id]
    y_pred = y_pred[..., class_id]

  thresholds = to_list(thresholds)
  num_thresholds = len(thresholds)
  num_predictions = array_ops.size(y_pred)

  # Reshape predictions and labels.
  predictions_2d = array_ops.reshape(y_pred, [1, -1])
  labels_2d = array_ops.reshape(
      math_ops.cast(y_true, dtype=dtypes.bool), [1, -1])

  # Tile the thresholds for every prediction.
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), 1),
      array_ops.stack([1, num_predictions]))

  # Tile the predictions for every threshold.
  preds_tiled = array_ops.tile(predictions_2d, [num_thresholds, 1])

  # Compare predictions and threshold.
  pred_is_pos = math_ops.greater(preds_tiled, thresh_tiled)

  # Tile labels by number of thresholds
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])

  if sample_weight is not None:
    weights = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(sample_weight, dtype=dtypes.float32), y_pred)
    weights_tiled = array_ops.tile(
        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
  else:
    weights_tiled = None

  update_ops = []

  def weighted_assign_add(label, pred, weights, var):
    """Adds the (weighted) count of `label & pred` per threshold to `var`."""
    label_and_pred = math_ops.cast(
        math_ops.logical_and(label, pred), dtype=dtypes.float32)
    if weights is not None:
      label_and_pred *= weights
    # Axis 1 runs over the flattened predictions, leaving one sum per
    # threshold row.
    return var.assign_add(math_ops.reduce_sum(label_and_pred, 1))

  # Maps each confusion matrix entry to the (label, prediction) boolean pair
  # whose conjunction counts it. Negated tensors are only built when an entry
  # that needs them was requested.
  loop_vars = {
      ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
  }
  update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
  update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
  update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

  if update_fn or update_tn:
    pred_is_neg = math_ops.logical_not(pred_is_pos)
    loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)

  if update_fp or update_tn:
    label_is_neg = math_ops.logical_not(label_is_pos)
    loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
    if update_tn:
      loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (label_is_neg, pred_is_neg)

  # Entries added above only as a by-product (for example FALSE_NEGATIVES when
  # just TRUE_NEGATIVES was requested) are skipped here.
  for matrix_cond, (label, pred) in loop_vars.items():
    if matrix_cond in variables_to_update:
      update_ops.append(
          weighted_assign_add(label, pred, weights_tiled,
                              variables_to_update[matrix_cond]))
  return control_flow_ops.group(update_ops)


def _filter_top_k(x, k):
  """Filters top-k values in the last dim of x and set the rest to NEG_INF.

  Used for computing top-k prediction values in dense labels (which has the same
  shape as predictions) for recall and precision top-k metrics.

  Args:
    x: tensor with any dimensions.
    k: the number of values to keep.

  Returns:
    tensor with same shape and dtype as x.
  """
  _, top_k_idx = nn_ops.top_k(x, k, sorted=False)
  # One-hot encode each of the k selected indices over the last dimension and
  # sum over the k axis to obtain a mask of the kept positions.
  top_k_mask = math_ops.reduce_sum(
      array_ops.one_hot(top_k_idx, x.shape[-1], axis=-1), axis=-2)
  return x * top_k_mask + NEG_INF * (1 - top_k_mask)