1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=unused-import
16"""Built-in metrics.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import abc
23import types
24import numpy as np
25import six
26
27from tensorflow.python.eager import context
28from tensorflow.python.eager import def_function
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.keras import backend as K
33from tensorflow.python.keras.engine.base_layer import Layer
34from tensorflow.python.keras.losses import binary_crossentropy
35from tensorflow.python.keras.losses import categorical_crossentropy
36from tensorflow.python.keras.losses import categorical_hinge
37from tensorflow.python.keras.losses import cosine_similarity
38from tensorflow.python.keras.losses import hinge
39from tensorflow.python.keras.losses import kullback_leibler_divergence
40from tensorflow.python.keras.losses import logcosh
41from tensorflow.python.keras.losses import mean_absolute_error
42from tensorflow.python.keras.losses import mean_absolute_percentage_error
43from tensorflow.python.keras.losses import mean_squared_error
44from tensorflow.python.keras.losses import mean_squared_logarithmic_error
45from tensorflow.python.keras.losses import poisson
46from tensorflow.python.keras.losses import sparse_categorical_crossentropy
47from tensorflow.python.keras.losses import squared_hinge
48from tensorflow.python.keras.utils import metrics_utils
49from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
50from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
51from tensorflow.python.keras.utils.generic_utils import to_list
52from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
53from tensorflow.python.keras.utils.tf_utils import is_tensor_or_variable
54from tensorflow.python.ops import array_ops
55from tensorflow.python.ops import confusion_matrix
56from tensorflow.python.ops import init_ops
57from tensorflow.python.ops import math_ops
58from tensorflow.python.ops import nn
59from tensorflow.python.ops import variables as tf_variables
60from tensorflow.python.ops import weights_broadcast_ops
61from tensorflow.python.util.tf_export import keras_export
62from tensorflow.tools.docs import doc_controls
63
64
65@keras_export('keras.metrics.Metric')
66@six.add_metaclass(abc.ABCMeta)
67class Metric(Layer):
68  """Encapsulates metric logic and state.
69
70  Usage:
71
72  ```python
73  m = SomeMetric(...)
74  for input in ...:
75    m.update_state(input)
76  print('Final result: ', m.result().numpy())
77  ```
78
79  Usage with tf.keras API:
80
81  ```python
82  model = tf.keras.Sequential()
83  model.add(tf.keras.layers.Dense(64, activation='relu'))
84  model.add(tf.keras.layers.Dense(64, activation='relu'))
85  model.add(tf.keras.layers.Dense(10, activation='softmax'))
86
87  model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01),
88                loss=tf.keras.losses.categorical_crossentropy,
89                metrics=[tf.keras.metrics.CategoricalAccuracy()])
90
91  data = np.random.random((1000, 32))
92  labels = np.random.random((1000, 10))
93
94  dataset = tf.data.Dataset.from_tensor_slices((data, labels))
95  dataset = dataset.batch(32)
96  dataset = dataset.repeat()
97
98  model.fit(dataset, epochs=10, steps_per_epoch=30)
99  ```
100
101  To be implemented by subclasses:
102  * `__init__()`: All state variables should be created in this method by
103    calling `self.add_weight()` like: `self.var = self.add_weight(...)`
104  * `update_state()`: Has all updates to the state variables like:
105    `self.var.assign_add(...)`.
106  * `result()`: Computes and returns a value for the metric
107    from the state variables.
108
109  Example subclass implementation:
110
111  ```python
112  class BinaryTruePositives(tf.keras.metrics.Metric):
113
114    def __init__(self, name='binary_true_positives', **kwargs):
115      super(BinaryTruePositives, self).__init__(name=name, **kwargs)
116      self.true_positives = self.add_weight(name='tp', initializer='zeros')
117
118    def update_state(self, y_true, y_pred, sample_weight=None):
119      y_true = tf.cast(y_true, tf.bool)
120      y_pred = tf.cast(y_pred, tf.bool)
121
122      values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
123      values = tf.cast(values, self.dtype)
124      if sample_weight is not None:
125        sample_weight = tf.cast(sample_weight, self.dtype)
126        sample_weight = tf.broadcast_to(sample_weight, values.shape)
127        values = tf.multiply(values, sample_weight)
128      self.true_positives.assign_add(tf.reduce_sum(values))
129
130    def result(self):
131      return self.true_positives
132  ```
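
  A minimal sketch of exercising the subclass above on its own (eager
  execution; the inputs and weights are illustrative):

  ```python
  m = BinaryTruePositives()
  m.update_state([0, 1, 1, 1], [0, 1, 0, 1], sample_weight=[1., 1., 1., 0.5])
  print('Final result: ', m.result().numpy())  # Final result: 1.5
  m.reset_states()
  ```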
133  """
134
135  def __init__(self, name=None, dtype=None, **kwargs):
136    super(Metric, self).__init__(name=name, dtype=dtype, **kwargs)
137    self.stateful = True  # All metric layers are stateful.
138    self.built = True
139    self._dtype = K.floatx() if dtype is None else dtypes.as_dtype(dtype).name
140
141  def __new__(cls, *args, **kwargs):
142    obj = super(Metric, cls).__new__(cls)
143
144    # TODO(psv): We are excluding wrapping `update_state` of built-in metrics
145    # with function here because of b/121302287. With this, built-in metrics
146    # will continue to work with TPUs and custom metrics will not, however
147    # users writing custom metrics need not worry about control dependencies
148    # and returning ops.
149    if cls.__module__ == Metric.__module__:
150      update_state_fn = obj.update_state
151    else:
152      update_state_fn = def_function.function(obj.update_state)
153
154    obj.update_state = types.MethodType(
155        metrics_utils.update_state_wrapper(update_state_fn), obj)
156    obj.result = types.MethodType(metrics_utils.result_wrapper(obj.result), obj)
157    return obj
158
159  def __call__(self, *args, **kwargs):
160    """Accumulates statistics and then computes metric result value.
161
162    Args:
163      *args: A mini-batch of inputs to the Metric, passed on to
164        `update_state()`.
165      **kwargs: Keyword arguments, also passed on to `update_state()`.
166
167    Returns:
168      The metric value tensor.
169    """
170    update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
171    with ops.control_dependencies([update_op]):
172      result_t = self.result()  # pylint: disable=not-callable
173
174      # We are adding the metric object as metadata on the result tensor.
175      # This is required when we want to use a metric with `add_metric` API on
176      # a Model/Layer in graph mode. This metric instance will later be used
177      # to reset variable state after each epoch of training.
178      # Example:
179      #   model = Model()
180      #   mean = Mean()
181      #   model.add_metric(mean(values), name='mean')
182      if not context.executing_eagerly():
183        result_t._metric_obj = self  # pylint: disable=protected-access
184      return result_t
185
186  @property
187  def dtype(self):
188    return self._dtype
189
190  def get_config(self):
191    """Returns the serializable config of the metric."""
192    return {'name': self.name, 'dtype': self.dtype}
193
194  def reset_states(self):
195    """Resets all of the metric state variables.
196
197    This function is called between epochs/steps,
198    when a metric is evaluated during training.
199    """
200    K.batch_set_value([(v, 0) for v in self.variables])
201
202  @abc.abstractmethod
203  def update_state(self, *args, **kwargs):
204    """Accumulates statistics for the metric.
205
206    Note: This function is executed as a graph function in graph mode.
207    This means:
208      a) Operations on the same resource are executed in textual order.
209         This should make it easier to do things like add the updated
210         value of a variable to another, for example.
211      b) You don't need to worry about collecting the update ops to execute.
212         All update ops added to the graph by this function will be executed.
213      As a result, code should generally work the same way with graph or
214      eager execution.
215
216    Please use `tf.config.experimental_run_functions_eagerly(True)` to execute
217    this function eagerly for debugging or profiling.
218
219    Args:
220      *args: A mini-batch of inputs to the Metric.
221      **kwargs: Additional keyword inputs to the Metric.
222    """
223    raise NotImplementedError('Must be implemented in subclasses.')
224
225  @abc.abstractmethod
226  def result(self):
227    """Computes and returns the metric value tensor.
228
229    Result computation is an idempotent operation that simply calculates the
230    metric value using the state variables.
231    """
232    raise NotImplementedError('Must be implemented in subclasses.')
233
234  ### For use by subclasses ###
235  @doc_controls.for_subclass_implementers
236  def add_weight(self,
237                 name,
238                 shape=(),
239                 aggregation=tf_variables.VariableAggregation.SUM,
240                 synchronization=tf_variables.VariableSynchronization.ON_READ,
241                 initializer=None,
242                 dtype=None):
243    """Adds state variable. Only for use by subclasses."""
244    return super(Metric, self).add_weight(
245        name=name,
246        shape=shape,
247        dtype=self._dtype if dtype is None else dtype,
248        trainable=False,
249        initializer=initializer,
250        collections=[],
251        synchronization=synchronization,
252        aggregation=aggregation)
253
254  ### End: For use by subclasses ###
255
256
257class Reduce(Metric):
258  """Encapsulates metrics that perform a reduce operation on the values."""
259
260  def __init__(self, reduction, name, dtype=None):
261    """Creates a `Reduce` instance.
262
263    Args:
264      reduction: a `tf.keras.metrics.Reduction` enum value.
265      name: string name of the metric instance.
266      dtype: (Optional) data type of the metric result.
267    """
268    super(Reduce, self).__init__(name=name, dtype=dtype)
269    self.reduction = reduction
270    self.total = self.add_weight(
271        'total', initializer=init_ops.zeros_initializer)
272    if reduction in [metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
273                     metrics_utils.Reduction.WEIGHTED_MEAN]:
274      self.count = self.add_weight(
275          'count', initializer=init_ops.zeros_initializer)
276
277  def update_state(self, values, sample_weight=None):
278    """Accumulates statistics for computing the reduction metric.
279
280    For example, if `values` is [1, 3, 5, 7] and reduction=SUM_OVER_BATCH_SIZE,
281    then the value of `result()` is 4. If `sample_weight` is [1, 1, 0, 0] and
282    reduction=WEIGHTED_MEAN, then the value of `result()` would be 2.
283
284    Args:
285      values: Per-example value.
286      sample_weight: Optional weighting of each example. Defaults to 1.
287
288    Returns:
289      Update op.
290    """
291    values = math_ops.cast(values, self._dtype)
292    if sample_weight is not None:
293      sample_weight = math_ops.cast(sample_weight, self._dtype)
294      # Update dimensions of weights to match with values if possible.
295      values, _, sample_weight = squeeze_or_expand_dimensions(
296          values, None, sample_weight)
297      try:
298        # Broadcast weights if possible.
299        sample_weight = weights_broadcast_ops.broadcast_weights(
300            sample_weight, values)
301      except ValueError:
302        # Reduce values to same ndim as weight array
303        ndim = K.ndim(values)
304        weight_ndim = K.ndim(sample_weight)
305        if self.reduction == metrics_utils.Reduction.SUM:
306          values = math_ops.reduce_sum(
307              values, axis=list(range(weight_ndim, ndim)))
308        else:
309          values = math_ops.reduce_mean(
310              values, axis=list(range(weight_ndim, ndim)))
311      values = math_ops.multiply(values, sample_weight)
312
313    value_sum = math_ops.reduce_sum(values)
314    with ops.control_dependencies([value_sum]):
315      update_total_op = self.total.assign_add(value_sum)
316
317    # Exit early if the reduction doesn't have a denominator.
318    if self.reduction == metrics_utils.Reduction.SUM:
319      return update_total_op
320
321    # Update `count` for reductions that require a denominator.
322    if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE:
323      num_values = math_ops.cast(array_ops.size(values), self._dtype)
324    elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN:
325      if sample_weight is None:
326        num_values = math_ops.cast(array_ops.size(values), self._dtype)
327      else:
328        num_values = math_ops.reduce_sum(sample_weight)
329    else:
330      raise NotImplementedError(
331          'reduction [%s] not implemented' % self.reduction)
332
333    with ops.control_dependencies([update_total_op]):
334      return self.count.assign_add(num_values)
335
336  def result(self):
337    if self.reduction == metrics_utils.Reduction.SUM:
338      return array_ops.identity(self.total)
339    elif self.reduction in [
340        metrics_utils.Reduction.WEIGHTED_MEAN,
341        metrics_utils.Reduction.SUM_OVER_BATCH_SIZE
342    ]:
343      return math_ops.div_no_nan(self.total, self.count)
344    else:
345      raise NotImplementedError(
346          'reduction [%s] not implemented' % self.reduction)
347
348
349@keras_export('keras.metrics.Sum')
350class Sum(Reduce):
351  """Computes the (weighted) sum of the given values.
352
353  For example, if values is [1, 3, 5, 7] then the sum is 16.
354  If the weights were specified as [1, 1, 0, 0] then the sum would be 4.
355
356  This metric creates one variable, `total`, that is used to compute the sum of
357  `values`. This is ultimately returned as `sum`.
358
359  If `sample_weight` is `None`, weights default to 1.  Use `sample_weight` of 0
360  to mask values.
361
362  Usage:
363
364  ```python
365  m = tf.keras.metrics.Sum()
366  m.update_state([1, 3, 5, 7])
367  print('Final result: ', m.result().numpy())  # Final result: 16.0
368  ```
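
  The weighted behavior described above, shown as a quick sketch:

  ```python
  m = tf.keras.metrics.Sum()
  m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
  print('Final result: ', m.result().numpy())  # Final result: 4.0
  ```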
369
370  Usage with tf.keras API:
371
372  ```python
373  model = tf.keras.Model(inputs, outputs)
374  model.add_metric(tf.keras.metrics.Sum(name='sum_1')(outputs))
375  model.compile('sgd', loss='mse')
376  ```
377  """
378
379  def __init__(self, name='sum', dtype=None):
380    """Creates a `Sum` instance.
381
382    Args:
383      name: (Optional) string name of the metric instance.
384      dtype: (Optional) data type of the metric result.
385    """
386    super(Sum, self).__init__(reduction=metrics_utils.Reduction.SUM,
387                              name=name, dtype=dtype)
388
389
390@keras_export('keras.metrics.Mean')
391class Mean(Reduce):
392  """Computes the (weighted) mean of the given values.
393
394  For example, if values is [1, 3, 5, 7] then the mean is 4.
395  If the weights were specified as [1, 1, 0, 0] then the mean would be 2.
396
397  This metric creates two variables, `total` and `count` that are used to
398  compute the average of `values`. This average is ultimately returned as `mean`
399  which is an idempotent operation that simply divides `total` by `count`.
400
401  If `sample_weight` is `None`, weights default to 1.
402  Use `sample_weight` of 0 to mask values.
403
404  Usage:
405
406  ```python
407  m = tf.keras.metrics.Mean()
408  m.update_state([1, 3, 5, 7])
409  print('Final result: ', m.result().numpy())  # Final result: 4.0
410  ```
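
  The weighted mean described above, shown as a quick sketch:

  ```python
  m = tf.keras.metrics.Mean()
  m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
  print('Final result: ', m.result().numpy())  # Final result: 2.0
  ```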
411
412  Usage with tf.keras API:
413
414  ```python
415  model = tf.keras.Model(inputs, outputs)
416  model.add_metric(tf.keras.metrics.Mean(name='mean_1')(outputs))
417  model.compile('sgd', loss='mse')
418  ```
419  """
420
421  def __init__(self, name='mean', dtype=None):
422    """Creates a `Mean` instance.
423
424    Args:
425      name: (Optional) string name of the metric instance.
426      dtype: (Optional) data type of the metric result.
427    """
428    super(Mean, self).__init__(
429        reduction=metrics_utils.Reduction.WEIGHTED_MEAN, name=name, dtype=dtype)
430
431
432@keras_export('keras.metrics.MeanRelativeError')
433class MeanRelativeError(Mean):
434  """Computes the mean relative error by normalizing with the given values.
435
436  This metric creates two local variables, `total` and `count` that are used to
437  compute the mean relative absolute error. This average is weighted by
438  `sample_weight`, and it is ultimately returned as `mean_relative_error`:
439  an idempotent operation that simply divides `total` by `count`.
440
441  If `sample_weight` is `None`, weights default to 1.
442  Use `sample_weight` of 0 to mask values.
443
444  Usage:
445
446  ```python
447  m = tf.keras.metrics.MeanRelativeError(normalizer=[1, 3, 2, 3])
448  m.update_state([1, 3, 2, 3], [2, 4, 6, 8])
449
450  # metric = mean(|y_pred - y_true| / normalizer)
451  #        = mean([1, 1, 4, 5] / [1, 3, 2, 3]) = mean([1, 1/3, 2, 5/3])
452  #        = 5/4 = 1.25
453  print('Final result: ', m.result().numpy())  # Final result: 1.25
454  ```
455
456  Usage with tf.keras API:
457
458  ```python
459  model = tf.keras.Model(inputs, outputs)
460  model.compile(
461    'sgd',
462    loss='mse',
463    metrics=[tf.keras.metrics.MeanRelativeError(normalizer=[1, 3])])
464  ```
465  """
466
467  def __init__(self, normalizer, name=None, dtype=None):
468    """Creates a `MeanRelativeError` instance.
469
470    Args:
471      normalizer: The normalizer values with same shape as predictions.
472      name: (Optional) string name of the metric instance.
473      dtype: (Optional) data type of the metric result.
474    """
475    super(MeanRelativeError, self).__init__(name=name, dtype=dtype)
476    normalizer = math_ops.cast(normalizer, self._dtype)
477    self.normalizer = normalizer
478
479  def update_state(self, y_true, y_pred, sample_weight=None):
480    """Accumulates metric statistics.
481
482    Args:
483      y_true: The ground truth values.
484      y_pred: The predicted values.
485      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
486        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
487        be broadcastable to `y_true`.
488
489    Returns:
490      Update op.
491    """
492    y_true = math_ops.cast(y_true, self._dtype)
493    y_pred = math_ops.cast(y_pred, self._dtype)
494    y_pred, y_true, sample_weight = squeeze_or_expand_dimensions(
495        y_pred, y_true, sample_weight)
496
497    y_pred, self.normalizer = confusion_matrix.remove_squeezable_dimensions(
498        y_pred, self.normalizer)
499    y_pred.shape.assert_is_compatible_with(y_true.shape)
500    relative_errors = math_ops.div_no_nan(
501        math_ops.abs(y_true - y_pred), self.normalizer)
502
503    return super(MeanRelativeError, self).update_state(
504        relative_errors, sample_weight=sample_weight)
505
506  def get_config(self):
507    n = self.normalizer
508    config = {'normalizer': K.eval(n) if is_tensor_or_variable(n) else n}
509    base_config = super(MeanRelativeError, self).get_config()
510    return dict(list(base_config.items()) + list(config.items()))
511
512
513class MeanMetricWrapper(Mean):
514  """Wraps a stateless metric function with the Mean metric."""
515
516  def __init__(self, fn, name=None, dtype=None, **kwargs):
517    """Creates a `MeanMetricWrapper` instance.
518
519    Args:
520      fn: The metric function to wrap, with signature
521        `fn(y_true, y_pred, **kwargs)`.
522      name: (Optional) string name of the metric instance.
523      dtype: (Optional) data type of the metric result.
524      **kwargs: The keyword arguments that are passed on to `fn`.
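
    A minimal sketch of wrapping a simple function (`squared_difference` below
    is illustrative, not part of this module):

    ```python
    def squared_difference(y_true, y_pred):
      return tf.math.squared_difference(y_true, y_pred)

    m = MeanMetricWrapper(squared_difference, name='mean_squared_difference')
    m.update_state([1., 2.], [1., 3.])
    print('Final result: ', m.result().numpy())  # Final result: 0.5
    ```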
525    """
526    super(MeanMetricWrapper, self).__init__(name=name, dtype=dtype)
527    self._fn = fn
528    self._fn_kwargs = kwargs
529
530  def update_state(self, y_true, y_pred, sample_weight=None):
531    """Accumulates metric statistics.
532
533    `y_true` and `y_pred` should have the same shape.
534
535    Args:
536      y_true: The ground truth values.
537      y_pred: The predicted values.
538      sample_weight: Optional weighting of each example. Defaults to 1. Can be
539        a `Tensor` whose rank is either 0, or the same rank as `y_true`,
540        and must be broadcastable to `y_true`.
541
542    Returns:
543      Update op.
544    """
545    y_true = math_ops.cast(y_true, self._dtype)
546    y_pred = math_ops.cast(y_pred, self._dtype)
547    y_pred, y_true, sample_weight = squeeze_or_expand_dimensions(
548        y_pred, y_true, sample_weight)
549
550    matches = self._fn(y_true, y_pred, **self._fn_kwargs)
551    return super(MeanMetricWrapper, self).update_state(
552        matches, sample_weight=sample_weight)
553
554  def get_config(self):
555    config = {}
556    for k, v in six.iteritems(self._fn_kwargs):
557      config[k] = K.eval(v) if is_tensor_or_variable(v) else v
558    base_config = super(MeanMetricWrapper, self).get_config()
559    return dict(list(base_config.items()) + list(config.items()))
560
561
562@keras_export('keras.metrics.Accuracy')
563class Accuracy(MeanMetricWrapper):
564  """Calculates how often predictions match labels.
565
566  For example, if `y_true` is [1, 2, 3, 4] and `y_pred` is [0, 2, 3, 4]
567  then the accuracy is 3/4 or .75.  If the weights were specified as
568  [1, 1, 0, 0] then the accuracy would be 1/2 or .5.
569
570  This metric creates two local variables, `total` and `count` that are used to
571  compute the frequency with which `y_pred` matches `y_true`. This frequency is
572  ultimately returned as `accuracy`: an idempotent operation that simply
573  divides `total` by `count`.
574
575  If `sample_weight` is `None`, weights default to 1.
576  Use `sample_weight` of 0 to mask values.
577
578  Usage:
579
580  ```python
581  m = tf.keras.metrics.Accuracy()
582  m.update_state([1, 2, 3, 4], [0, 2, 3, 4])
583  print('Final result: ', m.result().numpy())  # Final result: 0.75
584  ```
585
586  Usage with tf.keras API:
587
588  ```python
589  model = tf.keras.Model(inputs, outputs)
590  model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.Accuracy()])
591  ```
592  """
593
594  def __init__(self, name='accuracy', dtype=None):
595    super(Accuracy, self).__init__(accuracy, name, dtype=dtype)
596
597
598@keras_export('keras.metrics.BinaryAccuracy')
599class BinaryAccuracy(MeanMetricWrapper):
600  """Calculates how often predictions match binary labels.
601
602  For example, if `y_true` is [1, 1, 0, 0] and `y_pred` is [0.98, 1, 0, 0.6]
603  then the binary accuracy is 3/4 or .75.  If the weights were specified as
604  [1, 0, 0, 1] then the binary accuracy would be 1/2 or .5.
605
606  This metric creates two local variables, `total` and `count` that are used to
607  compute the frequency with which `y_pred` matches `y_true`. This frequency is
608  ultimately returned as `binary accuracy`: an idempotent operation that simply
609  divides `total` by `count`.
610
611  If `sample_weight` is `None`, weights default to 1.
612  Use `sample_weight` of 0 to mask values.
613
614  Usage:
615
616  ```python
617  m = tf.keras.metrics.BinaryAccuracy()
618  m.update_state([1, 1, 0, 0], [0.98, 1, 0, 0.6])
619  print('Final result: ', m.result().numpy())  # Final result: 0.75
620  ```
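
  A sketch with a non-default `threshold` (see `__init__` below); predictions
  strictly above the threshold are treated as 1:

  ```python
  m = tf.keras.metrics.BinaryAccuracy(threshold=0.7)
  m.update_state([1, 1, 0, 0], [0.98, 0.6, 0, 0.6])
  print('Final result: ', m.result().numpy())  # Final result: 0.75
  ```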
621
622  Usage with tf.keras API:
623
624  ```python
625  model = tf.keras.Model(inputs, outputs)
626  model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.BinaryAccuracy()])
627  ```
628  """
629
630  def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5):
631    """Creates a `BinaryAccuracy` instance.
632
633    Args:
634      name: (Optional) string name of the metric instance.
635      dtype: (Optional) data type of the metric result.
636      threshold: (Optional) Float representing the threshold for deciding
637        whether prediction values are 1 or 0.
638    """
639    super(BinaryAccuracy, self).__init__(
640        binary_accuracy, name, dtype=dtype, threshold=threshold)
641
642
643@keras_export('keras.metrics.CategoricalAccuracy')
644class CategoricalAccuracy(MeanMetricWrapper):
645  """Calculates how often predictions match one-hot labels.
646
647  For example, if `y_true` is [[0, 0, 1], [0, 1, 0]] and `y_pred` is
648  [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] then the categorical accuracy is 1/2 or .5.
649  If the weights were specified as [0.7, 0.3] then the categorical accuracy
650  would be .3. You can provide logits of classes as `y_pred`, since argmax of
651  logits and probabilities are the same.
652
653  This metric creates two local variables, `total` and `count` that are used to
654  compute the frequency with which `y_pred` matches `y_true`. This frequency is
655  ultimately returned as `categorical accuracy`: an idempotent operation that
656  simply divides `total` by `count`.
657
658  `y_pred` and `y_true` should be passed in as vectors of probabilities, rather
659  than as labels. If necessary, use `tf.one_hot` to expand `y_true` as a vector.
660
661  If `sample_weight` is `None`, weights default to 1.
662  Use `sample_weight` of 0 to mask values.
663
664  Usage:
665
666  ```python
667  m = tf.keras.metrics.CategoricalAccuracy()
668  m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
669  print('Final result: ', m.result().numpy())  # Final result: 0.5
670  ```
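
  Logits work as well, since the argmax is unchanged (a sketch):

  ```python
  m = tf.keras.metrics.CategoricalAccuracy()
  m.update_state([[0, 0, 1], [0, 1, 0]], [[-1., 2., 1.], [0.5, 3., -2.]])
  print('Final result: ', m.result().numpy())  # Final result: 0.5
  ```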
671
672  Usage with tf.keras API:
673
674  ```python
675  model = tf.keras.Model(inputs, outputs)
676  model.compile(
677    'sgd',
678    loss='mse',
679    metrics=[tf.keras.metrics.CategoricalAccuracy()])
680  ```
681  """
682
683  def __init__(self, name='categorical_accuracy', dtype=None):
684    """Creates a `CategoricalAccuracy` instance.
685
686    Args:
687      name: (Optional) string name of the metric instance.
688      dtype: (Optional) data type of the metric result.
689    """
690    super(CategoricalAccuracy, self).__init__(
691        categorical_accuracy, name, dtype=dtype)
692
693
694@keras_export('keras.metrics.SparseCategoricalAccuracy')
695class SparseCategoricalAccuracy(MeanMetricWrapper):
696  """Calculates how often predictions match integer labels.
697
698  For example, if `y_true` is [[2], [1]] and `y_pred` is
699  [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] then the categorical accuracy is 1/2 or .5.
700  If the weights were specified as [0.7, 0.3] then the categorical accuracy
701  would be .3. You can provide logits of classes as `y_pred`, since argmax of
702  logits and probabilities are the same.
703
704  This metric creates two local variables, `total` and `count` that are used to
705  compute the frequency with which `y_pred` matches `y_true`. This frequency is
706  ultimately returned as `sparse categorical accuracy`: an idempotent operation
707  that simply divides `total` by `count`.
708
709  If `sample_weight` is `None`, weights default to 1.
710  Use `sample_weight` of 0 to mask values.
711
712  Usage:
713
714  ```python
715  m = tf.keras.metrics.SparseCategoricalAccuracy()
716  m.update_state([[2], [1]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
717  print('Final result: ', m.result().numpy())  # Final result: 0.5
718  ```
719
720  Usage with tf.keras API:
721
722  ```python
723  model = tf.keras.Model(inputs, outputs)
724  model.compile(
725      'sgd',
726      loss='mse',
727      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
728  ```
729  """
730
731  def __init__(self, name='sparse_categorical_accuracy', dtype=None):
732    super(SparseCategoricalAccuracy, self).__init__(
733        sparse_categorical_accuracy, name, dtype=dtype)
734
735
736@keras_export('keras.metrics.TopKCategoricalAccuracy')
737class TopKCategoricalAccuracy(MeanMetricWrapper):
738  """Computes how often targets are in the top `K` predictions.
739
740  Usage:
741
742  ```python
743  m = tf.keras.metrics.TopKCategoricalAccuracy()
744  m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
745  print('Final result: ', m.result().numpy())  # Final result: 1.0
746  ```
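
  With `k=1` only the single highest prediction counts, as a sketch:

  ```python
  m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)
  m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
  print('Final result: ', m.result().numpy())  # Final result: 0.5
  ```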
747
748  Usage with tf.keras API:
749
750  ```python
751  model = tf.keras.Model(inputs, outputs)
752  model.compile('sgd', metrics=[tf.keras.metrics.TopKCategoricalAccuracy()])
753  ```
754  """
755
756  def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None):
757    """Creates a `TopKCategoricalAccuracy` instance.
758
759    Args:
760      k: (Optional) Number of top elements to look at for computing accuracy.
761        Defaults to 5.
762      name: (Optional) string name of the metric instance.
763      dtype: (Optional) data type of the metric result.
764    """
765    super(TopKCategoricalAccuracy, self).__init__(
766        top_k_categorical_accuracy, name, dtype=dtype, k=k)
767
768
769@keras_export('keras.metrics.SparseTopKCategoricalAccuracy')
770class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
771  """Computes how often integer targets are in the top `K` predictions.
772
773  Usage:
774
775  ```python
776  m = tf.keras.metrics.SparseTopKCategoricalAccuracy()
777  m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
778  print('Final result: ', m.result().numpy())  # Final result: 1.0
779  ```
780
781  Usage with tf.keras API:
782
783  ```python
784  model = tf.keras.Model(inputs, outputs)
785  model.compile(
786    'sgd',
787    metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy()])
788  ```
789  """
790
791  def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None):
792    """Creates a `SparseTopKCategoricalAccuracy` instance.
793
794    Args:
795      k: (Optional) Number of top elements to look at for computing accuracy.
796        Defaults to 5.
797      name: (Optional) string name of the metric instance.
798      dtype: (Optional) data type of the metric result.
799    """
800    super(SparseTopKCategoricalAccuracy, self).__init__(
801        sparse_top_k_categorical_accuracy, name, dtype=dtype, k=k)
802
803
804class _ConfusionMatrixConditionCount(Metric):
805  """Calculates the number of the given confusion matrix condition."""
806
807  def __init__(self,
808               confusion_matrix_cond,
809               thresholds=None,
810               name=None,
811               dtype=None):
812    """Creates a `_ConfusionMatrixConditionCount` instance.
813
814    Args:
815      confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` conditions.
816      thresholds: (Optional) Defaults to 0.5. A float value or a python
817        list/tuple of float threshold values in [0, 1]. A threshold is compared
818        with prediction values to determine the truth value of predictions
819        (i.e., above the threshold is `true`, below is `false`). One metric
820        value is generated for each threshold value.
821      name: (Optional) string name of the metric instance.
822      dtype: (Optional) data type of the metric result.
823    """
824    super(_ConfusionMatrixConditionCount, self).__init__(name=name, dtype=dtype)
825    self._confusion_matrix_cond = confusion_matrix_cond
826    self.init_thresholds = thresholds
827    self.thresholds = metrics_utils.parse_init_thresholds(
828        thresholds, default_threshold=0.5)
829    self.accumulator = self.add_weight(
830        'accumulator',
831        shape=(len(self.thresholds),),
832        initializer=init_ops.zeros_initializer)
833
834  def update_state(self, y_true, y_pred, sample_weight=None):
835    """Accumulates the given confusion matrix condition statistics.
836
837    Args:
838      y_true: The ground truth values.
839      y_pred: The predicted values.
840      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
841        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
842        be broadcastable to `y_true`.
843
844    Returns:
845      Update op.
846    """
847    return metrics_utils.update_confusion_matrix_variables(
848        {self._confusion_matrix_cond: self.accumulator},
849        y_true,
850        y_pred,
851        thresholds=self.thresholds,
852        sample_weight=sample_weight)
853
854  def result(self):
855    if len(self.thresholds) == 1:
856      result = self.accumulator[0]
857    else:
858      result = self.accumulator
859    return ops.convert_to_tensor(result)
860
861  def reset_states(self):
862    num_thresholds = len(to_list(self.thresholds))
863    K.batch_set_value(
864        [(v, np.zeros((num_thresholds,))) for v in self.variables])
865
866  def get_config(self):
867    config = {'thresholds': self.init_thresholds}
868    base_config = super(_ConfusionMatrixConditionCount, self).get_config()
869    return dict(list(base_config.items()) + list(config.items()))
870
871
872@keras_export('keras.metrics.FalsePositives')
873class FalsePositives(_ConfusionMatrixConditionCount):
874  """Calculates the number of false positives.
875
876  For example, if `y_true` is [0, 1, 0, 0] and `y_pred` is [0, 0, 1, 1]
877  then the false positives value is 2.  If the weights were specified as
878  [0, 0, 1, 0] then the false positives value would be 1.
879
880  If `sample_weight` is given, calculates the sum of the weights of
881  false positives. This metric creates one local variable, `accumulator`
882  that is used to keep track of the number of false positives.
883
884  If `sample_weight` is `None`, weights default to 1.
885  Use `sample_weight` of 0 to mask values.
886
887  Usage:
888
889  ```python
890  m = tf.keras.metrics.FalsePositives()
891  m.update_state([0, 1, 0, 0], [0, 0, 1, 1])
892  print('Final result: ', m.result().numpy())  # Final result: 2
893  ```
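
  With a list of `thresholds`, one count is accumulated per threshold (a
  sketch; the values are illustrative):

  ```python
  m = tf.keras.metrics.FalsePositives(thresholds=[0.3, 0.6, 0.9])
  m.update_state([0, 1, 0, 0], [0.1, 0.2, 0.7, 0.95])
  print('Final result: ', m.result().numpy())  # Final result: [2. 2. 1.]
  ```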
894
895  Usage with tf.keras API:
896
897  ```python
898  model = tf.keras.Model(inputs, outputs)
899  model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.FalsePositives()])
900  ```
901  """
902
903  def __init__(self, thresholds=None, name=None, dtype=None):
904    """Creates a `FalsePositives` instance.
905
906    Args:
907      thresholds: (Optional) Defaults to 0.5. A float value or a python
908        list/tuple of float threshold values in [0, 1]. A threshold is compared
909        with prediction values to determine the truth value of predictions
910        (i.e., above the threshold is `true`, below is `false`). One metric
911        value is generated for each threshold value.
912      name: (Optional) string name of the metric instance.
913      dtype: (Optional) data type of the metric result.
914    """
915    super(FalsePositives, self).__init__(
916        confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES,
917        thresholds=thresholds,
918        name=name,
919        dtype=dtype)
920
921
922@keras_export('keras.metrics.FalseNegatives')
923class FalseNegatives(_ConfusionMatrixConditionCount):
924  """Calculates the number of false negatives.
925
926  For example, if `y_true` is [0, 1, 1, 1] and `y_pred` is [0, 1, 0, 0]
927  then the false negatives value is 2.  If the weights were specified as
928  [0, 0, 1, 0] then the false negatives value would be 1.
929
930  If `sample_weight` is given, calculates the sum of the weights of
931  false negatives. This metric creates one local variable, `accumulator`
932  that is used to keep track of the number of false negatives.
933
934  If `sample_weight` is `None`, weights default to 1.
935  Use `sample_weight` of 0 to mask values.
936
937  Usage:
938
939  ```python
940  m = tf.keras.metrics.FalseNegatives()
941  m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
942  print('Final result: ', m.result().numpy())  # Final result: 2
943  ```
944
945  Usage with tf.keras API:
946
947  ```python
948  model = tf.keras.Model(inputs, outputs)
949  model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.FalseNegatives()])
950  ```
951  """
952
953  def __init__(self, thresholds=None, name=None, dtype=None):
954    """Creates a `FalseNegatives` instance.
955
956    Args:
957      thresholds: (Optional) Defaults to 0.5. A float value or a python
958        list/tuple of float threshold values in [0, 1]. A threshold is compared
959        with prediction values to determine the truth value of predictions
960        (i.e., above the threshold is `true`, below is `false`). One metric
961        value is generated for each threshold value.
962      name: (Optional) string name of the metric instance.
963      dtype: (Optional) data type of the metric result.
964    """
965    super(FalseNegatives, self).__init__(
966        confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES,
967        thresholds=thresholds,
968        name=name,
969        dtype=dtype)
970
971
972@keras_export('keras.metrics.TrueNegatives')
973class TrueNegatives(_ConfusionMatrixConditionCount):
974  """Calculates the number of true negatives.
975
976  For example, if `y_true` is [0, 1, 0, 0] and `y_pred` is [1, 1, 0, 0]
977  then the true negatives value is 2.  If the weights were specified as
978  [0, 0, 1, 0] then the true negatives value would be 1.
979
980  If `sample_weight` is given, calculates the sum of the weights of
981  true negatives. This metric creates one local variable, `accumulator`
982  that is used to keep track of the number of true negatives.
983
984  If `sample_weight` is `None`, weights default to 1.
985  Use `sample_weight` of 0 to mask values.
986
987  Usage:
988
989  ```python
990  m = tf.keras.metrics.TrueNegatives()
991  m.update_state([0, 1, 0, 0], [1, 1, 0, 0])
992  print('Final result: ', m.result().numpy())  # Final result: 2
993  ```
994
995  Usage with tf.keras API:
996
997  ```python
998  model = tf.keras.Model(inputs, outputs)
999  model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.TrueNegatives()])
1000  ```
1001  """
1002
1003  def __init__(self, thresholds=None, name=None, dtype=None):
1004    """Creates a `TrueNegatives` instance.
1005
1006    Args:
1007      thresholds: (Optional) Defaults to 0.5. A float value or a python
1008        list/tuple of float threshold values in [0, 1]. A threshold is compared
1009        with prediction values to determine the truth value of predictions
1010        (i.e., above the threshold is `true`, below is `false`). One metric
1011        value is generated for each threshold value.
1012      name: (Optional) string name of the metric instance.
1013      dtype: (Optional) data type of the metric result.
1014    """
1015    super(TrueNegatives, self).__init__(
1016        confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES,
1017        thresholds=thresholds,
1018        name=name,
1019        dtype=dtype)
1020
1021
1022@keras_export('keras.metrics.TruePositives')
1023class TruePositives(_ConfusionMatrixConditionCount):
1024  """Calculates the number of true positives.
1025
1026  For example, if `y_true` is [0, 1, 1, 1] and `y_pred` is [1, 0, 1, 1]
1027  then the true positives value is 2.  If the weights were specified as
1028  [0, 0, 1, 0] then the true positives value would be 1.
1029
1030  If `sample_weight` is given, calculates the sum of the weights of
1031  true positives. This metric creates one local variable, `true_positives`
1032  that is used to keep track of the number of true positives.
1033
1034  If `sample_weight` is `None`, weights default to 1.
1035  Use `sample_weight` of 0 to mask values.
1036
1037  Usage:
1038
1039  ```python
1040  m = tf.keras.metrics.TruePositives()
1041  m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
1042  print('Final result: ', m.result().numpy())  # Final result: 2
1043  ```
1044
1045  Usage with tf.keras API:
1046
1047  ```python
1048  model = tf.keras.Model(inputs, outputs)
1049  model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.TruePositives()])
1050  ```
1051  """
1052
1053  def __init__(self, thresholds=None, name=None, dtype=None):
1054    """Creates a `TruePositives` instance.
1055
1056    Args:
1057      thresholds: (Optional) Defaults to 0.5. A float value or a python
1058        list/tuple of float threshold values in [0, 1]. A threshold is compared
1059        with prediction values to determine the truth value of predictions
1060        (i.e., above the threshold is `true`, below is `false`). One metric
1061        value is generated for each threshold value.
1062      name: (Optional) string name of the metric instance.
1063      dtype: (Optional) data type of the metric result.
1064    """
1065    super(TruePositives, self).__init__(
1066        confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES,
1067        thresholds=thresholds,
1068        name=name,
1069        dtype=dtype)
1070
1071
1072@keras_export('keras.metrics.Precision')
1073class Precision(Metric):
1074  """Computes the precision of the predictions with respect to the labels.
1075
1076  For example, if `y_true` is [0, 1, 1, 1] and `y_pred` is [1, 0, 1, 1]
1077  then the precision value is 2/(2+1), i.e. 0.66. If the weights were specified as
1078  [0, 0, 1, 0] then the precision value would be 1.
1079
1080  The metric creates two local variables, `true_positives` and `false_positives`
1081  that are used to compute the precision. This value is ultimately returned as
1082  `precision`, an idempotent operation that simply divides `true_positives`
1083  by the sum of `true_positives` and `false_positives`.
1084
1085  If `sample_weight` is `None`, weights default to 1.
1086  Use `sample_weight` of 0 to mask values.
1087
1088  If `top_k` is set, we'll calculate precision as how often on average a class
1089  among the top-k classes with the highest predicted values of a batch entry is
1090  correct and can be found in the label for that entry.
1091
1092  If `class_id` is specified, we calculate precision by considering only the
1093  entries in the batch for which `class_id` is above the threshold and/or in the
1094  top-k highest predictions, and computing the fraction of them for which
1095  `class_id` is indeed a correct label.
1096
1097  Usage:
1098
1099  ```python
1100  m = tf.keras.metrics.Precision()
1101  m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
1102  print('Final result: ', m.result().numpy())  # Final result: 0.66
1103  ```
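
  A sketch of the `top_k` behavior described above (illustrative values):

  ```python
  m = tf.keras.metrics.Precision(top_k=2)
  m.update_state([0, 0, 1, 1], [0.1, 0.9, 0.6, 0.3])
  print('Final result: ', m.result().numpy())  # Final result: 0.5
  ```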
1104
1105  Usage with tf.keras API:
1106
1107  ```python
1108  model = tf.keras.Model(inputs, outputs)
1109  model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.Precision()])
1110  ```
1111  """
1112
1113  def __init__(self,
1114               thresholds=None,
1115               top_k=None,
1116               class_id=None,
1117               name=None,
1118               dtype=None):
1119    """Creates a `Precision` instance.
1120
1121    Args:
1122      thresholds: (Optional) A float value or a python list/tuple of float
1123        threshold values in [0, 1]. A threshold is compared with prediction
1124        values to determine the truth value of predictions (i.e., above the
1125        threshold is `true`, below is `false`). One metric value is generated
1126        for each threshold value. If neither thresholds nor top_k are set, the
1127        default is to calculate precision with `thresholds=0.5`.
1128      top_k: (Optional) Unset by default. An int value specifying the top-k
1129        predictions to consider when calculating precision.
1130      class_id: (Optional) Integer class ID for which we want binary metrics.
1131        This must be in the half-open interval `[0, num_classes)`, where
1132        `num_classes` is the last dimension of predictions.
1133      name: (Optional) string name of the metric instance.
1134      dtype: (Optional) data type of the metric result.
1135    """
1136    super(Precision, self).__init__(name=name, dtype=dtype)
1137    self.init_thresholds = thresholds
1138    self.top_k = top_k
1139    self.class_id = class_id
1140
1141    default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
1142    self.thresholds = metrics_utils.parse_init_thresholds(
1143        thresholds, default_threshold=default_threshold)
1144    self.true_positives = self.add_weight(
1145        'true_positives',
1146        shape=(len(self.thresholds),),
1147        initializer=init_ops.zeros_initializer)
1148    self.false_positives = self.add_weight(
1149        'false_positives',
1150        shape=(len(self.thresholds),),
1151        initializer=init_ops.zeros_initializer)
1152
1153  def update_state(self, y_true, y_pred, sample_weight=None):
1154    """Accumulates true positive and false positive statistics.
1155
1156    Args:
1157      y_true: The ground truth values, with the same dimensions as `y_pred`.
1158        Will be cast to `bool`.
1159      y_pred: The predicted values. Each element must be in the range `[0, 1]`.
1160      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1161        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1162        be broadcastable to `y_true`.
1163
1164    Returns:
1165      Update op.
1166    """
1167    return metrics_utils.update_confusion_matrix_variables(
1168        {
1169            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
1170            metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives
1171        },
1172        y_true,
1173        y_pred,
1174        thresholds=self.thresholds,
1175        top_k=self.top_k,
1176        class_id=self.class_id,
1177        sample_weight=sample_weight)
1178
1179  def result(self):
1180    result = math_ops.div_no_nan(self.true_positives,
1181                                 self.true_positives + self.false_positives)
1182    return result[0] if len(self.thresholds) == 1 else result
1183
1184  def reset_states(self):
1185    num_thresholds = len(to_list(self.thresholds))
1186    K.batch_set_value(
1187        [(v, np.zeros((num_thresholds,))) for v in self.variables])
1188
1189  def get_config(self):
1190    config = {
1191        'thresholds': self.init_thresholds,
1192        'top_k': self.top_k,
1193        'class_id': self.class_id
1194    }
1195    base_config = super(Precision, self).get_config()
1196    return dict(list(base_config.items()) + list(config.items()))
1197
1198
1199@keras_export('keras.metrics.Recall')
1200class Recall(Metric):
1201  """Computes the recall of the predictions with respect to the labels.
1202
1203  For example, if `y_true` is [0, 1, 1, 1] and `y_pred` is [1, 0, 1, 1]
1204  then the recall value is 2/(2+1), i.e. 0.66. If the weights were specified as
1205  [0, 0, 1, 0] then the recall value would be 1.
1206
1207  This metric creates two local variables, `true_positives` and
1208  `false_negatives`, that are used to compute the recall. This value is
1209  ultimately returned as `recall`, an idempotent operation that simply divides
1210  `true_positives` by the sum of `true_positives` and `false_negatives`.
1211
1212  If `sample_weight` is `None`, weights default to 1.
1213  Use `sample_weight` of 0 to mask values.
1214
1215  If `top_k` is set, recall will be computed as how often on average a class
1216  among the labels of a batch entry is in the top-k predictions.
1217
1218  If `class_id` is specified, we calculate recall by considering only the
1219  entries in the batch for which `class_id` is in the label, and computing the
1220  fraction of them for which `class_id` is above the threshold and/or in the
1221  top-k predictions.
1222
1223  Usage:
1224
1225  ```python
1226  m = tf.keras.metrics.Recall()
1227  m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
1228  print('Final result: ', m.result().numpy())  # Final result: 0.66
1229  ```
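
  A sketch of the `top_k` behavior described above (illustrative values):

  ```python
  m = tf.keras.metrics.Recall(top_k=2)
  m.update_state([0, 0, 1, 1], [0.1, 0.9, 0.6, 0.3])
  print('Final result: ', m.result().numpy())  # Final result: 0.5
  ```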
1230
1231  Usage with tf.keras API:
1232
1233  ```python
1234  model = tf.keras.Model(inputs, outputs)
1235  model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.Recall()])
1236  ```
1237  """
1238
1239  def __init__(self,
1240               thresholds=None,
1241               top_k=None,
1242               class_id=None,
1243               name=None,
1244               dtype=None):
1245    """Creates a `Recall` instance.
1246
1247    Args:
1248      thresholds: (Optional) A float value or a python list/tuple of float
1249        threshold values in [0, 1]. A threshold is compared with prediction
1250        values to determine the truth value of predictions (i.e., above the
1251        threshold is `true`, below is `false`). One metric value is generated
1252        for each threshold value. If neither thresholds nor top_k are set, the
1253        default is to calculate recall with `thresholds=0.5`.
1254      top_k: (Optional) Unset by default. An int value specifying the top-k
1255        predictions to consider when calculating recall.
1256      class_id: (Optional) Integer class ID for which we want binary metrics.
1257        This must be in the half-open interval `[0, num_classes)`, where
1258        `num_classes` is the last dimension of predictions.
1259      name: (Optional) string name of the metric instance.
1260      dtype: (Optional) data type of the metric result.
1261    """
1262    super(Recall, self).__init__(name=name, dtype=dtype)
1263    self.init_thresholds = thresholds
1264    self.top_k = top_k
1265    self.class_id = class_id
1266
1267    default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
1268    self.thresholds = metrics_utils.parse_init_thresholds(
1269        thresholds, default_threshold=default_threshold)
1270    self.true_positives = self.add_weight(
1271        'true_positives',
1272        shape=(len(self.thresholds),),
1273        initializer=init_ops.zeros_initializer)
1274    self.false_negatives = self.add_weight(
1275        'false_negatives',
1276        shape=(len(self.thresholds),),
1277        initializer=init_ops.zeros_initializer)
1278
1279  def update_state(self, y_true, y_pred, sample_weight=None):
1280    """Accumulates true positive and false negative statistics.
1281
1282    Args:
1283      y_true: The ground truth values, with the same dimensions as `y_pred`.
1284        Will be cast to `bool`.
1285      y_pred: The predicted values. Each element must be in the range `[0, 1]`.
1286      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1287        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1288        be broadcastable to `y_true`.
1289
1290    Returns:
1291      Update op.
1292    """
1293    return metrics_utils.update_confusion_matrix_variables(
1294        {
1295            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
1296            metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives
1297        },
1298        y_true,
1299        y_pred,
1300        thresholds=self.thresholds,
1301        top_k=self.top_k,
1302        class_id=self.class_id,
1303        sample_weight=sample_weight)
1304
1305  def result(self):
1306    result = math_ops.div_no_nan(self.true_positives,
1307                                 self.true_positives + self.false_negatives)
1308    return result[0] if len(self.thresholds) == 1 else result
1309
1310  def reset_states(self):
1311    num_thresholds = len(to_list(self.thresholds))
1312    K.batch_set_value(
1313        [(v, np.zeros((num_thresholds,))) for v in self.variables])
1314
1315  def get_config(self):
1316    config = {
1317        'thresholds': self.init_thresholds,
1318        'top_k': self.top_k,
1319        'class_id': self.class_id
1320    }
1321    base_config = super(Recall, self).get_config()
1322    return dict(list(base_config.items()) + list(config.items()))
1323
1324
1325@six.add_metaclass(abc.ABCMeta)
1326class SensitivitySpecificityBase(Metric):
1327  """Abstract base class for computing sensitivity and specificity.
1328
1329  For additional information about specificity and sensitivity, see the
1330  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
1331  """
1332
1333  def __init__(self, value, num_thresholds=200, name=None, dtype=None):
1334    super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype)
1335    if num_thresholds <= 0:
1336      raise ValueError('`num_thresholds` must be > 0.')
1337    self.value = value
1338    self.true_positives = self.add_weight(
1339        'true_positives',
1340        shape=(num_thresholds,),
1341        initializer=init_ops.zeros_initializer)
1342    self.true_negatives = self.add_weight(
1343        'true_negatives',
1344        shape=(num_thresholds,),
1345        initializer=init_ops.zeros_initializer)
1346    self.false_positives = self.add_weight(
1347        'false_positives',
1348        shape=(num_thresholds,),
1349        initializer=init_ops.zeros_initializer)
1350    self.false_negatives = self.add_weight(
1351        'false_negatives',
1352        shape=(num_thresholds,),
1353        initializer=init_ops.zeros_initializer)
1354
1355    # Compute `num_thresholds` thresholds in [0, 1]
1356    if num_thresholds == 1:
1357      self.thresholds = [0.5]
1358    else:
1359      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
1360                    for i in range(num_thresholds - 2)]
1361      self.thresholds = [0.0] + thresholds + [1.0]
1362
1363  def update_state(self, y_true, y_pred, sample_weight=None):
1364    """Accumulates confusion matrix statistics.
1365
1366    Args:
1367      y_true: The ground truth values.
1368      y_pred: The predicted values.
1369      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1370        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1371        be broadcastable to `y_true`.
1372
1373    Returns:
1374      Update op.
1375    """
1376    return metrics_utils.update_confusion_matrix_variables(
1377        {
1378            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
1379            metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
1380            metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
1381            metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
1382        },
1383        y_true,
1384        y_pred,
1385        thresholds=self.thresholds,
1386        sample_weight=sample_weight)
1387
1388  def reset_states(self):
1389    num_thresholds = len(self.thresholds)
1390    K.batch_set_value(
1391        [(v, np.zeros((num_thresholds,))) for v in self.variables])
1392
1393
1394@keras_export('keras.metrics.SensitivityAtSpecificity')
1395class SensitivityAtSpecificity(SensitivitySpecificityBase):
1396  """Computes the sensitivity at a given specificity.
1397
1398  `Sensitivity` measures the proportion of actual positives that are correctly
1399  identified as such (tp / (tp + fn)).
1400  `Specificity` measures the proportion of actual negatives that are correctly
1401  identified as such (tn / (tn + fp)).
1402
1403  This metric creates four local variables, `true_positives`, `true_negatives`,
1404  `false_positives` and `false_negatives` that are used to compute the
1405  sensitivity at the given specificity. The threshold for the given specificity
1406  value is computed and used to evaluate the corresponding sensitivity.
1407
1408  If `sample_weight` is `None`, weights default to 1.
1409  Use `sample_weight` of 0 to mask values.
1410
1411  For additional information about specificity and sensitivity, see the
1412  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
1413
1414  Usage:
1415
1416  ```python
1417  m = tf.keras.metrics.SensitivityAtSpecificity(0.4, num_thresholds=1)
1418  m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
1419  print('Final result: ', m.result().numpy())  # Final result: 0.5
1420  ```
1421
1422  Usage with tf.keras API:
1423
1424  ```python
1425  model = tf.keras.Model(inputs, outputs)
1426  model.compile(
1427      'sgd',
1428      loss='mse',
1429      metrics=[tf.keras.metrics.SensitivityAtSpecificity()])
1430  ```
1431  """
1432
1433  def __init__(self, specificity, num_thresholds=200, name=None, dtype=None):
1434    """Creates a `SensitivityAtSpecificity` instance.
1435
1436    Args:
1437      specificity: A scalar value in range `[0, 1]`.
1438      num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1439        use for matching the given specificity.
1440      name: (Optional) string name of the metric instance.
1441      dtype: (Optional) data type of the metric result.
1442    """
1443    if specificity < 0 or specificity > 1:
1444      raise ValueError('`specificity` must be in the range [0, 1].')
1445    self.specificity = specificity
1446    self.num_thresholds = num_thresholds
1447    super(SensitivityAtSpecificity, self).__init__(
1448        specificity, num_thresholds=num_thresholds, name=name, dtype=dtype)
1449
1450  def result(self):
1451    # Calculate specificities at all the thresholds.
1452    specificities = math_ops.div_no_nan(
1453        self.true_negatives, self.true_negatives + self.false_positives)
1454
1455    # Find the index of the threshold where the specificity is closest to the
1456    # given specificity.
1457    min_index = math_ops.argmin(
1458        math_ops.abs(specificities - self.value), axis=0)
1459    min_index = math_ops.cast(min_index, dtypes.int32)
1460
1461    # Compute sensitivity at that index.
1462    return math_ops.div_no_nan(
1463        self.true_positives[min_index],
1464        self.true_positives[min_index] + self.false_negatives[min_index])
1465
1466  def get_config(self):
1467    config = {
1468        'num_thresholds': self.num_thresholds,
1469        'specificity': self.specificity
1470    }
1471    base_config = super(SensitivityAtSpecificity, self).get_config()
1472    return dict(list(base_config.items()) + list(config.items()))
1473
1474
1475@keras_export('keras.metrics.SpecificityAtSensitivity')
1476class SpecificityAtSensitivity(SensitivitySpecificityBase):
1477  """Computes the specificity at a given sensitivity.
1478
1479  `Sensitivity` measures the proportion of actual positives that are correctly
1480  identified as such (tp / (tp + fn)).
1481  `Specificity` measures the proportion of actual negatives that are correctly
1482  identified as such (tn / (tn + fp)).
1483
1484  This metric creates four local variables, `true_positives`, `true_negatives`,
1485  `false_positives` and `false_negatives` that are used to compute the
1486  specificity at the given sensitivity. The threshold for the given sensitivity
1487  value is computed and used to evaluate the corresponding specificity.
1488
1489  If `sample_weight` is `None`, weights default to 1.
1490  Use `sample_weight` of 0 to mask values.
1491
1492  For additional information about specificity and sensitivity, see the
1493  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
1494
1495  Usage:
1496
1497  ```python
1498  m = tf.keras.metrics.SpecificityAtSensitivity(0.8, num_thresholds=1)
1499  m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
1500  print('Final result: ', m.result().numpy())  # Final result: 1.0
1501  ```
1502
1503  Usage with tf.keras API:
1504
1505  ```python
1506  model = tf.keras.Model(inputs, outputs)
1507  model.compile(
1508      'sgd',
1509      loss='mse',
1510      metrics=[tf.keras.metrics.SpecificityAtSensitivity(0.5)])
1511  ```
1512  """
1513
1514  def __init__(self, sensitivity, num_thresholds=200, name=None, dtype=None):
1515    """Creates a `SpecificityAtSensitivity` instance.
1516
1517    Args:
1518      sensitivity: A scalar value in range `[0, 1]`.
1519      num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1520        use for matching the given sensitivity.
1521      name: (Optional) string name of the metric instance.
1522      dtype: (Optional) data type of the metric result.
1523    """
1524    if sensitivity < 0 or sensitivity > 1:
1525      raise ValueError('`sensitivity` must be in the range [0, 1].')
1526    self.sensitivity = sensitivity
1527    self.num_thresholds = num_thresholds
1528    super(SpecificityAtSensitivity, self).__init__(
1529        sensitivity, num_thresholds=num_thresholds, name=name, dtype=dtype)
1530
1531  def result(self):
1532    # Calculate sensitivities at all the thresholds.
1533    sensitivities = math_ops.div_no_nan(
1534        self.true_positives, self.true_positives + self.false_negatives)
1535
1536    # Find the index of the threshold where the sensitivity is closest to the
1537    # given sensitivity.
1538    min_index = math_ops.argmin(
1539        math_ops.abs(sensitivities - self.value), axis=0)
1540    min_index = math_ops.cast(min_index, dtypes.int32)
1541
1542    # Compute specificity at that index.
1543    return math_ops.div_no_nan(
1544        self.true_negatives[min_index],
1545        self.true_negatives[min_index] + self.false_positives[min_index])
1546
1547  def get_config(self):
1548    config = {
1549        'num_thresholds': self.num_thresholds,
1550        'sensitivity': self.sensitivity
1551    }
1552    base_config = super(SpecificityAtSensitivity, self).get_config()
1553    return dict(list(base_config.items()) + list(config.items()))
1554
1555
1556@keras_export('keras.metrics.AUC')
1557class AUC(Metric):
1558  """Computes the approximate AUC (Area under the curve) via a Riemann sum.
1559
1560  This metric creates four local variables, `true_positives`, `true_negatives`,
1561  `false_positives` and `false_negatives` that are used to compute the AUC.
1562  To discretize the AUC curve, a linearly spaced set of thresholds is used to
1563  compute pairs of recall and precision values. The area under the ROC-curve is
1564  therefore computed using the height of the recall values by the false positive
1565  rate, while the area under the PR-curve is computed using the height of the
1566  precision values by the recall.
1567
1568  This value is ultimately returned as `auc`, an idempotent operation that
1569  computes the area under a discretized curve of precision versus recall values
1570  (computed using the aforementioned variables). The `num_thresholds` variable
1571  controls the degree of discretization with larger numbers of thresholds more
1572  closely approximating the true AUC. The quality of the approximation may vary
1573  dramatically depending on `num_thresholds`.
1574
1575  For best results, `predictions` should be distributed approximately uniformly
1576  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
1577  approximation may be poor if this is not the case. Setting `summation_method`
1578  to 'minoring' or 'majoring' can help quantify the error in the approximation
1579  by providing a lower or upper bound estimate of the AUC.
1580
1581  If `sample_weight` is `None`, weights default to 1.
1582  Use `sample_weight` of 0 to mask values.
1583
1584  Usage:
1585
1586  ```python
1587  m = tf.keras.metrics.AUC(num_thresholds=3)
1588  m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
1589
1590  # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
1591  # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
1592  # recall = [1, 0.5, 0], fp_rate = [1, 0, 0]
1593  # auc = ((((1+0.5)/2)*(1-0))+ (((0.5+0)/2)*(0-0))) = 0.75
1594
1595  print('Final result: ', m.result().numpy())  # Final result: 0.75
1596  ```
1597
1598  Usage with tf.keras API:
1599
1600  ```python
1601  model = tf.keras.Model(inputs, outputs)
1602  model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.AUC()])
1603  ```
1604  """
1605
1606  def __init__(self,
1607               num_thresholds=200,
1608               curve='ROC',
1609               summation_method='interpolation',
1610               name=None,
1611               dtype=None):
1612    """Creates an `AUC` instance.
1613
1614    Args:
1615      num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1616        use when discretizing the ROC curve. Values must be > 1.
1617      curve: (Optional) Specifies the name of the curve to be computed, 'ROC'
1618        [default] or 'PR' for the Precision-Recall-curve.
1619      summation_method: (Optional) Specifies the Riemann summation method used
1620        (https://en.wikipedia.org/wiki/Riemann_sum): 'interpolation' [default]
1621          applies the mid-point summation scheme for `ROC`; for PR-AUC it
1622          interpolates (true/false) positives but not the ratio that is
1623          precision (see Davis & Goadrich 2006 for details); 'minoring' applies
1624          left summation for increasing intervals and right summation for
1625          decreasing intervals; 'majoring' does the opposite.
1626      name: (Optional) string name of the metric instance.
1627      dtype: (Optional) data type of the metric result.
1628    """
1629    # Validate configurations.
1630    if num_thresholds <= 1:
1631      raise ValueError('`num_thresholds` must be > 1.')
1632    if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
1633        metrics_utils.AUCCurve):
1634      raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format(
1635          curve, list(metrics_utils.AUCCurve)))
1636    if isinstance(
1637        summation_method,
1638        metrics_utils.AUCSummationMethod) and summation_method not in list(
1639            metrics_utils.AUCSummationMethod):
1640      raise ValueError(
1641          'Invalid summation method: "{}". Valid options are: "{}"'.format(
1642              summation_method, list(metrics_utils.AUCSummationMethod)))
1643
1644    # Update properties.
1645    self.num_thresholds = num_thresholds
1646    if isinstance(curve, metrics_utils.AUCCurve):
1647      self.curve = curve
1648    else:
1649      self.curve = metrics_utils.AUCCurve.from_str(curve)
1650    if isinstance(summation_method, metrics_utils.AUCSummationMethod):
1651      self.summation_method = summation_method
1652    else:
1653      self.summation_method = metrics_utils.AUCSummationMethod.from_str(
1654          summation_method)
1655    super(AUC, self).__init__(name=name, dtype=dtype)
1656
1657    # Create metric variables
1658    self.true_positives = self.add_weight(
1659        'true_positives',
1660        shape=(num_thresholds,),
1661        initializer=init_ops.zeros_initializer)
1662    self.true_negatives = self.add_weight(
1663        'true_negatives',
1664        shape=(num_thresholds,),
1665        initializer=init_ops.zeros_initializer)
1666    self.false_positives = self.add_weight(
1667        'false_positives',
1668        shape=(num_thresholds,),
1669        initializer=init_ops.zeros_initializer)
1670    self.false_negatives = self.add_weight(
1671        'false_negatives',
1672        shape=(num_thresholds,),
1673        initializer=init_ops.zeros_initializer)
1674
1675    # Compute `num_thresholds` thresholds in [0, 1]
1676    thresholds = [
1677        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
1678    ]
1679    self.thresholds = [0.0 - K.epsilon()] + thresholds + [1.0 + K.epsilon()]
1680    # epsilon - to account for floating point imprecisions.
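    # For example (illustrative): with num_thresholds=3 this yields
    # [0.0 - 1e-7, 0.5, 1.0 + 1e-7], matching the thresholds listed in the
    # class docstring example above.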
1681
1682  def update_state(self, y_true, y_pred, sample_weight=None):
1683    """Accumulates confusion matrix statistics.
1684
1685    Args:
1686      y_true: The ground truth values.
1687      y_pred: The predicted values.
1688      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1689        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1690        be broadcastable to `y_true`.
1691
1692    Returns:
1693      Update op.
1694    """
1695    return metrics_utils.update_confusion_matrix_variables({
1696        metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
1697        metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
1698        metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
1699        metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
1700    }, y_true, y_pred, self.thresholds, sample_weight=sample_weight)
1701
1702  def interpolate_pr_auc(self):
1703    """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
1704
1705    https://www.biostat.wisc.edu/~page/rocpr.pdf
1706
1707    Note here we derive & use a closed formula not present in the paper
1708    as follows:
1709
1710      Precision = TP / (TP + FP) = TP / P
1711
1712    Modeling all of TP (true positive), FP (false positive) and their sum
1713    P = TP + FP (predicted positive) as varying linearly within each interval
1714    [A, B] between successive thresholds, we get
1715
1716      Precision slope = dTP / dP
1717                      = (TP_B - TP_A) / (P_B - P_A)
1718                      = (TP - TP_A) / (P - P_A)
1719      Precision = (TP_A + slope * (P - P_A)) / P
1720
1721    The area within the interval is (slope / total_pos_weight) times
1722
1723      int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
1724      int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
1725
1726    where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
1727
1728      int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
1729
1730    Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
1731
1732      slope * [dTP + intercept *  log(P_B / P_A)] / total_pos_weight
1733
1734    where dTP == TP_B - TP_A.
1735
1736    Note that when P_A == 0 the above calculation simplifies into
1737
1738      int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
1739
1740    which is really equivalent to imputing constant precision throughout the
1741    first bucket having >0 true positives.
1742
1743    Returns:
1744      pr_auc: an approximation of the area under the P-R curve.
1745    """
1746    dtp = self.true_positives[:self.num_thresholds -
1747                              1] - self.true_positives[1:]
1748    p = self.true_positives + self.false_positives
1749    dp = p[:self.num_thresholds - 1] - p[1:]
1750
1751    prec_slope = math_ops.div_no_nan(
1752        dtp, math_ops.maximum(dp, 0), name='prec_slope')
1753    intercept = self.true_positives[1:] - math_ops.multiply(prec_slope, p[1:])
1754
1755    safe_p_ratio = array_ops.where(
1756        math_ops.logical_and(p[:self.num_thresholds - 1] > 0, p[1:] > 0),
1757        math_ops.div_no_nan(
1758            p[:self.num_thresholds - 1],
1759            math_ops.maximum(p[1:], 0),
1760            name='recall_relative_ratio'),
1761        array_ops.ones_like(p[1:]))
1762
1763    return math_ops.reduce_sum(
1764        math_ops.div_no_nan(
1765            prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
1766            math_ops.maximum(self.true_positives[1:] + self.false_negatives[1:],
1767                             0),
1768            name='pr_auc_increment'),
1769        name='interpolate_pr_auc')
1770
1771  def result(self):
1772    if (self.curve == metrics_utils.AUCCurve.PR and
1773        self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION
1774       ):
1775      # This use case is different and is handled separately.
1776      return self.interpolate_pr_auc()
1777
1778    # Set `x` and `y` values for the curves based on `curve` config.
1779    recall = math_ops.div_no_nan(self.true_positives,
1780                                 self.true_positives + self.false_negatives)
1781    if self.curve == metrics_utils.AUCCurve.ROC:
1782      fp_rate = math_ops.div_no_nan(self.false_positives,
1783                                    self.false_positives + self.true_negatives)
1784      x = fp_rate
1785      y = recall
1786    else:  # curve == 'PR'.
1787      precision = math_ops.div_no_nan(
1788          self.true_positives, self.true_positives + self.false_positives)
1789      x = recall
1790      y = precision
1791
1792    # Find the rectangle heights based on `summation_method`.
1793    if self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION:
1794      # Note: the case ('PR', 'interpolation') has been handled above.
1795      heights = (y[:self.num_thresholds - 1] + y[1:]) / 2.
1796    elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING:
1797      heights = math_ops.minimum(y[:self.num_thresholds - 1], y[1:])
1798    else:  # self.summation_method = metrics_utils.AUCSummationMethod.MAJORING:
1799      heights = math_ops.maximum(y[:self.num_thresholds - 1], y[1:])
1800
1801    # Sum up the areas of all the rectangles.
1802    return math_ops.reduce_sum(
1803        math_ops.multiply(x[:self.num_thresholds - 1] - x[1:], heights),
1804        name=self.name)
1805
1806  def reset_states(self):
1807    num_thresholds = len(self.thresholds)
1808    K.batch_set_value(
1809        [(v, np.zeros((num_thresholds,))) for v in self.variables])
1810
1811  def get_config(self):
1812    config = {
1813        'num_thresholds': self.num_thresholds,
1814        'curve': self.curve.value,
1815        'summation_method': self.summation_method.value,
1816    }
1817    base_config = super(AUC, self).get_config()
1818    return dict(list(base_config.items()) + list(config.items()))
1819
1820
1821@keras_export('keras.metrics.CosineSimilarity')
1822class CosineSimilarity(MeanMetricWrapper):
1823  """Computes the cosine similarity between the labels and predictions.
1824
1825  cosine similarity = (a . b) / ||a|| ||b||
1826  (https://en.wikipedia.org/wiki/Cosine_similarity)
1827
1828  For example, if `y_true` is [0, 1, 1], and `y_pred` is [1, 0, 1], the cosine
1829  similarity is 0.5.
1830
1831  This metric keeps the average cosine similarity between `predictions` and
1832  `labels` over a stream of data.
1833
1834  Usage:
1835  ```python
1836  m = tf.keras.metrics.CosineSimilarity(axis=1)
1837  m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]])
1838  # l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]]
1839  # l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]]
1840  # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]]
1841  # result = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))
1842  #        = ((0. + 0.) + (0.5 + 0.5)) / 2
1843
1844  print('Final result: ', m.result().numpy())  # Final result: 0.5
1845  ```
1846
1847  Usage with tf.keras API:
1848
1849  ```python
1850  model = tf.keras.Model(inputs, outputs)
1851  model.compile(
1852      'sgd',
1853      loss='mse',
1854      metrics=[tf.keras.metrics.CosineSimilarity(axis=1)])
1855  ```
1856  """
1857
1858  def __init__(self, name='cosine_similarity', dtype=None, axis=-1):
1859    """Creates a `CosineSimilarity` instance.
1860
1861    Args:
1862      name: (Optional) string name of the metric instance.
1863      dtype: (Optional) data type of the metric result.
1864      axis: (Optional) Defaults to -1. The dimension along which the cosine
1865        similarity is computed.
1866    """
1867    super(CosineSimilarity, self).__init__(
1868        cosine_similarity, name, dtype=dtype, axis=axis)
1869
1870
1871@keras_export('keras.metrics.MeanAbsoluteError')
1872class MeanAbsoluteError(MeanMetricWrapper):
1873  """Computes the mean absolute error between the labels and predictions.
1874
1875  For example, if `y_true` is [0., 0., 1., 1.], and `y_pred` is [1., 1., 1., 0.]
1876  the mean absolute error is 3/4 (0.75).
1877
1878  Usage:
1879  ```python
1880  m = tf.keras.metrics.MeanAbsoluteError()
1881  m.update_state([0., 0., 1., 1.], [1., 1., 1., 0.])
1882  print('Final result: ', m.result().numpy())  # Final result: 0.75
1883  ```
1884
1885  Usage with tf.keras API:
1886
1887  ```python
1888  model = tf.keras.Model(inputs, outputs)
1889  model.compile('sgd', metrics=[tf.keras.metrics.MeanAbsoluteError()])
1890  ```
1891  """
1892
1893  def __init__(self, name='mean_absolute_error', dtype=None):
1894    super(MeanAbsoluteError, self).__init__(
1895        mean_absolute_error, name, dtype=dtype)
1896
1897
1898@keras_export('keras.metrics.MeanAbsolutePercentageError')
1899class MeanAbsolutePercentageError(MeanMetricWrapper):
1900  """Computes the mean absolute percentage error between `y_true` and `y_pred`.
1901
1902  For example, if `y_true` is [0., 0., 1., 1.], and `y_pred` is [1., 1., 1., 0.]
1903  the mean absolute percentage error is 5e+08.
1904
1905  Usage:
1906
1907  ```python
1908  m = tf.keras.metrics.MeanAbsolutePercentageError()
1909  m.update_state([0., 0., 1., 1.], [1., 1., 1., 0.])
1910  print('Final result: ', m.result().numpy())  # Final result: 5e+08
1911  ```
1912
1913  Usage with tf.keras API:
1914
1915  ```python
1916  model = tf.keras.Model(inputs, outputs)
1917  model.compile('sgd', metrics=[tf.keras.metrics.MeanAbsolutePercentageError()])
1918  ```
1919  """
1920
1921  def __init__(self, name='mean_absolute_percentage_error', dtype=None):
1922    super(MeanAbsolutePercentageError, self).__init__(
1923        mean_absolute_percentage_error, name, dtype=dtype)
1924
1925
1926@keras_export('keras.metrics.MeanSquaredError')
1927class MeanSquaredError(MeanMetricWrapper):
1928  """Computes the mean squared error between `y_true` and `y_pred`.
1929
1930  For example, if `y_true` is [0., 0., 1., 1.], and `y_pred` is [1., 1., 1., 0.]
1931  the mean squared error is 3/4 (0.75).
1932
1933  Usage:
1934
1935  ```python
1936  m = tf.keras.metrics.MeanSquaredError()
1937  m.update_state([0., 0., 1., 1.], [1., 1., 1., 0.])
1938  print('Final result: ', m.result().numpy())  # Final result: 0.75
1939  ```
1940
1941  Usage with tf.keras API:
1942
1943  ```python
1944  model = tf.keras.Model(inputs, outputs)
1945  model.compile('sgd', metrics=[tf.keras.metrics.MeanSquaredError()])
1946  ```
1947  """
1948
1949  def __init__(self, name='mean_squared_error', dtype=None):
1950    super(MeanSquaredError, self).__init__(
1951        mean_squared_error, name, dtype=dtype)
1952
1953
1954@keras_export('keras.metrics.MeanSquaredLogarithmicError')
1955class MeanSquaredLogarithmicError(MeanMetricWrapper):
1956  """Computes the mean squared logarithmic error between `y_true` and `y_pred`.
1957
1958  For example, if `y_true` is [0., 0., 1., 1.], and `y_pred` is [1., 1., 1., 0.]
1959  the mean squared logarithmic error is 0.36034.
1960
1961  Usage:
1962
1963  ```python
1964  m = tf.keras.metrics.MeanSquaredLogarithmicError()
1965  m.update_state([0., 0., 1., 1.], [1., 1., 1., 0.])
1966  print('Final result: ', m.result().numpy())  # Final result: 0.36034
1967  ```
1968
1969  Usage with tf.keras API:
1970
1971  ```python
1972  model = tf.keras.Model(inputs, outputs)
1973  model.compile('sgd', metrics=[tf.keras.metrics.MeanSquaredLogarithmicError()])
1974  ```
1975  """
1976
1977  def __init__(self, name='mean_squared_logarithmic_error', dtype=None):
1978    super(MeanSquaredLogarithmicError, self).__init__(
1979        mean_squared_logarithmic_error, name, dtype=dtype)
1980
1981
1982@keras_export('keras.metrics.Hinge')
1983class Hinge(MeanMetricWrapper):
1984  """Computes the hinge metric between `y_true` and `y_pred`.
1985
1986  `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
1987  provided we will convert them to -1 or 1.
1988
1989  For example, if `y_true` is [-1., 1., 1.], and `y_pred` is [0.6, -0.7, -0.5]
1990  the hinge metric value is 1.6.
1991
1992  Usage:
1993
1994  ```python
1995  m = tf.keras.metrics.Hinge()
1996  m.update_state([-1., 1., 1.], [0.6, -0.7, -0.5])
1997
1998  # result = mean(max(0, 1 - y_true * y_pred)) = (1.6 + 1.7 + 1.5) / 3
1999
2000  print('Final result: ', m.result().numpy())  # Final result: 1.6
2001  ```
2002
2003  Usage with tf.keras API:
2004
2005  ```python
2006  model = tf.keras.Model(inputs, outputs)
2007  model.compile('sgd', metrics=[tf.keras.metrics.Hinge()])
2008  ```
2009  """
2010
2011  def __init__(self, name='hinge', dtype=None):
2012    super(Hinge, self).__init__(hinge, name, dtype=dtype)
2013
2014
2015@keras_export('keras.metrics.SquaredHinge')
2016class SquaredHinge(MeanMetricWrapper):
2017  """Computes the squared hinge metric between `y_true` and `y_pred`.
2018
2019  `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
2020  provided we will convert them to -1 or 1.
2021
2022  For example, if `y_true` is [-1., 1., 1.], and `y_pred` is [0.6, -0.7, -0.5]
2023  the squared hinge metric value is 2.6.
2024
2025  Usage:
2026
2027  ```python
2028  m = tf.keras.metrics.SquaredHinge()
2029  m.update_state([-1., 1., 1.], [0.6, -0.7, -0.5])
2030
2031  # result = mean(max(0, 1 - y_true * y_pred)^2) = (1.6^2 + 1.7^2 + 1.5^2) / 3
2032
2033  print('Final result: ', m.result().numpy())  # Final result: 2.6
2034  ```
2035
2036  Usage with tf.keras API:
2037
2038  ```python
2039  model = tf.keras.Model(inputs, outputs)
2040  model.compile('sgd', metrics=[tf.keras.metrics.SquaredHinge()])
2041  ```
2042  """
2043
2044  def __init__(self, name='squared_hinge', dtype=None):
2045    super(SquaredHinge, self).__init__(squared_hinge, name, dtype=dtype)
2046
2047
2048@keras_export('keras.metrics.CategoricalHinge')
2049class CategoricalHinge(MeanMetricWrapper):
2050  """Computes the categorical hinge metric between `y_true` and `y_pred`.
2051
2052  For example, if `y_true` is [0., 1., 1.], and `y_pred` is [1., 0., 1.]
2053  the categorical hinge metric value is 1.0.
2054
2055  Usage:
2056
2057  ```python
2058  m = tf.keras.metrics.CategoricalHinge()
2059  m.update_state([0., 1., 1.], [1., 0., 1.])
2060  print('Final result: ', m.result().numpy())  # Final result: 1.0
2061  ```
2062
2063  Usage with tf.keras API:
2064
2065  ```python
2066  model = tf.keras.Model(inputs, outputs)
2067  model.compile('sgd', metrics=[tf.keras.metrics.CategoricalHinge()])
2068  ```
2069  """
2070
2071  def __init__(self, name='categorical_hinge', dtype=None):
2072    super(CategoricalHinge, self).__init__(categorical_hinge, name, dtype=dtype)
2073
2074
2075@keras_export('keras.metrics.RootMeanSquaredError')
2076class RootMeanSquaredError(Mean):
2077  """Computes root mean squared error metric between `y_true` and `y_pred`.
2078
2079  Usage:
2080
2081  ```python
2082  m = tf.keras.metrics.RootMeanSquaredError()
2083  m.update_state([2., 4., 6.], [1., 3., 2.])
2084  print('Final result: ', m.result().numpy())  # Final result: 2.449
2085  ```
2086
2087  Usage with tf.keras API:
2088
2089  ```python
2090  model = tf.keras.Model(inputs, outputs)
2091  model.compile('sgd', metrics=[tf.keras.metrics.RootMeanSquaredError()])
2092  ```
2093  """
2094
2095  def __init__(self, name='root_mean_squared_error', dtype=None):
2096    super(RootMeanSquaredError, self).__init__(name, dtype=dtype)
2097
2098  def update_state(self, y_true, y_pred, sample_weight=None):
2099    """Accumulates root mean squared error statistics.
2100
2101    Args:
2102      y_true: The ground truth values.
2103      y_pred: The predicted values.
2104      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
2105        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
2106        be broadcastable to `y_true`.
2107
2108    Returns:
2109      Update op.
2110    """
2111    y_true = math_ops.cast(y_true, self._dtype)
2112    y_pred = math_ops.cast(y_pred, self._dtype)
2113    y_pred, y_true, sample_weight = squeeze_or_expand_dimensions(
2114        y_pred, y_true, sample_weight)
2115    error_sq = math_ops.squared_difference(y_pred, y_true)
2116    return super(RootMeanSquaredError, self).update_state(
2117        error_sq, sample_weight=sample_weight)
2118
2119  def result(self):
2120    return math_ops.sqrt(math_ops.div_no_nan(self.total, self.count))
2121
2122
2123@keras_export('keras.metrics.LogCoshError')
2124class LogCoshError(MeanMetricWrapper):
2125  """Computes the logarithm of the hyperbolic cosine of the prediction error.
2126
2127  `logcosh = log((exp(x) + exp(-x))/2)`, where x is the error (y_pred - y_true)
2128
2129  Usage:
2130
2131  ```python
2132  m = tf.keras.metrics.LogCoshError()
2133  m.update_state([0., 1., 1.], [1., 0., 1.])
2134  print('Final result: ', m.result().numpy())  # Final result: 0.289
2135  ```
2136
2137  Usage with tf.keras API:
2138
2139  ```python
2140  model = tf.keras.Model(inputs, outputs)
2141  model.compile('sgd', metrics=[tf.keras.metrics.LogCoshError()])
2142  ```
2143  """
2144
2145  def __init__(self, name='logcosh', dtype=None):
2146    super(LogCoshError, self).__init__(logcosh, name, dtype=dtype)
2147
2148
2149@keras_export('keras.metrics.Poisson')
2150class Poisson(MeanMetricWrapper):
2151  """Computes the Poisson metric between `y_true` and `y_pred`.
2152
2153  `metric = y_pred - y_true * log(y_pred)`
2154
2155  Usage:
2156
2157  ```python
2158  m = tf.keras.metrics.Poisson()
2159  m.update_state([1, 9, 2], [4, 8, 12])
2160  print('Final result: ', m.result().numpy())  # Final result: -0.357
2161  ```
2162
2163  Usage with tf.keras API:
2164
2165  ```python
2166  model = tf.keras.Model(inputs, outputs)
2167  model.compile('sgd', metrics=[tf.keras.metrics.Poisson()])
2168  ```
2169  """
2170
2171  def __init__(self, name='poisson', dtype=None):
2172    super(Poisson, self).__init__(poisson, name, dtype=dtype)
2173
2174
2175@keras_export('keras.metrics.KLDivergence')
2176class KLDivergence(MeanMetricWrapper):
2177  """Computes Kullback Leibler divergence metric between `y_true` and `y_pred`.
2178
2179  `metric = y_true * log(y_true / y_pred)`
2180
2181  Usage:
2182
2183  ```python
2184  m = tf.keras.metrics.KLDivergence()
2185  m.update_state([.4, .9, .2], [.5, .8, .12])
2186  print('Final result: ', m.result().numpy())  # Final result: 0.119
2187  ```
2188
2189  Usage with tf.keras API:
2190
2191  ```python
2192  model = tf.keras.Model(inputs, outputs)
2193  model.compile('sgd', metrics=[tf.keras.metrics.KLDivergence()])
2194  ```
2195  """
2196
2197  def __init__(self, name='kullback_leibler_divergence', dtype=None):
2198    super(KLDivergence, self).__init__(
2199        kullback_leibler_divergence, name, dtype=dtype)
2200
2201
2202@keras_export('keras.metrics.MeanIoU')
2203class MeanIoU(Metric):
2204  """Computes the mean Intersection-Over-Union metric.
2205
2206  Mean Intersection-Over-Union is a common evaluation metric for semantic image
2207  segmentation, which first computes the IOU for each semantic class and then
2208  computes the average over classes. IOU is defined as follows:
2209    IOU = true_positive / (true_positive + false_positive + false_negative).
2210  The predictions are accumulated in a confusion matrix, weighted by
2211  `sample_weight` and the metric is then calculated from it.
2212
2213  If `sample_weight` is `None`, weights default to 1.
2214  Use `sample_weight` of 0 to mask values.
2215
2216  Usage:
2217
2218  ```python
2219  m = tf.keras.metrics.MeanIoU(num_classes=2)
2220  m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
2221
2222  # cm = [[1, 1],
2223  #       [1, 1]]
2224  # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
2225  # iou = true_positives / (sum_row + sum_col - true_positives)
2226  # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33
2227  print('Final result: ', m.result().numpy())  # Final result: 0.33
2228  ```
2229
2230  Usage with tf.keras API:
2231
2232  ```python
2233  model = tf.keras.Model(inputs, outputs)
2234  model.compile(
2235    'sgd',
2236    loss='mse',
2237    metrics=[tf.keras.metrics.MeanIoU(num_classes=2)])
2238  ```
2239  """
2240
2241  def __init__(self, num_classes, name=None, dtype=None):
2242    """Creates a `MeanIoU` instance.
2243
2244    Args:
2245      num_classes: The possible number of labels the prediction task can have.
2246        This value must be provided, since a confusion matrix of dimension =
2247        [num_classes, num_classes] will be allocated.
2248      name: (Optional) string name of the metric instance.
2249      dtype: (Optional) data type of the metric result.
2250    """
2251    super(MeanIoU, self).__init__(name=name, dtype=dtype)
2252    self.num_classes = num_classes
2253
2254    # Variable to accumulate the predictions in the confusion matrix. Setting
2255    # the type to be `float64` as required by confusion_matrix_ops.
2256    self.total_cm = self.add_weight(
2257        'total_confusion_matrix',
2258        shape=(num_classes, num_classes),
2259        initializer=init_ops.zeros_initializer,
2260        dtype=dtypes.float64)
2261
2262  def update_state(self, y_true, y_pred, sample_weight=None):
2263    """Accumulates the confusion matrix statistics.
2264
2265    Args:
2266      y_true: The ground truth values.
2267      y_pred: The predicted values.
2268      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
2269        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
2270        be broadcastable to `y_true`.
2271
2272    Returns:
2273      Update op.
2274    """
2275    # Flatten the input if its rank > 1.
2276    if y_pred.shape.ndims > 1:
2277      y_pred = array_ops.reshape(y_pred, [-1])
2278
2279    if y_true.shape.ndims > 1:
2280      y_true = array_ops.reshape(y_true, [-1])
2281
2282    if sample_weight is not None and sample_weight.shape.ndims > 1:
2283      sample_weight = array_ops.reshape(sample_weight, [-1])
2284
2285    # Accumulate the prediction to current confusion matrix.
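    # For example (illustrative): with y_true=[0, 0, 1, 1], y_pred=[0, 1, 0, 1]
    # and num_classes=2, `current_cm` is [[1, 1], [1, 1]], matching the example
    # in the class docstring.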
2286    current_cm = confusion_matrix.confusion_matrix(
2287        y_true,
2288        y_pred,
2289        self.num_classes,
2290        weights=sample_weight,
2291        dtype=dtypes.float64)
2292    return self.total_cm.assign_add(current_cm)
2293
2294  def result(self):
2295    """Compute the mean intersection-over-union via the confusion matrix."""
2296    sum_over_row = math_ops.cast(
2297        math_ops.reduce_sum(self.total_cm, axis=0), dtype=self._dtype)
2298    sum_over_col = math_ops.cast(
2299        math_ops.reduce_sum(self.total_cm, axis=1), dtype=self._dtype)
2300    true_positives = math_ops.cast(
2301        array_ops.diag_part(self.total_cm), dtype=self._dtype)
2302
2303    # sum_over_row + sum_over_col =
2304    #     2 * true_positives + false_positives + false_negatives.
2305    denominator = sum_over_row + sum_over_col - true_positives
2306
2307    # The mean is only computed over classes that appear in the
2308    # label or prediction tensor. If the denominator is 0, we need to
2309    # ignore the class.
2310    num_valid_entries = math_ops.reduce_sum(
2311        math_ops.cast(math_ops.not_equal(denominator, 0), dtype=self._dtype))
2312
2313    iou = math_ops.div_no_nan(true_positives, denominator)
2314
2315    return math_ops.div_no_nan(
2316        math_ops.reduce_sum(iou, name='mean_iou'), num_valid_entries)
2317
2318  def reset_states(self):
2319    K.set_value(self.total_cm, np.zeros((self.num_classes, self.num_classes)))
2320
2321  def get_config(self):
2322    config = {'num_classes': self.num_classes}
2323    base_config = super(MeanIoU, self).get_config()
2324    return dict(list(base_config.items()) + list(config.items()))
2325
2326
2327@keras_export('keras.metrics.MeanTensor')
2328class MeanTensor(Metric):
2329  """Computes the element-wise (weighted) mean of the given tensors.
2330
2331  `MeanTensor` returns a tensor with the same shape of the input tensors. The
2332  mean value is updated by keeping local variables `total` and `count`. The
2333  `total` tracks the sum of the weighted values, and `count` stores the sum of
2334  the weighted counts.
2335
2336  Usage:
2337
2338  ```python
2339  m = tf.keras.metrics.MeanTensor()
2340  m.update_state([0, 1, 2, 3])
2341  m.update_state([4, 5, 6, 7])
2342  print('Result: ', m.result().numpy())  # Result: [2, 3, 4, 5]
2343  m.update_state([12, 10, 8, 6], sample_weight=[0, 0.2, 0.5, 1])
2344  print('Result: ', m.result().numpy())  # Result: [2, 3.636, 4.8, 5.333]
2345  ```
2346  """
2347
2348  def __init__(self, name='mean_tensor', dtype=None):
2349    """Creates a `MeanTensor` instance.
2350
2351    Args:
2352      name: (Optional) string name of the metric instance.
2353      dtype: (Optional) data type of the metric result.
2354    """
2355    super(MeanTensor, self).__init__(name=name, dtype=dtype)
2356    self._shape = None
2357    self._total = None
2358    self._count = None
2359    self._built = False
2360
2361  def _build(self, shape):
2362    self._shape = tensor_shape.TensorShape(shape)
2363    # Create new state variables
2364    self._total = self.add_weight(
2365        'total', shape=shape, initializer=init_ops.zeros_initializer)
2366    self._count = self.add_weight(
2367        'count', shape=shape, initializer=init_ops.zeros_initializer)
2368    with ops.init_scope():
2369      if not context.executing_eagerly():
2370        K._initialize_variables(K._get_session())  # pylint: disable=protected-access
2371    self._built = True
2372
2373  @property
2374  def total(self):
2375    return self._total if self._built else None
2376
2377  @property
2378  def count(self):
2379    return self._count if self._built else None
2380
2381  def update_state(self, values, sample_weight=None):
2382    """Accumulates statistics for computing the element-wise mean.
2383
2384    Args:
2385      values: Per-example value.
2386      sample_weight: Optional weighting of each example. Defaults to 1.
2387
2388    Returns:
2389      Update op.
2390    """
2391    values = math_ops.cast(values, self._dtype)
2392    if not self._built:
2393      self._build(values.shape)
2394    elif values.shape != self._shape:
2395      raise ValueError('MeanTensor input values must always have the same '
2396                       'shape. Expected shape (set during the first call): {}. '
2397                       'Got: {}'.format(self._shape, values.get_shape()))
2398
2399    num_values = array_ops.ones_like(values)
2400    if sample_weight is not None:
2401      sample_weight = math_ops.cast(sample_weight, self._dtype)
2402
2403      # Update dimensions of weights to match with values if possible.
2404      values, _, sample_weight = squeeze_or_expand_dimensions(
2405          values, None, sample_weight)
2406      try:
2407        # Broadcast weights if possible.
2408        sample_weight = weights_broadcast_ops.broadcast_weights(
2409            sample_weight, values)
2410      except ValueError:
2411        # Reduce values to same ndim as weight array
2412        ndim = K.ndim(values)
2413        weight_ndim = K.ndim(sample_weight)
2414        values = math_ops.reduce_mean(
2415            values, axis=list(range(weight_ndim, ndim)))
2416
2417      num_values = math_ops.multiply(num_values, sample_weight)
2418      values = math_ops.multiply(values, sample_weight)
2419
2420    update_total_op = self._total.assign_add(values)
2421    with ops.control_dependencies([update_total_op]):
2422      return self._count.assign_add(num_values)
2423
2424  def result(self):
2425    if not self._built:
2426      raise ValueError(
2427          'MeanTensor does not have any result yet. Please call the MeanTensor '
2428          'instance or use `.update_state(value)` before retrieving the result.'
2429          )
2430    return math_ops.div_no_nan(self.total, self.count)
2431
2432  def reset_states(self):
2433    if self._built:
2434      K.batch_set_value(
2435          [(v, np.zeros(self._shape.as_list())) for v in self.variables])
2436
2437
2438@keras_export('keras.metrics.BinaryCrossentropy')
2439class BinaryCrossentropy(MeanMetricWrapper):
2440  """Computes the crossentropy metric between the labels and predictions.
2441
2442  This is the crossentropy metric class to be used when there are only two
2443  label classes (0 and 1).
2444
2445  Usage:
2446
2447  ```python
2448  m = tf.keras.metrics.BinaryCrossentropy()
2449  m.update_state([1., 0., 1., 0.], [1., 1., 1., 0.])
2450
2451  # EPSILON = 1e-7, y = y_true, y` = y_pred, Y_MAX = 0.9999999
2452  # y` = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON)
2453  # y` = [Y_MAX, Y_MAX, Y_MAX, EPSILON]
2454
2455  # Metric = -(y log(y` + EPSILON) + (1 - y) log(1 - y` + EPSILON))
2456  #        = [-log(Y_MAX + EPSILON), -log(1 - Y_MAX + EPSILON),
2457  #           -log(Y_MAX + EPSILON), -log(1)]
2458  #        = [(0 + 15.33) / 2, (0 + 0) / 2]
2459  # Reduced metric = 7.665 / 2
2460
2461  print('Final result: ', m.result().numpy())  # Final result: 3.833
2462  ```
2463
2464  Usage with tf.keras API:
2465
2466  ```python
2467  model = tf.keras.Model(inputs, outputs)
2468  model.compile(
2469      'sgd',
2470      loss='mse',
2471      metrics=[tf.keras.metrics.BinaryCrossentropy()])
2472  ```
2473  """
2474
2475  def __init__(self,
2476               name='binary_crossentropy',
2477               dtype=None,
2478               from_logits=False,
2479               label_smoothing=0):
2480    """Creates a `BinaryCrossentropy` instance.
2481
2482    Args:
2483      name: (Optional) string name of the metric instance.
2484      dtype: (Optional) data type of the metric result.
2485      from_logits: (Optional) Whether output is expected to be a logits tensor.
2486        By default, we consider that output encodes a probability distribution.
2487      label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are
2488        smoothed, meaning the confidence on label values is relaxed.
2489        e.g. `label_smoothing=0.2` means that we will use a value of `0.1` for
2490        label `0` and `0.9` for label `1`.
2491    """
2492
2493    super(BinaryCrossentropy, self).__init__(
2494        binary_crossentropy,
2495        name,
2496        dtype=dtype,
2497        from_logits=from_logits,
2498        label_smoothing=label_smoothing)
2499
2500
2501@keras_export('keras.metrics.CategoricalCrossentropy')
2502class CategoricalCrossentropy(MeanMetricWrapper):
2503  """Computes the crossentropy metric between the labels and predictions.
2504
2505  This is the crossentropy metric class to be used when there are multiple
2506  label classes (2 or more). Here we assume that labels are given as a `one_hot`
2507  representation. e.g., when label values are [2, 0, 1],
2508  `y_true` = [[0, 0, 1], [1, 0, 0], [0, 1, 0]].
2509
2510  Usage:
2511
2512  ```python
2513  m = tf.keras.metrics.CategoricalCrossentropy()
2514  m.update_state([[0, 1, 0], [0, 0, 1]],
2515                 [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
2516
2517  # EPSILON = 1e-7, y = y_true, y` = y_pred
2518  # y` = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON)
2519  # y` = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]
2520
2521  # xent = -sum(y * log(y'), axis = -1)
2522  #      = -((log 0.95), (log 0.1))
2523  #      = [0.051, 2.302]
2524  # Reduced xent = (0.051 + 2.302) / 2
2525
2526  print('Final result: ', m.result().numpy())  # Final result: 1.176
2527  ```
2528
2529  Usage with tf.keras API:
2530
2531  ```python
2532  model = tf.keras.Model(inputs, outputs)
2533  model.compile(
2534    'sgd',
2535    loss='mse',
2536    metrics=[tf.keras.metrics.CategoricalCrossentropy()])
2537  ```
2538
2539  Args:
2540    name: (Optional) string name of the metric instance.
2541    dtype: (Optional) data type of the metric result.
2542    from_logits: (Optional) Whether `y_pred` is expected to be a logits tensor.
2543      By default, we assume that `y_pred` encodes a probability distribution.
2544    label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,
2545      meaning the confidence on label values is relaxed. e.g.
2546      `label_smoothing=0.2` means that we will use a value of `0.1` for label
2547      `0` and `0.9` for label `1`.
2548  """
2549
2550  def __init__(self,
2551               name='categorical_crossentropy',
2552               dtype=None,
2553               from_logits=False,
2554               label_smoothing=0):
2555
2556    super(CategoricalCrossentropy, self).__init__(
2557        categorical_crossentropy,
2558        name,
2559        dtype=dtype,
2560        from_logits=from_logits,
2561        label_smoothing=label_smoothing)
2562
2563
2564@keras_export('keras.metrics.SparseCategoricalCrossentropy')
2565class SparseCategoricalCrossentropy(MeanMetricWrapper):
2566  """Computes the crossentropy metric between the labels and predictions.
2567
2568  Use this crossentropy metric when there are two or more label classes.
2569  We expect labels to be provided as integers. If you want to provide labels
2570  using `one-hot` representation, please use `CategoricalCrossentropy` metric.
2571  There should be `# classes` floating point values per feature for `y_pred`
2572  and a single floating point value per feature for `y_true`.
2573
2574  In the snippet below, there is a single floating point value per example for
2575  `y_true` and `# classes` floating point values per example for `y_pred`.
2576  The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is
2577  `[batch_size, num_classes]`.
2578
2579  Usage:
2580
2581  ```python
2582  m = tf.keras.metrics.SparseCategoricalCrossentropy()
2583  m.update_state(
2584    [1, 2],
2585    [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
2586
2587  # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]]
2588  # logits = log(y_pred)
2589  # softmax = exp(logits) / sum(exp(logits), axis=-1)
2590  # softmax = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]
2591
2592  # xent = -sum(y * log(softmax), 1)
2593  # log(softmax) = [[-2.9957, -0.0513, -16.1181], [-2.3026, -0.2231, -2.3026]]
2594  # y_true * log(softmax) = [[0, -0.0513, 0], [0, 0, -2.3026]]
2595
2596  # xent = [0.0513, 2.3026]
2597  # Reduced xent = (0.0513 + 2.3026) / 2
2598
2599  print('Final result: ', m.result().numpy())  # Final result: 1.176
2600  ```
2601
2602  Usage with tf.keras API:
2603
2604  ```python
2605  model = tf.keras.Model(inputs, outputs)
2606  model.compile(
2607    'sgd',
2608    loss='mse',
2609    metrics=[tf.keras.metrics.SparseCategoricalCrossentropy()])
2610  ```
2611
2612  Args:
2613    name: (Optional) string name of the metric instance.
2614    dtype: (Optional) data type of the metric result.
2615    from_logits: (Optional) Whether `y_pred` is expected to be a logits tensor.
2616      By default, we assume that `y_pred` encodes a probability distribution.
2617    axis: (Optional) Defaults to -1. The dimension along which the metric is
2618      computed.
2619  """
2620
2621  def __init__(self,
2622               name='sparse_categorical_crossentropy',
2623               dtype=None,
2624               from_logits=False,
2625               axis=-1):
2626
2627    super(SparseCategoricalCrossentropy, self).__init__(
2628        sparse_categorical_crossentropy,
2629        name,
2630        dtype=dtype,
2631        from_logits=from_logits,
2632        axis=axis)
2633
2634
2635class SumOverBatchSize(Reduce):
2636  """Computes the weighted sum over batch size of the given values.
2637
2638  For example, if values is [1, 3, 5, 7] then the metric value is 4.
2639  If the weights were specified as [1, 1, 0, 0] then the value would be 1.
2640
2641  This metric creates two variables, `total` and `count` that are used to
2642  compute the average of `values`. This average is ultimately returned as sum
2643  over batch size which is an idempotent operation that simply divides `total`
2644  by `count`.
2645
2646  If `sample_weight` is `None`, weights default to 1.  Use `sample_weight` of 0
2647  to mask values.
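
  Usage (a minimal sketch; note this class is module-internal and not exported
  under `tf.keras.metrics`):

  ```python
  m = SumOverBatchSize()
  m.update_state([1., 3., 5., 7.])
  print('Final result: ', m.result().numpy())  # Final result: 4.0
  ```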
2648  """
2649
2650  def __init__(self, name='sum_over_batch_size', dtype=None):
2651    super(SumOverBatchSize, self).__init__(
2652        reduction=metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
2653        name=name,
2654        dtype=dtype)
2655
2656
2657class SumOverBatchSizeMetricWrapper(SumOverBatchSize):
2658  """Wraps a function with the `SumOverBatchSizeMetricWrapper` metric."""
2659
2660  def __init__(self, fn, name=None, dtype=None, **kwargs):
2661    """Creates a `SumOverBatchSizeMetricWrapper` instance.
2662
2663    Args:
2664      fn: The metric function to wrap, with signature `fn(y_true, y_pred,
2665        **kwargs)`.
2666      name: (Optional) string name of the metric instance.
2667      dtype: (Optional) data type of the metric result.
2668      **kwargs: The keyword arguments that are passed on to `fn`.
2669    """
2670    super(SumOverBatchSizeMetricWrapper, self).__init__(name=name, dtype=dtype)
2671    self._fn = fn
2672    self._fn_kwargs = kwargs
2673
2674  def update_state(self, y_true, y_pred, sample_weight=None):
2675    y_true = math_ops.cast(y_true, self._dtype)
2676    y_pred = math_ops.cast(y_pred, self._dtype)
2677    y_pred, y_true, sample_weight = squeeze_or_expand_dimensions(
2678        y_pred, y_true, sample_weight)
2679
2680    matches = self._fn(y_true, y_pred, **self._fn_kwargs)
2681    return super(SumOverBatchSizeMetricWrapper, self).update_state(
2682        matches, sample_weight=sample_weight)
2683
2684  def get_config(self):
2685    config = {}
2686    for k, v in six.iteritems(self._fn_kwargs):
2687      config[k] = K.eval(v) if is_tensor_or_variable(v) else v
2688    base_config = super(SumOverBatchSizeMetricWrapper, self).get_config()
2689    return dict(list(base_config.items()) + list(config.items()))
2690
2691
2692def accuracy(y_true, y_pred):
2693  y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
2694  if y_true.dtype != y_pred.dtype:
2695    y_pred = math_ops.cast(y_pred, y_true.dtype)
2696  return math_ops.cast(math_ops.equal(y_true, y_pred), K.floatx())
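
# For example (illustrative): given equal-shaped float tensors
# y_true=[1., 1., 0.] and y_pred=[1., 0., 0.], accuracy() returns the
# element-wise match vector [1., 0., 1.].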
2697
2698
2699@keras_export('keras.metrics.binary_accuracy')
2700def binary_accuracy(y_true, y_pred, threshold=0.5):
2701  threshold = math_ops.cast(threshold, y_pred.dtype)
2702  y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype)
2703  return K.mean(math_ops.equal(y_true, y_pred), axis=-1)
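
# For example (illustrative): given float tensors y_true=[1., 1., 0., 0.] and
# y_pred=[0.98, 0.6, 0.4, 0.3], the predictions are thresholded at 0.5 to
# [1., 1., 0., 0.] and the mean match is 1.0.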
2704
2705
2706@keras_export('keras.metrics.categorical_accuracy')
2707def categorical_accuracy(y_true, y_pred):
2708  return math_ops.cast(
2709      math_ops.equal(
2710          math_ops.argmax(y_true, axis=-1), math_ops.argmax(y_pred, axis=-1)),
2711      K.floatx())
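
# For example (illustrative): given one-hot y_true=[[0., 0., 1.], [0., 1., 0.]]
# and y_pred=[[0.1, 0.2, 0.7], [0.6, 0.3, 0.1]], the per-example argmax
# comparison yields [1., 0.].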
2712
2713
2714@keras_export('keras.metrics.sparse_categorical_accuracy')
2715def sparse_categorical_accuracy(y_true, y_pred):
2716  y_pred_rank = ops.convert_to_tensor(y_pred).get_shape().ndims
2717  y_true_rank = ops.convert_to_tensor(y_true).get_shape().ndims
2718  # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
2719  if (y_true_rank is not None) and (y_pred_rank is not None) and (len(
2720      K.int_shape(y_true)) == len(K.int_shape(y_pred))):
2721    y_true = array_ops.squeeze(y_true, [-1])
2722  y_pred = math_ops.argmax(y_pred, axis=-1)
2723
2724  # If the predicted output and actual output types don't match, force cast them
2725  # to match.
2726  if K.dtype(y_pred) != K.dtype(y_true):
2727    y_pred = math_ops.cast(y_pred, K.dtype(y_true))
2728
2729  return math_ops.cast(math_ops.equal(y_true, y_pred), K.floatx())
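
# For example (illustrative): given integer labels y_true=[2, 1] and
# y_pred=[[0.1, 0.2, 0.7], [0.05, 0.9, 0.05]], the argmax of y_pred is [2, 1]
# and the result is [1., 1.].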
2730
2731
2732@keras_export('keras.metrics.top_k_categorical_accuracy')
2733def top_k_categorical_accuracy(y_true, y_pred, k=5):
2734  return K.mean(
2735      nn.in_top_k(y_pred, math_ops.argmax(y_true, axis=-1), k), axis=-1)
2736
2737
2738@keras_export('keras.metrics.sparse_top_k_categorical_accuracy')
2739def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
2740  y_pred_rank = ops.convert_to_tensor(y_pred).get_shape().ndims
2741  y_true_rank = ops.convert_to_tensor(y_true).get_shape().ndims
2742  # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
2743  if (y_true_rank is not None) and (y_pred_rank is not None) and (len(
2744      K.int_shape(y_true)) == len(K.int_shape(y_pred))):
2745    y_true = array_ops.squeeze(y_true, [-1])
2746
2747  return K.mean(nn.in_top_k(y_pred, math_ops.cast(y_true, 'int32'), k), axis=-1)
2748
2749# Aliases
2750
2751mse = MSE = mean_squared_error
2752mae = MAE = mean_absolute_error
2753mape = MAPE = mean_absolute_percentage_error
2754msle = MSLE = mean_squared_logarithmic_error
2755cosine_proximity = cosine_similarity
2756
2757
2758def clone_metric(metric):
2759  """Returns a clone of the metric if stateful, otherwise returns it as is."""
2760  if isinstance(metric, Metric):
2761    return metric.__class__.from_config(metric.get_config())
2762  return metric
2763
2764
2765def clone_metrics(metrics):
2766  """Clones the given metric list/dict."""
2767  if metrics is None:
2768    return None
2769  if isinstance(metrics, dict):
2770    return {key: clone_metric(value) for key, value in metrics.items()}
2771  return [clone_metric(metric) for metric in metrics]
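
# For example (illustrative): clone_metrics([tf.keras.metrics.AUC(), 'accuracy'])
# returns a list with a fresh AUC instance (rebuilt via from_config) while the
# string identifier is passed through unchanged.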
2772
2773
2774@keras_export('keras.metrics.serialize')
2775def serialize(metric):
2776  return serialize_keras_object(metric)
2777
2778
2779@keras_export('keras.metrics.deserialize')
2780def deserialize(config, custom_objects=None):
2781  return deserialize_keras_object(
2782      config,
2783      module_objects=globals(),
2784      custom_objects=custom_objects,
2785      printable_module_name='metric function')
2786
2787
2788@keras_export('keras.metrics.get')
2789def get(identifier):
2790  if isinstance(identifier, dict):
2791    return deserialize(identifier)
2792  elif isinstance(identifier, six.string_types):
2793    return deserialize(str(identifier))
2794  elif callable(identifier):
2795    return identifier
2796  else:
2797    raise ValueError('Could not interpret '
2798                     'metric function identifier: %s' % identifier)
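
# For example (illustrative): get('mse') looks up the mean_squared_error alias
# defined above, get(tf.keras.metrics.AUC()) returns the callable instance
# unchanged, and get({'class_name': 'AUC', 'config': {}}) rebuilds the metric
# via deserialize().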
2799