1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=unused-import 16# pylint: disable=g-classes-have-attributes 17"""Built-in metrics. 18""" 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import abc 24import math 25import types 26 27import numpy as np 28import six 29 30from tensorflow.python.autograph.core import ag_ctx 31from tensorflow.python.autograph.impl import api as autograph 32from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx 33from tensorflow.python.eager import context 34from tensorflow.python.eager import def_function 35from tensorflow.python.framework import constant_op 36from tensorflow.python.framework import dtypes 37from tensorflow.python.framework import ops 38from tensorflow.python.framework import tensor_shape 39from tensorflow.python.framework import tensor_spec 40from tensorflow.python.keras import activations 41from tensorflow.python.keras import backend as K 42from tensorflow.python.keras.engine import base_layer 43from tensorflow.python.keras.engine import base_layer_utils 44from tensorflow.python.keras.engine import keras_tensor 45from tensorflow.python.keras.losses import binary_crossentropy 46from tensorflow.python.keras.losses import categorical_crossentropy 47from tensorflow.python.keras.losses import categorical_hinge 48from tensorflow.python.keras.losses import hinge 49from tensorflow.python.keras.losses import kullback_leibler_divergence 50from tensorflow.python.keras.losses import logcosh 51from tensorflow.python.keras.losses import mean_absolute_error 52from tensorflow.python.keras.losses import mean_absolute_percentage_error 53from tensorflow.python.keras.losses import mean_squared_error 54from tensorflow.python.keras.losses import mean_squared_logarithmic_error 55from tensorflow.python.keras.losses import poisson 56from tensorflow.python.keras.losses import sparse_categorical_crossentropy 57from tensorflow.python.keras.losses import squared_hinge 58from tensorflow.python.keras.saving.saved_model import metric_serialization 59from tensorflow.python.keras.utils import losses_utils 60from tensorflow.python.keras.utils import metrics_utils 61from tensorflow.python.keras.utils import tf_inspect 62from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object 63from tensorflow.python.keras.utils.generic_utils import serialize_keras_object 64from tensorflow.python.keras.utils.generic_utils import to_list 65from tensorflow.python.keras.utils.tf_utils import is_tensor_or_variable 66from tensorflow.python.ops import array_ops 67from tensorflow.python.ops import check_ops 68from tensorflow.python.ops import confusion_matrix 69from tensorflow.python.ops import control_flow_ops 70from tensorflow.python.ops import init_ops 71from tensorflow.python.ops import math_ops 72from tensorflow.python.ops import nn 73from tensorflow.python.ops import variables as tf_variables 74from tensorflow.python.ops import weights_broadcast_ops 75from tensorflow.python.training.tracking import base as trackable 76from tensorflow.python.util import dispatch 77from tensorflow.python.util import nest 78from tensorflow.python.util.tf_export import keras_export 79from tensorflow.tools.docs import doc_controls 80 81 82@keras_export('keras.metrics.Metric') 83@six.add_metaclass(abc.ABCMeta) 84class Metric(base_layer.Layer): 85 """Encapsulates metric logic and state. 86 87 Args: 88 name: (Optional) string name of the metric instance. 89 dtype: (Optional) data type of the metric result. 90 **kwargs: Additional layer keywords arguments. 91 92 Standalone usage: 93 94 ```python 95 m = SomeMetric(...) 96 for input in ...: 97 m.update_state(input) 98 print('Final result: ', m.result().numpy()) 99 ``` 100 101 Usage with `compile()` API: 102 103 ```python 104 model = tf.keras.Sequential() 105 model.add(tf.keras.layers.Dense(64, activation='relu')) 106 model.add(tf.keras.layers.Dense(64, activation='relu')) 107 model.add(tf.keras.layers.Dense(10, activation='softmax')) 108 109 model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01), 110 loss=tf.keras.losses.CategoricalCrossentropy(), 111 metrics=[tf.keras.metrics.CategoricalAccuracy()]) 112 113 data = np.random.random((1000, 32)) 114 labels = np.random.random((1000, 10)) 115 116 dataset = tf.data.Dataset.from_tensor_slices((data, labels)) 117 dataset = dataset.batch(32) 118 119 model.fit(dataset, epochs=10) 120 ``` 121 122 To be implemented by subclasses: 123 * `__init__()`: All state variables should be created in this method by 124 calling `self.add_weight()` like: `self.var = self.add_weight(...)` 125 * `update_state()`: Has all updates to the state variables like: 126 self.var.assign_add(...). 127 * `result()`: Computes and returns a value for the metric 128 from the state variables. 129 130 Example subclass implementation: 131 132 ```python 133 class BinaryTruePositives(tf.keras.metrics.Metric): 134 135 def __init__(self, name='binary_true_positives', **kwargs): 136 super(BinaryTruePositives, self).__init__(name=name, **kwargs) 137 self.true_positives = self.add_weight(name='tp', initializer='zeros') 138 139 def update_state(self, y_true, y_pred, sample_weight=None): 140 y_true = tf.cast(y_true, tf.bool) 141 y_pred = tf.cast(y_pred, tf.bool) 142 143 values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True)) 144 values = tf.cast(values, self.dtype) 145 if sample_weight is not None: 146 sample_weight = tf.cast(sample_weight, self.dtype) 147 sample_weight = tf.broadcast_to(sample_weight, values.shape) 148 values = tf.multiply(values, sample_weight) 149 self.true_positives.assign_add(tf.reduce_sum(values)) 150 151 def result(self): 152 return self.true_positives 153 ``` 154 """ 155 156 def __init__(self, name=None, dtype=None, **kwargs): 157 super(Metric, self).__init__(name=name, dtype=dtype, **kwargs) 158 self.stateful = True # All metric layers are stateful. 159 self.built = True 160 if not base_layer_utils.v2_dtype_behavior_enabled(): 161 # We only do this when the V2 behavior is not enabled, as when it is 162 # enabled, the dtype already defaults to floatx. 163 self._dtype = K.floatx() if dtype is None else dtypes.as_dtype(dtype).name 164 165 def __new__(cls, *args, **kwargs): 166 obj = super(Metric, cls).__new__(cls) 167 168 # If `update_state` is not in eager/tf.function and it is not from a 169 # built-in metric, wrap it in `tf.function`. This is so that users writing 170 # custom metrics in v1 need not worry about control dependencies and 171 # return ops. 172 if (base_layer_utils.is_in_eager_or_tf_function() or 173 is_built_in(cls)): 174 obj_update_state = obj.update_state 175 176 def update_state_fn(*args, **kwargs): 177 control_status = ag_ctx.control_status_ctx() 178 ag_update_state = autograph.tf_convert(obj_update_state, control_status) 179 return ag_update_state(*args, **kwargs) 180 else: 181 if isinstance(obj.update_state, def_function.Function): 182 update_state_fn = obj.update_state 183 else: 184 update_state_fn = def_function.function(obj.update_state) 185 186 obj.update_state = types.MethodType( 187 metrics_utils.update_state_wrapper(update_state_fn), obj) 188 189 obj_result = obj.result 190 191 def result_fn(*args, **kwargs): 192 control_status = ag_ctx.control_status_ctx() 193 ag_result = autograph.tf_convert(obj_result, control_status) 194 return ag_result(*args, **kwargs) 195 196 obj.result = types.MethodType(metrics_utils.result_wrapper(result_fn), obj) 197 198 return obj 199 200 def __call__(self, *args, **kwargs): 201 """Accumulates statistics and then computes metric result value. 202 203 Args: 204 *args: 205 **kwargs: A mini-batch of inputs to the Metric, 206 passed on to `update_state()`. 207 208 Returns: 209 The metric value tensor. 210 """ 211 212 def replica_local_fn(*args, **kwargs): 213 """Updates the state of the metric in a replica-local context.""" 214 if any( 215 isinstance(arg, keras_tensor.KerasTensor) 216 for arg in nest.flatten((args, kwargs))): 217 update_op = None 218 else: 219 update_op = self.update_state(*args, **kwargs) # pylint: disable=not-callable 220 update_ops = [] 221 if update_op is not None: 222 update_ops.append(update_op) 223 with ops.control_dependencies(update_ops): 224 result_t = self.result() # pylint: disable=not-callable 225 226 # We are adding the metric object as metadata on the result tensor. 227 # This is required when we want to use a metric with `add_metric` API on 228 # a Model/Layer in graph mode. This metric instance will later be used 229 # to reset variable state after each epoch of training. 230 # Example: 231 # model = Model() 232 # mean = Mean() 233 # model.add_metric(mean(values), name='mean') 234 result_t._metric_obj = self # pylint: disable=protected-access 235 return result_t 236 237 from tensorflow.python.keras.distribute import distributed_training_utils # pylint:disable=g-import-not-at-top 238 return distributed_training_utils.call_replica_local_fn( 239 replica_local_fn, *args, **kwargs) 240 241 @property 242 def dtype(self): 243 return self._dtype 244 245 def get_config(self): 246 """Returns the serializable config of the metric.""" 247 return {'name': self.name, 'dtype': self.dtype} 248 249 def reset_states(self): 250 """Resets all of the metric state variables. 251 252 This function is called between epochs/steps, 253 when a metric is evaluated during training. 254 """ 255 K.batch_set_value([(v, 0) for v in self.variables]) 256 257 @abc.abstractmethod 258 def update_state(self, *args, **kwargs): 259 """Accumulates statistics for the metric. 260 261 Note: This function is executed as a graph function in graph mode. 262 This means: 263 a) Operations on the same resource are executed in textual order. 264 This should make it easier to do things like add the updated 265 value of a variable to another, for example. 266 b) You don't need to worry about collecting the update ops to execute. 267 All update ops added to the graph by this function will be executed. 268 As a result, code should generally work the same way with graph or 269 eager execution. 270 271 Args: 272 *args: 273 **kwargs: A mini-batch of inputs to the Metric. 274 """ 275 raise NotImplementedError('Must be implemented in subclasses.') 276 277 @abc.abstractmethod 278 def result(self): 279 """Computes and returns the metric value tensor. 280 281 Result computation is an idempotent operation that simply calculates the 282 metric value using the state variables. 283 """ 284 raise NotImplementedError('Must be implemented in subclasses.') 285 286 ### For use by subclasses ### 287 @doc_controls.for_subclass_implementers 288 def add_weight(self, 289 name, 290 shape=(), 291 aggregation=tf_variables.VariableAggregation.SUM, 292 synchronization=tf_variables.VariableSynchronization.ON_READ, 293 initializer=None, 294 dtype=None): 295 """Adds state variable. Only for use by subclasses.""" 296 if distribute_ctx.has_strategy(): 297 strategy = distribute_ctx.get_strategy() 298 else: 299 strategy = None 300 301 # TODO(b/120571621): Make `ON_READ` work with Keras metrics on TPU. 302 if K.is_tpu_strategy(strategy): 303 synchronization = tf_variables.VariableSynchronization.ON_WRITE 304 305 with ops.init_scope(): 306 return super(Metric, self).add_weight( 307 name=name, 308 shape=shape, 309 dtype=self._dtype if dtype is None else dtype, 310 trainable=False, 311 initializer=initializer, 312 collections=[], 313 synchronization=synchronization, 314 aggregation=aggregation) 315 316 ### End: For use by subclasses ### 317 318 @property 319 def trainable_weights(self): 320 # Overridden from Layer class to track submetric weights. 321 if self.trainable: 322 trainable_weights = self._trainable_weights 323 for m in self._metrics: 324 trainable_weights += m.trainable_weights 325 return self._dedup_weights(trainable_weights) 326 else: 327 return [] 328 329 @property 330 def non_trainable_weights(self): 331 # Overridden from Layer class to track submetric weights. 332 if self.trainable: 333 non_trainable_weights = self._non_trainable_weights 334 for m in self._metrics: 335 non_trainable_weights += m.non_trainable_weights 336 else: 337 non_trainable_weights = ( 338 self._non_trainable_weights + self._trainable_weights) 339 for m in self._metrics: 340 non_trainable_weights += m.weights 341 return self._dedup_weights(non_trainable_weights) 342 343 @property 344 def _trackable_saved_model_saver(self): 345 return metric_serialization.MetricSavedModelSaver(self) 346 347 348class Reduce(Metric): 349 """Encapsulates metrics that perform a reduce operation on the values. 350 351 Args: 352 reduction: a `tf.keras.metrics.Reduction` enum value. 353 name: string name of the metric instance. 354 dtype: (Optional) data type of the metric result. 355 """ 356 357 def __init__(self, reduction, name, dtype=None): 358 super(Reduce, self).__init__(name=name, dtype=dtype) 359 self.reduction = reduction 360 self.total = self.add_weight( 361 'total', initializer=init_ops.zeros_initializer) 362 if reduction in [metrics_utils.Reduction.SUM_OVER_BATCH_SIZE, 363 metrics_utils.Reduction.WEIGHTED_MEAN]: 364 self.count = self.add_weight( 365 'count', initializer=init_ops.zeros_initializer) 366 367 def update_state(self, values, sample_weight=None): 368 """Accumulates statistics for computing the metric. 369 370 Args: 371 values: Per-example value. 372 sample_weight: Optional weighting of each example. Defaults to 1. 373 374 Returns: 375 Update op. 376 """ 377 [values], sample_weight = \ 378 metrics_utils.ragged_assert_compatible_and_get_flat_values( 379 [values], sample_weight) 380 values = math_ops.cast(values, self._dtype) 381 if sample_weight is not None: 382 sample_weight = math_ops.cast(sample_weight, self._dtype) 383 # Update dimensions of weights to match with values if possible. 384 values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions( 385 values, sample_weight=sample_weight) 386 try: 387 # Broadcast weights if possible. 388 sample_weight = weights_broadcast_ops.broadcast_weights( 389 sample_weight, values) 390 except ValueError: 391 # Reduce values to same ndim as weight array 392 ndim = K.ndim(values) 393 weight_ndim = K.ndim(sample_weight) 394 if self.reduction == metrics_utils.Reduction.SUM: 395 values = math_ops.reduce_sum( 396 values, axis=list(range(weight_ndim, ndim))) 397 else: 398 values = math_ops.reduce_mean( 399 values, axis=list(range(weight_ndim, ndim))) 400 values = math_ops.multiply(values, sample_weight) 401 402 value_sum = math_ops.reduce_sum(values) 403 with ops.control_dependencies([value_sum]): 404 update_total_op = self.total.assign_add(value_sum) 405 406 # Exit early if the reduction doesn't have a denominator. 407 if self.reduction == metrics_utils.Reduction.SUM: 408 return update_total_op 409 410 # Update `count` for reductions that require a denominator. 411 if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE: 412 num_values = math_ops.cast(array_ops.size(values), self._dtype) 413 elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN: 414 if sample_weight is None: 415 num_values = math_ops.cast(array_ops.size(values), self._dtype) 416 else: 417 num_values = math_ops.reduce_sum(sample_weight) 418 else: 419 raise NotImplementedError( 420 'reduction [%s] not implemented' % self.reduction) 421 422 with ops.control_dependencies([update_total_op]): 423 return self.count.assign_add(num_values) 424 425 def result(self): 426 if self.reduction == metrics_utils.Reduction.SUM: 427 return array_ops.identity(self.total) 428 elif self.reduction in [ 429 metrics_utils.Reduction.WEIGHTED_MEAN, 430 metrics_utils.Reduction.SUM_OVER_BATCH_SIZE 431 ]: 432 return math_ops.div_no_nan(self.total, self.count) 433 else: 434 raise NotImplementedError( 435 'reduction [%s] not implemented' % self.reduction) 436 437 438@keras_export('keras.metrics.Sum') 439class Sum(Reduce): 440 """Computes the (weighted) sum of the given values. 441 442 For example, if values is [1, 3, 5, 7] then the sum is 16. 443 If the weights were specified as [1, 1, 0, 0] then the sum would be 4. 444 445 This metric creates one variable, `total`, that is used to compute the sum of 446 `values`. This is ultimately returned as `sum`. 447 448 If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 449 to mask values. 450 451 Args: 452 name: (Optional) string name of the metric instance. 453 dtype: (Optional) data type of the metric result. 454 455 Standalone usage: 456 457 >>> m = tf.keras.metrics.Sum() 458 >>> m.update_state([1, 3, 5, 7]) 459 >>> m.result().numpy() 460 16.0 461 462 Usage with `compile()` API: 463 464 ```python 465 model.add_metric(tf.keras.metrics.Sum(name='sum_1')(outputs)) 466 model.compile(optimizer='sgd', loss='mse') 467 ``` 468 """ 469 470 def __init__(self, name='sum', dtype=None): 471 super(Sum, self).__init__(reduction=metrics_utils.Reduction.SUM, 472 name=name, dtype=dtype) 473 474 475@keras_export('keras.metrics.Mean') 476class Mean(Reduce): 477 """Computes the (weighted) mean of the given values. 478 479 For example, if values is [1, 3, 5, 7] then the mean is 4. 480 If the weights were specified as [1, 1, 0, 0] then the mean would be 2. 481 482 This metric creates two variables, `total` and `count` that are used to 483 compute the average of `values`. This average is ultimately returned as `mean` 484 which is an idempotent operation that simply divides `total` by `count`. 485 486 If `sample_weight` is `None`, weights default to 1. 487 Use `sample_weight` of 0 to mask values. 488 489 Args: 490 name: (Optional) string name of the metric instance. 491 dtype: (Optional) data type of the metric result. 492 493 Standalone usage: 494 495 >>> m = tf.keras.metrics.Mean() 496 >>> m.update_state([1, 3, 5, 7]) 497 >>> m.result().numpy() 498 4.0 499 >>> m.reset_states() 500 >>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0]) 501 >>> m.result().numpy() 502 2.0 503 504 Usage with `compile()` API: 505 506 ```python 507 model.add_metric(tf.keras.metrics.Mean(name='mean_1')(outputs)) 508 model.compile(optimizer='sgd', loss='mse') 509 ``` 510 """ 511 512 def __init__(self, name='mean', dtype=None): 513 super(Mean, self).__init__( 514 reduction=metrics_utils.Reduction.WEIGHTED_MEAN, name=name, dtype=dtype) 515 516 517@keras_export('keras.metrics.MeanRelativeError') 518class MeanRelativeError(Mean): 519 """Computes the mean relative error by normalizing with the given values. 520 521 This metric creates two local variables, `total` and `count` that are used to 522 compute the mean relative error. This is weighted by `sample_weight`, and 523 it is ultimately returned as `mean_relative_error`: 524 an idempotent operation that simply divides `total` by `count`. 525 526 If `sample_weight` is `None`, weights default to 1. 527 Use `sample_weight` of 0 to mask values. 528 529 Args: 530 normalizer: The normalizer values with same shape as predictions. 531 name: (Optional) string name of the metric instance. 532 dtype: (Optional) data type of the metric result. 533 534 Standalone usage: 535 536 >>> m = tf.keras.metrics.MeanRelativeError(normalizer=[1, 3, 2, 3]) 537 >>> m.update_state([1, 3, 2, 3], [2, 4, 6, 8]) 538 539 >>> # metric = mean(|y_pred - y_true| / normalizer) 540 >>> # = mean([1, 1, 4, 5] / [1, 3, 2, 3]) = mean([1, 1/3, 2, 5/3]) 541 >>> # = 5/4 = 1.25 542 >>> m.result().numpy() 543 1.25 544 545 Usage with `compile()` API: 546 547 ```python 548 model.compile( 549 optimizer='sgd', 550 loss='mse', 551 metrics=[tf.keras.metrics.MeanRelativeError(normalizer=[1, 3])]) 552 ``` 553 """ 554 555 def __init__(self, normalizer, name=None, dtype=None): 556 super(MeanRelativeError, self).__init__(name=name, dtype=dtype) 557 normalizer = math_ops.cast(normalizer, self._dtype) 558 self.normalizer = normalizer 559 560 def update_state(self, y_true, y_pred, sample_weight=None): 561 """Accumulates metric statistics. 562 563 Args: 564 y_true: The ground truth values. 565 y_pred: The predicted values. 566 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 567 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 568 be broadcastable to `y_true`. 569 570 Returns: 571 Update op. 572 """ 573 y_true = math_ops.cast(y_true, self._dtype) 574 y_pred = math_ops.cast(y_pred, self._dtype) 575 [y_pred, y_true], sample_weight = \ 576 metrics_utils.ragged_assert_compatible_and_get_flat_values( 577 [y_pred, y_true], sample_weight) 578 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions( 579 y_pred, y_true) 580 581 y_pred, self.normalizer = losses_utils.remove_squeezable_dimensions( 582 y_pred, self.normalizer) 583 y_pred.shape.assert_is_compatible_with(y_true.shape) 584 relative_errors = math_ops.div_no_nan( 585 math_ops.abs(y_true - y_pred), self.normalizer) 586 587 return super(MeanRelativeError, self).update_state( 588 relative_errors, sample_weight=sample_weight) 589 590 def get_config(self): 591 n = self.normalizer 592 config = {'normalizer': K.eval(n) if is_tensor_or_variable(n) else n} 593 base_config = super(MeanRelativeError, self).get_config() 594 return dict(list(base_config.items()) + list(config.items())) 595 596 597class MeanMetricWrapper(Mean): 598 """Wraps a stateless metric function with the Mean metric. 599 600 Args: 601 fn: The metric function to wrap, with signature `fn(y_true, y_pred, 602 **kwargs)`. 603 name: (Optional) string name of the metric instance. 604 dtype: (Optional) data type of the metric result. 605 **kwargs: The keyword arguments that are passed on to `fn`. 606 """ 607 608 def __init__(self, fn, name=None, dtype=None, **kwargs): 609 super(MeanMetricWrapper, self).__init__(name=name, dtype=dtype) 610 self._fn = fn 611 self._fn_kwargs = kwargs 612 613 def update_state(self, y_true, y_pred, sample_weight=None): 614 """Accumulates metric statistics. 615 616 `y_true` and `y_pred` should have the same shape. 617 618 Args: 619 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. 620 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 621 sample_weight: Optional `sample_weight` acts as a 622 coefficient for the metric. If a scalar is provided, then the metric is 623 simply scaled by the given value. If `sample_weight` is a tensor of size 624 `[batch_size]`, then the metric for each sample of the batch is rescaled 625 by the corresponding element in the `sample_weight` vector. If the shape 626 of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted 627 to this shape), then each metric element of `y_pred` is scaled by the 628 corresponding value of `sample_weight`. (Note on `dN-1`: all metric 629 functions reduce by 1 dimension, usually the last axis (-1)). 630 631 Returns: 632 Update op. 633 """ 634 y_true = math_ops.cast(y_true, self._dtype) 635 y_pred = math_ops.cast(y_pred, self._dtype) 636 [y_true, y_pred], sample_weight = \ 637 metrics_utils.ragged_assert_compatible_and_get_flat_values( 638 [y_true, y_pred], sample_weight) 639 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions( 640 y_pred, y_true) 641 642 ag_fn = autograph.tf_convert(self._fn, ag_ctx.control_status_ctx()) 643 matches = ag_fn(y_true, y_pred, **self._fn_kwargs) 644 return super(MeanMetricWrapper, self).update_state( 645 matches, sample_weight=sample_weight) 646 647 def get_config(self): 648 config = {} 649 650 if type(self) is MeanMetricWrapper: # pylint: disable=unidiomatic-typecheck 651 # Only include function argument when the object is a MeanMetricWrapper 652 # and not a subclass. 653 config['fn'] = self._fn 654 655 for k, v in six.iteritems(self._fn_kwargs): 656 config[k] = K.eval(v) if is_tensor_or_variable(v) else v 657 base_config = super(MeanMetricWrapper, self).get_config() 658 return dict(list(base_config.items()) + list(config.items())) 659 660 @classmethod 661 def from_config(cls, config): 662 # Note that while MeanMetricWrapper itself isn't public, objects of this 663 # class may be created and added to the model by calling model.compile. 664 fn = config.pop('fn', None) 665 if cls is MeanMetricWrapper: 666 return cls(get(fn), **config) 667 return super(MeanMetricWrapper, cls).from_config(config) 668 669 670@keras_export('keras.metrics.Accuracy') 671class Accuracy(MeanMetricWrapper): 672 """Calculates how often predictions equal labels. 673 674 This metric creates two local variables, `total` and `count` that are used to 675 compute the frequency with which `y_pred` matches `y_true`. This frequency is 676 ultimately returned as `binary accuracy`: an idempotent operation that simply 677 divides `total` by `count`. 678 679 If `sample_weight` is `None`, weights default to 1. 680 Use `sample_weight` of 0 to mask values. 681 682 Args: 683 name: (Optional) string name of the metric instance. 684 dtype: (Optional) data type of the metric result. 685 686 Standalone usage: 687 688 >>> m = tf.keras.metrics.Accuracy() 689 >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]]) 690 >>> m.result().numpy() 691 0.75 692 693 >>> m.reset_states() 694 >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]], 695 ... sample_weight=[1, 1, 0, 0]) 696 >>> m.result().numpy() 697 0.5 698 699 Usage with `compile()` API: 700 701 ```python 702 model.compile(optimizer='sgd', 703 loss='mse', 704 metrics=[tf.keras.metrics.Accuracy()]) 705 ``` 706 """ 707 708 def __init__(self, name='accuracy', dtype=None): 709 super(Accuracy, self).__init__(accuracy, name, dtype=dtype) 710 711 712@keras_export('keras.metrics.BinaryAccuracy') 713class BinaryAccuracy(MeanMetricWrapper): 714 """Calculates how often predictions match binary labels. 715 716 This metric creates two local variables, `total` and `count` that are used to 717 compute the frequency with which `y_pred` matches `y_true`. This frequency is 718 ultimately returned as `binary accuracy`: an idempotent operation that simply 719 divides `total` by `count`. 720 721 If `sample_weight` is `None`, weights default to 1. 722 Use `sample_weight` of 0 to mask values. 723 724 Args: 725 name: (Optional) string name of the metric instance. 726 dtype: (Optional) data type of the metric result. 727 threshold: (Optional) Float representing the threshold for deciding 728 whether prediction values are 1 or 0. 729 730 Standalone usage: 731 732 >>> m = tf.keras.metrics.BinaryAccuracy() 733 >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]]) 734 >>> m.result().numpy() 735 0.75 736 737 >>> m.reset_states() 738 >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]], 739 ... sample_weight=[1, 0, 0, 1]) 740 >>> m.result().numpy() 741 0.5 742 743 Usage with `compile()` API: 744 745 ```python 746 model.compile(optimizer='sgd', 747 loss='mse', 748 metrics=[tf.keras.metrics.BinaryAccuracy()]) 749 ``` 750 """ 751 752 def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5): 753 super(BinaryAccuracy, self).__init__( 754 binary_accuracy, name, dtype=dtype, threshold=threshold) 755 756 757@keras_export('keras.metrics.CategoricalAccuracy') 758class CategoricalAccuracy(MeanMetricWrapper): 759 """Calculates how often predictions match one-hot labels. 760 761 You can provide logits of classes as `y_pred`, since argmax of 762 logits and probabilities are same. 763 764 This metric creates two local variables, `total` and `count` that are used to 765 compute the frequency with which `y_pred` matches `y_true`. This frequency is 766 ultimately returned as `categorical accuracy`: an idempotent operation that 767 simply divides `total` by `count`. 768 769 `y_pred` and `y_true` should be passed in as vectors of probabilities, rather 770 than as labels. If necessary, use `tf.one_hot` to expand `y_true` as a vector. 771 772 If `sample_weight` is `None`, weights default to 1. 773 Use `sample_weight` of 0 to mask values. 774 775 Args: 776 name: (Optional) string name of the metric instance. 777 dtype: (Optional) data type of the metric result. 778 779 Standalone usage: 780 781 >>> m = tf.keras.metrics.CategoricalAccuracy() 782 >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], 783 ... [0.05, 0.95, 0]]) 784 >>> m.result().numpy() 785 0.5 786 787 >>> m.reset_states() 788 >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], 789 ... [0.05, 0.95, 0]], 790 ... sample_weight=[0.7, 0.3]) 791 >>> m.result().numpy() 792 0.3 793 794 Usage with `compile()` API: 795 796 ```python 797 model.compile( 798 optimizer='sgd', 799 loss='mse', 800 metrics=[tf.keras.metrics.CategoricalAccuracy()]) 801 ``` 802 """ 803 804 def __init__(self, name='categorical_accuracy', dtype=None): 805 super(CategoricalAccuracy, self).__init__( 806 categorical_accuracy, name, dtype=dtype) 807 808 809@keras_export('keras.metrics.SparseCategoricalAccuracy') 810class SparseCategoricalAccuracy(MeanMetricWrapper): 811 """Calculates how often predictions match integer labels. 812 813 ```python 814 acc = np.dot(sample_weight, np.equal(y_true, np.argmax(y_pred, axis=1)) 815 ``` 816 817 You can provide logits of classes as `y_pred`, since argmax of 818 logits and probabilities are same. 819 820 This metric creates two local variables, `total` and `count` that are used to 821 compute the frequency with which `y_pred` matches `y_true`. This frequency is 822 ultimately returned as `sparse categorical accuracy`: an idempotent operation 823 that simply divides `total` by `count`. 824 825 If `sample_weight` is `None`, weights default to 1. 826 Use `sample_weight` of 0 to mask values. 827 828 Args: 829 name: (Optional) string name of the metric instance. 830 dtype: (Optional) data type of the metric result. 831 832 Standalone usage: 833 834 >>> m = tf.keras.metrics.SparseCategoricalAccuracy() 835 >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]]) 836 >>> m.result().numpy() 837 0.5 838 839 >>> m.reset_states() 840 >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]], 841 ... sample_weight=[0.7, 0.3]) 842 >>> m.result().numpy() 843 0.3 844 845 Usage with `compile()` API: 846 847 ```python 848 model.compile( 849 optimizer='sgd', 850 loss='mse', 851 metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]) 852 ``` 853 """ 854 855 def __init__(self, name='sparse_categorical_accuracy', dtype=None): 856 super(SparseCategoricalAccuracy, self).__init__( 857 sparse_categorical_accuracy, name, dtype=dtype) 858 859 860@keras_export('keras.metrics.TopKCategoricalAccuracy') 861class TopKCategoricalAccuracy(MeanMetricWrapper): 862 """Computes how often targets are in the top `K` predictions. 863 864 Args: 865 k: (Optional) Number of top elements to look at for computing accuracy. 866 Defaults to 5. 867 name: (Optional) string name of the metric instance. 868 dtype: (Optional) data type of the metric result. 869 870 Standalone usage: 871 872 >>> m = tf.keras.metrics.TopKCategoricalAccuracy(k=1) 873 >>> m.update_state([[0, 0, 1], [0, 1, 0]], 874 ... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) 875 >>> m.result().numpy() 876 0.5 877 878 >>> m.reset_states() 879 >>> m.update_state([[0, 0, 1], [0, 1, 0]], 880 ... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]], 881 ... sample_weight=[0.7, 0.3]) 882 >>> m.result().numpy() 883 0.3 884 885 Usage with `compile()` API: 886 887 ```python 888 model.compile(optimizer='sgd', 889 loss='mse', 890 metrics=[tf.keras.metrics.TopKCategoricalAccuracy()]) 891 ``` 892 """ 893 894 def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None): 895 super(TopKCategoricalAccuracy, self).__init__( 896 top_k_categorical_accuracy, name, dtype=dtype, k=k) 897 898 899@keras_export('keras.metrics.SparseTopKCategoricalAccuracy') 900class SparseTopKCategoricalAccuracy(MeanMetricWrapper): 901 """Computes how often integer targets are in the top `K` predictions. 902 903 Args: 904 k: (Optional) Number of top elements to look at for computing accuracy. 905 Defaults to 5. 906 name: (Optional) string name of the metric instance. 907 dtype: (Optional) data type of the metric result. 908 909 Standalone usage: 910 911 >>> m = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1) 912 >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) 913 >>> m.result().numpy() 914 0.5 915 916 >>> m.reset_states() 917 >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]], 918 ... sample_weight=[0.7, 0.3]) 919 >>> m.result().numpy() 920 0.3 921 922 Usage with `compile()` API: 923 924 ```python 925 model.compile( 926 optimizer='sgd', 927 loss='mse', 928 metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy()]) 929 ``` 930 """ 931 932 def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None): 933 super(SparseTopKCategoricalAccuracy, self).__init__( 934 sparse_top_k_categorical_accuracy, name, dtype=dtype, k=k) 935 936 937class _ConfusionMatrixConditionCount(Metric): 938 """Calculates the number of the given confusion matrix condition. 939 940 Args: 941 confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` conditions. 942 thresholds: (Optional) Defaults to 0.5. A float value or a python list/tuple 943 of float threshold values in [0, 1]. A threshold is compared with 944 prediction values to determine the truth value of predictions (i.e., above 945 the threshold is `true`, below is `false`). One metric value is generated 946 for each threshold value. 947 name: (Optional) string name of the metric instance. 948 dtype: (Optional) data type of the metric result. 949 """ 950 951 def __init__(self, 952 confusion_matrix_cond, 953 thresholds=None, 954 name=None, 955 dtype=None): 956 super(_ConfusionMatrixConditionCount, self).__init__(name=name, dtype=dtype) 957 self._confusion_matrix_cond = confusion_matrix_cond 958 self.init_thresholds = thresholds 959 self.thresholds = metrics_utils.parse_init_thresholds( 960 thresholds, default_threshold=0.5) 961 self.accumulator = self.add_weight( 962 'accumulator', 963 shape=(len(self.thresholds),), 964 initializer=init_ops.zeros_initializer) 965 966 def update_state(self, y_true, y_pred, sample_weight=None): 967 """Accumulates the metric statistics. 968 969 Args: 970 y_true: The ground truth values. 971 y_pred: The predicted values. 972 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 973 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 974 be broadcastable to `y_true`. 975 976 Returns: 977 Update op. 978 """ 979 return metrics_utils.update_confusion_matrix_variables( 980 {self._confusion_matrix_cond: self.accumulator}, 981 y_true, 982 y_pred, 983 thresholds=self.thresholds, 984 sample_weight=sample_weight) 985 986 def result(self): 987 if len(self.thresholds) == 1: 988 result = self.accumulator[0] 989 else: 990 result = self.accumulator 991 return ops.convert_to_tensor_v2_with_dispatch(result) 992 993 def reset_states(self): 994 num_thresholds = len(to_list(self.thresholds)) 995 K.batch_set_value( 996 [(v, np.zeros((num_thresholds,))) for v in self.variables]) 997 998 def get_config(self): 999 config = {'thresholds': self.init_thresholds} 1000 base_config = super(_ConfusionMatrixConditionCount, self).get_config() 1001 return dict(list(base_config.items()) + list(config.items())) 1002 1003 1004@keras_export('keras.metrics.FalsePositives') 1005class FalsePositives(_ConfusionMatrixConditionCount): 1006 """Calculates the number of false positives. 1007 1008 If `sample_weight` is given, calculates the sum of the weights of 1009 false positives. This metric creates one local variable, `accumulator` 1010 that is used to keep track of the number of false positives. 1011 1012 If `sample_weight` is `None`, weights default to 1. 1013 Use `sample_weight` of 0 to mask values. 1014 1015 Args: 1016 thresholds: (Optional) Defaults to 0.5. A float value or a python 1017 list/tuple of float threshold values in [0, 1]. A threshold is compared 1018 with prediction values to determine the truth value of predictions 1019 (i.e., above the threshold is `true`, below is `false`). One metric 1020 value is generated for each threshold value. 1021 name: (Optional) string name of the metric instance. 1022 dtype: (Optional) data type of the metric result. 1023 1024 Standalone usage: 1025 1026 >>> m = tf.keras.metrics.FalsePositives() 1027 >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1]) 1028 >>> m.result().numpy() 1029 2.0 1030 1031 >>> m.reset_states() 1032 >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0]) 1033 >>> m.result().numpy() 1034 1.0 1035 1036 Usage with `compile()` API: 1037 1038 ```python 1039 model.compile(optimizer='sgd', 1040 loss='mse', 1041 metrics=[tf.keras.metrics.FalsePositives()]) 1042 ``` 1043 """ 1044 1045 def __init__(self, thresholds=None, name=None, dtype=None): 1046 super(FalsePositives, self).__init__( 1047 confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES, 1048 thresholds=thresholds, 1049 name=name, 1050 dtype=dtype) 1051 1052 1053@keras_export('keras.metrics.FalseNegatives') 1054class FalseNegatives(_ConfusionMatrixConditionCount): 1055 """Calculates the number of false negatives. 1056 1057 If `sample_weight` is given, calculates the sum of the weights of 1058 false negatives. This metric creates one local variable, `accumulator` 1059 that is used to keep track of the number of false negatives. 1060 1061 If `sample_weight` is `None`, weights default to 1. 1062 Use `sample_weight` of 0 to mask values. 1063 1064 Args: 1065 thresholds: (Optional) Defaults to 0.5. A float value or a python 1066 list/tuple of float threshold values in [0, 1]. A threshold is compared 1067 with prediction values to determine the truth value of predictions 1068 (i.e., above the threshold is `true`, below is `false`). One metric 1069 value is generated for each threshold value. 1070 name: (Optional) string name of the metric instance. 1071 dtype: (Optional) data type of the metric result. 1072 1073 Standalone usage: 1074 1075 >>> m = tf.keras.metrics.FalseNegatives() 1076 >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0]) 1077 >>> m.result().numpy() 1078 2.0 1079 1080 >>> m.reset_states() 1081 >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0], sample_weight=[0, 0, 1, 0]) 1082 >>> m.result().numpy() 1083 1.0 1084 1085 Usage with `compile()` API: 1086 1087 ```python 1088 model.compile(optimizer='sgd', 1089 loss='mse', 1090 metrics=[tf.keras.metrics.FalseNegatives()]) 1091 ``` 1092 """ 1093 1094 def __init__(self, thresholds=None, name=None, dtype=None): 1095 super(FalseNegatives, self).__init__( 1096 confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES, 1097 thresholds=thresholds, 1098 name=name, 1099 dtype=dtype) 1100 1101 1102@keras_export('keras.metrics.TrueNegatives') 1103class TrueNegatives(_ConfusionMatrixConditionCount): 1104 """Calculates the number of true negatives. 1105 1106 If `sample_weight` is given, calculates the sum of the weights of 1107 true negatives. This metric creates one local variable, `accumulator` 1108 that is used to keep track of the number of true negatives. 1109 1110 If `sample_weight` is `None`, weights default to 1. 1111 Use `sample_weight` of 0 to mask values. 1112 1113 Args: 1114 thresholds: (Optional) Defaults to 0.5. A float value or a python 1115 list/tuple of float threshold values in [0, 1]. A threshold is compared 1116 with prediction values to determine the truth value of predictions 1117 (i.e., above the threshold is `true`, below is `false`). One metric 1118 value is generated for each threshold value. 1119 name: (Optional) string name of the metric instance. 1120 dtype: (Optional) data type of the metric result. 1121 1122 Standalone usage: 1123 1124 >>> m = tf.keras.metrics.TrueNegatives() 1125 >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0]) 1126 >>> m.result().numpy() 1127 2.0 1128 1129 >>> m.reset_states() 1130 >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0], sample_weight=[0, 0, 1, 0]) 1131 >>> m.result().numpy() 1132 1.0 1133 1134 Usage with `compile()` API: 1135 1136 ```python 1137 model.compile(optimizer='sgd', 1138 loss='mse', 1139 metrics=[tf.keras.metrics.TrueNegatives()]) 1140 ``` 1141 """ 1142 1143 def __init__(self, thresholds=None, name=None, dtype=None): 1144 super(TrueNegatives, self).__init__( 1145 confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES, 1146 thresholds=thresholds, 1147 name=name, 1148 dtype=dtype) 1149 1150 1151@keras_export('keras.metrics.TruePositives') 1152class TruePositives(_ConfusionMatrixConditionCount): 1153 """Calculates the number of true positives. 1154 1155 If `sample_weight` is given, calculates the sum of the weights of 1156 true positives. This metric creates one local variable, `true_positives` 1157 that is used to keep track of the number of true positives. 1158 1159 If `sample_weight` is `None`, weights default to 1. 1160 Use `sample_weight` of 0 to mask values. 1161 1162 Args: 1163 thresholds: (Optional) Defaults to 0.5. A float value or a python 1164 list/tuple of float threshold values in [0, 1]. A threshold is compared 1165 with prediction values to determine the truth value of predictions 1166 (i.e., above the threshold is `true`, below is `false`). One metric 1167 value is generated for each threshold value. 1168 name: (Optional) string name of the metric instance. 1169 dtype: (Optional) data type of the metric result. 1170 1171 Standalone usage: 1172 1173 >>> m = tf.keras.metrics.TruePositives() 1174 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) 1175 >>> m.result().numpy() 1176 2.0 1177 1178 >>> m.reset_states() 1179 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) 1180 >>> m.result().numpy() 1181 1.0 1182 1183 Usage with `compile()` API: 1184 1185 ```python 1186 model.compile(optimizer='sgd', 1187 loss='mse', 1188 metrics=[tf.keras.metrics.TruePositives()]) 1189 ``` 1190 """ 1191 1192 def __init__(self, thresholds=None, name=None, dtype=None): 1193 super(TruePositives, self).__init__( 1194 confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES, 1195 thresholds=thresholds, 1196 name=name, 1197 dtype=dtype) 1198 1199 1200@keras_export('keras.metrics.Precision') 1201class Precision(Metric): 1202 """Computes the precision of the predictions with respect to the labels. 1203 1204 The metric creates two local variables, `true_positives` and `false_positives` 1205 that are used to compute the precision. This value is ultimately returned as 1206 `precision`, an idempotent operation that simply divides `true_positives` 1207 by the sum of `true_positives` and `false_positives`. 1208 1209 If `sample_weight` is `None`, weights default to 1. 1210 Use `sample_weight` of 0 to mask values. 1211 1212 If `top_k` is set, we'll calculate precision as how often on average a class 1213 among the top-k classes with the highest predicted values of a batch entry is 1214 correct and can be found in the label for that entry. 1215 1216 If `class_id` is specified, we calculate precision by considering only the 1217 entries in the batch for which `class_id` is above the threshold and/or in the 1218 top-k highest predictions, and computing the fraction of them for which 1219 `class_id` is indeed a correct label. 1220 1221 Args: 1222 thresholds: (Optional) A float value or a python list/tuple of float 1223 threshold values in [0, 1]. A threshold is compared with prediction 1224 values to determine the truth value of predictions (i.e., above the 1225 threshold is `true`, below is `false`). One metric value is generated 1226 for each threshold value. If neither thresholds nor top_k are set, the 1227 default is to calculate precision with `thresholds=0.5`. 1228 top_k: (Optional) Unset by default. An int value specifying the top-k 1229 predictions to consider when calculating precision. 1230 class_id: (Optional) Integer class ID for which we want binary metrics. 1231 This must be in the half-open interval `[0, num_classes)`, where 1232 `num_classes` is the last dimension of predictions. 1233 name: (Optional) string name of the metric instance. 1234 dtype: (Optional) data type of the metric result. 1235 1236 Standalone usage: 1237 1238 >>> m = tf.keras.metrics.Precision() 1239 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) 1240 >>> m.result().numpy() 1241 0.6666667 1242 1243 >>> m.reset_states() 1244 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) 1245 >>> m.result().numpy() 1246 1.0 1247 1248 >>> # With top_k=2, it will calculate precision over y_true[:2] and y_pred[:2] 1249 >>> m = tf.keras.metrics.Precision(top_k=2) 1250 >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1]) 1251 >>> m.result().numpy() 1252 0.0 1253 1254 >>> # With top_k=4, it will calculate precision over y_true[:4] and y_pred[:4] 1255 >>> m = tf.keras.metrics.Precision(top_k=4) 1256 >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1]) 1257 >>> m.result().numpy() 1258 0.5 1259 1260 Usage with `compile()` API: 1261 1262 ```python 1263 model.compile(optimizer='sgd', 1264 loss='mse', 1265 metrics=[tf.keras.metrics.Precision()]) 1266 ``` 1267 """ 1268 1269 def __init__(self, 1270 thresholds=None, 1271 top_k=None, 1272 class_id=None, 1273 name=None, 1274 dtype=None): 1275 super(Precision, self).__init__(name=name, dtype=dtype) 1276 self.init_thresholds = thresholds 1277 self.top_k = top_k 1278 self.class_id = class_id 1279 1280 default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF 1281 self.thresholds = metrics_utils.parse_init_thresholds( 1282 thresholds, default_threshold=default_threshold) 1283 self.true_positives = self.add_weight( 1284 'true_positives', 1285 shape=(len(self.thresholds),), 1286 initializer=init_ops.zeros_initializer) 1287 self.false_positives = self.add_weight( 1288 'false_positives', 1289 shape=(len(self.thresholds),), 1290 initializer=init_ops.zeros_initializer) 1291 1292 def update_state(self, y_true, y_pred, sample_weight=None): 1293 """Accumulates true positive and false positive statistics. 1294 1295 Args: 1296 y_true: The ground truth values, with the same dimensions as `y_pred`. 1297 Will be cast to `bool`. 1298 y_pred: The predicted values. Each element must be in the range `[0, 1]`. 1299 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 1300 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 1301 be broadcastable to `y_true`. 1302 1303 Returns: 1304 Update op. 1305 """ 1306 return metrics_utils.update_confusion_matrix_variables( 1307 { 1308 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, 1309 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives 1310 }, 1311 y_true, 1312 y_pred, 1313 thresholds=self.thresholds, 1314 top_k=self.top_k, 1315 class_id=self.class_id, 1316 sample_weight=sample_weight) 1317 1318 def result(self): 1319 result = math_ops.div_no_nan(self.true_positives, 1320 self.true_positives + self.false_positives) 1321 return result[0] if len(self.thresholds) == 1 else result 1322 1323 def reset_states(self): 1324 num_thresholds = len(to_list(self.thresholds)) 1325 K.batch_set_value( 1326 [(v, np.zeros((num_thresholds,))) for v in self.variables]) 1327 1328 def get_config(self): 1329 config = { 1330 'thresholds': self.init_thresholds, 1331 'top_k': self.top_k, 1332 'class_id': self.class_id 1333 } 1334 base_config = super(Precision, self).get_config() 1335 return dict(list(base_config.items()) + list(config.items())) 1336 1337 1338@keras_export('keras.metrics.Recall') 1339class Recall(Metric): 1340 """Computes the recall of the predictions with respect to the labels. 1341 1342 This metric creates two local variables, `true_positives` and 1343 `false_negatives`, that are used to compute the recall. This value is 1344 ultimately returned as `recall`, an idempotent operation that simply divides 1345 `true_positives` by the sum of `true_positives` and `false_negatives`. 1346 1347 If `sample_weight` is `None`, weights default to 1. 1348 Use `sample_weight` of 0 to mask values. 1349 1350 If `top_k` is set, recall will be computed as how often on average a class 1351 among the labels of a batch entry is in the top-k predictions. 1352 1353 If `class_id` is specified, we calculate recall by considering only the 1354 entries in the batch for which `class_id` is in the label, and computing the 1355 fraction of them for which `class_id` is above the threshold and/or in the 1356 top-k predictions. 1357 1358 Args: 1359 thresholds: (Optional) A float value or a python list/tuple of float 1360 threshold values in [0, 1]. A threshold is compared with prediction 1361 values to determine the truth value of predictions (i.e., above the 1362 threshold is `true`, below is `false`). One metric value is generated 1363 for each threshold value. If neither thresholds nor top_k are set, the 1364 default is to calculate recall with `thresholds=0.5`. 1365 top_k: (Optional) Unset by default. An int value specifying the top-k 1366 predictions to consider when calculating recall. 1367 class_id: (Optional) Integer class ID for which we want binary metrics. 1368 This must be in the half-open interval `[0, num_classes)`, where 1369 `num_classes` is the last dimension of predictions. 1370 name: (Optional) string name of the metric instance. 1371 dtype: (Optional) data type of the metric result. 1372 1373 Standalone usage: 1374 1375 >>> m = tf.keras.metrics.Recall() 1376 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) 1377 >>> m.result().numpy() 1378 0.6666667 1379 1380 >>> m.reset_states() 1381 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) 1382 >>> m.result().numpy() 1383 1.0 1384 1385 Usage with `compile()` API: 1386 1387 ```python 1388 model.compile(optimizer='sgd', 1389 loss='mse', 1390 metrics=[tf.keras.metrics.Recall()]) 1391 ``` 1392 """ 1393 1394 def __init__(self, 1395 thresholds=None, 1396 top_k=None, 1397 class_id=None, 1398 name=None, 1399 dtype=None): 1400 super(Recall, self).__init__(name=name, dtype=dtype) 1401 self.init_thresholds = thresholds 1402 self.top_k = top_k 1403 self.class_id = class_id 1404 1405 default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF 1406 self.thresholds = metrics_utils.parse_init_thresholds( 1407 thresholds, default_threshold=default_threshold) 1408 self.true_positives = self.add_weight( 1409 'true_positives', 1410 shape=(len(self.thresholds),), 1411 initializer=init_ops.zeros_initializer) 1412 self.false_negatives = self.add_weight( 1413 'false_negatives', 1414 shape=(len(self.thresholds),), 1415 initializer=init_ops.zeros_initializer) 1416 1417 def update_state(self, y_true, y_pred, sample_weight=None): 1418 """Accumulates true positive and false negative statistics. 1419 1420 Args: 1421 y_true: The ground truth values, with the same dimensions as `y_pred`. 1422 Will be cast to `bool`. 1423 y_pred: The predicted values. Each element must be in the range `[0, 1]`. 1424 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 1425 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 1426 be broadcastable to `y_true`. 1427 1428 Returns: 1429 Update op. 1430 """ 1431 return metrics_utils.update_confusion_matrix_variables( 1432 { 1433 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, 1434 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives 1435 }, 1436 y_true, 1437 y_pred, 1438 thresholds=self.thresholds, 1439 top_k=self.top_k, 1440 class_id=self.class_id, 1441 sample_weight=sample_weight) 1442 1443 def result(self): 1444 result = math_ops.div_no_nan(self.true_positives, 1445 self.true_positives + self.false_negatives) 1446 return result[0] if len(self.thresholds) == 1 else result 1447 1448 def reset_states(self): 1449 num_thresholds = len(to_list(self.thresholds)) 1450 K.batch_set_value( 1451 [(v, np.zeros((num_thresholds,))) for v in self.variables]) 1452 1453 def get_config(self): 1454 config = { 1455 'thresholds': self.init_thresholds, 1456 'top_k': self.top_k, 1457 'class_id': self.class_id 1458 } 1459 base_config = super(Recall, self).get_config() 1460 return dict(list(base_config.items()) + list(config.items())) 1461 1462 1463@six.add_metaclass(abc.ABCMeta) 1464class SensitivitySpecificityBase(Metric): 1465 """Abstract base class for computing sensitivity and specificity. 1466 1467 For additional information about specificity and sensitivity, see 1468 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). 1469 """ 1470 1471 def __init__(self, value, num_thresholds=200, name=None, dtype=None): 1472 super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype) 1473 if num_thresholds <= 0: 1474 raise ValueError('`num_thresholds` must be > 0.') 1475 self.value = value 1476 self.true_positives = self.add_weight( 1477 'true_positives', 1478 shape=(num_thresholds,), 1479 initializer=init_ops.zeros_initializer) 1480 self.true_negatives = self.add_weight( 1481 'true_negatives', 1482 shape=(num_thresholds,), 1483 initializer=init_ops.zeros_initializer) 1484 self.false_positives = self.add_weight( 1485 'false_positives', 1486 shape=(num_thresholds,), 1487 initializer=init_ops.zeros_initializer) 1488 self.false_negatives = self.add_weight( 1489 'false_negatives', 1490 shape=(num_thresholds,), 1491 initializer=init_ops.zeros_initializer) 1492 1493 # Compute `num_thresholds` thresholds in [0, 1] 1494 if num_thresholds == 1: 1495 self.thresholds = [0.5] 1496 else: 1497 thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) 1498 for i in range(num_thresholds - 2)] 1499 self.thresholds = [0.0] + thresholds + [1.0] 1500 1501 def update_state(self, y_true, y_pred, sample_weight=None): 1502 """Accumulates confusion matrix statistics. 1503 1504 Args: 1505 y_true: The ground truth values. 1506 y_pred: The predicted values. 1507 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 1508 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 1509 be broadcastable to `y_true`. 1510 1511 Returns: 1512 Update op. 1513 """ 1514 return metrics_utils.update_confusion_matrix_variables( 1515 { 1516 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, 1517 metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, 1518 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, 1519 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, 1520 }, 1521 y_true, 1522 y_pred, 1523 thresholds=self.thresholds, 1524 sample_weight=sample_weight) 1525 1526 def reset_states(self): 1527 num_thresholds = len(self.thresholds) 1528 K.batch_set_value( 1529 [(v, np.zeros((num_thresholds,))) for v in self.variables]) 1530 1531 def _find_max_under_constraint(self, constrained, dependent, predicate): 1532 """Returns the maximum of dependent_statistic that satisfies the constraint. 1533 1534 Args: 1535 constrained: Over these values the constraint 1536 is specified. A rank-1 tensor. 1537 dependent: From these values the maximum that satiesfies the 1538 constraint is selected. Values in this tensor and in 1539 `constrained` are linked by having the same threshold at each 1540 position, hence this tensor must have the same shape. 1541 predicate: A binary boolean functor to be applied to arguments 1542 `constrained` and `self.value`, e.g. `tf.greater`. 1543 1544 Returns maximal dependent value, if no value satiesfies the constraint 0.0. 1545 """ 1546 feasible = array_ops.where(predicate(constrained, self.value)) 1547 feasible_exists = math_ops.greater(array_ops.size(feasible), 0) 1548 1549 def get_max(): 1550 return math_ops.reduce_max(array_ops.gather(dependent, feasible)) 1551 1552 return control_flow_ops.cond(feasible_exists, get_max, lambda: 0.0) 1553 1554 1555@keras_export('keras.metrics.SensitivityAtSpecificity') 1556class SensitivityAtSpecificity(SensitivitySpecificityBase): 1557 """Computes best sensitivity where specificity is >= specified value. 1558 1559 the sensitivity at a given specificity. 1560 1561 `Sensitivity` measures the proportion of actual positives that are correctly 1562 identified as such (tp / (tp + fn)). 1563 `Specificity` measures the proportion of actual negatives that are correctly 1564 identified as such (tn / (tn + fp)). 1565 1566 This metric creates four local variables, `true_positives`, `true_negatives`, 1567 `false_positives` and `false_negatives` that are used to compute the 1568 sensitivity at the given specificity. The threshold for the given specificity 1569 value is computed and used to evaluate the corresponding sensitivity. 1570 1571 If `sample_weight` is `None`, weights default to 1. 1572 Use `sample_weight` of 0 to mask values. 1573 1574 For additional information about specificity and sensitivity, see 1575 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). 1576 1577 Args: 1578 specificity: A scalar value in range `[0, 1]`. 1579 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 1580 use for matching the given specificity. 1581 name: (Optional) string name of the metric instance. 1582 dtype: (Optional) data type of the metric result. 1583 1584 Standalone usage: 1585 1586 >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5) 1587 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) 1588 >>> m.result().numpy() 1589 0.5 1590 1591 >>> m.reset_states() 1592 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], 1593 ... sample_weight=[1, 1, 2, 2, 1]) 1594 >>> m.result().numpy() 1595 0.333333 1596 1597 Usage with `compile()` API: 1598 1599 ```python 1600 model.compile( 1601 optimizer='sgd', 1602 loss='mse', 1603 metrics=[tf.keras.metrics.SensitivityAtSpecificity()]) 1604 ``` 1605 """ 1606 1607 def __init__(self, specificity, num_thresholds=200, name=None, dtype=None): 1608 if specificity < 0 or specificity > 1: 1609 raise ValueError('`specificity` must be in the range [0, 1].') 1610 self.specificity = specificity 1611 self.num_thresholds = num_thresholds 1612 super(SensitivityAtSpecificity, self).__init__( 1613 specificity, num_thresholds=num_thresholds, name=name, dtype=dtype) 1614 1615 def result(self): 1616 specificities = math_ops.div_no_nan( 1617 self.true_negatives, self.true_negatives + self.false_positives) 1618 sensitivities = math_ops.div_no_nan( 1619 self.true_positives, self.true_positives + self.false_negatives) 1620 return self._find_max_under_constraint( 1621 specificities, sensitivities, math_ops.greater_equal) 1622 1623 def get_config(self): 1624 config = { 1625 'num_thresholds': self.num_thresholds, 1626 'specificity': self.specificity 1627 } 1628 base_config = super(SensitivityAtSpecificity, self).get_config() 1629 return dict(list(base_config.items()) + list(config.items())) 1630 1631 1632@keras_export('keras.metrics.SpecificityAtSensitivity') 1633class SpecificityAtSensitivity(SensitivitySpecificityBase): 1634 """Computes best specificity where sensitivity is >= specified value. 1635 1636 `Sensitivity` measures the proportion of actual positives that are correctly 1637 identified as such (tp / (tp + fn)). 1638 `Specificity` measures the proportion of actual negatives that are correctly 1639 identified as such (tn / (tn + fp)). 1640 1641 This metric creates four local variables, `true_positives`, `true_negatives`, 1642 `false_positives` and `false_negatives` that are used to compute the 1643 specificity at the given sensitivity. The threshold for the given sensitivity 1644 value is computed and used to evaluate the corresponding specificity. 1645 1646 If `sample_weight` is `None`, weights default to 1. 1647 Use `sample_weight` of 0 to mask values. 1648 1649 For additional information about specificity and sensitivity, see 1650 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). 1651 1652 Args: 1653 sensitivity: A scalar value in range `[0, 1]`. 1654 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 1655 use for matching the given sensitivity. 1656 name: (Optional) string name of the metric instance. 1657 dtype: (Optional) data type of the metric result. 1658 1659 Standalone usage: 1660 1661 >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5) 1662 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) 1663 >>> m.result().numpy() 1664 0.66666667 1665 1666 >>> m.reset_states() 1667 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], 1668 ... sample_weight=[1, 1, 2, 2, 2]) 1669 >>> m.result().numpy() 1670 0.5 1671 1672 Usage with `compile()` API: 1673 1674 ```python 1675 model.compile( 1676 optimizer='sgd', 1677 loss='mse', 1678 metrics=[tf.keras.metrics.SpecificityAtSensitivity()]) 1679 ``` 1680 """ 1681 1682 def __init__(self, sensitivity, num_thresholds=200, name=None, dtype=None): 1683 if sensitivity < 0 or sensitivity > 1: 1684 raise ValueError('`sensitivity` must be in the range [0, 1].') 1685 self.sensitivity = sensitivity 1686 self.num_thresholds = num_thresholds 1687 super(SpecificityAtSensitivity, self).__init__( 1688 sensitivity, num_thresholds=num_thresholds, name=name, dtype=dtype) 1689 1690 def result(self): 1691 sensitivities = math_ops.div_no_nan( 1692 self.true_positives, self.true_positives + self.false_negatives) 1693 specificities = math_ops.div_no_nan( 1694 self.true_negatives, self.true_negatives + self.false_positives) 1695 return self._find_max_under_constraint( 1696 sensitivities, specificities, math_ops.greater_equal) 1697 1698 def get_config(self): 1699 config = { 1700 'num_thresholds': self.num_thresholds, 1701 'sensitivity': self.sensitivity 1702 } 1703 base_config = super(SpecificityAtSensitivity, self).get_config() 1704 return dict(list(base_config.items()) + list(config.items())) 1705 1706 1707@keras_export('keras.metrics.PrecisionAtRecall') 1708class PrecisionAtRecall(SensitivitySpecificityBase): 1709 """Computes best precision where recall is >= specified value. 1710 1711 This metric creates four local variables, `true_positives`, `true_negatives`, 1712 `false_positives` and `false_negatives` that are used to compute the 1713 precision at the given recall. The threshold for the given recall 1714 value is computed and used to evaluate the corresponding precision. 1715 1716 If `sample_weight` is `None`, weights default to 1. 1717 Use `sample_weight` of 0 to mask values. 1718 1719 Args: 1720 recall: A scalar value in range `[0, 1]`. 1721 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 1722 use for matching the given recall. 1723 name: (Optional) string name of the metric instance. 1724 dtype: (Optional) data type of the metric result. 1725 1726 Standalone usage: 1727 1728 >>> m = tf.keras.metrics.PrecisionAtRecall(0.5) 1729 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) 1730 >>> m.result().numpy() 1731 0.5 1732 1733 >>> m.reset_states() 1734 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], 1735 ... sample_weight=[2, 2, 2, 1, 1]) 1736 >>> m.result().numpy() 1737 0.33333333 1738 1739 Usage with `compile()` API: 1740 1741 ```python 1742 model.compile( 1743 optimizer='sgd', 1744 loss='mse', 1745 metrics=[tf.keras.metrics.PrecisionAtRecall(recall=0.8)]) 1746 ``` 1747 """ 1748 1749 def __init__(self, recall, num_thresholds=200, name=None, dtype=None): 1750 if recall < 0 or recall > 1: 1751 raise ValueError('`recall` must be in the range [0, 1].') 1752 self.recall = recall 1753 self.num_thresholds = num_thresholds 1754 super(PrecisionAtRecall, self).__init__( 1755 value=recall, 1756 num_thresholds=num_thresholds, 1757 name=name, 1758 dtype=dtype) 1759 1760 def result(self): 1761 recalls = math_ops.div_no_nan( 1762 self.true_positives, self.true_positives + self.false_negatives) 1763 precisions = math_ops.div_no_nan( 1764 self.true_positives, self.true_positives + self.false_positives) 1765 return self._find_max_under_constraint( 1766 recalls, precisions, math_ops.greater_equal) 1767 1768 def get_config(self): 1769 config = {'num_thresholds': self.num_thresholds, 'recall': self.recall} 1770 base_config = super(PrecisionAtRecall, self).get_config() 1771 return dict(list(base_config.items()) + list(config.items())) 1772 1773 1774@keras_export('keras.metrics.RecallAtPrecision') 1775class RecallAtPrecision(SensitivitySpecificityBase): 1776 """Computes best recall where precision is >= specified value. 1777 1778 For a given score-label-distribution the required precision might not 1779 be achievable, in this case 0.0 is returned as recall. 1780 1781 This metric creates four local variables, `true_positives`, `true_negatives`, 1782 `false_positives` and `false_negatives` that are used to compute the 1783 recall at the given precision. The threshold for the given precision 1784 value is computed and used to evaluate the corresponding recall. 1785 1786 If `sample_weight` is `None`, weights default to 1. 1787 Use `sample_weight` of 0 to mask values. 1788 1789 Args: 1790 precision: A scalar value in range `[0, 1]`. 1791 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 1792 use for matching the given precision. 1793 name: (Optional) string name of the metric instance. 1794 dtype: (Optional) data type of the metric result. 1795 1796 Standalone usage: 1797 1798 >>> m = tf.keras.metrics.RecallAtPrecision(0.8) 1799 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) 1800 >>> m.result().numpy() 1801 0.5 1802 1803 >>> m.reset_states() 1804 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], 1805 ... sample_weight=[1, 0, 0, 1]) 1806 >>> m.result().numpy() 1807 1.0 1808 1809 Usage with `compile()` API: 1810 1811 ```python 1812 model.compile( 1813 optimizer='sgd', 1814 loss='mse', 1815 metrics=[tf.keras.metrics.RecallAtPrecision(precision=0.8)]) 1816 ``` 1817 """ 1818 1819 def __init__(self, precision, num_thresholds=200, name=None, dtype=None): 1820 if precision < 0 or precision > 1: 1821 raise ValueError('`precision` must be in the range [0, 1].') 1822 self.precision = precision 1823 self.num_thresholds = num_thresholds 1824 super(RecallAtPrecision, self).__init__( 1825 value=precision, 1826 num_thresholds=num_thresholds, 1827 name=name, 1828 dtype=dtype) 1829 1830 def result(self): 1831 precisions = math_ops.div_no_nan( 1832 self.true_positives, self.true_positives + self.false_positives) 1833 recalls = math_ops.div_no_nan( 1834 self.true_positives, self.true_positives + self.false_negatives) 1835 return self._find_max_under_constraint( 1836 precisions, recalls, math_ops.greater_equal) 1837 1838 def get_config(self): 1839 config = {'num_thresholds': self.num_thresholds, 1840 'precision': self.precision} 1841 base_config = super(RecallAtPrecision, self).get_config() 1842 return dict(list(base_config.items()) + list(config.items())) 1843 1844 1845@keras_export('keras.metrics.AUC') 1846class AUC(Metric): 1847 """Approximates the AUC (Area under the curve) of the ROC or PR curves. 1848 1849 The AUC (Area under the curve) of the ROC (Receiver operating 1850 characteristic; default) or PR (Precision Recall) curves are quality measures 1851 of binary classifiers. Unlike the accuracy, and like cross-entropy 1852 losses, ROC-AUC and PR-AUC evaluate all the operational points of a model. 1853 1854 This classes approximates AUCs using a Riemann sum: During the metric 1855 accumulation phrase, predictions are accumulated within predefined buckets 1856 by value. The AUC is then computed by interpolating per-bucket averages. These 1857 buckets define the evaluated operational points. 1858 1859 This metric creates four local variables, `true_positives`, `true_negatives`, 1860 `false_positives` and `false_negatives` that are used to compute the AUC. 1861 To discretize the AUC curve, a linearly spaced set of thresholds is used to 1862 compute pairs of recall and precision values. The area under the ROC-curve is 1863 therefore computed using the height of the recall values by the false positive 1864 rate, while the area under the PR-curve is the computed using the height of 1865 the precision values by the recall. 1866 1867 This value is ultimately returned as `auc`, an idempotent operation that 1868 computes the area under a discretized curve of precision versus recall values 1869 (computed using the aforementioned variables). The `num_thresholds` variable 1870 controls the degree of discretization with larger numbers of thresholds more 1871 closely approximating the true AUC. The quality of the approximation may vary 1872 dramatically depending on `num_thresholds`. The `thresholds` parameter can be 1873 used to manually specify thresholds which split the predictions more evenly. 1874 1875 For a best approximation of the real AUC, `predictions` should be distributed 1876 approximately uniformly in the range [0, 1] (if `from_logits=False`). The 1877 quality of the AUC approximation may be poor if this is not the case. Setting 1878 `summation_method` to 'minoring' or 'majoring' can help quantify the error in 1879 the approximation by providing lower or upper bound estimate of the AUC. 1880 1881 If `sample_weight` is `None`, weights default to 1. 1882 Use `sample_weight` of 0 to mask values. 1883 1884 Args: 1885 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 1886 use when discretizing the roc curve. Values must be > 1. 1887 curve: (Optional) Specifies the name of the curve to be computed, 'ROC' 1888 [default] or 'PR' for the Precision-Recall-curve. 1889 summation_method: (Optional) Specifies the [Riemann summation method]( 1890 https://en.wikipedia.org/wiki/Riemann_sum) used. 1891 'interpolation' (default) applies mid-point summation scheme for `ROC`. 1892 For PR-AUC, interpolates (true/false) positives but not the ratio that 1893 is precision (see Davis & Goadrich 2006 for details); 1894 'minoring' applies left summation 1895 for increasing intervals and right summation for decreasing intervals; 1896 'majoring' does the opposite. 1897 name: (Optional) string name of the metric instance. 1898 dtype: (Optional) data type of the metric result. 1899 thresholds: (Optional) A list of floating point values to use as the 1900 thresholds for discretizing the curve. If set, the `num_thresholds` 1901 parameter is ignored. Values should be in [0, 1]. Endpoint thresholds 1902 equal to {-epsilon, 1+epsilon} for a small positive epsilon value will 1903 be automatically included with these to correctly handle predictions 1904 equal to exactly 0 or 1. 1905 multi_label: boolean indicating whether multilabel data should be 1906 treated as such, wherein AUC is computed separately for each label and 1907 then averaged across labels, or (when False) if the data should be 1908 flattened into a single label before AUC computation. In the latter 1909 case, when multilabel data is passed to AUC, each label-prediction pair 1910 is treated as an individual data point. Should be set to False for 1911 multi-class data. 1912 num_labels: (Optional) The number of labels, used when `multi_label' is 1913 True. If `num_labels` is not specified, then state variables get created 1914 on the first call to `update_state`. 1915 label_weights: (Optional) list, array, or tensor of non-negative weights 1916 used to compute AUCs for multilabel data. When `multi_label` is True, 1917 the weights are applied to the individual label AUCs when they are 1918 averaged to produce the multi-label AUC. When it's False, they are used 1919 to weight the individual label predictions in computing the confusion 1920 matrix on the flattened data. Note that this is unlike class_weights in 1921 that class_weights weights the example depending on the value of its 1922 label, whereas label_weights depends only on the index of that label 1923 before flattening; therefore `label_weights` should not be used for 1924 multi-class data. 1925 from_logits: boolean indicating whether the predictions (`y_pred` in 1926 `update_state`) are probabilities or sigmoid logits. As a rule of thumb, 1927 when using a keras loss, the `from_logits` constructor argument of the 1928 loss should match the AUC `from_logits` constructor argument. 1929 1930 Standalone usage: 1931 1932 >>> m = tf.keras.metrics.AUC(num_thresholds=3) 1933 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) 1934 >>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7] 1935 >>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2] 1936 >>> # recall = [1, 0.5, 0], fp_rate = [1, 0, 0] 1937 >>> # auc = ((((1+0.5)/2)*(1-0))+ (((0.5+0)/2)*(0-0))) = 0.75 1938 >>> m.result().numpy() 1939 0.75 1940 1941 >>> m.reset_states() 1942 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], 1943 ... sample_weight=[1, 0, 0, 1]) 1944 >>> m.result().numpy() 1945 1.0 1946 1947 Usage with `compile()` API: 1948 1949 ```python 1950 # Reports the AUC of a model outputing a probability. 1951 model.compile(optimizer='sgd', 1952 loss=tf.keras.losses.BinaryCrossentropy(), 1953 metrics=[tf.keras.metrics.AUC()]) 1954 1955 # Reports the AUC of a model outputing a logit. 1956 model.compile(optimizer='sgd', 1957 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), 1958 metrics=[tf.keras.metrics.AUC(from_logits=True)]) 1959 ``` 1960 """ 1961 1962 def __init__(self, 1963 num_thresholds=200, 1964 curve='ROC', 1965 summation_method='interpolation', 1966 name=None, 1967 dtype=None, 1968 thresholds=None, 1969 multi_label=False, 1970 num_labels=None, 1971 label_weights=None, 1972 from_logits=False): 1973 # Validate configurations. 1974 if isinstance(curve, metrics_utils.AUCCurve) and curve not in list( 1975 metrics_utils.AUCCurve): 1976 raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format( 1977 curve, list(metrics_utils.AUCCurve))) 1978 if isinstance( 1979 summation_method, 1980 metrics_utils.AUCSummationMethod) and summation_method not in list( 1981 metrics_utils.AUCSummationMethod): 1982 raise ValueError( 1983 'Invalid summation method: "{}". Valid options are: "{}"'.format( 1984 summation_method, list(metrics_utils.AUCSummationMethod))) 1985 1986 # Update properties. 1987 if thresholds is not None: 1988 # If specified, use the supplied thresholds. 1989 self.num_thresholds = len(thresholds) + 2 1990 thresholds = sorted(thresholds) 1991 else: 1992 if num_thresholds <= 1: 1993 raise ValueError('`num_thresholds` must be > 1.') 1994 1995 # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in 1996 # (0, 1). 1997 self.num_thresholds = num_thresholds 1998 thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) 1999 for i in range(num_thresholds - 2)] 2000 2001 # Add an endpoint "threshold" below zero and above one for either 2002 # threshold method to account for floating point imprecisions. 2003 self._thresholds = np.array([0.0 - K.epsilon()] + thresholds + 2004 [1.0 + K.epsilon()]) 2005 2006 if isinstance(curve, metrics_utils.AUCCurve): 2007 self.curve = curve 2008 else: 2009 self.curve = metrics_utils.AUCCurve.from_str(curve) 2010 if isinstance(summation_method, metrics_utils.AUCSummationMethod): 2011 self.summation_method = summation_method 2012 else: 2013 self.summation_method = metrics_utils.AUCSummationMethod.from_str( 2014 summation_method) 2015 super(AUC, self).__init__(name=name, dtype=dtype) 2016 2017 # Handle multilabel arguments. 2018 self.multi_label = multi_label 2019 if label_weights is not None: 2020 label_weights = constant_op.constant(label_weights, dtype=self.dtype) 2021 checks = [ 2022 check_ops.assert_non_negative( 2023 label_weights, 2024 message='All values of `label_weights` must be non-negative.') 2025 ] 2026 with ops.control_dependencies(checks): 2027 self.label_weights = label_weights 2028 2029 else: 2030 self.label_weights = None 2031 2032 self._from_logits = from_logits 2033 2034 self._built = False 2035 if self.multi_label: 2036 if num_labels: 2037 shape = tensor_shape.TensorShape([None, num_labels]) 2038 self._build(shape) 2039 else: 2040 if num_labels: 2041 raise ValueError( 2042 '`num_labels` is needed only when `multi_label` is True.') 2043 self._build(None) 2044 2045 @property 2046 def thresholds(self): 2047 """The thresholds used for evaluating AUC.""" 2048 return list(self._thresholds) 2049 2050 def _build(self, shape): 2051 """Initialize TP, FP, TN, and FN tensors, given the shape of the data.""" 2052 if self.multi_label: 2053 if shape.ndims != 2: 2054 raise ValueError('`y_true` must have rank=2 when `multi_label` is ' 2055 'True. Found rank %s.' % shape.ndims) 2056 self._num_labels = shape[1] 2057 variable_shape = tensor_shape.TensorShape( 2058 [tensor_shape.Dimension(self.num_thresholds), self._num_labels]) 2059 2060 else: 2061 variable_shape = tensor_shape.TensorShape( 2062 [tensor_shape.Dimension(self.num_thresholds)]) 2063 self._build_input_shape = shape 2064 # Create metric variables 2065 self.true_positives = self.add_weight( 2066 'true_positives', 2067 shape=variable_shape, 2068 initializer=init_ops.zeros_initializer) 2069 self.true_negatives = self.add_weight( 2070 'true_negatives', 2071 shape=variable_shape, 2072 initializer=init_ops.zeros_initializer) 2073 self.false_positives = self.add_weight( 2074 'false_positives', 2075 shape=variable_shape, 2076 initializer=init_ops.zeros_initializer) 2077 self.false_negatives = self.add_weight( 2078 'false_negatives', 2079 shape=variable_shape, 2080 initializer=init_ops.zeros_initializer) 2081 2082 if self.multi_label: 2083 with ops.init_scope(): 2084 # This should only be necessary for handling v1 behavior. In v2, AUC 2085 # should be initialized outside of any tf.functions, and therefore in 2086 # eager mode. 2087 if not context.executing_eagerly(): 2088 K._initialize_variables(K._get_session()) # pylint: disable=protected-access 2089 2090 self._built = True 2091 2092 def update_state(self, y_true, y_pred, sample_weight=None): 2093 """Accumulates confusion matrix statistics. 2094 2095 Args: 2096 y_true: The ground truth values. 2097 y_pred: The predicted values. 2098 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 2099 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 2100 be broadcastable to `y_true`. 2101 2102 Returns: 2103 Update op. 2104 """ 2105 deps = [] 2106 if not self._built: 2107 self._build(tensor_shape.TensorShape(y_pred.shape)) 2108 2109 if self.multi_label or (self.label_weights is not None): 2110 # y_true should have shape (number of examples, number of labels). 2111 shapes = [ 2112 (y_true, ('N', 'L')) 2113 ] 2114 if self.multi_label: 2115 # TP, TN, FP, and FN should all have shape 2116 # (number of thresholds, number of labels). 2117 shapes.extend([(self.true_positives, ('T', 'L')), 2118 (self.true_negatives, ('T', 'L')), 2119 (self.false_positives, ('T', 'L')), 2120 (self.false_negatives, ('T', 'L'))]) 2121 if self.label_weights is not None: 2122 # label_weights should be of length equal to the number of labels. 2123 shapes.append((self.label_weights, ('L',))) 2124 deps = [ 2125 check_ops.assert_shapes( 2126 shapes, message='Number of labels is not consistent.') 2127 ] 2128 2129 # Only forward label_weights to update_confusion_matrix_variables when 2130 # multi_label is False. Otherwise the averaging of individual label AUCs is 2131 # handled in AUC.result 2132 label_weights = None if self.multi_label else self.label_weights 2133 2134 if self._from_logits: 2135 y_pred = activations.sigmoid(y_pred) 2136 2137 with ops.control_dependencies(deps): 2138 return metrics_utils.update_confusion_matrix_variables( 2139 { 2140 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: 2141 self.true_positives, 2142 metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: 2143 self.true_negatives, 2144 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: 2145 self.false_positives, 2146 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: 2147 self.false_negatives, 2148 }, 2149 y_true, 2150 y_pred, 2151 self._thresholds, 2152 sample_weight=sample_weight, 2153 multi_label=self.multi_label, 2154 label_weights=label_weights) 2155 2156 def interpolate_pr_auc(self): 2157 """Interpolation formula inspired by section 4 of Davis & Goadrich 2006. 2158 2159 https://www.biostat.wisc.edu/~page/rocpr.pdf 2160 2161 Note here we derive & use a closed formula not present in the paper 2162 as follows: 2163 2164 Precision = TP / (TP + FP) = TP / P 2165 2166 Modeling all of TP (true positive), FP (false positive) and their sum 2167 P = TP + FP (predicted positive) as varying linearly within each interval 2168 [A, B] between successive thresholds, we get 2169 2170 Precision slope = dTP / dP 2171 = (TP_B - TP_A) / (P_B - P_A) 2172 = (TP - TP_A) / (P - P_A) 2173 Precision = (TP_A + slope * (P - P_A)) / P 2174 2175 The area within the interval is (slope / total_pos_weight) times 2176 2177 int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P} 2178 int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P} 2179 2180 where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in 2181 2182 int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A) 2183 2184 Bringing back the factor (slope / total_pos_weight) we'd put aside, we get 2185 2186 slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight 2187 2188 where dTP == TP_B - TP_A. 2189 2190 Note that when P_A == 0 the above calculation simplifies into 2191 2192 int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A) 2193 2194 which is really equivalent to imputing constant precision throughout the 2195 first bucket having >0 true positives. 2196 2197 Returns: 2198 pr_auc: an approximation of the area under the P-R curve. 2199 """ 2200 dtp = self.true_positives[:self.num_thresholds - 2201 1] - self.true_positives[1:] 2202 p = self.true_positives + self.false_positives 2203 dp = p[:self.num_thresholds - 1] - p[1:] 2204 prec_slope = math_ops.div_no_nan( 2205 dtp, math_ops.maximum(dp, 0), name='prec_slope') 2206 intercept = self.true_positives[1:] - math_ops.multiply(prec_slope, p[1:]) 2207 2208 safe_p_ratio = array_ops.where( 2209 math_ops.logical_and(p[:self.num_thresholds - 1] > 0, p[1:] > 0), 2210 math_ops.div_no_nan( 2211 p[:self.num_thresholds - 1], 2212 math_ops.maximum(p[1:], 0), 2213 name='recall_relative_ratio'), 2214 array_ops.ones_like(p[1:])) 2215 2216 pr_auc_increment = math_ops.div_no_nan( 2217 prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)), 2218 math_ops.maximum(self.true_positives[1:] + self.false_negatives[1:], 0), 2219 name='pr_auc_increment') 2220 2221 if self.multi_label: 2222 by_label_auc = math_ops.reduce_sum( 2223 pr_auc_increment, name=self.name + '_by_label', axis=0) 2224 if self.label_weights is None: 2225 # Evenly weighted average of the label AUCs. 2226 return math_ops.reduce_mean(by_label_auc, name=self.name) 2227 else: 2228 # Weighted average of the label AUCs. 2229 return math_ops.div_no_nan( 2230 math_ops.reduce_sum( 2231 math_ops.multiply(by_label_auc, self.label_weights)), 2232 math_ops.reduce_sum(self.label_weights), 2233 name=self.name) 2234 else: 2235 return math_ops.reduce_sum(pr_auc_increment, name='interpolate_pr_auc') 2236 2237 def result(self): 2238 if (self.curve == metrics_utils.AUCCurve.PR and 2239 self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION 2240 ): 2241 # This use case is different and is handled separately. 2242 return self.interpolate_pr_auc() 2243 2244 # Set `x` and `y` values for the curves based on `curve` config. 2245 recall = math_ops.div_no_nan(self.true_positives, 2246 self.true_positives + self.false_negatives) 2247 if self.curve == metrics_utils.AUCCurve.ROC: 2248 fp_rate = math_ops.div_no_nan(self.false_positives, 2249 self.false_positives + self.true_negatives) 2250 x = fp_rate 2251 y = recall 2252 else: # curve == 'PR'. 2253 precision = math_ops.div_no_nan( 2254 self.true_positives, self.true_positives + self.false_positives) 2255 x = recall 2256 y = precision 2257 2258 # Find the rectangle heights based on `summation_method`. 2259 if self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION: 2260 # Note: the case ('PR', 'interpolation') has been handled above. 2261 heights = (y[:self.num_thresholds - 1] + y[1:]) / 2. 2262 elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING: 2263 heights = math_ops.minimum(y[:self.num_thresholds - 1], y[1:]) 2264 else: # self.summation_method = metrics_utils.AUCSummationMethod.MAJORING: 2265 heights = math_ops.maximum(y[:self.num_thresholds - 1], y[1:]) 2266 2267 # Sum up the areas of all the rectangles. 2268 if self.multi_label: 2269 riemann_terms = math_ops.multiply(x[:self.num_thresholds - 1] - x[1:], 2270 heights) 2271 by_label_auc = math_ops.reduce_sum( 2272 riemann_terms, name=self.name + '_by_label', axis=0) 2273 2274 if self.label_weights is None: 2275 # Unweighted average of the label AUCs. 2276 return math_ops.reduce_mean(by_label_auc, name=self.name) 2277 else: 2278 # Weighted average of the label AUCs. 2279 return math_ops.div_no_nan( 2280 math_ops.reduce_sum( 2281 math_ops.multiply(by_label_auc, self.label_weights)), 2282 math_ops.reduce_sum(self.label_weights), 2283 name=self.name) 2284 else: 2285 return math_ops.reduce_sum( 2286 math_ops.multiply(x[:self.num_thresholds - 1] - x[1:], heights), 2287 name=self.name) 2288 2289 def reset_states(self): 2290 if self.multi_label: 2291 K.batch_set_value([(v, np.zeros((self.num_thresholds, self._num_labels))) 2292 for v in self.variables]) 2293 else: 2294 K.batch_set_value([ 2295 (v, np.zeros((self.num_thresholds,))) for v in self.variables 2296 ]) 2297 2298 def get_config(self): 2299 if is_tensor_or_variable(self.label_weights): 2300 label_weights = K.eval(self.label_weights) 2301 else: 2302 label_weights = self.label_weights 2303 config = { 2304 'num_thresholds': self.num_thresholds, 2305 'curve': self.curve.value, 2306 'summation_method': self.summation_method.value, 2307 # We remove the endpoint thresholds as an inverse of how the thresholds 2308 # were initialized. This ensures that a metric initialized from this 2309 # config has the same thresholds. 2310 'thresholds': self.thresholds[1:-1], 2311 'multi_label': self.multi_label, 2312 'label_weights': label_weights 2313 } 2314 base_config = super(AUC, self).get_config() 2315 return dict(list(base_config.items()) + list(config.items())) 2316 2317 2318@keras_export('keras.metrics.CosineSimilarity') 2319class CosineSimilarity(MeanMetricWrapper): 2320 """Computes the cosine similarity between the labels and predictions. 2321 2322 `cosine similarity = (a . b) / ||a|| ||b||` 2323 2324 See: [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity). 2325 2326 This metric keeps the average cosine similarity between `predictions` and 2327 `labels` over a stream of data. 2328 2329 Args: 2330 name: (Optional) string name of the metric instance. 2331 dtype: (Optional) data type of the metric result. 2332 axis: (Optional) Defaults to -1. The dimension along which the cosine 2333 similarity is computed. 2334 2335 Standalone usage: 2336 2337 >>> # l2_norm(y_true) = [[0., 1.], [1./1.414], 1./1.414]]] 2338 >>> # l2_norm(y_pred) = [[1., 0.], [1./1.414], 1./1.414]]] 2339 >>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]] 2340 >>> # result = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1)) 2341 >>> # = ((0. + 0.) + (0.5 + 0.5)) / 2 2342 >>> m = tf.keras.metrics.CosineSimilarity(axis=1) 2343 >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]]) 2344 >>> m.result().numpy() 2345 0.49999997 2346 2347 >>> m.reset_states() 2348 >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]], 2349 ... sample_weight=[0.3, 0.7]) 2350 >>> m.result().numpy() 2351 0.6999999 2352 2353 Usage with `compile()` API: 2354 2355 ```python 2356 model.compile( 2357 optimizer='sgd', 2358 loss='mse', 2359 metrics=[tf.keras.metrics.CosineSimilarity(axis=1)]) 2360 ``` 2361 """ 2362 2363 def __init__(self, name='cosine_similarity', dtype=None, axis=-1): 2364 super(CosineSimilarity, self).__init__( 2365 cosine_similarity, name, dtype=dtype, axis=axis) 2366 2367 2368@keras_export('keras.metrics.MeanAbsoluteError') 2369class MeanAbsoluteError(MeanMetricWrapper): 2370 """Computes the mean absolute error between the labels and predictions. 2371 2372 Args: 2373 name: (Optional) string name of the metric instance. 2374 dtype: (Optional) data type of the metric result. 2375 2376 Standalone usage: 2377 2378 >>> m = tf.keras.metrics.MeanAbsoluteError() 2379 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 2380 >>> m.result().numpy() 2381 0.25 2382 2383 >>> m.reset_states() 2384 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 2385 ... sample_weight=[1, 0]) 2386 >>> m.result().numpy() 2387 0.5 2388 2389 Usage with `compile()` API: 2390 2391 ```python 2392 model.compile( 2393 optimizer='sgd', 2394 loss='mse', 2395 metrics=[tf.keras.metrics.MeanAbsoluteError()]) 2396 ``` 2397 """ 2398 2399 def __init__(self, name='mean_absolute_error', dtype=None): 2400 super(MeanAbsoluteError, self).__init__( 2401 mean_absolute_error, name, dtype=dtype) 2402 2403 2404@keras_export('keras.metrics.MeanAbsolutePercentageError') 2405class MeanAbsolutePercentageError(MeanMetricWrapper): 2406 """Computes the mean absolute percentage error between `y_true` and `y_pred`. 2407 2408 Args: 2409 name: (Optional) string name of the metric instance. 2410 dtype: (Optional) data type of the metric result. 2411 2412 Standalone usage: 2413 2414 >>> m = tf.keras.metrics.MeanAbsolutePercentageError() 2415 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 2416 >>> m.result().numpy() 2417 250000000.0 2418 2419 >>> m.reset_states() 2420 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 2421 ... sample_weight=[1, 0]) 2422 >>> m.result().numpy() 2423 500000000.0 2424 2425 Usage with `compile()` API: 2426 2427 ```python 2428 model.compile( 2429 optimizer='sgd', 2430 loss='mse', 2431 metrics=[tf.keras.metrics.MeanAbsolutePercentageError()]) 2432 ``` 2433 """ 2434 2435 def __init__(self, name='mean_absolute_percentage_error', dtype=None): 2436 super(MeanAbsolutePercentageError, self).__init__( 2437 mean_absolute_percentage_error, name, dtype=dtype) 2438 2439 2440@keras_export('keras.metrics.MeanSquaredError') 2441class MeanSquaredError(MeanMetricWrapper): 2442 """Computes the mean squared error between `y_true` and `y_pred`. 2443 2444 Args: 2445 name: (Optional) string name of the metric instance. 2446 dtype: (Optional) data type of the metric result. 2447 2448 Standalone usage: 2449 2450 >>> m = tf.keras.metrics.MeanSquaredError() 2451 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 2452 >>> m.result().numpy() 2453 0.25 2454 2455 >>> m.reset_states() 2456 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 2457 ... sample_weight=[1, 0]) 2458 >>> m.result().numpy() 2459 0.5 2460 2461 Usage with `compile()` API: 2462 2463 ```python 2464 model.compile( 2465 optimizer='sgd', 2466 loss='mse', 2467 metrics=[tf.keras.metrics.MeanSquaredError()]) 2468 ``` 2469 """ 2470 2471 def __init__(self, name='mean_squared_error', dtype=None): 2472 super(MeanSquaredError, self).__init__( 2473 mean_squared_error, name, dtype=dtype) 2474 2475 2476@keras_export('keras.metrics.MeanSquaredLogarithmicError') 2477class MeanSquaredLogarithmicError(MeanMetricWrapper): 2478 """Computes the mean squared logarithmic error between `y_true` and `y_pred`. 2479 2480 Args: 2481 name: (Optional) string name of the metric instance. 2482 dtype: (Optional) data type of the metric result. 2483 2484 Standalone usage: 2485 2486 >>> m = tf.keras.metrics.MeanSquaredLogarithmicError() 2487 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 2488 >>> m.result().numpy() 2489 0.12011322 2490 2491 >>> m.reset_states() 2492 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 2493 ... sample_weight=[1, 0]) 2494 >>> m.result().numpy() 2495 0.24022643 2496 2497 Usage with `compile()` API: 2498 2499 ```python 2500 model.compile( 2501 optimizer='sgd', 2502 loss='mse', 2503 metrics=[tf.keras.metrics.MeanSquaredLogarithmicError()]) 2504 ``` 2505 """ 2506 2507 def __init__(self, name='mean_squared_logarithmic_error', dtype=None): 2508 super(MeanSquaredLogarithmicError, self).__init__( 2509 mean_squared_logarithmic_error, name, dtype=dtype) 2510 2511 2512@keras_export('keras.metrics.Hinge') 2513class Hinge(MeanMetricWrapper): 2514 """Computes the hinge metric between `y_true` and `y_pred`. 2515 2516 `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are 2517 provided we will convert them to -1 or 1. 2518 2519 Args: 2520 name: (Optional) string name of the metric instance. 2521 dtype: (Optional) data type of the metric result. 2522 2523 Standalone usage: 2524 2525 >>> m = tf.keras.metrics.Hinge() 2526 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 2527 >>> m.result().numpy() 2528 1.3 2529 2530 >>> m.reset_states() 2531 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 2532 ... sample_weight=[1, 0]) 2533 >>> m.result().numpy() 2534 1.1 2535 2536 Usage with `compile()` API: 2537 2538 ```python 2539 model.compile(optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.Hinge()]) 2540 ``` 2541 """ 2542 2543 def __init__(self, name='hinge', dtype=None): 2544 super(Hinge, self).__init__(hinge, name, dtype=dtype) 2545 2546 2547@keras_export('keras.metrics.SquaredHinge') 2548class SquaredHinge(MeanMetricWrapper): 2549 """Computes the squared hinge metric between `y_true` and `y_pred`. 2550 2551 `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are 2552 provided we will convert them to -1 or 1. 2553 2554 Args: 2555 name: (Optional) string name of the metric instance. 2556 dtype: (Optional) data type of the metric result. 2557 2558 Standalone usage: 2559 2560 >>> m = tf.keras.metrics.SquaredHinge() 2561 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 2562 >>> m.result().numpy() 2563 1.86 2564 2565 >>> m.reset_states() 2566 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 2567 ... sample_weight=[1, 0]) 2568 >>> m.result().numpy() 2569 1.46 2570 2571 Usage with `compile()` API: 2572 2573 ```python 2574 model.compile( 2575 optimizer='sgd', 2576 loss='mse', 2577 metrics=[tf.keras.metrics.SquaredHinge()]) 2578 ``` 2579 """ 2580 2581 def __init__(self, name='squared_hinge', dtype=None): 2582 super(SquaredHinge, self).__init__(squared_hinge, name, dtype=dtype) 2583 2584 2585@keras_export('keras.metrics.CategoricalHinge') 2586class CategoricalHinge(MeanMetricWrapper): 2587 """Computes the categorical hinge metric between `y_true` and `y_pred`. 2588 2589 Args: 2590 name: (Optional) string name of the metric instance. 2591 dtype: (Optional) data type of the metric result. 2592 2593 Standalone usage: 2594 2595 >>> m = tf.keras.metrics.CategoricalHinge() 2596 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 2597 >>> m.result().numpy() 2598 1.4000001 2599 2600 >>> m.reset_states() 2601 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 2602 ... sample_weight=[1, 0]) 2603 >>> m.result().numpy() 2604 1.2 2605 2606 Usage with `compile()` API: 2607 2608 ```python 2609 model.compile( 2610 optimizer='sgd', 2611 loss='mse', 2612 metrics=[tf.keras.metrics.CategoricalHinge()]) 2613 ``` 2614 """ 2615 2616 def __init__(self, name='categorical_hinge', dtype=None): 2617 super(CategoricalHinge, self).__init__(categorical_hinge, name, dtype=dtype) 2618 2619 2620@keras_export('keras.metrics.RootMeanSquaredError') 2621class RootMeanSquaredError(Mean): 2622 """Computes root mean squared error metric between `y_true` and `y_pred`. 2623 2624 Standalone usage: 2625 2626 >>> m = tf.keras.metrics.RootMeanSquaredError() 2627 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 2628 >>> m.result().numpy() 2629 0.5 2630 2631 >>> m.reset_states() 2632 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 2633 ... sample_weight=[1, 0]) 2634 >>> m.result().numpy() 2635 0.70710677 2636 2637 Usage with `compile()` API: 2638 2639 ```python 2640 model.compile( 2641 optimizer='sgd', 2642 loss='mse', 2643 metrics=[tf.keras.metrics.RootMeanSquaredError()]) 2644 ``` 2645 """ 2646 2647 def __init__(self, name='root_mean_squared_error', dtype=None): 2648 super(RootMeanSquaredError, self).__init__(name, dtype=dtype) 2649 2650 def update_state(self, y_true, y_pred, sample_weight=None): 2651 """Accumulates root mean squared error statistics. 2652 2653 Args: 2654 y_true: The ground truth values. 2655 y_pred: The predicted values. 2656 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 2657 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 2658 be broadcastable to `y_true`. 2659 2660 Returns: 2661 Update op. 2662 """ 2663 y_true = math_ops.cast(y_true, self._dtype) 2664 y_pred = math_ops.cast(y_pred, self._dtype) 2665 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions( 2666 y_pred, y_true) 2667 error_sq = math_ops.squared_difference(y_pred, y_true) 2668 return super(RootMeanSquaredError, self).update_state( 2669 error_sq, sample_weight=sample_weight) 2670 2671 def result(self): 2672 return math_ops.sqrt(math_ops.div_no_nan(self.total, self.count)) 2673 2674 2675@keras_export('keras.metrics.LogCoshError') 2676class LogCoshError(MeanMetricWrapper): 2677 """Computes the logarithm of the hyperbolic cosine of the prediction error. 2678 2679 `logcosh = log((exp(x) + exp(-x))/2)`, where x is the error (y_pred - y_true) 2680 2681 Args: 2682 name: (Optional) string name of the metric instance. 2683 dtype: (Optional) data type of the metric result. 2684 2685 Standalone usage: 2686 2687 >>> m = tf.keras.metrics.LogCoshError() 2688 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 2689 >>> m.result().numpy() 2690 0.10844523 2691 2692 >>> m.reset_states() 2693 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 2694 ... sample_weight=[1, 0]) 2695 >>> m.result().numpy() 2696 0.21689045 2697 2698 Usage with `compile()` API: 2699 2700 ```python 2701 model.compile(optimizer='sgd', 2702 loss='mse', 2703 metrics=[tf.keras.metrics.LogCoshError()]) 2704 ``` 2705 """ 2706 2707 def __init__(self, name='logcosh', dtype=None): 2708 super(LogCoshError, self).__init__(logcosh, name, dtype=dtype) 2709 2710 2711@keras_export('keras.metrics.Poisson') 2712class Poisson(MeanMetricWrapper): 2713 """Computes the Poisson metric between `y_true` and `y_pred`. 2714 2715 `metric = y_pred - y_true * log(y_pred)` 2716 2717 Args: 2718 name: (Optional) string name of the metric instance. 2719 dtype: (Optional) data type of the metric result. 2720 2721 Standalone usage: 2722 2723 >>> m = tf.keras.metrics.Poisson() 2724 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 2725 >>> m.result().numpy() 2726 0.49999997 2727 2728 >>> m.reset_states() 2729 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 2730 ... sample_weight=[1, 0]) 2731 >>> m.result().numpy() 2732 0.99999994 2733 2734 Usage with `compile()` API: 2735 2736 ```python 2737 model.compile(optimizer='sgd', 2738 loss='mse', 2739 metrics=[tf.keras.metrics.Poisson()]) 2740 ``` 2741 """ 2742 2743 def __init__(self, name='poisson', dtype=None): 2744 super(Poisson, self).__init__(poisson, name, dtype=dtype) 2745 2746 2747@keras_export('keras.metrics.KLDivergence') 2748class KLDivergence(MeanMetricWrapper): 2749 """Computes Kullback-Leibler divergence metric between `y_true` and `y_pred`. 2750 2751 `metric = y_true * log(y_true / y_pred)` 2752 2753 Args: 2754 name: (Optional) string name of the metric instance. 2755 dtype: (Optional) data type of the metric result. 2756 2757 Standalone usage: 2758 2759 >>> m = tf.keras.metrics.KLDivergence() 2760 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 2761 >>> m.result().numpy() 2762 0.45814306 2763 2764 >>> m.reset_states() 2765 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 2766 ... sample_weight=[1, 0]) 2767 >>> m.result().numpy() 2768 0.9162892 2769 2770 Usage with `compile()` API: 2771 2772 ```python 2773 model.compile(optimizer='sgd', 2774 loss='mse', 2775 metrics=[tf.keras.metrics.KLDivergence()]) 2776 ``` 2777 """ 2778 2779 def __init__(self, name='kullback_leibler_divergence', dtype=None): 2780 super(KLDivergence, self).__init__( 2781 kullback_leibler_divergence, name, dtype=dtype) 2782 2783 2784@keras_export('keras.metrics.MeanIoU') 2785class MeanIoU(Metric): 2786 """Computes the mean Intersection-Over-Union metric. 2787 2788 Mean Intersection-Over-Union is a common evaluation metric for semantic image 2789 segmentation, which first computes the IOU for each semantic class and then 2790 computes the average over classes. IOU is defined as follows: 2791 IOU = true_positive / (true_positive + false_positive + false_negative). 2792 The predictions are accumulated in a confusion matrix, weighted by 2793 `sample_weight` and the metric is then calculated from it. 2794 2795 If `sample_weight` is `None`, weights default to 1. 2796 Use `sample_weight` of 0 to mask values. 2797 2798 Args: 2799 num_classes: The possible number of labels the prediction task can have. 2800 This value must be provided, since a confusion matrix of dimension = 2801 [num_classes, num_classes] will be allocated. 2802 name: (Optional) string name of the metric instance. 2803 dtype: (Optional) data type of the metric result. 2804 2805 Standalone usage: 2806 2807 >>> # cm = [[1, 1], 2808 >>> # [1, 1]] 2809 >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1] 2810 >>> # iou = true_positives / (sum_row + sum_col - true_positives)) 2811 >>> # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33 2812 >>> m = tf.keras.metrics.MeanIoU(num_classes=2) 2813 >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1]) 2814 >>> m.result().numpy() 2815 0.33333334 2816 2817 >>> m.reset_states() 2818 >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1], 2819 ... sample_weight=[0.3, 0.3, 0.3, 0.1]) 2820 >>> m.result().numpy() 2821 0.23809525 2822 2823 Usage with `compile()` API: 2824 2825 ```python 2826 model.compile( 2827 optimizer='sgd', 2828 loss='mse', 2829 metrics=[tf.keras.metrics.MeanIoU(num_classes=2)]) 2830 ``` 2831 """ 2832 2833 def __init__(self, num_classes, name=None, dtype=None): 2834 super(MeanIoU, self).__init__(name=name, dtype=dtype) 2835 self.num_classes = num_classes 2836 2837 # Variable to accumulate the predictions in the confusion matrix. 2838 self.total_cm = self.add_weight( 2839 'total_confusion_matrix', 2840 shape=(num_classes, num_classes), 2841 initializer=init_ops.zeros_initializer) 2842 2843 def update_state(self, y_true, y_pred, sample_weight=None): 2844 """Accumulates the confusion matrix statistics. 2845 2846 Args: 2847 y_true: The ground truth values. 2848 y_pred: The predicted values. 2849 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 2850 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 2851 be broadcastable to `y_true`. 2852 2853 Returns: 2854 Update op. 2855 """ 2856 2857 y_true = math_ops.cast(y_true, self._dtype) 2858 y_pred = math_ops.cast(y_pred, self._dtype) 2859 2860 # Flatten the input if its rank > 1. 2861 if y_pred.shape.ndims > 1: 2862 y_pred = array_ops.reshape(y_pred, [-1]) 2863 2864 if y_true.shape.ndims > 1: 2865 y_true = array_ops.reshape(y_true, [-1]) 2866 2867 if sample_weight is not None: 2868 sample_weight = math_ops.cast(sample_weight, self._dtype) 2869 if sample_weight.shape.ndims > 1: 2870 sample_weight = array_ops.reshape(sample_weight, [-1]) 2871 2872 # Accumulate the prediction to current confusion matrix. 2873 current_cm = confusion_matrix.confusion_matrix( 2874 y_true, 2875 y_pred, 2876 self.num_classes, 2877 weights=sample_weight, 2878 dtype=self._dtype) 2879 return self.total_cm.assign_add(current_cm) 2880 2881 def result(self): 2882 """Compute the mean intersection-over-union via the confusion matrix.""" 2883 sum_over_row = math_ops.cast( 2884 math_ops.reduce_sum(self.total_cm, axis=0), dtype=self._dtype) 2885 sum_over_col = math_ops.cast( 2886 math_ops.reduce_sum(self.total_cm, axis=1), dtype=self._dtype) 2887 true_positives = math_ops.cast( 2888 array_ops.tensor_diag_part(self.total_cm), dtype=self._dtype) 2889 2890 # sum_over_row + sum_over_col = 2891 # 2 * true_positives + false_positives + false_negatives. 2892 denominator = sum_over_row + sum_over_col - true_positives 2893 2894 # The mean is only computed over classes that appear in the 2895 # label or prediction tensor. If the denominator is 0, we need to 2896 # ignore the class. 2897 num_valid_entries = math_ops.reduce_sum( 2898 math_ops.cast(math_ops.not_equal(denominator, 0), dtype=self._dtype)) 2899 2900 iou = math_ops.div_no_nan(true_positives, denominator) 2901 2902 return math_ops.div_no_nan( 2903 math_ops.reduce_sum(iou, name='mean_iou'), num_valid_entries) 2904 2905 def reset_states(self): 2906 K.set_value(self.total_cm, np.zeros((self.num_classes, self.num_classes))) 2907 2908 def get_config(self): 2909 config = {'num_classes': self.num_classes} 2910 base_config = super(MeanIoU, self).get_config() 2911 return dict(list(base_config.items()) + list(config.items())) 2912 2913 2914@keras_export('keras.metrics.MeanTensor') 2915class MeanTensor(Metric): 2916 """Computes the element-wise (weighted) mean of the given tensors. 2917 2918 `MeanTensor` returns a tensor with the same shape of the input tensors. The 2919 mean value is updated by keeping local variables `total` and `count`. The 2920 `total` tracks the sum of the weighted values, and `count` stores the sum of 2921 the weighted counts. 2922 2923 Args: 2924 name: (Optional) string name of the metric instance. 2925 dtype: (Optional) data type of the metric result. 2926 shape: (Optional) A list of integers, a tuple of integers, or a 1-D Tensor 2927 of type int32. If not specified, the shape is inferred from the values at 2928 the first call of update_state. 2929 2930 Standalone usage: 2931 2932 >>> m = tf.keras.metrics.MeanTensor() 2933 >>> m.update_state([0, 1, 2, 3]) 2934 >>> m.update_state([4, 5, 6, 7]) 2935 >>> m.result().numpy() 2936 array([2., 3., 4., 5.], dtype=float32) 2937 2938 >>> m.update_state([12, 10, 8, 6], sample_weight= [0, 0.2, 0.5, 1]) 2939 >>> m.result().numpy() 2940 array([2. , 3.6363635, 4.8 , 5.3333335], dtype=float32) 2941 2942 >>> m = tf.keras.metrics.MeanTensor(dtype=tf.float64, shape=(1, 4)) 2943 >>> m.result().numpy() 2944 array([[0., 0., 0., 0.]]) 2945 >>> m.update_state([[0, 1, 2, 3]]) 2946 >>> m.update_state([[4, 5, 6, 7]]) 2947 >>> m.result().numpy() 2948 array([[2., 3., 4., 5.]]) 2949 """ 2950 2951 def __init__(self, name='mean_tensor', dtype=None, shape=None): 2952 super(MeanTensor, self).__init__(name=name, dtype=dtype) 2953 self._shape = None 2954 self._total = None 2955 self._count = None 2956 self._built = False 2957 if shape is not None: 2958 self._build(shape) 2959 2960 def _build(self, shape): 2961 self._shape = tensor_shape.TensorShape(shape) 2962 self._build_input_shape = self._shape 2963 # Create new state variables 2964 self._total = self.add_weight( 2965 'total', shape=shape, initializer=init_ops.zeros_initializer) 2966 self._count = self.add_weight( 2967 'count', shape=shape, initializer=init_ops.zeros_initializer) 2968 with ops.init_scope(): 2969 if not context.executing_eagerly(): 2970 K._initialize_variables(K._get_session()) # pylint: disable=protected-access 2971 self._built = True 2972 2973 @property 2974 def total(self): 2975 return self._total if self._built else None 2976 2977 @property 2978 def count(self): 2979 return self._count if self._built else None 2980 2981 def update_state(self, values, sample_weight=None): 2982 """Accumulates statistics for computing the element-wise mean. 2983 2984 Args: 2985 values: Per-example value. 2986 sample_weight: Optional weighting of each example. Defaults to 1. 2987 2988 Returns: 2989 Update op. 2990 """ 2991 values = math_ops.cast(values, self._dtype) 2992 if not self._built: 2993 self._build(values.shape) 2994 elif values.shape != self._shape: 2995 raise ValueError('MeanTensor input values must always have the same ' 2996 'shape. Expected shape (set during the first call): {}. ' 2997 'Got: {}'.format(self._shape, values.shape)) 2998 2999 num_values = array_ops.ones_like(values) 3000 if sample_weight is not None: 3001 sample_weight = math_ops.cast(sample_weight, self._dtype) 3002 3003 # Update dimensions of weights to match with values if possible. 3004 values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions( 3005 values, sample_weight=sample_weight) 3006 try: 3007 # Broadcast weights if possible. 3008 sample_weight = weights_broadcast_ops.broadcast_weights( 3009 sample_weight, values) 3010 except ValueError: 3011 # Reduce values to same ndim as weight array 3012 ndim = K.ndim(values) 3013 weight_ndim = K.ndim(sample_weight) 3014 values = math_ops.reduce_mean( 3015 values, axis=list(range(weight_ndim, ndim))) 3016 3017 num_values = math_ops.multiply(num_values, sample_weight) 3018 values = math_ops.multiply(values, sample_weight) 3019 3020 update_total_op = self._total.assign_add(values) 3021 with ops.control_dependencies([update_total_op]): 3022 return self._count.assign_add(num_values) 3023 3024 def result(self): 3025 if not self._built: 3026 raise ValueError( 3027 'MeanTensor does not have any result yet. Please call the MeanTensor ' 3028 'instance or use `.update_state(value)` before retrieving the result.' 3029 ) 3030 return math_ops.div_no_nan(self.total, self.count) 3031 3032 def reset_states(self): 3033 if self._built: 3034 K.batch_set_value( 3035 [(v, np.zeros(self._shape.as_list())) for v in self.variables]) 3036 3037 3038@keras_export('keras.metrics.BinaryCrossentropy') 3039class BinaryCrossentropy(MeanMetricWrapper): 3040 """Computes the crossentropy metric between the labels and predictions. 3041 3042 This is the crossentropy metric class to be used when there are only two 3043 label classes (0 and 1). 3044 3045 Args: 3046 name: (Optional) string name of the metric instance. 3047 dtype: (Optional) data type of the metric result. 3048 from_logits: (Optional )Whether output is expected to be a logits tensor. 3049 By default, we consider that output encodes a probability distribution. 3050 label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are 3051 smoothed, meaning the confidence on label values are relaxed. 3052 e.g. `label_smoothing=0.2` means that we will use a value of `0.1` for 3053 label `0` and `0.9` for label `1`". 3054 3055 Standalone usage: 3056 3057 >>> m = tf.keras.metrics.BinaryCrossentropy() 3058 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 3059 >>> m.result().numpy() 3060 0.81492424 3061 3062 >>> m.reset_states() 3063 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 3064 ... sample_weight=[1, 0]) 3065 >>> m.result().numpy() 3066 0.9162905 3067 3068 Usage with `compile()` API: 3069 3070 ```python 3071 model.compile( 3072 optimizer='sgd', 3073 loss='mse', 3074 metrics=[tf.keras.metrics.BinaryCrossentropy()]) 3075 ``` 3076 """ 3077 3078 def __init__(self, 3079 name='binary_crossentropy', 3080 dtype=None, 3081 from_logits=False, 3082 label_smoothing=0): 3083 super(BinaryCrossentropy, self).__init__( 3084 binary_crossentropy, 3085 name, 3086 dtype=dtype, 3087 from_logits=from_logits, 3088 label_smoothing=label_smoothing) 3089 3090 3091@keras_export('keras.metrics.CategoricalCrossentropy') 3092class CategoricalCrossentropy(MeanMetricWrapper): 3093 """Computes the crossentropy metric between the labels and predictions. 3094 3095 This is the crossentropy metric class to be used when there are multiple 3096 label classes (2 or more). Here we assume that labels are given as a `one_hot` 3097 representation. eg., When labels values are [2, 0, 1], 3098 `y_true` = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]. 3099 3100 Args: 3101 name: (Optional) string name of the metric instance. 3102 dtype: (Optional) data type of the metric result. 3103 from_logits: (Optional) Whether output is expected to be a logits tensor. 3104 By default, we consider that output encodes a probability distribution. 3105 label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are 3106 smoothed, meaning the confidence on label values are relaxed. e.g. 3107 `label_smoothing=0.2` means that we will use a value of `0.1` for label 3108 `0` and `0.9` for label `1`" 3109 3110 Standalone usage: 3111 3112 >>> # EPSILON = 1e-7, y = y_true, y` = y_pred 3113 >>> # y` = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON) 3114 >>> # y` = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]] 3115 >>> # xent = -sum(y * log(y'), axis = -1) 3116 >>> # = -((log 0.95), (log 0.1)) 3117 >>> # = [0.051, 2.302] 3118 >>> # Reduced xent = (0.051 + 2.302) / 2 3119 >>> m = tf.keras.metrics.CategoricalCrossentropy() 3120 >>> m.update_state([[0, 1, 0], [0, 0, 1]], 3121 ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) 3122 >>> m.result().numpy() 3123 1.1769392 3124 3125 >>> m.reset_states() 3126 >>> m.update_state([[0, 1, 0], [0, 0, 1]], 3127 ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]], 3128 ... sample_weight=tf.constant([0.3, 0.7])) 3129 >>> m.result().numpy() 3130 1.6271976 3131 3132 Usage with `compile()` API: 3133 3134 ```python 3135 model.compile( 3136 optimizer='sgd', 3137 loss='mse', 3138 metrics=[tf.keras.metrics.CategoricalCrossentropy()]) 3139 ``` 3140 """ 3141 3142 def __init__(self, 3143 name='categorical_crossentropy', 3144 dtype=None, 3145 from_logits=False, 3146 label_smoothing=0): 3147 super(CategoricalCrossentropy, self).__init__( 3148 categorical_crossentropy, 3149 name, 3150 dtype=dtype, 3151 from_logits=from_logits, 3152 label_smoothing=label_smoothing) 3153 3154 3155@keras_export('keras.metrics.SparseCategoricalCrossentropy') 3156class SparseCategoricalCrossentropy(MeanMetricWrapper): 3157 """Computes the crossentropy metric between the labels and predictions. 3158 3159 Use this crossentropy metric when there are two or more label classes. 3160 We expect labels to be provided as integers. If you want to provide labels 3161 using `one-hot` representation, please use `CategoricalCrossentropy` metric. 3162 There should be `# classes` floating point values per feature for `y_pred` 3163 and a single floating point value per feature for `y_true`. 3164 3165 In the snippet below, there is a single floating point value per example for 3166 `y_true` and `# classes` floating pointing values per example for `y_pred`. 3167 The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is 3168 `[batch_size, num_classes]`. 3169 3170 Args: 3171 name: (Optional) string name of the metric instance. 3172 dtype: (Optional) data type of the metric result. 3173 from_logits: (Optional) Whether output is expected to be a logits tensor. 3174 By default, we consider that output encodes a probability distribution. 3175 axis: (Optional) Defaults to -1. The dimension along which the metric is 3176 computed. 3177 3178 Standalone usage: 3179 3180 >>> # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]] 3181 >>> # logits = log(y_pred) 3182 >>> # softmax = exp(logits) / sum(exp(logits), axis=-1) 3183 >>> # softmax = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]] 3184 >>> # xent = -sum(y * log(softmax), 1) 3185 >>> # log(softmax) = [[-2.9957, -0.0513, -16.1181], 3186 >>> # [-2.3026, -0.2231, -2.3026]] 3187 >>> # y_true * log(softmax) = [[0, -0.0513, 0], [0, 0, -2.3026]] 3188 >>> # xent = [0.0513, 2.3026] 3189 >>> # Reduced xent = (0.0513 + 2.3026) / 2 3190 >>> m = tf.keras.metrics.SparseCategoricalCrossentropy() 3191 >>> m.update_state([1, 2], 3192 ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) 3193 >>> m.result().numpy() 3194 1.1769392 3195 3196 >>> m.reset_states() 3197 >>> m.update_state([1, 2], 3198 ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]], 3199 ... sample_weight=tf.constant([0.3, 0.7])) 3200 >>> m.result().numpy() 3201 1.6271976 3202 3203 Usage with `compile()` API: 3204 3205 ```python 3206 model.compile( 3207 optimizer='sgd', 3208 loss='mse', 3209 metrics=[tf.keras.metrics.SparseCategoricalCrossentropy()]) 3210 ``` 3211 """ 3212 3213 def __init__(self, 3214 name='sparse_categorical_crossentropy', 3215 dtype=None, 3216 from_logits=False, 3217 axis=-1): 3218 super(SparseCategoricalCrossentropy, self).__init__( 3219 sparse_categorical_crossentropy, 3220 name, 3221 dtype=dtype, 3222 from_logits=from_logits, 3223 axis=axis) 3224 3225 3226class SumOverBatchSize(Reduce): 3227 """Computes the weighted sum over batch size of the given values. 3228 3229 For example, if values is [1, 3, 5, 7] then the metric value is 4. 3230 If the weights were specified as [1, 1, 0, 0] then the value would be 1. 3231 3232 This metric creates two variables, `total` and `count` that are used to 3233 compute the average of `values`. This average is ultimately returned as sum 3234 over batch size which is an idempotent operation that simply divides `total` 3235 by `count`. 3236 3237 If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 3238 to mask values. 3239 """ 3240 3241 def __init__(self, name='sum_over_batch_size', dtype=None): 3242 super(SumOverBatchSize, self).__init__( 3243 reduction=metrics_utils.Reduction.SUM_OVER_BATCH_SIZE, 3244 name=name, 3245 dtype=dtype) 3246 3247 3248class SumOverBatchSizeMetricWrapper(SumOverBatchSize): 3249 """Wraps a function with the `SumOverBatchSizeMetricWrapper` metric.""" 3250 3251 def __init__(self, fn, name=None, dtype=None, **kwargs): 3252 """Creates a `SumOverBatchSizeMetricWrapper` instance. 3253 3254 Args: 3255 fn: The metric function to wrap, with signature `fn(y_true, y_pred, 3256 **kwargs)`. 3257 name: (Optional) string name of the metric instance. 3258 dtype: (Optional) data type of the metric result. 3259 **kwargs: The keyword arguments that are passed on to `fn`. 3260 """ 3261 super(SumOverBatchSizeMetricWrapper, self).__init__(name=name, dtype=dtype) 3262 self._fn = fn 3263 self._fn_kwargs = kwargs 3264 3265 def update_state(self, y_true, y_pred, sample_weight=None): 3266 y_true = math_ops.cast(y_true, self._dtype) 3267 y_pred = math_ops.cast(y_pred, self._dtype) 3268 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions( 3269 y_pred, y_true) 3270 3271 ag_fn = autograph.tf_convert(self._fn, ag_ctx.control_status_ctx()) 3272 matches = ag_fn(y_true, y_pred, **self._fn_kwargs) 3273 return super(SumOverBatchSizeMetricWrapper, self).update_state( 3274 matches, sample_weight=sample_weight) 3275 3276 def get_config(self): 3277 config = {} 3278 for k, v in six.iteritems(self._fn_kwargs): 3279 config[k] = K.eval(v) if is_tensor_or_variable(v) else v 3280 base_config = super(SumOverBatchSizeMetricWrapper, self).get_config() 3281 return dict(list(base_config.items()) + list(config.items())) 3282 3283 3284def accuracy(y_true, y_pred): 3285 [y_pred, y_true], _ = \ 3286 metrics_utils.ragged_assert_compatible_and_get_flat_values( 3287 [y_pred, y_true]) 3288 y_pred.shape.assert_is_compatible_with(y_true.shape) 3289 if y_true.dtype != y_pred.dtype: 3290 y_pred = math_ops.cast(y_pred, y_true.dtype) 3291 return math_ops.cast(math_ops.equal(y_true, y_pred), K.floatx()) 3292 3293 3294@keras_export('keras.metrics.binary_accuracy') 3295@dispatch.add_dispatch_support 3296def binary_accuracy(y_true, y_pred, threshold=0.5): 3297 """Calculates how often predictions match binary labels. 3298 3299 Standalone usage: 3300 >>> y_true = [[1], [1], [0], [0]] 3301 >>> y_pred = [[1], [1], [0], [0]] 3302 >>> m = tf.keras.metrics.binary_accuracy(y_true, y_pred) 3303 >>> assert m.shape == (4,) 3304 >>> m.numpy() 3305 array([1., 1., 1., 1.], dtype=float32) 3306 3307 Args: 3308 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. 3309 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 3310 threshold: (Optional) Float representing the threshold for deciding whether 3311 prediction values are 1 or 0. 3312 3313 Returns: 3314 Binary accuracy values. shape = `[batch_size, d0, .. dN-1]` 3315 """ 3316 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 3317 threshold = math_ops.cast(threshold, y_pred.dtype) 3318 y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype) 3319 return K.mean(math_ops.equal(y_true, y_pred), axis=-1) 3320 3321 3322@keras_export('keras.metrics.categorical_accuracy') 3323@dispatch.add_dispatch_support 3324def categorical_accuracy(y_true, y_pred): 3325 """Calculates how often predictions match one-hot labels. 3326 3327 Standalone usage: 3328 >>> y_true = [[0, 0, 1], [0, 1, 0]] 3329 >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] 3330 >>> m = tf.keras.metrics.categorical_accuracy(y_true, y_pred) 3331 >>> assert m.shape == (2,) 3332 >>> m.numpy() 3333 array([0., 1.], dtype=float32) 3334 3335 You can provide logits of classes as `y_pred`, since argmax of 3336 logits and probabilities are same. 3337 3338 Args: 3339 y_true: One-hot ground truth values. 3340 y_pred: The prediction values. 3341 3342 Returns: 3343 Categorical accuracy values. 3344 """ 3345 return math_ops.cast( 3346 math_ops.equal( 3347 math_ops.argmax(y_true, axis=-1), math_ops.argmax(y_pred, axis=-1)), 3348 K.floatx()) 3349 3350 3351@keras_export('keras.metrics.sparse_categorical_accuracy') 3352@dispatch.add_dispatch_support 3353def sparse_categorical_accuracy(y_true, y_pred): 3354 """Calculates how often predictions match integer labels. 3355 3356 Standalone usage: 3357 >>> y_true = [2, 1] 3358 >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] 3359 >>> m = tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred) 3360 >>> assert m.shape == (2,) 3361 >>> m.numpy() 3362 array([0., 1.], dtype=float32) 3363 3364 You can provide logits of classes as `y_pred`, since argmax of 3365 logits and probabilities are same. 3366 3367 Args: 3368 y_true: Integer ground truth values. 3369 y_pred: The prediction values. 3370 3371 Returns: 3372 Sparse categorical accuracy values. 3373 """ 3374 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 3375 y_true = ops.convert_to_tensor_v2_with_dispatch(y_true) 3376 y_pred_rank = y_pred.shape.ndims 3377 y_true_rank = y_true.shape.ndims 3378 # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,) 3379 if (y_true_rank is not None) and (y_pred_rank is not None) and (len( 3380 K.int_shape(y_true)) == len(K.int_shape(y_pred))): 3381 y_true = array_ops.squeeze(y_true, [-1]) 3382 y_pred = math_ops.argmax(y_pred, axis=-1) 3383 3384 # If the predicted output and actual output types don't match, force cast them 3385 # to match. 3386 if K.dtype(y_pred) != K.dtype(y_true): 3387 y_pred = math_ops.cast(y_pred, K.dtype(y_true)) 3388 3389 return math_ops.cast(math_ops.equal(y_true, y_pred), K.floatx()) 3390 3391 3392@keras_export('keras.metrics.top_k_categorical_accuracy') 3393@dispatch.add_dispatch_support 3394def top_k_categorical_accuracy(y_true, y_pred, k=5): 3395 """Computes how often targets are in the top `K` predictions. 3396 3397 Standalone usage: 3398 >>> y_true = [[0, 0, 1], [0, 1, 0]] 3399 >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] 3400 >>> m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3) 3401 >>> assert m.shape == (2,) 3402 >>> m.numpy() 3403 array([1., 1.], dtype=float32) 3404 3405 Args: 3406 y_true: The ground truth values. 3407 y_pred: The prediction values. 3408 k: (Optional) Number of top elements to look at for computing accuracy. 3409 Defaults to 5. 3410 3411 Returns: 3412 Top K categorical accuracy value. 3413 """ 3414 return math_ops.cast( 3415 nn.in_top_k(y_pred, math_ops.argmax(y_true, axis=-1), k), K.floatx()) 3416 3417 3418@keras_export('keras.metrics.sparse_top_k_categorical_accuracy') 3419@dispatch.add_dispatch_support 3420def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5): 3421 """Computes how often integer targets are in the top `K` predictions. 3422 3423 Standalone usage: 3424 >>> y_true = [2, 1] 3425 >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] 3426 >>> m = tf.keras.metrics.sparse_top_k_categorical_accuracy( 3427 ... y_true, y_pred, k=3) 3428 >>> assert m.shape == (2,) 3429 >>> m.numpy() 3430 array([1., 1.], dtype=float32) 3431 3432 Args: 3433 y_true: tensor of true targets. 3434 y_pred: tensor of predicted targets. 3435 k: (Optional) Number of top elements to look at for computing accuracy. 3436 Defaults to 5. 3437 3438 Returns: 3439 Sparse top K categorical accuracy value. 3440 """ 3441 y_pred_rank = ops.convert_to_tensor_v2_with_dispatch(y_pred).shape.ndims 3442 y_true_rank = ops.convert_to_tensor_v2_with_dispatch(y_true).shape.ndims 3443 # Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,) 3444 if (y_true_rank is not None) and (y_pred_rank is not None): 3445 if y_pred_rank > 2: 3446 y_pred = array_ops.reshape(y_pred, [-1, y_pred.shape[-1]]) 3447 if y_true_rank > 1: 3448 y_true = array_ops.reshape(y_true, [-1]) 3449 3450 return math_ops.cast( 3451 nn.in_top_k(y_pred, math_ops.cast(y_true, 'int32'), k), K.floatx()) 3452 3453 3454def cosine_proximity(y_true, y_pred, axis=-1): 3455 """Computes the cosine similarity between labels and predictions. 3456 3457 Args: 3458 y_true: The ground truth values. 3459 y_pred: The prediction values. 3460 axis: (Optional) Defaults to -1. The dimension along which the cosine 3461 similarity is computed. 3462 3463 Returns: 3464 Cosine similarity value. 3465 """ 3466 y_true = nn.l2_normalize(y_true, axis=axis) 3467 y_pred = nn.l2_normalize(y_pred, axis=axis) 3468 return math_ops.reduce_sum(y_true * y_pred, axis=axis) 3469 3470# Aliases 3471 3472acc = ACC = accuracy 3473bce = BCE = binary_crossentropy 3474mse = MSE = mean_squared_error 3475mae = MAE = mean_absolute_error 3476mape = MAPE = mean_absolute_percentage_error 3477msle = MSLE = mean_squared_logarithmic_error 3478cosine_similarity = cosine_proximity 3479log_cosh = logcosh 3480 3481 3482def clone_metric(metric): 3483 """Returns a clone of the metric if stateful, otherwise returns it as is.""" 3484 if isinstance(metric, Metric): 3485 with ops.init_scope(): 3486 return metric.__class__.from_config(metric.get_config()) 3487 return metric 3488 3489 3490def clone_metrics(metrics): 3491 """Clones the given metric list/dict.""" 3492 return nest.map_structure(clone_metric, metrics) 3493 3494 3495@keras_export('keras.metrics.serialize') 3496def serialize(metric): 3497 """Serializes metric function or `Metric` instance. 3498 3499 Args: 3500 metric: A Keras `Metric` instance or a metric function. 3501 3502 Returns: 3503 Metric configuration dictionary. 3504 """ 3505 return serialize_keras_object(metric) 3506 3507 3508@keras_export('keras.metrics.deserialize') 3509def deserialize(config, custom_objects=None): 3510 """Deserializes a serialized metric class/function instance. 3511 3512 Args: 3513 config: Metric configuration. 3514 custom_objects: Optional dictionary mapping names (strings) to custom 3515 objects (classes and functions) to be considered during deserialization. 3516 3517 Returns: 3518 A Keras `Metric` instance or a metric function. 3519 """ 3520 return deserialize_keras_object( 3521 config, 3522 module_objects=globals(), 3523 custom_objects=custom_objects, 3524 printable_module_name='metric function') 3525 3526 3527@keras_export('keras.metrics.get') 3528def get(identifier): 3529 """Retrieves a Keras metric as a `function`/`Metric` class instance. 3530 3531 The `identifier` may be the string name of a metric function or class. 3532 3533 >>> metric = tf.keras.metrics.get("categorical_crossentropy") 3534 >>> type(metric) 3535 <class 'function'> 3536 >>> metric = tf.keras.metrics.get("CategoricalCrossentropy") 3537 >>> type(metric) 3538 <class '...tensorflow.python.keras.metrics.CategoricalCrossentropy'> 3539 3540 You can also specify `config` of the metric to this function by passing dict 3541 containing `class_name` and `config` as an identifier. Also note that the 3542 `class_name` must map to a `Metric` class 3543 3544 >>> identifier = {"class_name": "CategoricalCrossentropy", 3545 ... "config": {"from_logits": True}} 3546 >>> metric = tf.keras.metrics.get(identifier) 3547 >>> type(metric) 3548 <class '...tensorflow.python.keras.metrics.CategoricalCrossentropy'> 3549 3550 Args: 3551 identifier: A metric identifier. One of None or string name of a metric 3552 function/class or metric configuration dictionary or a metric function or 3553 a metric class instance 3554 3555 Returns: 3556 A Keras metric as a `function`/ `Metric` class instance. 3557 3558 Raises: 3559 ValueError: If `identifier` cannot be interpreted. 3560 """ 3561 if isinstance(identifier, dict): 3562 return deserialize(identifier) 3563 elif isinstance(identifier, six.string_types): 3564 return deserialize(str(identifier)) 3565 elif callable(identifier): 3566 return identifier 3567 else: 3568 raise ValueError( 3569 'Could not interpret metric function identifier: {}'.format(identifier)) 3570 3571 3572def is_built_in(cls): 3573 return cls.__module__ == Metric.__module__ 3574