1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=unused-import 16"""Built-in metrics. 17""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import abc 23import types 24import numpy as np 25import six 26 27from tensorflow.python.eager import context 28from tensorflow.python.eager import def_function 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.keras import backend as K 33from tensorflow.python.keras.engine.base_layer import Layer 34from tensorflow.python.keras.losses import binary_crossentropy 35from tensorflow.python.keras.losses import categorical_crossentropy 36from tensorflow.python.keras.losses import categorical_hinge 37from tensorflow.python.keras.losses import cosine_similarity 38from tensorflow.python.keras.losses import hinge 39from tensorflow.python.keras.losses import kullback_leibler_divergence 40from tensorflow.python.keras.losses import logcosh 41from tensorflow.python.keras.losses import mean_absolute_error 42from tensorflow.python.keras.losses import mean_absolute_percentage_error 43from tensorflow.python.keras.losses import mean_squared_error 44from tensorflow.python.keras.losses import mean_squared_logarithmic_error 45from tensorflow.python.keras.losses import poisson 46from tensorflow.python.keras.losses import sparse_categorical_crossentropy 47from tensorflow.python.keras.losses import squared_hinge 48from tensorflow.python.keras.utils import metrics_utils 49from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object 50from tensorflow.python.keras.utils.generic_utils import serialize_keras_object 51from tensorflow.python.keras.utils.generic_utils import to_list 52from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions 53from tensorflow.python.keras.utils.tf_utils import is_tensor_or_variable 54from tensorflow.python.ops import array_ops 55from tensorflow.python.ops import confusion_matrix 56from tensorflow.python.ops import init_ops 57from tensorflow.python.ops import math_ops 58from tensorflow.python.ops import nn 59from tensorflow.python.ops import variables as tf_variables 60from tensorflow.python.ops import weights_broadcast_ops 61from tensorflow.python.util.tf_export import keras_export 62from tensorflow.tools.docs import doc_controls 63 64 65@keras_export('keras.metrics.Metric') 66@six.add_metaclass(abc.ABCMeta) 67class Metric(Layer): 68 """Encapsulates metric logic and state. 69 70 Usage: 71 72 ```python 73 m = SomeMetric(...) 74 for input in ...: 75 m.update_state(input) 76 print('Final result: ', m.result().numpy()) 77 ``` 78 79 Usage with tf.keras API: 80 81 ```python 82 model = tf.keras.Sequential() 83 model.add(tf.keras.layers.Dense(64, activation='relu')) 84 model.add(tf.keras.layers.Dense(64, activation='relu')) 85 model.add(tf.keras.layers.Dense(10, activation='softmax')) 86 87 model.compile(optimizer=tf.train.RMSPropOptimizer(0.01), 88 loss=tf.keras.losses.categorical_crossentropy, 89 metrics=[tf.keras.metrics.CategoricalAccuracy()]) 90 91 data = np.random.random((1000, 32)) 92 labels = np.random.random((1000, 10)) 93 94 dataset = tf.data.Dataset.from_tensor_slices((data, labels)) 95 dataset = dataset.batch(32) 96 dataset = dataset.repeat() 97 98 model.fit(dataset, epochs=10, steps_per_epoch=30) 99 ``` 100 101 To be implemented by subclasses: 102 * `__init__()`: All state variables should be created in this method by 103 calling `self.add_weight()` like: `self.var = self.add_weight(...)` 104 * `update_state()`: Has all updates to the state variables like: 105 self.var.assign_add(...). 106 * `result()`: Computes and returns a value for the metric 107 from the state variables. 108 109 Example subclass implementation: 110 111 ``` 112 class BinaryTruePositives(tf.keras.metrics.Metric): 113 114 def __init__(self, name='binary_true_positives', **kwargs): 115 super(BinaryTruePositives, self).__init__(name=name, **kwargs) 116 self.true_positives = self.add_weight(name='tp', initializer='zeros') 117 118 def update_state(self, y_true, y_pred, sample_weight=None): 119 y_true = tf.cast(y_true, tf.bool) 120 y_pred = tf.cast(y_pred, tf.bool) 121 122 values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True)) 123 values = tf.cast(values, self.dtype) 124 if sample_weight is not None: 125 sample_weight = tf.cast(sample_weight, self.dtype) 126 sample_weight = tf.broadcast_weights(sample_weight, values) 127 values = tf.multiply(values, sample_weight) 128 self.true_positives.assign_add(tf.reduce_sum(values)) 129 130 def result(self): 131 return self.true_positives 132 ``` 133 """ 134 135 def __init__(self, name=None, dtype=None, **kwargs): 136 super(Metric, self).__init__(name=name, dtype=dtype, **kwargs) 137 self.stateful = True # All metric layers are stateful. 138 self.built = True 139 self._dtype = K.floatx() if dtype is None else dtypes.as_dtype(dtype).name 140 141 def __new__(cls, *args, **kwargs): 142 obj = super(Metric, cls).__new__(cls) 143 144 # TODO(psv): We are excluding wrapping `update_state` of built-in metrics 145 # with function here because of b/121302287. With this, built-in metrics 146 # will continue to work with TPUs and custom metrics will not, however 147 # users writing custom metrics need not worry about control dependencies 148 # and returning ops. 149 if cls.__module__ == Metric.__module__: 150 update_state_fn = obj.update_state 151 else: 152 update_state_fn = def_function.function(obj.update_state) 153 154 obj.update_state = types.MethodType( 155 metrics_utils.update_state_wrapper(update_state_fn), obj) 156 obj.result = types.MethodType(metrics_utils.result_wrapper(obj.result), obj) 157 return obj 158 159 def __call__(self, *args, **kwargs): 160 """Accumulates statistics and then computes metric result value. 161 162 Args: 163 *args: 164 **kwargs: A mini-batch of inputs to the Metric, 165 passed on to `update_state()`. 166 167 Returns: 168 The metric value tensor. 169 """ 170 update_op = self.update_state(*args, **kwargs) # pylint: disable=not-callable 171 with ops.control_dependencies([update_op]): 172 result_t = self.result() # pylint: disable=not-callable 173 174 # We are adding the metric object as metadata on the result tensor. 175 # This is required when we want to use a metric with `add_metric` API on 176 # a Model/Layer in graph mode. This metric instance will later be used 177 # to reset variable state after each epoch of training. 178 # Example: 179 # model = Model() 180 # mean = Mean() 181 # model.add_metric(mean(values), name='mean') 182 if not context.executing_eagerly(): 183 result_t._metric_obj = self # pylint: disable=protected-access 184 return result_t 185 186 @property 187 def dtype(self): 188 return self._dtype 189 190 def get_config(self): 191 """Returns the serializable config of the metric.""" 192 return {'name': self.name, 'dtype': self.dtype} 193 194 def reset_states(self): 195 """Resets all of the metric state variables. 196 197 This function is called between epochs/steps, 198 when a metric is evaluated during training. 199 """ 200 K.batch_set_value([(v, 0) for v in self.variables]) 201 202 @abc.abstractmethod 203 def update_state(self, *args, **kwargs): 204 """Accumulates statistics for the metric. 205 206 Note: This function is executed as a graph function in graph mode. 207 This means: 208 a) Operations on the same resource are executed in textual order. 209 This should make it easier to do things like add the updated 210 value of a variable to another, for example. 211 b) You don't need to worry about collecting the update ops to execute. 212 All update ops added to the graph by this function will be executed. 213 As a result, code should generally work the same way with graph or 214 eager execution. 215 216 Please use `tf.config.experimental_run_functions_eagerly(True)` to execute 217 this function eagerly for debugging or profiling. 218 219 Args: 220 *args: 221 **kwargs: A mini-batch of inputs to the Metric. 222 """ 223 NotImplementedError('Must be implemented in subclasses.') 224 225 @abc.abstractmethod 226 def result(self): 227 """Computes and returns the metric value tensor. 228 229 Result computation is an idempotent operation that simply calculates the 230 metric value using the state variables. 231 """ 232 NotImplementedError('Must be implemented in subclasses.') 233 234 ### For use by subclasses ### 235 @doc_controls.for_subclass_implementers 236 def add_weight(self, 237 name, 238 shape=(), 239 aggregation=tf_variables.VariableAggregation.SUM, 240 synchronization=tf_variables.VariableSynchronization.ON_READ, 241 initializer=None, 242 dtype=None): 243 """Adds state variable. Only for use by subclasses.""" 244 return super(Metric, self).add_weight( 245 name=name, 246 shape=shape, 247 dtype=self._dtype if dtype is None else dtype, 248 trainable=False, 249 initializer=initializer, 250 collections=[], 251 synchronization=synchronization, 252 aggregation=aggregation) 253 254 ### End: For use by subclasses ### 255 256 257class Reduce(Metric): 258 """Encapsulates metrics that perform a reduce operation on the values.""" 259 260 def __init__(self, reduction, name, dtype=None): 261 """Creates a `Reduce` instance. 262 263 Args: 264 reduction: a `tf.keras.metrics.Reduction` enum value. 265 name: string name of the metric instance. 266 dtype: (Optional) data type of the metric result. 267 """ 268 super(Reduce, self).__init__(name=name, dtype=dtype) 269 self.reduction = reduction 270 self.total = self.add_weight( 271 'total', initializer=init_ops.zeros_initializer) 272 if reduction in [metrics_utils.Reduction.SUM_OVER_BATCH_SIZE, 273 metrics_utils.Reduction.WEIGHTED_MEAN]: 274 self.count = self.add_weight( 275 'count', initializer=init_ops.zeros_initializer) 276 277 def update_state(self, values, sample_weight=None): 278 """Accumulates statistics for computing the reduction metric. 279 280 For example, if `values` is [1, 3, 5, 7] and reduction=SUM_OVER_BATCH_SIZE, 281 then the value of `result()` is 4. If the `sample_weight` is specified as 282 [1, 1, 0, 0] then value of `result()` would be 2. 283 284 Args: 285 values: Per-example value. 286 sample_weight: Optional weighting of each example. Defaults to 1. 287 288 Returns: 289 Update op. 290 """ 291 values = math_ops.cast(values, self._dtype) 292 if sample_weight is not None: 293 sample_weight = math_ops.cast(sample_weight, self._dtype) 294 # Update dimensions of weights to match with values if possible. 295 values, _, sample_weight = squeeze_or_expand_dimensions( 296 values, None, sample_weight) 297 try: 298 # Broadcast weights if possible. 299 sample_weight = weights_broadcast_ops.broadcast_weights( 300 sample_weight, values) 301 except ValueError: 302 # Reduce values to same ndim as weight array 303 ndim = K.ndim(values) 304 weight_ndim = K.ndim(sample_weight) 305 if self.reduction == metrics_utils.Reduction.SUM: 306 values = math_ops.reduce_sum( 307 values, axis=list(range(weight_ndim, ndim))) 308 else: 309 values = math_ops.reduce_mean( 310 values, axis=list(range(weight_ndim, ndim))) 311 values = math_ops.multiply(values, sample_weight) 312 313 value_sum = math_ops.reduce_sum(values) 314 with ops.control_dependencies([value_sum]): 315 update_total_op = self.total.assign_add(value_sum) 316 317 # Exit early if the reduction doesn't have a denominator. 318 if self.reduction == metrics_utils.Reduction.SUM: 319 return update_total_op 320 321 # Update `count` for reductions that require a denominator. 322 if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE: 323 num_values = math_ops.cast(array_ops.size(values), self._dtype) 324 elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN: 325 if sample_weight is None: 326 num_values = math_ops.cast(array_ops.size(values), self._dtype) 327 else: 328 num_values = math_ops.reduce_sum(sample_weight) 329 else: 330 raise NotImplementedError( 331 'reduction [%s] not implemented' % self.reduction) 332 333 with ops.control_dependencies([update_total_op]): 334 return self.count.assign_add(num_values) 335 336 def result(self): 337 if self.reduction == metrics_utils.Reduction.SUM: 338 return array_ops.identity(self.total) 339 elif self.reduction in [ 340 metrics_utils.Reduction.WEIGHTED_MEAN, 341 metrics_utils.Reduction.SUM_OVER_BATCH_SIZE 342 ]: 343 return math_ops.div_no_nan(self.total, self.count) 344 else: 345 raise NotImplementedError( 346 'reduction [%s] not implemented' % self.reduction) 347 348 349@keras_export('keras.metrics.Sum') 350class Sum(Reduce): 351 """Computes the (weighted) sum of the given values. 352 353 For example, if values is [1, 3, 5, 7] then the sum is 16. 354 If the weights were specified as [1, 1, 0, 0] then the sum would be 4. 355 356 This metric creates one variable, `total`, that is used to compute the sum of 357 `values`. This is ultimately returned as `sum`. 358 359 If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 360 to mask values. 361 362 Usage: 363 364 ```python 365 m = tf.keras.metrics.Sum() 366 m.update_state([1, 3, 5, 7]) 367 print('Final result: ', m.result().numpy()) # Final result: 16.0 368 ``` 369 370 Usage with tf.keras API: 371 372 ```python 373 model = tf.keras.Model(inputs, outputs) 374 model.add_metric(tf.keras.metrics.Sum(name='sum_1')(outputs)) 375 model.compile('sgd', loss='mse') 376 ``` 377 """ 378 379 def __init__(self, name='sum', dtype=None): 380 """Creates a `Sum` instance. 381 382 Args: 383 name: (Optional) string name of the metric instance. 384 dtype: (Optional) data type of the metric result. 385 """ 386 super(Sum, self).__init__(reduction=metrics_utils.Reduction.SUM, 387 name=name, dtype=dtype) 388 389 390@keras_export('keras.metrics.Mean') 391class Mean(Reduce): 392 """Computes the (weighted) mean of the given values. 393 394 For example, if values is [1, 3, 5, 7] then the mean is 4. 395 If the weights were specified as [1, 1, 0, 0] then the mean would be 2. 396 397 This metric creates two variables, `total` and `count` that are used to 398 compute the average of `values`. This average is ultimately returned as `mean` 399 which is an idempotent operation that simply divides `total` by `count`. 400 401 If `sample_weight` is `None`, weights default to 1. 402 Use `sample_weight` of 0 to mask values. 403 404 Usage: 405 406 ```python 407 m = tf.keras.metrics.Mean() 408 m.update_state([1, 3, 5, 7]) 409 print('Final result: ', m.result().numpy()) # Final result: 4.0 410 ``` 411 412 Usage with tf.keras API: 413 414 ```python 415 model = tf.keras.Model(inputs, outputs) 416 model.add_metric(tf.keras.metrics.Mean(name='mean_1')(outputs)) 417 model.compile('sgd', loss='mse') 418 ``` 419 """ 420 421 def __init__(self, name='mean', dtype=None): 422 """Creates a `Mean` instance. 423 424 Args: 425 name: (Optional) string name of the metric instance. 426 dtype: (Optional) data type of the metric result. 427 """ 428 super(Mean, self).__init__( 429 reduction=metrics_utils.Reduction.WEIGHTED_MEAN, name=name, dtype=dtype) 430 431 432@keras_export('keras.metrics.MeanRelativeError') 433class MeanRelativeError(Mean): 434 """Computes the mean relative error by normalizing with the given values. 435 436 This metric creates two local variables, `total` and `count` that are used to 437 compute the mean relative absolute error. This average is weighted by 438 `sample_weight`, and it is ultimately returned as `mean_relative_error`: 439 an idempotent operation that simply divides `total` by `count`. 440 441 If `sample_weight` is `None`, weights default to 1. 442 Use `sample_weight` of 0 to mask values. 443 444 Usage: 445 446 ```python 447 m = tf.keras.metrics.MeanRelativeError(normalizer=[1, 3, 2, 3]) 448 m.update_state([1, 3, 2, 3], [2, 4, 6, 8]) 449 450 # metric = mean(|y_pred - y_true| / normalizer) 451 # = mean([1, 1, 4, 5] / [1, 3, 2, 3]) = mean([1, 1/3, 2, 5/3]) 452 # = 5/4 = 1.25 453 print('Final result: ', m.result().numpy()) # Final result: 1.25 454 ``` 455 456 Usage with tf.keras API: 457 458 ```python 459 model = tf.keras.Model(inputs, outputs) 460 model.compile( 461 'sgd', 462 loss='mse', 463 metrics=[tf.keras.metrics.MeanRelativeError(normalizer=[1, 3])]) 464 ``` 465 """ 466 467 def __init__(self, normalizer, name=None, dtype=None): 468 """Creates a `MeanRelativeError` instance. 469 470 Args: 471 normalizer: The normalizer values with same shape as predictions. 472 name: (Optional) string name of the metric instance. 473 dtype: (Optional) data type of the metric result. 474 """ 475 super(MeanRelativeError, self).__init__(name=name, dtype=dtype) 476 normalizer = math_ops.cast(normalizer, self._dtype) 477 self.normalizer = normalizer 478 479 def update_state(self, y_true, y_pred, sample_weight=None): 480 """Accumulates metric statistics. 481 482 Args: 483 y_true: The ground truth values. 484 y_pred: The predicted values. 485 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 486 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 487 be broadcastable to `y_true`. 488 489 Returns: 490 Update op. 491 """ 492 y_true = math_ops.cast(y_true, self._dtype) 493 y_pred = math_ops.cast(y_pred, self._dtype) 494 y_pred, y_true, sample_weight = squeeze_or_expand_dimensions( 495 y_pred, y_true, sample_weight) 496 497 y_pred, self.normalizer = confusion_matrix.remove_squeezable_dimensions( 498 y_pred, self.normalizer) 499 y_pred.shape.assert_is_compatible_with(y_true.shape) 500 relative_errors = math_ops.div_no_nan( 501 math_ops.abs(y_true - y_pred), self.normalizer) 502 503 return super(MeanRelativeError, self).update_state( 504 relative_errors, sample_weight=sample_weight) 505 506 def get_config(self): 507 n = self.normalizer 508 config = {'normalizer': K.eval(n) if is_tensor_or_variable(n) else n} 509 base_config = super(MeanRelativeError, self).get_config() 510 return dict(list(base_config.items()) + list(config.items())) 511 512 513class MeanMetricWrapper(Mean): 514 """Wraps a stateless metric function with the Mean metric.""" 515 516 def __init__(self, fn, name=None, dtype=None, **kwargs): 517 """Creates a `MeanMetricWrapper` instance. 518 519 Args: 520 fn: The metric function to wrap, with signature 521 `fn(y_true, y_pred, **kwargs)`. 522 name: (Optional) string name of the metric instance. 523 dtype: (Optional) data type of the metric result. 524 **kwargs: The keyword arguments that are passed on to `fn`. 525 """ 526 super(MeanMetricWrapper, self).__init__(name=name, dtype=dtype) 527 self._fn = fn 528 self._fn_kwargs = kwargs 529 530 def update_state(self, y_true, y_pred, sample_weight=None): 531 """Accumulates metric statistics. 532 533 `y_true` and `y_pred` should have the same shape. 534 535 Args: 536 y_true: The ground truth values. 537 y_pred: The predicted values. 538 sample_weight: Optional weighting of each example. Defaults to 1. Can be 539 a `Tensor` whose rank is either 0, or the same rank as `y_true`, 540 and must be broadcastable to `y_true`. 541 542 Returns: 543 Update op. 544 """ 545 y_true = math_ops.cast(y_true, self._dtype) 546 y_pred = math_ops.cast(y_pred, self._dtype) 547 y_pred, y_true, sample_weight = squeeze_or_expand_dimensions( 548 y_pred, y_true, sample_weight) 549 550 matches = self._fn(y_true, y_pred, **self._fn_kwargs) 551 return super(MeanMetricWrapper, self).update_state( 552 matches, sample_weight=sample_weight) 553 554 def get_config(self): 555 config = {} 556 for k, v in six.iteritems(self._fn_kwargs): 557 config[k] = K.eval(v) if is_tensor_or_variable(v) else v 558 base_config = super(MeanMetricWrapper, self).get_config() 559 return dict(list(base_config.items()) + list(config.items())) 560 561 562@keras_export('keras.metrics.Accuracy') 563class Accuracy(MeanMetricWrapper): 564 """Calculates how often predictions matches labels. 565 566 For example, if `y_true` is [1, 2, 3, 4] and `y_pred` is [0, 2, 3, 4] 567 then the accuracy is 3/4 or .75. If the weights were specified as 568 [1, 1, 0, 0] then the accuracy would be 1/2 or .5. 569 570 This metric creates two local variables, `total` and `count` that are used to 571 compute the frequency with which `y_pred` matches `y_true`. This frequency is 572 ultimately returned as `binary accuracy`: an idempotent operation that simply 573 divides `total` by `count`. 574 575 If `sample_weight` is `None`, weights default to 1. 576 Use `sample_weight` of 0 to mask values. 577 578 Usage: 579 580 ```python 581 m = tf.keras.metrics.Accuracy() 582 m.update_state([1, 2, 3, 4], [0, 2, 3, 4]) 583 print('Final result: ', m.result().numpy()) # Final result: 0.75 584 ``` 585 586 Usage with tf.keras API: 587 588 ```python 589 model = tf.keras.Model(inputs, outputs) 590 model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.Accuracy()]) 591 ``` 592 """ 593 594 def __init__(self, name='accuracy', dtype=None): 595 super(Accuracy, self).__init__(accuracy, name, dtype=dtype) 596 597 598@keras_export('keras.metrics.BinaryAccuracy') 599class BinaryAccuracy(MeanMetricWrapper): 600 """Calculates how often predictions matches labels. 601 602 For example, if `y_true` is [1, 1, 0, 0] and `y_pred` is [0.98, 1, 0, 0.6] 603 then the binary accuracy is 3/4 or .75. If the weights were specified as 604 [1, 0, 0, 1] then the binary accuracy would be 1/2 or .5. 605 606 This metric creates two local variables, `total` and `count` that are used to 607 compute the frequency with which `y_pred` matches `y_true`. This frequency is 608 ultimately returned as `binary accuracy`: an idempotent operation that simply 609 divides `total` by `count`. 610 611 If `sample_weight` is `None`, weights default to 1. 612 Use `sample_weight` of 0 to mask values. 613 614 Usage: 615 616 ```python 617 m = tf.keras.metrics.BinaryAccuracy() 618 m.update_state([1, 1, 0, 0], [0.98, 1, 0, 0.6]) 619 print('Final result: ', m.result().numpy()) # Final result: 0.75 620 ``` 621 622 Usage with tf.keras API: 623 624 ```python 625 model = tf.keras.Model(inputs, outputs) 626 model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.BinaryAccuracy()]) 627 ``` 628 """ 629 630 def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5): 631 """Creates a `BinaryAccuracy` instance. 632 633 Args: 634 name: (Optional) string name of the metric instance. 635 dtype: (Optional) data type of the metric result. 636 threshold: (Optional) Float representing the threshold for deciding 637 whether prediction values are 1 or 0. 638 """ 639 super(BinaryAccuracy, self).__init__( 640 binary_accuracy, name, dtype=dtype, threshold=threshold) 641 642 643@keras_export('keras.metrics.CategoricalAccuracy') 644class CategoricalAccuracy(MeanMetricWrapper): 645 """Calculates how often predictions matches labels. 646 647 For example, if `y_true` is [[0, 0, 1], [0, 1, 0]] and `y_pred` is 648 [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] then the categorical accuracy is 1/2 or .5. 649 If the weights were specified as [0.7, 0.3] then the categorical accuracy 650 would be .3. You can provide logits of classes as `y_pred`, since argmax of 651 logits and probabilities are same. 652 653 This metric creates two local variables, `total` and `count` that are used to 654 compute the frequency with which `y_pred` matches `y_true`. This frequency is 655 ultimately returned as `categorical accuracy`: an idempotent operation that 656 simply divides `total` by `count`. 657 658 `y_pred` and `y_true` should be passed in as vectors of probabilities, rather 659 than as labels. If necessary, use `tf.one_hot` to expand `y_true` as a vector. 660 661 If `sample_weight` is `None`, weights default to 1. 662 Use `sample_weight` of 0 to mask values. 663 664 Usage: 665 666 ```python 667 m = tf.keras.metrics.CategoricalAccuracy() 668 m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) 669 print('Final result: ', m.result().numpy()) # Final result: 0.5 670 ``` 671 672 Usage with tf.keras API: 673 674 ```python 675 model = tf.keras.Model(inputs, outputs) 676 model.compile( 677 'sgd', 678 loss='mse', 679 metrics=[tf.keras.metrics.CategoricalAccuracy()]) 680 ``` 681 """ 682 683 def __init__(self, name='categorical_accuracy', dtype=None): 684 """Creates a `CategoricalAccuracy` instance. 685 686 Args: 687 name: (Optional) string name of the metric instance. 688 dtype: (Optional) data type of the metric result. 689 """ 690 super(CategoricalAccuracy, self).__init__( 691 categorical_accuracy, name, dtype=dtype) 692 693 694@keras_export('keras.metrics.SparseCategoricalAccuracy') 695class SparseCategoricalAccuracy(MeanMetricWrapper): 696 """Calculates how often predictions matches integer labels. 697 698 For example, if `y_true` is [[2], [1]] and `y_pred` is 699 [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] then the categorical accuracy is 1/2 or .5. 700 If the weights were specified as [0.7, 0.3] then the categorical accuracy 701 would be .3. You can provide logits of classes as `y_pred`, since argmax of 702 logits and probabilities are same. 703 704 This metric creates two local variables, `total` and `count` that are used to 705 compute the frequency with which `y_pred` matches `y_true`. This frequency is 706 ultimately returned as `sparse categorical accuracy`: an idempotent operation 707 that simply divides `total` by `count`. 708 709 If `sample_weight` is `None`, weights default to 1. 710 Use `sample_weight` of 0 to mask values. 711 712 Usage: 713 714 ```python 715 m = tf.keras.metrics.SparseCategoricalAccuracy() 716 m.update_state([[2], [1]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) 717 print('Final result: ', m.result().numpy()) # Final result: 0.5 718 ``` 719 720 Usage with tf.keras API: 721 722 ```python 723 model = tf.keras.Model(inputs, outputs) 724 model.compile( 725 'sgd', 726 loss='mse', 727 metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]) 728 ``` 729 """ 730 731 def __init__(self, name='sparse_categorical_accuracy', dtype=None): 732 super(SparseCategoricalAccuracy, self).__init__( 733 sparse_categorical_accuracy, name, dtype=dtype) 734 735 736@keras_export('keras.metrics.TopKCategoricalAccuracy') 737class TopKCategoricalAccuracy(MeanMetricWrapper): 738 """Computes how often targets are in the top `K` predictions. 739 740 Usage: 741 742 ```python 743 m = tf.keras.metrics.TopKCategoricalAccuracy() 744 m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) 745 print('Final result: ', m.result().numpy()) # Final result: 1.0 746 ``` 747 748 Usage with tf.keras API: 749 750 ```python 751 model = tf.keras.Model(inputs, outputs) 752 model.compile('sgd', metrics=[tf.keras.metrics.TopKCategoricalAccuracy()]) 753 ``` 754 """ 755 756 def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None): 757 """Creates a `TopKCategoricalAccuracy` instance. 758 759 Args: 760 k: (Optional) Number of top elements to look at for computing accuracy. 761 Defaults to 5. 762 name: (Optional) string name of the metric instance. 763 dtype: (Optional) data type of the metric result. 764 """ 765 super(TopKCategoricalAccuracy, self).__init__( 766 top_k_categorical_accuracy, name, dtype=dtype, k=k) 767 768 769@keras_export('keras.metrics.SparseTopKCategoricalAccuracy') 770class SparseTopKCategoricalAccuracy(MeanMetricWrapper): 771 """Computes how often integer targets are in the top `K` predictions. 772 773 Usage: 774 775 ```python 776 m = tf.keras.metrics.SparseTopKCategoricalAccuracy() 777 m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) 778 print('Final result: ', m.result().numpy()) # Final result: 1.0 779 ``` 780 781 Usage with tf.keras API: 782 783 ```python 784 model = tf.keras.Model(inputs, outputs) 785 model.compile( 786 'sgd', 787 metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy()]) 788 ``` 789 """ 790 791 def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None): 792 """Creates a `SparseTopKCategoricalAccuracy` instance. 793 794 Args: 795 k: (Optional) Number of top elements to look at for computing accuracy. 796 Defaults to 5. 797 name: (Optional) string name of the metric instance. 798 dtype: (Optional) data type of the metric result. 799 """ 800 super(SparseTopKCategoricalAccuracy, self).__init__( 801 sparse_top_k_categorical_accuracy, name, dtype=dtype, k=k) 802 803 804class _ConfusionMatrixConditionCount(Metric): 805 """Calculates the number of the given confusion matrix condition.""" 806 807 def __init__(self, 808 confusion_matrix_cond, 809 thresholds=None, 810 name=None, 811 dtype=None): 812 """Creates a `_ConfusionMatrixConditionCount` instance. 813 814 Args: 815 confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` conditions. 816 thresholds: (Optional) Defaults to 0.5. A float value or a python 817 list/tuple of float threshold values in [0, 1]. A threshold is compared 818 with prediction values to determine the truth value of predictions 819 (i.e., above the threshold is `true`, below is `false`). One metric 820 value is generated for each threshold value. 821 name: (Optional) string name of the metric instance. 822 dtype: (Optional) data type of the metric result. 823 """ 824 super(_ConfusionMatrixConditionCount, self).__init__(name=name, dtype=dtype) 825 self._confusion_matrix_cond = confusion_matrix_cond 826 self.init_thresholds = thresholds 827 self.thresholds = metrics_utils.parse_init_thresholds( 828 thresholds, default_threshold=0.5) 829 self.accumulator = self.add_weight( 830 'accumulator', 831 shape=(len(self.thresholds),), 832 initializer=init_ops.zeros_initializer) 833 834 def update_state(self, y_true, y_pred, sample_weight=None): 835 """Accumulates the given confusion matrix condition statistics. 836 837 Args: 838 y_true: The ground truth values. 839 y_pred: The predicted values. 840 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 841 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 842 be broadcastable to `y_true`. 843 844 Returns: 845 Update op. 846 """ 847 return metrics_utils.update_confusion_matrix_variables( 848 {self._confusion_matrix_cond: self.accumulator}, 849 y_true, 850 y_pred, 851 thresholds=self.thresholds, 852 sample_weight=sample_weight) 853 854 def result(self): 855 if len(self.thresholds) == 1: 856 result = self.accumulator[0] 857 else: 858 result = self.accumulator 859 return ops.convert_to_tensor(result) 860 861 def reset_states(self): 862 num_thresholds = len(to_list(self.thresholds)) 863 K.batch_set_value( 864 [(v, np.zeros((num_thresholds,))) for v in self.variables]) 865 866 def get_config(self): 867 config = {'thresholds': self.init_thresholds} 868 base_config = super(_ConfusionMatrixConditionCount, self).get_config() 869 return dict(list(base_config.items()) + list(config.items())) 870 871 872@keras_export('keras.metrics.FalsePositives') 873class FalsePositives(_ConfusionMatrixConditionCount): 874 """Calculates the number of false positives. 875 876 For example, if `y_true` is [0, 1, 0, 0] and `y_pred` is [0, 0, 1, 1] 877 then the false positives value is 2. If the weights were specified as 878 [0, 0, 1, 0] then the false positives value would be 1. 879 880 If `sample_weight` is given, calculates the sum of the weights of 881 false positives. This metric creates one local variable, `accumulator` 882 that is used to keep track of the number of false positives. 883 884 If `sample_weight` is `None`, weights default to 1. 885 Use `sample_weight` of 0 to mask values. 886 887 Usage: 888 889 ```python 890 m = tf.keras.metrics.FalsePositives() 891 m.update_state([0, 1, 0, 0], [0, 0, 1, 1]) 892 print('Final result: ', m.result().numpy()) # Final result: 2 893 ``` 894 895 Usage with tf.keras API: 896 897 ```python 898 model = tf.keras.Model(inputs, outputs) 899 model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.FalsePositives()]) 900 ``` 901 """ 902 903 def __init__(self, thresholds=None, name=None, dtype=None): 904 """Creates a `FalsePositives` instance. 905 906 Args: 907 thresholds: (Optional) Defaults to 0.5. A float value or a python 908 list/tuple of float threshold values in [0, 1]. A threshold is compared 909 with prediction values to determine the truth value of predictions 910 (i.e., above the threshold is `true`, below is `false`). One metric 911 value is generated for each threshold value. 912 name: (Optional) string name of the metric instance. 913 dtype: (Optional) data type of the metric result. 914 """ 915 super(FalsePositives, self).__init__( 916 confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES, 917 thresholds=thresholds, 918 name=name, 919 dtype=dtype) 920 921 922@keras_export('keras.metrics.FalseNegatives') 923class FalseNegatives(_ConfusionMatrixConditionCount): 924 """Calculates the number of false negatives. 925 926 For example, if `y_true` is [0, 1, 1, 1] and `y_pred` is [0, 1, 0, 0] 927 then the false negatives value is 2. If the weights were specified as 928 [0, 0, 1, 0] then the false negatives value would be 1. 929 930 If `sample_weight` is given, calculates the sum of the weights of 931 false negatives. This metric creates one local variable, `accumulator` 932 that is used to keep track of the number of false negatives. 933 934 If `sample_weight` is `None`, weights default to 1. 935 Use `sample_weight` of 0 to mask values. 936 937 Usage: 938 939 ```python 940 m = tf.keras.metrics.FalseNegatives() 941 m.update_state([0, 1, 1, 1], [0, 1, 0, 0]) 942 print('Final result: ', m.result().numpy()) # Final result: 2 943 ``` 944 945 Usage with tf.keras API: 946 947 ```python 948 model = tf.keras.Model(inputs, outputs) 949 model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.FalseNegatives()]) 950 ``` 951 """ 952 953 def __init__(self, thresholds=None, name=None, dtype=None): 954 """Creates a `FalseNegatives` instance. 955 956 Args: 957 thresholds: (Optional) Defaults to 0.5. A float value or a python 958 list/tuple of float threshold values in [0, 1]. A threshold is compared 959 with prediction values to determine the truth value of predictions 960 (i.e., above the threshold is `true`, below is `false`). One metric 961 value is generated for each threshold value. 962 name: (Optional) string name of the metric instance. 963 dtype: (Optional) data type of the metric result. 964 """ 965 super(FalseNegatives, self).__init__( 966 confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES, 967 thresholds=thresholds, 968 name=name, 969 dtype=dtype) 970 971 972@keras_export('keras.metrics.TrueNegatives') 973class TrueNegatives(_ConfusionMatrixConditionCount): 974 """Calculates the number of true negatives. 975 976 For example, if `y_true` is [0, 1, 0, 0] and `y_pred` is [1, 1, 0, 0] 977 then the true negatives value is 2. If the weights were specified as 978 [0, 0, 1, 0] then the true negatives value would be 1. 979 980 If `sample_weight` is given, calculates the sum of the weights of 981 true negatives. This metric creates one local variable, `accumulator` 982 that is used to keep track of the number of true negatives. 983 984 If `sample_weight` is `None`, weights default to 1. 985 Use `sample_weight` of 0 to mask values. 986 987 Usage: 988 989 ```python 990 m = tf.keras.metrics.TrueNegatives() 991 m.update_state([0, 1, 0, 0], [1, 1, 0, 0]) 992 print('Final result: ', m.result().numpy()) # Final result: 2 993 ``` 994 995 Usage with tf.keras API: 996 997 ```python 998 model = tf.keras.Model(inputs, outputs) 999 model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.TrueNegatives()]) 1000 ``` 1001 """ 1002 1003 def __init__(self, thresholds=None, name=None, dtype=None): 1004 """Creates a `TrueNegatives` instance. 1005 1006 Args: 1007 thresholds: (Optional) Defaults to 0.5. A float value or a python 1008 list/tuple of float threshold values in [0, 1]. A threshold is compared 1009 with prediction values to determine the truth value of predictions 1010 (i.e., above the threshold is `true`, below is `false`). One metric 1011 value is generated for each threshold value. 1012 name: (Optional) string name of the metric instance. 1013 dtype: (Optional) data type of the metric result. 1014 """ 1015 super(TrueNegatives, self).__init__( 1016 confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES, 1017 thresholds=thresholds, 1018 name=name, 1019 dtype=dtype) 1020 1021 1022@keras_export('keras.metrics.TruePositives') 1023class TruePositives(_ConfusionMatrixConditionCount): 1024 """Calculates the number of true positives. 1025 1026 For example, if `y_true` is [0, 1, 1, 1] and `y_pred` is [1, 0, 1, 1] 1027 then the true positives value is 2. If the weights were specified as 1028 [0, 0, 1, 0] then the true positives value would be 1. 1029 1030 If `sample_weight` is given, calculates the sum of the weights of 1031 true positives. This metric creates one local variable, `true_positives` 1032 that is used to keep track of the number of true positives. 1033 1034 If `sample_weight` is `None`, weights default to 1. 1035 Use `sample_weight` of 0 to mask values. 1036 1037 Usage: 1038 1039 ```python 1040 m = tf.keras.metrics.TruePositives() 1041 m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) 1042 print('Final result: ', m.result().numpy()) # Final result: 2 1043 ``` 1044 1045 Usage with tf.keras API: 1046 1047 ```python 1048 model = tf.keras.Model(inputs, outputs) 1049 model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.TruePositives()]) 1050 ``` 1051 """ 1052 1053 def __init__(self, thresholds=None, name=None, dtype=None): 1054 """Creates a `TruePositives` instance. 1055 1056 Args: 1057 thresholds: (Optional) Defaults to 0.5. A float value or a python 1058 list/tuple of float threshold values in [0, 1]. A threshold is compared 1059 with prediction values to determine the truth value of predictions 1060 (i.e., above the threshold is `true`, below is `false`). One metric 1061 value is generated for each threshold value. 1062 name: (Optional) string name of the metric instance. 1063 dtype: (Optional) data type of the metric result. 1064 """ 1065 super(TruePositives, self).__init__( 1066 confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES, 1067 thresholds=thresholds, 1068 name=name, 1069 dtype=dtype) 1070 1071 1072@keras_export('keras.metrics.Precision') 1073class Precision(Metric): 1074 """Computes the precision of the predictions with respect to the labels. 1075 1076 For example, if `y_true` is [0, 1, 1, 1] and `y_pred` is [1, 0, 1, 1] 1077 then the precision value is 2/(2+1) ie. 0.66. If the weights were specified as 1078 [0, 0, 1, 0] then the precision value would be 1. 1079 1080 The metric creates two local variables, `true_positives` and `false_positives` 1081 that are used to compute the precision. This value is ultimately returned as 1082 `precision`, an idempotent operation that simply divides `true_positives` 1083 by the sum of `true_positives` and `false_positives`. 1084 1085 If `sample_weight` is `None`, weights default to 1. 1086 Use `sample_weight` of 0 to mask values. 1087 1088 If `top_k` is set, we'll calculate precision as how often on average a class 1089 among the top-k classes with the highest predicted values of a batch entry is 1090 correct and can be found in the label for that entry. 1091 1092 If `class_id` is specified, we calculate precision by considering only the 1093 entries in the batch for which `class_id` is above the threshold and/or in the 1094 top-k highest predictions, and computing the fraction of them for which 1095 `class_id` is indeed a correct label. 1096 1097 Usage: 1098 1099 ```python 1100 m = tf.keras.metrics.Precision() 1101 m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) 1102 print('Final result: ', m.result().numpy()) # Final result: 0.66 1103 ``` 1104 1105 Usage with tf.keras API: 1106 1107 ```python 1108 model = tf.keras.Model(inputs, outputs) 1109 model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.Precision()]) 1110 ``` 1111 """ 1112 1113 def __init__(self, 1114 thresholds=None, 1115 top_k=None, 1116 class_id=None, 1117 name=None, 1118 dtype=None): 1119 """Creates a `Precision` instance. 1120 1121 Args: 1122 thresholds: (Optional) A float value or a python list/tuple of float 1123 threshold values in [0, 1]. A threshold is compared with prediction 1124 values to determine the truth value of predictions (i.e., above the 1125 threshold is `true`, below is `false`). One metric value is generated 1126 for each threshold value. If neither thresholds nor top_k are set, the 1127 default is to calculate precision with `thresholds=0.5`. 1128 top_k: (Optional) Unset by default. An int value specifying the top-k 1129 predictions to consider when calculating precision. 1130 class_id: (Optional) Integer class ID for which we want binary metrics. 1131 This must be in the half-open interval `[0, num_classes)`, where 1132 `num_classes` is the last dimension of predictions. 1133 name: (Optional) string name of the metric instance. 1134 dtype: (Optional) data type of the metric result. 1135 """ 1136 super(Precision, self).__init__(name=name, dtype=dtype) 1137 self.init_thresholds = thresholds 1138 self.top_k = top_k 1139 self.class_id = class_id 1140 1141 default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF 1142 self.thresholds = metrics_utils.parse_init_thresholds( 1143 thresholds, default_threshold=default_threshold) 1144 self.true_positives = self.add_weight( 1145 'true_positives', 1146 shape=(len(self.thresholds),), 1147 initializer=init_ops.zeros_initializer) 1148 self.false_positives = self.add_weight( 1149 'false_positives', 1150 shape=(len(self.thresholds),), 1151 initializer=init_ops.zeros_initializer) 1152 1153 def update_state(self, y_true, y_pred, sample_weight=None): 1154 """Accumulates true positive and false positive statistics. 1155 1156 Args: 1157 y_true: The ground truth values, with the same dimensions as `y_pred`. 1158 Will be cast to `bool`. 1159 y_pred: The predicted values. Each element must be in the range `[0, 1]`. 1160 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 1161 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 1162 be broadcastable to `y_true`. 1163 1164 Returns: 1165 Update op. 1166 """ 1167 return metrics_utils.update_confusion_matrix_variables( 1168 { 1169 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, 1170 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives 1171 }, 1172 y_true, 1173 y_pred, 1174 thresholds=self.thresholds, 1175 top_k=self.top_k, 1176 class_id=self.class_id, 1177 sample_weight=sample_weight) 1178 1179 def result(self): 1180 result = math_ops.div_no_nan(self.true_positives, 1181 self.true_positives + self.false_positives) 1182 return result[0] if len(self.thresholds) == 1 else result 1183 1184 def reset_states(self): 1185 num_thresholds = len(to_list(self.thresholds)) 1186 K.batch_set_value( 1187 [(v, np.zeros((num_thresholds,))) for v in self.variables]) 1188 1189 def get_config(self): 1190 config = { 1191 'thresholds': self.init_thresholds, 1192 'top_k': self.top_k, 1193 'class_id': self.class_id 1194 } 1195 base_config = super(Precision, self).get_config() 1196 return dict(list(base_config.items()) + list(config.items())) 1197 1198 1199@keras_export('keras.metrics.Recall') 1200class Recall(Metric): 1201 """Computes the recall of the predictions with respect to the labels. 1202 1203 For example, if `y_true` is [0, 1, 1, 1] and `y_pred` is [1, 0, 1, 1] 1204 then the recall value is 2/(2+1) ie. 0.66. If the weights were specified as 1205 [0, 0, 1, 0] then the recall value would be 1. 1206 1207 This metric creates two local variables, `true_positives` and 1208 `false_negatives`, that are used to compute the recall. This value is 1209 ultimately returned as `recall`, an idempotent operation that simply divides 1210 `true_positives` by the sum of `true_positives` and `false_negatives`. 1211 1212 If `sample_weight` is `None`, weights default to 1. 1213 Use `sample_weight` of 0 to mask values. 1214 1215 If `top_k` is set, recall will be computed as how often on average a class 1216 among the labels of a batch entry is in the top-k predictions. 1217 1218 If `class_id` is specified, we calculate recall by considering only the 1219 entries in the batch for which `class_id` is in the label, and computing the 1220 fraction of them for which `class_id` is above the threshold and/or in the 1221 top-k predictions. 1222 1223 Usage: 1224 1225 ```python 1226 m = tf.keras.metrics.Recall() 1227 m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) 1228 print('Final result: ', m.result().numpy()) # Final result: 0.66 1229 ``` 1230 1231 Usage with tf.keras API: 1232 1233 ```python 1234 model = tf.keras.Model(inputs, outputs) 1235 model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.Recall()]) 1236 ``` 1237 """ 1238 1239 def __init__(self, 1240 thresholds=None, 1241 top_k=None, 1242 class_id=None, 1243 name=None, 1244 dtype=None): 1245 """Creates a `Recall` instance. 1246 1247 Args: 1248 thresholds: (Optional) A float value or a python list/tuple of float 1249 threshold values in [0, 1]. A threshold is compared with prediction 1250 values to determine the truth value of predictions (i.e., above the 1251 threshold is `true`, below is `false`). One metric value is generated 1252 for each threshold value. If neither thresholds nor top_k are set, the 1253 default is to calculate recall with `thresholds=0.5`. 1254 top_k: (Optional) Unset by default. An int value specifying the top-k 1255 predictions to consider when calculating recall. 1256 class_id: (Optional) Integer class ID for which we want binary metrics. 1257 This must be in the half-open interval `[0, num_classes)`, where 1258 `num_classes` is the last dimension of predictions. 1259 name: (Optional) string name of the metric instance. 1260 dtype: (Optional) data type of the metric result. 1261 """ 1262 super(Recall, self).__init__(name=name, dtype=dtype) 1263 self.init_thresholds = thresholds 1264 self.top_k = top_k 1265 self.class_id = class_id 1266 1267 default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF 1268 self.thresholds = metrics_utils.parse_init_thresholds( 1269 thresholds, default_threshold=default_threshold) 1270 self.true_positives = self.add_weight( 1271 'true_positives', 1272 shape=(len(self.thresholds),), 1273 initializer=init_ops.zeros_initializer) 1274 self.false_negatives = self.add_weight( 1275 'false_negatives', 1276 shape=(len(self.thresholds),), 1277 initializer=init_ops.zeros_initializer) 1278 1279 def update_state(self, y_true, y_pred, sample_weight=None): 1280 """Accumulates true positive and false negative statistics. 1281 1282 Args: 1283 y_true: The ground truth values, with the same dimensions as `y_pred`. 1284 Will be cast to `bool`. 1285 y_pred: The predicted values. Each element must be in the range `[0, 1]`. 1286 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 1287 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 1288 be broadcastable to `y_true`. 1289 1290 Returns: 1291 Update op. 1292 """ 1293 return metrics_utils.update_confusion_matrix_variables( 1294 { 1295 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, 1296 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives 1297 }, 1298 y_true, 1299 y_pred, 1300 thresholds=self.thresholds, 1301 top_k=self.top_k, 1302 class_id=self.class_id, 1303 sample_weight=sample_weight) 1304 1305 def result(self): 1306 result = math_ops.div_no_nan(self.true_positives, 1307 self.true_positives + self.false_negatives) 1308 return result[0] if len(self.thresholds) == 1 else result 1309 1310 def reset_states(self): 1311 num_thresholds = len(to_list(self.thresholds)) 1312 K.batch_set_value( 1313 [(v, np.zeros((num_thresholds,))) for v in self.variables]) 1314 1315 def get_config(self): 1316 config = { 1317 'thresholds': self.init_thresholds, 1318 'top_k': self.top_k, 1319 'class_id': self.class_id 1320 } 1321 base_config = super(Recall, self).get_config() 1322 return dict(list(base_config.items()) + list(config.items())) 1323 1324 1325@six.add_metaclass(abc.ABCMeta) 1326class SensitivitySpecificityBase(Metric): 1327 """Abstract base class for computing sensitivity and specificity. 1328 1329 For additional information about specificity and sensitivity, see the 1330 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 1331 """ 1332 1333 def __init__(self, value, num_thresholds=200, name=None, dtype=None): 1334 super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype) 1335 if num_thresholds <= 0: 1336 raise ValueError('`num_thresholds` must be > 0.') 1337 self.value = value 1338 self.true_positives = self.add_weight( 1339 'true_positives', 1340 shape=(num_thresholds,), 1341 initializer=init_ops.zeros_initializer) 1342 self.true_negatives = self.add_weight( 1343 'true_negatives', 1344 shape=(num_thresholds,), 1345 initializer=init_ops.zeros_initializer) 1346 self.false_positives = self.add_weight( 1347 'false_positives', 1348 shape=(num_thresholds,), 1349 initializer=init_ops.zeros_initializer) 1350 self.false_negatives = self.add_weight( 1351 'false_negatives', 1352 shape=(num_thresholds,), 1353 initializer=init_ops.zeros_initializer) 1354 1355 # Compute `num_thresholds` thresholds in [0, 1] 1356 if num_thresholds == 1: 1357 self.thresholds = [0.5] 1358 else: 1359 thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) 1360 for i in range(num_thresholds - 2)] 1361 self.thresholds = [0.0] + thresholds + [1.0] 1362 1363 def update_state(self, y_true, y_pred, sample_weight=None): 1364 """Accumulates confusion matrix statistics. 1365 1366 Args: 1367 y_true: The ground truth values. 1368 y_pred: The predicted values. 1369 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 1370 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 1371 be broadcastable to `y_true`. 1372 1373 Returns: 1374 Update op. 1375 """ 1376 return metrics_utils.update_confusion_matrix_variables( 1377 { 1378 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, 1379 metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, 1380 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, 1381 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, 1382 }, 1383 y_true, 1384 y_pred, 1385 thresholds=self.thresholds, 1386 sample_weight=sample_weight) 1387 1388 def reset_states(self): 1389 num_thresholds = len(self.thresholds) 1390 K.batch_set_value( 1391 [(v, np.zeros((num_thresholds,))) for v in self.variables]) 1392 1393 1394@keras_export('keras.metrics.SensitivityAtSpecificity') 1395class SensitivityAtSpecificity(SensitivitySpecificityBase): 1396 """Computes the sensitivity at a given specificity. 1397 1398 `Sensitivity` measures the proportion of actual positives that are correctly 1399 identified as such (tp / (tp + fn)). 1400 `Specificity` measures the proportion of actual negatives that are correctly 1401 identified as such (tn / (tn + fp)). 1402 1403 This metric creates four local variables, `true_positives`, `true_negatives`, 1404 `false_positives` and `false_negatives` that are used to compute the 1405 sensitivity at the given specificity. The threshold for the given specificity 1406 value is computed and used to evaluate the corresponding sensitivity. 1407 1408 If `sample_weight` is `None`, weights default to 1. 1409 Use `sample_weight` of 0 to mask values. 1410 1411 For additional information about specificity and sensitivity, see the 1412 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 1413 1414 Usage: 1415 1416 ```python 1417 m = tf.keras.metrics.SensitivityAtSpecificity(0.4, num_thresholds=1) 1418 m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) 1419 print('Final result: ', m.result().numpy()) # Final result: 0.5 1420 ``` 1421 1422 Usage with tf.keras API: 1423 1424 ```python 1425 model = tf.keras.Model(inputs, outputs) 1426 model.compile( 1427 'sgd', 1428 loss='mse', 1429 metrics=[tf.keras.metrics.SensitivityAtSpecificity()]) 1430 ``` 1431 """ 1432 1433 def __init__(self, specificity, num_thresholds=200, name=None, dtype=None): 1434 """Creates a `SensitivityAtSpecificity` instance. 1435 1436 Args: 1437 specificity: A scalar value in range `[0, 1]`. 1438 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 1439 use for matching the given specificity. 1440 name: (Optional) string name of the metric instance. 1441 dtype: (Optional) data type of the metric result. 1442 """ 1443 if specificity < 0 or specificity > 1: 1444 raise ValueError('`specificity` must be in the range [0, 1].') 1445 self.specificity = specificity 1446 self.num_thresholds = num_thresholds 1447 super(SensitivityAtSpecificity, self).__init__( 1448 specificity, num_thresholds=num_thresholds, name=name, dtype=dtype) 1449 1450 def result(self): 1451 # Calculate specificities at all the thresholds. 1452 specificities = math_ops.div_no_nan( 1453 self.true_negatives, self.true_negatives + self.false_positives) 1454 1455 # Find the index of the threshold where the specificity is closest to the 1456 # given specificity. 1457 min_index = math_ops.argmin( 1458 math_ops.abs(specificities - self.value), axis=0) 1459 min_index = math_ops.cast(min_index, dtypes.int32) 1460 1461 # Compute sensitivity at that index. 1462 return math_ops.div_no_nan( 1463 self.true_positives[min_index], 1464 self.true_positives[min_index] + self.false_negatives[min_index]) 1465 1466 def get_config(self): 1467 config = { 1468 'num_thresholds': self.num_thresholds, 1469 'specificity': self.specificity 1470 } 1471 base_config = super(SensitivityAtSpecificity, self).get_config() 1472 return dict(list(base_config.items()) + list(config.items())) 1473 1474 1475@keras_export('keras.metrics.SpecificityAtSensitivity') 1476class SpecificityAtSensitivity(SensitivitySpecificityBase): 1477 """Computes the specificity at a given sensitivity. 1478 1479 `Sensitivity` measures the proportion of actual positives that are correctly 1480 identified as such (tp / (tp + fn)). 1481 `Specificity` measures the proportion of actual negatives that are correctly 1482 identified as such (tn / (tn + fp)). 1483 1484 This metric creates four local variables, `true_positives`, `true_negatives`, 1485 `false_positives` and `false_negatives` that are used to compute the 1486 specificity at the given sensitivity. The threshold for the given sensitivity 1487 value is computed and used to evaluate the corresponding specificity. 1488 1489 If `sample_weight` is `None`, weights default to 1. 1490 Use `sample_weight` of 0 to mask values. 1491 1492 For additional information about specificity and sensitivity, see the 1493 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 1494 1495 Usage: 1496 1497 ```python 1498 m = tf.keras.metrics.SpecificityAtSensitivity(0.8, num_thresholds=1) 1499 m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) 1500 print('Final result: ', m.result().numpy()) # Final result: 1.0 1501 ``` 1502 1503 Usage with tf.keras API: 1504 1505 ```python 1506 model = tf.keras.Model(inputs, outputs) 1507 model.compile( 1508 'sgd', 1509 loss='mse', 1510 metrics=[tf.keras.metrics.SpecificityAtSensitivity()]) 1511 ``` 1512 """ 1513 1514 def __init__(self, sensitivity, num_thresholds=200, name=None, dtype=None): 1515 """Creates a `SpecificityAtSensitivity` instance. 1516 1517 Args: 1518 sensitivity: A scalar value in range `[0, 1]`. 1519 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 1520 use for matching the given specificity. 1521 name: (Optional) string name of the metric instance. 1522 dtype: (Optional) data type of the metric result. 1523 """ 1524 if sensitivity < 0 or sensitivity > 1: 1525 raise ValueError('`sensitivity` must be in the range [0, 1].') 1526 self.sensitivity = sensitivity 1527 self.num_thresholds = num_thresholds 1528 super(SpecificityAtSensitivity, self).__init__( 1529 sensitivity, num_thresholds=num_thresholds, name=name, dtype=dtype) 1530 1531 def result(self): 1532 # Calculate sensitivities at all the thresholds. 1533 sensitivities = math_ops.div_no_nan( 1534 self.true_positives, self.true_positives + self.false_negatives) 1535 1536 # Find the index of the threshold where the sensitivity is closest to the 1537 # given specificity. 1538 min_index = math_ops.argmin( 1539 math_ops.abs(sensitivities - self.value), axis=0) 1540 min_index = math_ops.cast(min_index, dtypes.int32) 1541 1542 # Compute specificity at that index. 1543 return math_ops.div_no_nan( 1544 self.true_negatives[min_index], 1545 self.true_negatives[min_index] + self.false_positives[min_index]) 1546 1547 def get_config(self): 1548 config = { 1549 'num_thresholds': self.num_thresholds, 1550 'sensitivity': self.sensitivity 1551 } 1552 base_config = super(SpecificityAtSensitivity, self).get_config() 1553 return dict(list(base_config.items()) + list(config.items())) 1554 1555 1556@keras_export('keras.metrics.AUC') 1557class AUC(Metric): 1558 """Computes the approximate AUC (Area under the curve) via a Riemann sum. 1559 1560 This metric creates four local variables, `true_positives`, `true_negatives`, 1561 `false_positives` and `false_negatives` that are used to compute the AUC. 1562 To discretize the AUC curve, a linearly spaced set of thresholds is used to 1563 compute pairs of recall and precision values. The area under the ROC-curve is 1564 therefore computed using the height of the recall values by the false positive 1565 rate, while the area under the PR-curve is the computed using the height of 1566 the precision values by the recall. 1567 1568 This value is ultimately returned as `auc`, an idempotent operation that 1569 computes the area under a discretized curve of precision versus recall values 1570 (computed using the aforementioned variables). The `num_thresholds` variable 1571 controls the degree of discretization with larger numbers of thresholds more 1572 closely approximating the true AUC. The quality of the approximation may vary 1573 dramatically depending on `num_thresholds`. 1574 1575 For best results, `predictions` should be distributed approximately uniformly 1576 in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC 1577 approximation may be poor if this is not the case. Setting `summation_method` 1578 to 'minoring' or 'majoring' can help quantify the error in the approximation 1579 by providing lower or upper bound estimate of the AUC. 1580 1581 If `sample_weight` is `None`, weights default to 1. 1582 Use `sample_weight` of 0 to mask values. 1583 1584 Usage: 1585 1586 ```python 1587 m = tf.keras.metrics.AUC(num_thresholds=3) 1588 m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) 1589 1590 # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7] 1591 # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2] 1592 # recall = [1, 0.5, 0], fp_rate = [1, 0, 0] 1593 # auc = ((((1+0.5)/2)*(1-0))+ (((0.5+0)/2)*(0-0))) = 0.75 1594 1595 print('Final result: ', m.result().numpy()) # Final result: 0.75 1596 ``` 1597 1598 Usage with tf.keras API: 1599 1600 ```python 1601 model = tf.keras.Model(inputs, outputs) 1602 model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.AUC()]) 1603 ``` 1604 """ 1605 1606 def __init__(self, 1607 num_thresholds=200, 1608 curve='ROC', 1609 summation_method='interpolation', 1610 name=None, 1611 dtype=None): 1612 """Creates an `AUC` instance. 1613 1614 Args: 1615 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 1616 use when discretizing the roc curve. Values must be > 1. 1617 curve: (Optional) Specifies the name of the curve to be computed, 'ROC' 1618 [default] or 'PR' for the Precision-Recall-curve. 1619 summation_method: (Optional) Specifies the Riemann summation method used 1620 (https://en.wikipedia.org/wiki/Riemann_sum): 'interpolation' [default], 1621 applies mid-point summation scheme for `ROC`. For PR-AUC, interpolates 1622 (true/false) positives but not the ratio that is precision (see Davis 1623 & Goadrich 2006 for details); 'minoring' that applies left summation 1624 for increasing intervals and right summation for decreasing intervals; 1625 'majoring' that does the opposite. 1626 name: (Optional) string name of the metric instance. 1627 dtype: (Optional) data type of the metric result. 1628 """ 1629 # Validate configurations. 1630 if num_thresholds <= 1: 1631 raise ValueError('`num_thresholds` must be > 1.') 1632 if isinstance(curve, metrics_utils.AUCCurve) and curve not in list( 1633 metrics_utils.AUCCurve): 1634 raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format( 1635 curve, list(metrics_utils.AUCCurve))) 1636 if isinstance( 1637 summation_method, 1638 metrics_utils.AUCSummationMethod) and summation_method not in list( 1639 metrics_utils.AUCSummationMethod): 1640 raise ValueError( 1641 'Invalid summation method: "{}". Valid options are: "{}"'.format( 1642 summation_method, list(metrics_utils.AUCSummationMethod))) 1643 1644 # Update properties. 1645 self.num_thresholds = num_thresholds 1646 if isinstance(curve, metrics_utils.AUCCurve): 1647 self.curve = curve 1648 else: 1649 self.curve = metrics_utils.AUCCurve.from_str(curve) 1650 if isinstance(summation_method, metrics_utils.AUCSummationMethod): 1651 self.summation_method = summation_method 1652 else: 1653 self.summation_method = metrics_utils.AUCSummationMethod.from_str( 1654 summation_method) 1655 super(AUC, self).__init__(name=name, dtype=dtype) 1656 1657 # Create metric variables 1658 self.true_positives = self.add_weight( 1659 'true_positives', 1660 shape=(num_thresholds,), 1661 initializer=init_ops.zeros_initializer) 1662 self.true_negatives = self.add_weight( 1663 'true_negatives', 1664 shape=(num_thresholds,), 1665 initializer=init_ops.zeros_initializer) 1666 self.false_positives = self.add_weight( 1667 'false_positives', 1668 shape=(num_thresholds,), 1669 initializer=init_ops.zeros_initializer) 1670 self.false_negatives = self.add_weight( 1671 'false_negatives', 1672 shape=(num_thresholds,), 1673 initializer=init_ops.zeros_initializer) 1674 1675 # Compute `num_thresholds` thresholds in [0, 1] 1676 thresholds = [ 1677 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) 1678 ] 1679 self.thresholds = [0.0 - K.epsilon()] + thresholds + [1.0 + K.epsilon()] 1680 # epsilon - to account for floating point imprecisions. 1681 1682 def update_state(self, y_true, y_pred, sample_weight=None): 1683 """Accumulates confusion matrix statistics. 1684 1685 Args: 1686 y_true: The ground truth values. 1687 y_pred: The predicted values. 1688 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 1689 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 1690 be broadcastable to `y_true`. 1691 1692 Returns: 1693 Update op. 1694 """ 1695 return metrics_utils.update_confusion_matrix_variables({ 1696 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, 1697 metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, 1698 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, 1699 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, 1700 }, y_true, y_pred, self.thresholds, sample_weight=sample_weight) 1701 1702 def interpolate_pr_auc(self): 1703 """Interpolation formula inspired by section 4 of Davis & Goadrich 2006. 1704 1705 https://www.biostat.wisc.edu/~page/rocpr.pdf 1706 1707 Note here we derive & use a closed formula not present in the paper 1708 as follows: 1709 1710 Precision = TP / (TP + FP) = TP / P 1711 1712 Modeling all of TP (true positive), FP (false positive) and their sum 1713 P = TP + FP (predicted positive) as varying linearly within each interval 1714 [A, B] between successive thresholds, we get 1715 1716 Precision slope = dTP / dP 1717 = (TP_B - TP_A) / (P_B - P_A) 1718 = (TP - TP_A) / (P - P_A) 1719 Precision = (TP_A + slope * (P - P_A)) / P 1720 1721 The area within the interval is (slope / total_pos_weight) times 1722 1723 int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P} 1724 int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P} 1725 1726 where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in 1727 1728 int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A) 1729 1730 Bringing back the factor (slope / total_pos_weight) we'd put aside, we get 1731 1732 slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight 1733 1734 where dTP == TP_B - TP_A. 1735 1736 Note that when P_A == 0 the above calculation simplifies into 1737 1738 int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A) 1739 1740 which is really equivalent to imputing constant precision throughout the 1741 first bucket having >0 true positives. 1742 1743 Returns: 1744 pr_auc: an approximation of the area under the P-R curve. 1745 """ 1746 dtp = self.true_positives[:self.num_thresholds - 1747 1] - self.true_positives[1:] 1748 p = self.true_positives + self.false_positives 1749 dp = p[:self.num_thresholds - 1] - p[1:] 1750 1751 prec_slope = math_ops.div_no_nan( 1752 dtp, math_ops.maximum(dp, 0), name='prec_slope') 1753 intercept = self.true_positives[1:] - math_ops.multiply(prec_slope, p[1:]) 1754 1755 safe_p_ratio = array_ops.where( 1756 math_ops.logical_and(p[:self.num_thresholds - 1] > 0, p[1:] > 0), 1757 math_ops.div_no_nan( 1758 p[:self.num_thresholds - 1], 1759 math_ops.maximum(p[1:], 0), 1760 name='recall_relative_ratio'), 1761 array_ops.ones_like(p[1:])) 1762 1763 return math_ops.reduce_sum( 1764 math_ops.div_no_nan( 1765 prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)), 1766 math_ops.maximum(self.true_positives[1:] + self.false_negatives[1:], 1767 0), 1768 name='pr_auc_increment'), 1769 name='interpolate_pr_auc') 1770 1771 def result(self): 1772 if (self.curve == metrics_utils.AUCCurve.PR and 1773 self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION 1774 ): 1775 # This use case is different and is handled separately. 1776 return self.interpolate_pr_auc() 1777 1778 # Set `x` and `y` values for the curves based on `curve` config. 1779 recall = math_ops.div_no_nan(self.true_positives, 1780 self.true_positives + self.false_negatives) 1781 if self.curve == metrics_utils.AUCCurve.ROC: 1782 fp_rate = math_ops.div_no_nan(self.false_positives, 1783 self.false_positives + self.true_negatives) 1784 x = fp_rate 1785 y = recall 1786 else: # curve == 'PR'. 1787 precision = math_ops.div_no_nan( 1788 self.true_positives, self.true_positives + self.false_positives) 1789 x = recall 1790 y = precision 1791 1792 # Find the rectangle heights based on `summation_method`. 1793 if self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION: 1794 # Note: the case ('PR', 'interpolation') has been handled above. 1795 heights = (y[:self.num_thresholds - 1] + y[1:]) / 2. 1796 elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING: 1797 heights = math_ops.minimum(y[:self.num_thresholds - 1], y[1:]) 1798 else: # self.summation_method = metrics_utils.AUCSummationMethod.MAJORING: 1799 heights = math_ops.maximum(y[:self.num_thresholds - 1], y[1:]) 1800 1801 # Sum up the areas of all the rectangles. 1802 return math_ops.reduce_sum( 1803 math_ops.multiply(x[:self.num_thresholds - 1] - x[1:], heights), 1804 name=self.name) 1805 1806 def reset_states(self): 1807 num_thresholds = len(self.thresholds) 1808 K.batch_set_value( 1809 [(v, np.zeros((num_thresholds,))) for v in self.variables]) 1810 1811 def get_config(self): 1812 config = { 1813 'num_thresholds': self.num_thresholds, 1814 'curve': self.curve.value, 1815 'summation_method': self.summation_method.value, 1816 } 1817 base_config = super(AUC, self).get_config() 1818 return dict(list(base_config.items()) + list(config.items())) 1819 1820 1821@keras_export('keras.metrics.CosineSimilarity') 1822class CosineSimilarity(MeanMetricWrapper): 1823 """Computes the cosine similarity between the labels and predictions. 1824 1825 cosine similarity = (a . b) / ||a|| ||b|| 1826 (https://en.wikipedia.org/wiki/Cosine_similarity) 1827 1828 For example, if `y_true` is [0, 1, 1], and `y_pred` is [1, 0, 1], the cosine 1829 similarity is 0.5. 1830 1831 This metric keeps the average cosine similarity between `predictions` and 1832 `labels` over a stream of data. 1833 1834 Usage: 1835 ```python 1836 m = tf.keras.metrics.CosineSimilarity(axis=1) 1837 m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]]) 1838 # l2_norm(y_true) = [[0., 1.], [1./1.414], 1./1.414]]] 1839 # l2_norm(y_pred) = [[1., 0.], [1./1.414], 1./1.414]]] 1840 # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]] 1841 # result = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1)) 1842 = ((0. + 0.) + (0.5 + 0.5)) / 2 1843 1844 print('Final result: ', m.result().numpy()) # Final result: 0.5 1845 ``` 1846 1847 Usage with tf.keras API: 1848 1849 ```python 1850 model = tf.keras.Model(inputs, outputs) 1851 model.compile( 1852 'sgd', 1853 loss='mse', 1854 metrics=[tf.keras.metrics.CosineSimilarity(axis=1)]) 1855 ``` 1856 """ 1857 1858 def __init__(self, name='cosine_similarity', dtype=None, axis=-1): 1859 """Creates a `CosineSimilarity` instance. 1860 1861 Args: 1862 name: (Optional) string name of the metric instance. 1863 dtype: (Optional) data type of the metric result. 1864 axis: (Optional) Defaults to -1. The dimension along which the cosine 1865 similarity is computed. 1866 """ 1867 super(CosineSimilarity, self).__init__( 1868 cosine_similarity, name, dtype=dtype, axis=axis) 1869 1870 1871@keras_export('keras.metrics.MeanAbsoluteError') 1872class MeanAbsoluteError(MeanMetricWrapper): 1873 """Computes the mean absolute error between the labels and predictions. 1874 1875 For example, if `y_true` is [0., 0., 1., 1.], and `y_pred` is [1., 1., 1., 0.] 1876 the mean absolute error is 3/4 (0.75). 1877 1878 Usage: 1879 ```python 1880 m = tf.keras.metrics.MeanAbsoluteError() 1881 m.update_state([0., 0., 1., 1.], [1., 1., 1., 0.]) 1882 print('Final result: ', m.result().numpy()) # Final result: 0.75 1883 ``` 1884 1885 Usage with tf.keras API: 1886 1887 ```python 1888 model = tf.keras.Model(inputs, outputs) 1889 model.compile('sgd', metrics=[tf.keras.metrics.MeanAbsoluteError()]) 1890 ``` 1891 """ 1892 1893 def __init__(self, name='mean_absolute_error', dtype=None): 1894 super(MeanAbsoluteError, self).__init__( 1895 mean_absolute_error, name, dtype=dtype) 1896 1897 1898@keras_export('keras.metrics.MeanAbsolutePercentageError') 1899class MeanAbsolutePercentageError(MeanMetricWrapper): 1900 """Computes the mean absolute percentage error between `y_true` and `y_pred`. 1901 1902 For example, if `y_true` is [0., 0., 1., 1.], and `y_pred` is [1., 1., 1., 0.] 1903 the mean absolute percentage error is 5e+08. 1904 1905 Usage: 1906 1907 ```python 1908 m = tf.keras.metrics.MeanAbsolutePercentageError() 1909 m.update_state([0., 0., 1., 1.], [1., 1., 1., 0.]) 1910 print('Final result: ', m.result().numpy()) # Final result: 5e+08 1911 ``` 1912 1913 Usage with tf.keras API: 1914 1915 ```python 1916 model = tf.keras.Model(inputs, outputs) 1917 model.compile('sgd', metrics=[tf.keras.metrics.MeanAbsolutePercentageError()]) 1918 ``` 1919 """ 1920 1921 def __init__(self, name='mean_absolute_percentage_error', dtype=None): 1922 super(MeanAbsolutePercentageError, self).__init__( 1923 mean_absolute_percentage_error, name, dtype=dtype) 1924 1925 1926@keras_export('keras.metrics.MeanSquaredError') 1927class MeanSquaredError(MeanMetricWrapper): 1928 """Computes the mean squared error between `y_true` and `y_pred`. 1929 1930 For example, if `y_true` is [0., 0., 1., 1.], and `y_pred` is [1., 1., 1., 0.] 1931 the mean squared error is 3/4 (0.75). 1932 1933 Usage: 1934 1935 ```python 1936 m = tf.keras.metrics.MeanSquaredError() 1937 m.update_state([0., 0., 1., 1.], [1., 1., 1., 0.]) 1938 print('Final result: ', m.result().numpy()) # Final result: 0.75 1939 ``` 1940 1941 Usage with tf.keras API: 1942 1943 ```python 1944 model = tf.keras.Model(inputs, outputs) 1945 model.compile('sgd', metrics=[tf.keras.metrics.MeanSquaredError()]) 1946 ``` 1947 """ 1948 1949 def __init__(self, name='mean_squared_error', dtype=None): 1950 super(MeanSquaredError, self).__init__( 1951 mean_squared_error, name, dtype=dtype) 1952 1953 1954@keras_export('keras.metrics.MeanSquaredLogarithmicError') 1955class MeanSquaredLogarithmicError(MeanMetricWrapper): 1956 """Computes the mean squared logarithmic error between `y_true` and `y_pred`. 1957 1958 For example, if `y_true` is [0., 0., 1., 1.], and `y_pred` is [1., 1., 1., 0.] 1959 the mean squared logarithmic error is 0.36034. 1960 1961 Usage: 1962 1963 ```python 1964 m = tf.keras.metrics.MeanSquaredLogarithmicError() 1965 m.update_state([0., 0., 1., 1.], [1., 1., 1., 0.]) 1966 print('Final result: ', m.result().numpy()) # Final result: 0.36034 1967 ``` 1968 1969 Usage with tf.keras API: 1970 1971 ```python 1972 model = tf.keras.Model(inputs, outputs) 1973 model.compile('sgd', metrics=[tf.keras.metrics.MeanSquaredLogarithmicError()]) 1974 ``` 1975 """ 1976 1977 def __init__(self, name='mean_squared_logarithmic_error', dtype=None): 1978 super(MeanSquaredLogarithmicError, self).__init__( 1979 mean_squared_logarithmic_error, name, dtype=dtype) 1980 1981 1982@keras_export('keras.metrics.Hinge') 1983class Hinge(MeanMetricWrapper): 1984 """Computes the hinge metric between `y_true` and `y_pred`. 1985 1986 `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are 1987 provided we will convert them to -1 or 1. 1988 1989 For example, if `y_true` is [-1., 1., 1.], and `y_pred` is [0.6, -0.7, -0.5] 1990 the hinge metric value is 1.6. 1991 1992 Usage: 1993 1994 ```python 1995 m = tf.keras.metrics.Hinge() 1996 m.update_state([-1., 1., 1.], [0.6, -0.7, -0.5]) 1997 1998 # result = max(0, 1-y_true * y_pred) = [1.6 + 1.7 + 1.5] / 3 1999 2000 print('Final result: ', m.result().numpy()) # Final result: 1.6 2001 ``` 2002 2003 Usage with tf.keras API: 2004 2005 ```python 2006 model = tf.keras.Model(inputs, outputs) 2007 model.compile('sgd', metrics=[tf.keras.metrics.Hinge()]) 2008 ``` 2009 """ 2010 2011 def __init__(self, name='hinge', dtype=None): 2012 super(Hinge, self).__init__(hinge, name, dtype=dtype) 2013 2014 2015@keras_export('keras.metrics.SquaredHinge') 2016class SquaredHinge(MeanMetricWrapper): 2017 """Computes the squared hinge metric between `y_true` and `y_pred`. 2018 2019 `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are 2020 provided we will convert them to -1 or 1. 2021 2022 For example, if `y_true` is [-1., 1., 1.], and `y_pred` is [0.6, -0.7, -0.5] 2023 the squared hinge metric value is 2.6. 2024 2025 Usage: 2026 2027 ```python 2028 m = tf.keras.metrics.SquaredHinge() 2029 m.update_state([-1., 1., 1.], [0.6, -0.7, -0.5]) 2030 2031 # result = max(0, 1-y_true * y_pred) = [1.6^2 + 1.7^2 + 1.5^2] / 3 2032 2033 print('Final result: ', m.result().numpy()) # Final result: 2.6 2034 ``` 2035 2036 Usage with tf.keras API: 2037 2038 ```python 2039 model = tf.keras.Model(inputs, outputs) 2040 model.compile('sgd', metrics=[tf.keras.metrics.SquaredHinge()]) 2041 ``` 2042 """ 2043 2044 def __init__(self, name='squared_hinge', dtype=None): 2045 super(SquaredHinge, self).__init__(squared_hinge, name, dtype=dtype) 2046 2047 2048@keras_export('keras.metrics.CategoricalHinge') 2049class CategoricalHinge(MeanMetricWrapper): 2050 """Computes the categorical hinge metric between `y_true` and `y_pred`. 2051 2052 For example, if `y_true` is [0., 1., 1.], and `y_pred` is [1., 0., 1.] 2053 the categorical hinge metric value is 1.0. 2054 2055 Usage: 2056 2057 ```python 2058 m = tf.keras.metrics.CategoricalHinge() 2059 m.update_state([0., 1., 1.], [1., 0., 1.]) 2060 print('Final result: ', m.result().numpy()) # Final result: 1.0 2061 ``` 2062 2063 Usage with tf.keras API: 2064 2065 ```python 2066 model = tf.keras.Model(inputs, outputs) 2067 model.compile('sgd', metrics=[tf.keras.metrics.CategoricalHinge()]) 2068 ``` 2069 """ 2070 2071 def __init__(self, name='categorical_hinge', dtype=None): 2072 super(CategoricalHinge, self).__init__(categorical_hinge, name, dtype=dtype) 2073 2074 2075@keras_export('keras.metrics.RootMeanSquaredError') 2076class RootMeanSquaredError(Mean): 2077 """Computes root mean squared error metric between `y_true` and `y_pred`. 2078 2079 Usage: 2080 2081 ```python 2082 m = tf.keras.metrics.RootMeanSquaredError() 2083 m.update_state([2., 4., 6.], [1., 3., 2.]) 2084 print('Final result: ', m.result().numpy()) # Final result: 2.449 2085 ``` 2086 2087 Usage with tf.keras API: 2088 2089 ```python 2090 model = tf.keras.Model(inputs, outputs) 2091 model.compile('sgd', metrics=[tf.keras.metrics.RootMeanSquaredError()]) 2092 ``` 2093 """ 2094 2095 def __init__(self, name='root_mean_squared_error', dtype=None): 2096 super(RootMeanSquaredError, self).__init__(name, dtype=dtype) 2097 2098 def update_state(self, y_true, y_pred, sample_weight=None): 2099 """Accumulates root mean squared error statistics. 2100 2101 Args: 2102 y_true: The ground truth values. 2103 y_pred: The predicted values. 2104 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 2105 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 2106 be broadcastable to `y_true`. 2107 2108 Returns: 2109 Update op. 2110 """ 2111 y_true = math_ops.cast(y_true, self._dtype) 2112 y_pred = math_ops.cast(y_pred, self._dtype) 2113 y_pred, y_true, sample_weight = squeeze_or_expand_dimensions( 2114 y_pred, y_true, sample_weight) 2115 error_sq = math_ops.squared_difference(y_pred, y_true) 2116 return super(RootMeanSquaredError, self).update_state( 2117 error_sq, sample_weight=sample_weight) 2118 2119 def result(self): 2120 return math_ops.sqrt(math_ops.div_no_nan(self.total, self.count)) 2121 2122 2123@keras_export('keras.metrics.LogCoshError') 2124class LogCoshError(MeanMetricWrapper): 2125 """Computes the logarithm of the hyperbolic cosine of the prediction error. 2126 2127 `logcosh = log((exp(x) + exp(-x))/2)`, where x is the error (y_pred - y_true) 2128 2129 Usage: 2130 2131 ```python 2132 m = tf.keras.metrics.LogCoshError() 2133 m.update_state([0., 1., 1.], [1., 0., 1.]) 2134 print('Final result: ', m.result().numpy()) # Final result: 0.289 2135 ``` 2136 2137 Usage with tf.keras API: 2138 2139 ```python 2140 model = tf.keras.Model(inputs, outputs) 2141 model.compile('sgd', metrics=[tf.keras.metrics.LogCoshError()]) 2142 ``` 2143 """ 2144 2145 def __init__(self, name='logcosh', dtype=None): 2146 super(LogCoshError, self).__init__(logcosh, name, dtype=dtype) 2147 2148 2149@keras_export('keras.metrics.Poisson') 2150class Poisson(MeanMetricWrapper): 2151 """Computes the Poisson metric between `y_true` and `y_pred`. 2152 2153 `metric = y_pred - y_true * log(y_pred)` 2154 2155 Usage: 2156 2157 ```python 2158 m = tf.keras.metrics.Poisson() 2159 m.update_state([1, 9, 2], [4, 8, 12]) 2160 print('Final result: ', m.result().numpy()) # Final result: -4.63 2161 ``` 2162 2163 Usage with tf.keras API: 2164 2165 ```python 2166 model = tf.keras.Model(inputs, outputs) 2167 model.compile('sgd', metrics=[tf.keras.metrics.Poisson()]) 2168 ``` 2169 """ 2170 2171 def __init__(self, name='poisson', dtype=None): 2172 super(Poisson, self).__init__(poisson, name, dtype=dtype) 2173 2174 2175@keras_export('keras.metrics.KLDivergence') 2176class KLDivergence(MeanMetricWrapper): 2177 """Computes Kullback Leibler divergence metric between `y_true` and `y_pred`. 2178 2179 `metric = y_true * log(y_true / y_pred)` 2180 2181 Usage: 2182 2183 ```python 2184 m = tf.keras.metrics.KLDivergence() 2185 m.update_state([.4, .9, .2], [.5, .8, .12]) 2186 print('Final result: ', m.result().numpy()) # Final result: -0.043 2187 ``` 2188 2189 Usage with tf.keras API: 2190 2191 ```python 2192 model = tf.keras.Model(inputs, outputs) 2193 model.compile('sgd', metrics=[tf.keras.metrics.KLDivergence()]) 2194 ``` 2195 """ 2196 2197 def __init__(self, name='kullback_leibler_divergence', dtype=None): 2198 super(KLDivergence, self).__init__( 2199 kullback_leibler_divergence, name, dtype=dtype) 2200 2201 2202@keras_export('keras.metrics.MeanIoU') 2203class MeanIoU(Metric): 2204 """Computes the mean Intersection-Over-Union metric. 2205 2206 Mean Intersection-Over-Union is a common evaluation metric for semantic image 2207 segmentation, which first computes the IOU for each semantic class and then 2208 computes the average over classes. IOU is defined as follows: 2209 IOU = true_positive / (true_positive + false_positive + false_negative). 2210 The predictions are accumulated in a confusion matrix, weighted by 2211 `sample_weight` and the metric is then calculated from it. 2212 2213 If `sample_weight` is `None`, weights default to 1. 2214 Use `sample_weight` of 0 to mask values. 2215 2216 Usage: 2217 2218 ```python 2219 m = tf.keras.metrics.MeanIoU(num_classes=2) 2220 m.update_state([0, 0, 1, 1], [0, 1, 0, 1]) 2221 2222 # cm = [[1, 1], 2223 [1, 1]] 2224 # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1] 2225 # iou = true_positives / (sum_row + sum_col - true_positives)) 2226 # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33 2227 print('Final result: ', m.result().numpy()) # Final result: 0.33 2228 ``` 2229 2230 Usage with tf.keras API: 2231 2232 ```python 2233 model = tf.keras.Model(inputs, outputs) 2234 model.compile( 2235 'sgd', 2236 loss='mse', 2237 metrics=[tf.keras.metrics.MeanIoU(num_classes=2)]) 2238 ``` 2239 """ 2240 2241 def __init__(self, num_classes, name=None, dtype=None): 2242 """Creates a `MeanIoU` instance. 2243 2244 Args: 2245 num_classes: The possible number of labels the prediction task can have. 2246 This value must be provided, since a confusion matrix of dimension = 2247 [num_classes, num_classes] will be allocated. 2248 name: (Optional) string name of the metric instance. 2249 dtype: (Optional) data type of the metric result. 2250 """ 2251 super(MeanIoU, self).__init__(name=name, dtype=dtype) 2252 self.num_classes = num_classes 2253 2254 # Variable to accumulate the predictions in the confusion matrix. Setting 2255 # the type to be `float64` as required by confusion_matrix_ops. 2256 self.total_cm = self.add_weight( 2257 'total_confusion_matrix', 2258 shape=(num_classes, num_classes), 2259 initializer=init_ops.zeros_initializer, 2260 dtype=dtypes.float64) 2261 2262 def update_state(self, y_true, y_pred, sample_weight=None): 2263 """Accumulates the confusion matrix statistics. 2264 2265 Args: 2266 y_true: The ground truth values. 2267 y_pred: The predicted values. 2268 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 2269 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 2270 be broadcastable to `y_true`. 2271 2272 Returns: 2273 Update op. 2274 """ 2275 # Flatten the input if its rank > 1. 2276 if y_pred.shape.ndims > 1: 2277 y_pred = array_ops.reshape(y_pred, [-1]) 2278 2279 if y_true.shape.ndims > 1: 2280 y_true = array_ops.reshape(y_true, [-1]) 2281 2282 if sample_weight is not None and sample_weight.shape.ndims > 1: 2283 sample_weight = array_ops.reshape(sample_weight, [-1]) 2284 2285 # Accumulate the prediction to current confusion matrix. 2286 current_cm = confusion_matrix.confusion_matrix( 2287 y_true, 2288 y_pred, 2289 self.num_classes, 2290 weights=sample_weight, 2291 dtype=dtypes.float64) 2292 return self.total_cm.assign_add(current_cm) 2293 2294 def result(self): 2295 """Compute the mean intersection-over-union via the confusion matrix.""" 2296 sum_over_row = math_ops.cast( 2297 math_ops.reduce_sum(self.total_cm, axis=0), dtype=self._dtype) 2298 sum_over_col = math_ops.cast( 2299 math_ops.reduce_sum(self.total_cm, axis=1), dtype=self._dtype) 2300 true_positives = math_ops.cast( 2301 array_ops.diag_part(self.total_cm), dtype=self._dtype) 2302 2303 # sum_over_row + sum_over_col = 2304 # 2 * true_positives + false_positives + false_negatives. 2305 denominator = sum_over_row + sum_over_col - true_positives 2306 2307 # The mean is only computed over classes that appear in the 2308 # label or prediction tensor. If the denominator is 0, we need to 2309 # ignore the class. 2310 num_valid_entries = math_ops.reduce_sum( 2311 math_ops.cast(math_ops.not_equal(denominator, 0), dtype=self._dtype)) 2312 2313 iou = math_ops.div_no_nan(true_positives, denominator) 2314 2315 return math_ops.div_no_nan( 2316 math_ops.reduce_sum(iou, name='mean_iou'), num_valid_entries) 2317 2318 def reset_states(self): 2319 K.set_value(self.total_cm, np.zeros((self.num_classes, self.num_classes))) 2320 2321 def get_config(self): 2322 config = {'num_classes': self.num_classes} 2323 base_config = super(MeanIoU, self).get_config() 2324 return dict(list(base_config.items()) + list(config.items())) 2325 2326 2327@keras_export('keras.metrics.MeanTensor') 2328class MeanTensor(Metric): 2329 """Computes the element-wise (weighted) mean of the given tensors. 2330 2331 `MeanTensor` returns a tensor with the same shape of the input tensors. The 2332 mean value is updated by keeping local variables `total` and `count`. The 2333 `total` tracks the sum of the weighted values, and `count` stores the sum of 2334 the weighted counts. 2335 2336 Usage: 2337 2338 ```python 2339 m = tf.keras.metrics.MeanTensor() 2340 m.update_state([0, 1, 2, 3]) 2341 m.update_state([4, 5, 6, 7]) 2342 print('Result: ', m.result().numpy()) # Result: [2, 3, 4, 5] 2343 m.update_state([12, 10, 8, 6], sample_weights= [0, 0.2, 0.5, 1]) 2344 print('Result: ', m.result().numpy()) # Result: [2, 3.636, 4.8, 5.333] 2345 ``` 2346 """ 2347 2348 def __init__(self, name='mean_tensor', dtype=None): 2349 """Creates a `MeanTensor` instance. 2350 2351 Args: 2352 name: (Optional) string name of the metric instance. 2353 dtype: (Optional) data type of the metric result. 2354 """ 2355 super(MeanTensor, self).__init__(name=name, dtype=dtype) 2356 self._shape = None 2357 self._total = None 2358 self._count = None 2359 self._built = False 2360 2361 def _build(self, shape): 2362 self._shape = tensor_shape.TensorShape(shape) 2363 # Create new state variables 2364 self._total = self.add_weight( 2365 'total', shape=shape, initializer=init_ops.zeros_initializer) 2366 self._count = self.add_weight( 2367 'count', shape=shape, initializer=init_ops.zeros_initializer) 2368 with ops.init_scope(): 2369 if not context.executing_eagerly(): 2370 K._initialize_variables(K._get_session()) # pylint: disable=protected-access 2371 self._built = True 2372 2373 @property 2374 def total(self): 2375 return self._total if self._built else None 2376 2377 @property 2378 def count(self): 2379 return self._count if self._built else None 2380 2381 def update_state(self, values, sample_weight=None): 2382 """Accumulates statistics for computing the element-wise mean. 2383 2384 Args: 2385 values: Per-example value. 2386 sample_weight: Optional weighting of each example. Defaults to 1. 2387 2388 Returns: 2389 Update op. 2390 """ 2391 values = math_ops.cast(values, self._dtype) 2392 if not self._built: 2393 self._build(values.shape) 2394 elif values.shape != self._shape: 2395 raise ValueError('MeanTensor input values must always have the same ' 2396 'shape. Expected shape (set during the first call): {}. ' 2397 'Got: {}'.format(self._shape, values.get_shape())) 2398 2399 num_values = array_ops.ones_like(values) 2400 if sample_weight is not None: 2401 sample_weight = math_ops.cast(sample_weight, self._dtype) 2402 2403 # Update dimensions of weights to match with values if possible. 2404 values, _, sample_weight = squeeze_or_expand_dimensions( 2405 values, None, sample_weight) 2406 try: 2407 # Broadcast weights if possible. 2408 sample_weight = weights_broadcast_ops.broadcast_weights( 2409 sample_weight, values) 2410 except ValueError: 2411 # Reduce values to same ndim as weight array 2412 ndim = K.ndim(values) 2413 weight_ndim = K.ndim(sample_weight) 2414 values = math_ops.reduce_mean( 2415 values, axis=list(range(weight_ndim, ndim))) 2416 2417 num_values = math_ops.multiply(num_values, sample_weight) 2418 values = math_ops.multiply(values, sample_weight) 2419 2420 update_total_op = self._total.assign_add(values) 2421 with ops.control_dependencies([update_total_op]): 2422 return self._count.assign_add(num_values) 2423 2424 def result(self): 2425 if not self._built: 2426 raise ValueError( 2427 'MeanTensor does not have any result yet. Please call the MeanTensor ' 2428 'instance or use `.update_state(value)` before retrieving the result.' 2429 ) 2430 return math_ops.div_no_nan(self.total, self.count) 2431 2432 def reset_states(self): 2433 if self._built: 2434 K.batch_set_value( 2435 [(v, np.zeros(self._shape.as_list())) for v in self.variables]) 2436 2437 2438@keras_export('keras.metrics.BinaryCrossentropy') 2439class BinaryCrossentropy(MeanMetricWrapper): 2440 """Computes the crossentropy metric between the labels and predictions. 2441 2442 This is the crossentropy metric class to be used when there are only two 2443 label classes (0 and 1). 2444 2445 Usage: 2446 2447 ```python 2448 m = tf.keras.metrics.BinaryCrossentropy() 2449 m.update_state([1., 0., 1., 0.], [1., 1., 1., 0.]) 2450 2451 # EPSILON = 1e-7, y = y_true, y` = y_pred, Y_MAX = 0.9999999 2452 # y` = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON) 2453 # y` = [Y_MAX, Y_MAX, Y_MAX, EPSILON] 2454 2455 # Metric = -(y log(y` + EPSILON) + (1 - y) log(1 - y` + EPSILON)) 2456 # = [-log(Y_MAX + EPSILON), -log(1 - Y_MAX + EPSILON), 2457 # -log(Y_MAX + EPSILON), -log(1)] 2458 # = [(0 + 15.33) / 2, (0 + 0) / 2] 2459 # Reduced metric = 7.665 / 2 2460 2461 print('Final result: ', m.result().numpy()) # Final result: 3.833 2462 ``` 2463 2464 Usage with tf.keras API: 2465 2466 ```python 2467 model = tf.keras.Model(inputs, outputs) 2468 model.compile( 2469 'sgd', 2470 loss='mse', 2471 metrics=[tf.keras.metrics.BinaryCrossentropy()]) 2472 ``` 2473 """ 2474 2475 def __init__(self, 2476 name='binary_crossentropy', 2477 dtype=None, 2478 from_logits=False, 2479 label_smoothing=0): 2480 """Creates a `BinaryCrossentropy` instance. 2481 2482 Args: 2483 name: (Optional) string name of the metric instance. 2484 dtype: (Optional) data type of the metric result. 2485 from_logits: (Optional )Whether output is expected to be a logits tensor. 2486 By default, we consider that output encodes a probability distribution. 2487 label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are 2488 smoothed, meaning the confidence on label values are relaxed. 2489 e.g. `label_smoothing=0.2` means that we will use a value of `0.1` for 2490 label `0` and `0.9` for label `1`" 2491 """ 2492 2493 super(BinaryCrossentropy, self).__init__( 2494 binary_crossentropy, 2495 name, 2496 dtype=dtype, 2497 from_logits=from_logits, 2498 label_smoothing=label_smoothing) 2499 2500 2501@keras_export('keras.metrics.CategoricalCrossentropy') 2502class CategoricalCrossentropy(MeanMetricWrapper): 2503 """Computes the crossentropy metric between the labels and predictions. 2504 2505 This is the crossentropy metric class to be used when there are multiple 2506 label classes (2 or more). Here we assume that labels are given as a `one_hot` 2507 representation. eg., When labels values are [2, 0, 1], 2508 `y_true` = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]. 2509 2510 Usage: 2511 2512 ```python 2513 m = tf.keras.metrics.CategoricalCrossentropy() 2514 m.update_state([[0, 1, 0], [0, 0, 1]], 2515 [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) 2516 2517 # EPSILON = 1e-7, y = y_true, y` = y_pred 2518 # y` = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON) 2519 # y` = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]] 2520 2521 # xent = -sum(y * log(y'), axis = -1) 2522 # = -((log 0.95), (log 0.1)) 2523 # = [0.051, 2.302] 2524 # Reduced xent = (0.051 + 2.302) / 2 2525 2526 print('Final result: ', m.result().numpy()) # Final result: 1.176 2527 ``` 2528 2529 Usage with tf.keras API: 2530 2531 ```python 2532 model = tf.keras.Model(inputs, outputs) 2533 model.compile( 2534 'sgd', 2535 loss='mse', 2536 metrics=[tf.keras.metrics.CategoricalCrossentropy()]) 2537 ``` 2538 2539 Args: 2540 name: (Optional) string name of the metric instance. 2541 dtype: (Optional) data type of the metric result. 2542 from_logits: (Optional ) Whether `y_pred` is expected to be a logits tensor. 2543 By default, we assume that `y_pred` encodes a probability distribution. 2544 label_smoothing: Float in [0, 1]. When > 0, label values are smoothed, 2545 meaning the confidence on label values are relaxed. e.g. 2546 `label_smoothing=0.2` means that we will use a value of `0.1` for label 2547 `0` and `0.9` for label `1`" 2548 """ 2549 2550 def __init__(self, 2551 name='categorical_crossentropy', 2552 dtype=None, 2553 from_logits=False, 2554 label_smoothing=0): 2555 2556 super(CategoricalCrossentropy, self).__init__( 2557 categorical_crossentropy, 2558 name, 2559 dtype=dtype, 2560 from_logits=from_logits, 2561 label_smoothing=label_smoothing) 2562 2563 2564@keras_export('keras.metrics.SparseCategoricalCrossentropy') 2565class SparseCategoricalCrossentropy(MeanMetricWrapper): 2566 """Computes the crossentropy metric between the labels and predictions. 2567 2568 Use this crossentropy metric when there are two or more label classes. 2569 We expect labels to be provided as integers. If you want to provide labels 2570 using `one-hot` representation, please use `CategoricalCrossentropy` metric. 2571 There should be `# classes` floating point values per feature for `y_pred` 2572 and a single floating point value per feature for `y_true`. 2573 2574 In the snippet below, there is a single floating point value per example for 2575 `y_true` and `# classes` floating pointing values per example for `y_pred`. 2576 The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is 2577 `[batch_size, num_classes]`. 2578 2579 Usage: 2580 2581 ```python 2582 m = tf.keras.metrics.SparseCategoricalCrossentropy() 2583 m.update_state( 2584 [1, 2], 2585 [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) 2586 2587 # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]] 2588 # logits = log(y_pred) 2589 # softmax = exp(logits) / sum(exp(logits), axis=-1) 2590 # softmax = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]] 2591 2592 # xent = -sum(y * log(softmax), 1) 2593 # log(softmax) = [[-2.9957, -0.0513, -16.1181], [-2.3026, -0.2231, -2.3026]] 2594 # y_true * log(softmax) = [[0, -0.0513, 0], [0, 0, -2.3026]] 2595 2596 # xent = [0.0513, 2.3026] 2597 # Reduced xent = (0.0513 + 2.3026) / 2 2598 2599 print('Final result: ', m.result().numpy()) # Final result: 1.176 2600 ``` 2601 2602 Usage with tf.keras API: 2603 2604 ```python 2605 model = tf.keras.Model(inputs, outputs) 2606 model.compile( 2607 'sgd', 2608 loss='mse', 2609 metrics=[tf.keras.metrics.SparseCategoricalCrossentropy()]) 2610 ``` 2611 2612 Args: 2613 name: (Optional) string name of the metric instance. 2614 dtype: (Optional) data type of the metric result. 2615 from_logits: (Optional ) Whether `y_pred` is expected to be a logits tensor. 2616 By default, we assume that `y_pred` encodes a probability distribution. 2617 axis: (Optional) Defaults to -1. The dimension along which the metric is 2618 computed. 2619 """ 2620 2621 def __init__(self, 2622 name='sparse_categorical_crossentropy', 2623 dtype=None, 2624 from_logits=False, 2625 axis=-1): 2626 2627 super(SparseCategoricalCrossentropy, self).__init__( 2628 sparse_categorical_crossentropy, 2629 name, 2630 dtype=dtype, 2631 from_logits=from_logits, 2632 axis=axis) 2633 2634 2635class SumOverBatchSize(Reduce): 2636 """Computes the weighted sum over batch size of the given values. 2637 2638 For example, if values is [1, 3, 5, 7] then the metric value is 4. 2639 If the weights were specified as [1, 1, 0, 0] then the value would be 1. 2640 2641 This metric creates two variables, `total` and `count` that are used to 2642 compute the average of `values`. This average is ultimately returned as sum 2643 over batch size which is an idempotent operation that simply divides `total` 2644 by `count`. 2645 2646 If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 2647 to mask values. 2648 """ 2649 2650 def __init__(self, name='sum_over_batch_size', dtype=None): 2651 super(SumOverBatchSize, self).__init__( 2652 reduction=metrics_utils.Reduction.SUM_OVER_BATCH_SIZE, 2653 name=name, 2654 dtype=dtype) 2655 2656 2657class SumOverBatchSizeMetricWrapper(SumOverBatchSize): 2658 """Wraps a function with the `SumOverBatchSizeMetricWrapper` metric.""" 2659 2660 def __init__(self, fn, name=None, dtype=None, **kwargs): 2661 """Creates a `SumOverBatchSizeMetricWrapper` instance. 2662 2663 Args: 2664 fn: The metric function to wrap, with signature `fn(y_true, y_pred, 2665 **kwargs)`. 2666 name: (Optional) string name of the metric instance. 2667 dtype: (Optional) data type of the metric result. 2668 **kwargs: The keyword arguments that are passed on to `fn`. 2669 """ 2670 super(SumOverBatchSizeMetricWrapper, self).__init__(name=name, dtype=dtype) 2671 self._fn = fn 2672 self._fn_kwargs = kwargs 2673 2674 def update_state(self, y_true, y_pred, sample_weight=None): 2675 y_true = math_ops.cast(y_true, self._dtype) 2676 y_pred = math_ops.cast(y_pred, self._dtype) 2677 y_pred, y_true, sample_weight = squeeze_or_expand_dimensions( 2678 y_pred, y_true, sample_weight) 2679 2680 matches = self._fn(y_true, y_pred, **self._fn_kwargs) 2681 return super(SumOverBatchSizeMetricWrapper, self).update_state( 2682 matches, sample_weight=sample_weight) 2683 2684 def get_config(self): 2685 config = {} 2686 for k, v in six.iteritems(self._fn_kwargs): 2687 config[k] = K.eval(v) if is_tensor_or_variable(v) else v 2688 base_config = super(SumOverBatchSizeMetricWrapper, self).get_config() 2689 return dict(list(base_config.items()) + list(config.items())) 2690 2691 2692def accuracy(y_true, y_pred): 2693 y_pred.get_shape().assert_is_compatible_with(y_true.get_shape()) 2694 if y_true.dtype != y_pred.dtype: 2695 y_pred = math_ops.cast(y_pred, y_true.dtype) 2696 return math_ops.cast(math_ops.equal(y_true, y_pred), K.floatx()) 2697 2698 2699@keras_export('keras.metrics.binary_accuracy') 2700def binary_accuracy(y_true, y_pred, threshold=0.5): 2701 threshold = math_ops.cast(threshold, y_pred.dtype) 2702 y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype) 2703 return K.mean(math_ops.equal(y_true, y_pred), axis=-1) 2704 2705 2706@keras_export('keras.metrics.categorical_accuracy') 2707def categorical_accuracy(y_true, y_pred): 2708 return math_ops.cast( 2709 math_ops.equal( 2710 math_ops.argmax(y_true, axis=-1), math_ops.argmax(y_pred, axis=-1)), 2711 K.floatx()) 2712 2713 2714@keras_export('keras.metrics.sparse_categorical_accuracy') 2715def sparse_categorical_accuracy(y_true, y_pred): 2716 y_pred_rank = ops.convert_to_tensor(y_pred).get_shape().ndims 2717 y_true_rank = ops.convert_to_tensor(y_true).get_shape().ndims 2718 # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,) 2719 if (y_true_rank is not None) and (y_pred_rank is not None) and (len( 2720 K.int_shape(y_true)) == len(K.int_shape(y_pred))): 2721 y_true = array_ops.squeeze(y_true, [-1]) 2722 y_pred = math_ops.argmax(y_pred, axis=-1) 2723 2724 # If the predicted output and actual output types don't match, force cast them 2725 # to match. 2726 if K.dtype(y_pred) != K.dtype(y_true): 2727 y_pred = math_ops.cast(y_pred, K.dtype(y_true)) 2728 2729 return math_ops.cast(math_ops.equal(y_true, y_pred), K.floatx()) 2730 2731 2732@keras_export('keras.metrics.top_k_categorical_accuracy') 2733def top_k_categorical_accuracy(y_true, y_pred, k=5): 2734 return K.mean( 2735 nn.in_top_k(y_pred, math_ops.argmax(y_true, axis=-1), k), axis=-1) 2736 2737 2738@keras_export('keras.metrics.sparse_top_k_categorical_accuracy') 2739def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5): 2740 y_pred_rank = ops.convert_to_tensor(y_pred).get_shape().ndims 2741 y_true_rank = ops.convert_to_tensor(y_true).get_shape().ndims 2742 # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,) 2743 if (y_true_rank is not None) and (y_pred_rank is not None) and (len( 2744 K.int_shape(y_true)) == len(K.int_shape(y_pred))): 2745 y_true = array_ops.squeeze(y_true, [-1]) 2746 2747 return K.mean(nn.in_top_k(y_pred, math_ops.cast(y_true, 'int32'), k), axis=-1) 2748 2749# Aliases 2750 2751mse = MSE = mean_squared_error 2752mae = MAE = mean_absolute_error 2753mape = MAPE = mean_absolute_percentage_error 2754msle = MSLE = mean_squared_logarithmic_error 2755cosine_proximity = cosine_similarity 2756 2757 2758def clone_metric(metric): 2759 """Returns a clone of the metric if stateful, otherwise returns it as is.""" 2760 if isinstance(metric, Metric): 2761 return metric.__class__.from_config(metric.get_config()) 2762 return metric 2763 2764 2765def clone_metrics(metrics): 2766 """Clones the given metric list/dict.""" 2767 if metrics is None: 2768 return None 2769 if isinstance(metrics, dict): 2770 return {key: clone_metric(value) for key, value in metrics.items()} 2771 return [clone_metric(metric) for metric in metrics] 2772 2773 2774@keras_export('keras.metrics.serialize') 2775def serialize(metric): 2776 return serialize_keras_object(metric) 2777 2778 2779@keras_export('keras.metrics.deserialize') 2780def deserialize(config, custom_objects=None): 2781 return deserialize_keras_object( 2782 config, 2783 module_objects=globals(), 2784 custom_objects=custom_objects, 2785 printable_module_name='metric function') 2786 2787 2788@keras_export('keras.metrics.get') 2789def get(identifier): 2790 if isinstance(identifier, dict): 2791 return deserialize(identifier) 2792 elif isinstance(identifier, six.string_types): 2793 return deserialize(str(identifier)) 2794 elif callable(identifier): 2795 return identifier 2796 else: 2797 raise ValueError('Could not interpret ' 2798 'metric function identifier: %s' % identifier) 2799