# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unused-import
# pylint: disable=g-classes-have-attributes
"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import math
import types

import numpy as np
import six

from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.losses import binary_crossentropy
from tensorflow.python.keras.losses import categorical_crossentropy
from tensorflow.python.keras.losses import categorical_hinge
from tensorflow.python.keras.losses import hinge
from tensorflow.python.keras.losses import kullback_leibler_divergence
from tensorflow.python.keras.losses import logcosh
from tensorflow.python.keras.losses import mean_absolute_error
from tensorflow.python.keras.losses import mean_absolute_percentage_error
from tensorflow.python.keras.losses import mean_squared_error
from tensorflow.python.keras.losses import mean_squared_logarithmic_error
from tensorflow.python.keras.losses import poisson
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python.keras.losses import squared_hinge
from tensorflow.python.keras.saving.saved_model import metric_serialization
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.keras.utils.generic_utils import to_list
from tensorflow.python.keras.utils.tf_utils import is_tensor_or_variable
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls


@keras_export('keras.metrics.Metric')
@six.add_metaclass(abc.ABCMeta)
class Metric(base_layer.Layer):
  """Encapsulates metric logic and state.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.
    **kwargs: Additional layer keyword arguments.

  Standalone usage:

  ```python
  m = SomeMetric(...)
  for input in ...:
    m.update_state(input)
  print('Final result: ', m.result().numpy())
  ```

  Usage with `compile()` API:

  ```python
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(64, activation='relu'))
  model.add(tf.keras.layers.Dense(64, activation='relu'))
  model.add(tf.keras.layers.Dense(10, activation='softmax'))

  model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01),
                loss=tf.keras.losses.CategoricalCrossentropy(),
                metrics=[tf.keras.metrics.CategoricalAccuracy()])

  data = np.random.random((1000, 32))
  labels = np.random.random((1000, 10))

  dataset = tf.data.Dataset.from_tensor_slices((data, labels))
  dataset = dataset.batch(32)

  model.fit(dataset, epochs=10)
  ```

  To be implemented by subclasses:
  * `__init__()`: All state variables should be created in this method by
    calling `self.add_weight()` like: `self.var = self.add_weight(...)`
  * `update_state()`: Has all updates to the state variables like:
    `self.var.assign_add(...)`.
  * `result()`: Computes and returns a value for the metric
    from the state variables.

  Example subclass implementation:

  ```python
  class BinaryTruePositives(tf.keras.metrics.Metric):

    def __init__(self, name='binary_true_positives', **kwargs):
      super(BinaryTruePositives, self).__init__(name=name, **kwargs)
      self.true_positives = self.add_weight(name='tp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
      y_true = tf.cast(y_true, tf.bool)
      y_pred = tf.cast(y_pred, tf.bool)

      values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
      values = tf.cast(values, self.dtype)
      if sample_weight is not None:
        sample_weight = tf.cast(sample_weight, self.dtype)
        sample_weight = tf.broadcast_to(sample_weight, values.shape)
        values = tf.multiply(values, sample_weight)
      self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
      return self.true_positives
  ```
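
  A quick check of the subclass above (illustrative values only):

  ```python
  m = BinaryTruePositives()
  m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
  print(m.result().numpy())  # 1.0: only the second example is a true positive
  ```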
  """

  def __init__(self, name=None, dtype=None, **kwargs):
    super(Metric, self).__init__(name=name, dtype=dtype, **kwargs)
    self.stateful = True  # All metric layers are stateful.
    self.built = True
    if not base_layer_utils.v2_dtype_behavior_enabled():
      # We only do this when the V2 behavior is not enabled, as when it is
      # enabled, the dtype already defaults to floatx.
      self._dtype = K.floatx() if dtype is None else dtypes.as_dtype(dtype).name

  def __new__(cls, *args, **kwargs):
    obj = super(Metric, cls).__new__(cls)

    # If `update_state` is not in eager/tf.function and it is not from a
    # built-in metric, wrap it in `tf.function`. This is so that users writing
    # custom metrics in v1 need not worry about control dependencies and
    # return ops.
    if (base_layer_utils.is_in_eager_or_tf_function() or
        is_built_in(cls)):
      obj_update_state = obj.update_state

      def update_state_fn(*args, **kwargs):
        control_status = ag_ctx.control_status_ctx()
        ag_update_state = autograph.tf_convert(obj_update_state, control_status)
        return ag_update_state(*args, **kwargs)
    else:
      if isinstance(obj.update_state, def_function.Function):
        update_state_fn = obj.update_state
      else:
        update_state_fn = def_function.function(obj.update_state)

    obj.update_state = types.MethodType(
        metrics_utils.update_state_wrapper(update_state_fn), obj)

    obj_result = obj.result

    def result_fn(*args, **kwargs):
      control_status = ag_ctx.control_status_ctx()
      ag_result = autograph.tf_convert(obj_result, control_status)
      return ag_result(*args, **kwargs)

    obj.result = types.MethodType(metrics_utils.result_wrapper(result_fn), obj)

    return obj

  def __call__(self, *args, **kwargs):
    """Accumulates statistics and then computes metric result value.

    Args:
      *args:
      **kwargs: A mini-batch of inputs to the Metric,
        passed on to `update_state()`.

    Returns:
      The metric value tensor.
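
    A minimal sketch of the call path (illustrative only; uses the `Mean`
    metric defined later in this module):

    ```python
    m = tf.keras.metrics.Mean()
    result = m([1, 3, 5, 7])  # updates the state, then returns the mean: 4.0
    ```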
    """

    def replica_local_fn(*args, **kwargs):
      """Updates the state of the metric in a replica-local context."""
      if any(
          isinstance(arg, keras_tensor.KerasTensor)
          for arg in nest.flatten((args, kwargs))):
        update_op = None
      else:
        update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
      update_ops = []
      if update_op is not None:
        update_ops.append(update_op)
      with ops.control_dependencies(update_ops):
        result_t = self.result()  # pylint: disable=not-callable

        # We are adding the metric object as metadata on the result tensor.
        # This is required when we want to use a metric with `add_metric` API on
        # a Model/Layer in graph mode. This metric instance will later be used
        # to reset variable state after each epoch of training.
        # Example:
        #   model = Model()
        #   mean = Mean()
        #   model.add_metric(mean(values), name='mean')
        result_t._metric_obj = self  # pylint: disable=protected-access
        return result_t

    from tensorflow.python.keras.distribute import distributed_training_utils  # pylint:disable=g-import-not-at-top
    return distributed_training_utils.call_replica_local_fn(
        replica_local_fn, *args, **kwargs)

  @property
  def dtype(self):
    return self._dtype

  def get_config(self):
    """Returns the serializable config of the metric."""
    return {'name': self.name, 'dtype': self.dtype}

  def reset_states(self):
    """Resets all of the metric state variables.

    This function is called between epochs/steps,
    when a metric is evaluated during training.
    """
    K.batch_set_value([(v, 0) for v in self.variables])

  @abc.abstractmethod
  def update_state(self, *args, **kwargs):
    """Accumulates statistics for the metric.

    Note: This function is executed as a graph function in graph mode.
    This means:
      a) Operations on the same resource are executed in textual order.
         This should make it easier to do things like add the updated
         value of a variable to another, for example.
      b) You don't need to worry about collecting the update ops to execute.
         All update ops added to the graph by this function will be executed.
      As a result, code should generally work the same way with graph or
      eager execution.

    Args:
      *args:
      **kwargs: A mini-batch of inputs to the Metric.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

  @abc.abstractmethod
  def result(self):
    """Computes and returns the metric value tensor.

    Result computation is an idempotent operation that simply calculates the
    metric value using the state variables.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

  ### For use by subclasses ###
  @doc_controls.for_subclass_implementers
  def add_weight(self,
                 name,
                 shape=(),
                 aggregation=tf_variables.VariableAggregation.SUM,
                 synchronization=tf_variables.VariableSynchronization.ON_READ,
                 initializer=None,
                 dtype=None):
    """Adds state variable. Only for use by subclasses."""
    if distribute_ctx.has_strategy():
      strategy = distribute_ctx.get_strategy()
    else:
      strategy = None

    # TODO(b/120571621): Make `ON_READ` work with Keras metrics on TPU.
    if K.is_tpu_strategy(strategy):
      synchronization = tf_variables.VariableSynchronization.ON_WRITE

    with ops.init_scope():
      return super(Metric, self).add_weight(
          name=name,
          shape=shape,
          dtype=self._dtype if dtype is None else dtype,
          trainable=False,
          initializer=initializer,
          collections=[],
          synchronization=synchronization,
          aggregation=aggregation)

  ### End: For use by subclasses ###

  @property
  def trainable_weights(self):
    # Overridden from Layer class to track submetric weights.
    if self.trainable:
      trainable_weights = self._trainable_weights
      for m in self._metrics:
        trainable_weights += m.trainable_weights
      return self._dedup_weights(trainable_weights)
    else:
      return []

  @property
  def non_trainable_weights(self):
    # Overridden from Layer class to track submetric weights.
    if self.trainable:
      non_trainable_weights = self._non_trainable_weights
      for m in self._metrics:
        non_trainable_weights += m.non_trainable_weights
    else:
      non_trainable_weights = (
          self._non_trainable_weights + self._trainable_weights)
      for m in self._metrics:
        non_trainable_weights += m.weights
    return self._dedup_weights(non_trainable_weights)

  @property
  def _trackable_saved_model_saver(self):
    return metric_serialization.MetricSavedModelSaver(self)


class Reduce(Metric):
  """Encapsulates metrics that perform a reduce operation on the values.

  Args:
    reduction: a `tf.keras.metrics.Reduction` enum value.
    name: string name of the metric instance.
    dtype: (Optional) data type of the metric result.
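
  A minimal sketch of the weighted-mean reduction (illustrative only; uses the
  `metrics_utils.Reduction` values imported above):

  ```python
  m = Reduce(metrics_utils.Reduction.WEIGHTED_MEAN, name='reduce_mean')
  m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
  m.result()  # (1 + 3) / (1 + 1) = 2.0
  ```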
  """

  def __init__(self, reduction, name, dtype=None):
    super(Reduce, self).__init__(name=name, dtype=dtype)
    self.reduction = reduction
    self.total = self.add_weight(
        'total', initializer=init_ops.zeros_initializer)
    if reduction in [metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
                     metrics_utils.Reduction.WEIGHTED_MEAN]:
      self.count = self.add_weight(
          'count', initializer=init_ops.zeros_initializer)

  def update_state(self, values, sample_weight=None):
    """Accumulates statistics for computing the metric.

    Args:
      values: Per-example value.
      sample_weight: Optional weighting of each example. Defaults to 1.

    Returns:
      Update op.
    """
    [values], sample_weight = \
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [values], sample_weight)
    values = math_ops.cast(values, self._dtype)
    if sample_weight is not None:
      sample_weight = math_ops.cast(sample_weight, self._dtype)
      # Update dimensions of weights to match with values if possible.
      values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions(
          values, sample_weight=sample_weight)
      try:
        # Broadcast weights if possible.
        sample_weight = weights_broadcast_ops.broadcast_weights(
            sample_weight, values)
      except ValueError:
        # Reduce values to same ndim as weight array
        ndim = K.ndim(values)
        weight_ndim = K.ndim(sample_weight)
        if self.reduction == metrics_utils.Reduction.SUM:
          values = math_ops.reduce_sum(
              values, axis=list(range(weight_ndim, ndim)))
        else:
          values = math_ops.reduce_mean(
              values, axis=list(range(weight_ndim, ndim)))
      values = math_ops.multiply(values, sample_weight)

    value_sum = math_ops.reduce_sum(values)
    with ops.control_dependencies([value_sum]):
      update_total_op = self.total.assign_add(value_sum)

    # Exit early if the reduction doesn't have a denominator.
    if self.reduction == metrics_utils.Reduction.SUM:
      return update_total_op

    # Update `count` for reductions that require a denominator.
    if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE:
      num_values = math_ops.cast(array_ops.size(values), self._dtype)
    elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN:
      if sample_weight is None:
        num_values = math_ops.cast(array_ops.size(values), self._dtype)
      else:
        num_values = math_ops.reduce_sum(sample_weight)
    else:
      raise NotImplementedError(
          'reduction [%s] not implemented' % self.reduction)

    with ops.control_dependencies([update_total_op]):
      return self.count.assign_add(num_values)

  def result(self):
    if self.reduction == metrics_utils.Reduction.SUM:
      return array_ops.identity(self.total)
    elif self.reduction in [
        metrics_utils.Reduction.WEIGHTED_MEAN,
        metrics_utils.Reduction.SUM_OVER_BATCH_SIZE
    ]:
      return math_ops.div_no_nan(self.total, self.count)
    else:
      raise NotImplementedError(
          'reduction [%s] not implemented' % self.reduction)


@keras_export('keras.metrics.Sum')
class Sum(Reduce):
  """Computes the (weighted) sum of the given values.

  For example, if values is [1, 3, 5, 7] then the sum is 16.
  If the weights were specified as [1, 1, 0, 0] then the sum would be 4.

  This metric creates one variable, `total`, that is used to compute the sum of
  `values`. This is ultimately returned as `sum`.

  If `sample_weight` is `None`, weights default to 1.  Use `sample_weight` of 0
  to mask values.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.Sum()
  >>> m.update_state([1, 3, 5, 7])
  >>> m.result().numpy()
  16.0

  Usage with `compile()` API:

  ```python
  model.add_metric(tf.keras.metrics.Sum(name='sum_1')(outputs))
  model.compile(optimizer='sgd', loss='mse')
  ```
  """

  def __init__(self, name='sum', dtype=None):
    super(Sum, self).__init__(reduction=metrics_utils.Reduction.SUM,
                              name=name, dtype=dtype)


@keras_export('keras.metrics.Mean')
class Mean(Reduce):
  """Computes the (weighted) mean of the given values.

  For example, if values is [1, 3, 5, 7] then the mean is 4.
  If the weights were specified as [1, 1, 0, 0] then the mean would be 2.

  This metric creates two variables, `total` and `count` that are used to
  compute the average of `values`. This average is ultimately returned as `mean`
  which is an idempotent operation that simply divides `total` by `count`.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.Mean()
  >>> m.update_state([1, 3, 5, 7])
  >>> m.result().numpy()
  4.0
  >>> m.reset_states()
  >>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
  >>> m.result().numpy()
  2.0

  Usage with `compile()` API:

  ```python
  model.add_metric(tf.keras.metrics.Mean(name='mean_1')(outputs))
  model.compile(optimizer='sgd', loss='mse')
  ```
  """

  def __init__(self, name='mean', dtype=None):
    super(Mean, self).__init__(
        reduction=metrics_utils.Reduction.WEIGHTED_MEAN, name=name, dtype=dtype)


@keras_export('keras.metrics.MeanRelativeError')
class MeanRelativeError(Mean):
  """Computes the mean relative error by normalizing with the given values.

  This metric creates two local variables, `total` and `count` that are used to
  compute the mean relative error. This is weighted by `sample_weight`, and
  it is ultimately returned as `mean_relative_error`:
  an idempotent operation that simply divides `total` by `count`.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    normalizer: The normalizer values with same shape as predictions.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.MeanRelativeError(normalizer=[1, 3, 2, 3])
  >>> m.update_state([1, 3, 2, 3], [2, 4, 6, 8])

  >>> # metric = mean(|y_pred - y_true| / normalizer)
  >>> #        = mean([1, 1, 4, 5] / [1, 3, 2, 3]) = mean([1, 1/3, 2, 5/3])
  >>> #        = 5/4 = 1.25
  >>> m.result().numpy()
  1.25

  Usage with `compile()` API:

  ```python
  model.compile(
    optimizer='sgd',
    loss='mse',
    metrics=[tf.keras.metrics.MeanRelativeError(normalizer=[1, 3])])
  ```
  """

  def __init__(self, normalizer, name=None, dtype=None):
    super(MeanRelativeError, self).__init__(name=name, dtype=dtype)
    normalizer = math_ops.cast(normalizer, self._dtype)
    self.normalizer = normalizer

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates metric statistics.

    Args:
      y_true: The ground truth values.
      y_pred: The predicted values.
      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
        be broadcastable to `y_true`.

    Returns:
      Update op.
    """
    y_true = math_ops.cast(y_true, self._dtype)
    y_pred = math_ops.cast(y_pred, self._dtype)
    [y_pred, y_true], sample_weight = \
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [y_pred, y_true], sample_weight)
    y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
        y_pred, y_true)

    y_pred, self.normalizer = losses_utils.remove_squeezable_dimensions(
        y_pred, self.normalizer)
    y_pred.shape.assert_is_compatible_with(y_true.shape)
    relative_errors = math_ops.div_no_nan(
        math_ops.abs(y_true - y_pred), self.normalizer)

    return super(MeanRelativeError, self).update_state(
        relative_errors, sample_weight=sample_weight)

  def get_config(self):
    n = self.normalizer
    config = {'normalizer': K.eval(n) if is_tensor_or_variable(n) else n}
    base_config = super(MeanRelativeError, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class MeanMetricWrapper(Mean):
  """Wraps a stateless metric function with the Mean metric.

  Args:
    fn: The metric function to wrap, with signature `fn(y_true, y_pred,
      **kwargs)`.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.
    **kwargs: The keyword arguments that are passed on to `fn`.
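
  A minimal sketch of the wrapping behavior (illustrative only; uses the
  module-level `mean_squared_error` loss imported above):

  ```python
  m = MeanMetricWrapper(mean_squared_error, name='mse')
  m.update_state([[0., 1.]], [[1., 1.]])
  m.result()  # 0.5: per-sample values from `fn`, averaged by the `Mean` base
  ```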
  """

  def __init__(self, fn, name=None, dtype=None, **kwargs):
    super(MeanMetricWrapper, self).__init__(name=name, dtype=dtype)
    self._fn = fn
    self._fn_kwargs = kwargs

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates metric statistics.

    `y_true` and `y_pred` should have the same shape.

    Args:
      y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
      y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
      sample_weight: Optional `sample_weight` acts as a
        coefficient for the metric. If a scalar is provided, then the metric is
        simply scaled by the given value. If `sample_weight` is a tensor of size
        `[batch_size]`, then the metric for each sample of the batch is rescaled
        by the corresponding element in the `sample_weight` vector. If the shape
        of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted
        to this shape), then each metric element of `y_pred` is scaled by the
        corresponding value of `sample_weight`. (Note on `dN-1`: all metric
        functions reduce by 1 dimension, usually the last axis (-1)).

    Returns:
      Update op.
    """
    y_true = math_ops.cast(y_true, self._dtype)
    y_pred = math_ops.cast(y_pred, self._dtype)
    [y_true, y_pred], sample_weight = \
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [y_true, y_pred], sample_weight)
    y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
        y_pred, y_true)

    ag_fn = autograph.tf_convert(self._fn, ag_ctx.control_status_ctx())
    matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
    return super(MeanMetricWrapper, self).update_state(
        matches, sample_weight=sample_weight)

  def get_config(self):
    config = {}

    if type(self) is MeanMetricWrapper:  # pylint: disable=unidiomatic-typecheck
      # Only include function argument when the object is a MeanMetricWrapper
      # and not a subclass.
      config['fn'] = self._fn

    for k, v in six.iteritems(self._fn_kwargs):
      config[k] = K.eval(v) if is_tensor_or_variable(v) else v
    base_config = super(MeanMetricWrapper, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    # Note that while MeanMetricWrapper itself isn't public, objects of this
    # class may be created and added to the model by calling model.compile.
    fn = config.pop('fn', None)
    if cls is MeanMetricWrapper:
      return cls(get(fn), **config)
    return super(MeanMetricWrapper, cls).from_config(config)


@keras_export('keras.metrics.Accuracy')
class Accuracy(MeanMetricWrapper):
  """Calculates how often predictions equal labels.

  This metric creates two local variables, `total` and `count` that are used to
  compute the frequency with which `y_pred` matches `y_true`. This frequency is
  ultimately returned as `accuracy`: an idempotent operation that simply
  divides `total` by `count`.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.Accuracy()
  >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]])
  >>> m.result().numpy()
  0.75

  >>> m.reset_states()
  >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]],
  ...                sample_weight=[1, 1, 0, 0])
  >>> m.result().numpy()
  0.5

  Usage with `compile()` API:

  ```python
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.Accuracy()])
  ```
  """

  def __init__(self, name='accuracy', dtype=None):
    super(Accuracy, self).__init__(accuracy, name, dtype=dtype)


@keras_export('keras.metrics.BinaryAccuracy')
class BinaryAccuracy(MeanMetricWrapper):
  """Calculates how often predictions match binary labels.

  This metric creates two local variables, `total` and `count` that are used to
  compute the frequency with which `y_pred` matches `y_true`. This frequency is
  ultimately returned as `binary accuracy`: an idempotent operation that simply
  divides `total` by `count`.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.
    threshold: (Optional) Float representing the threshold for deciding
      whether prediction values are 1 or 0.

  Standalone usage:

  >>> m = tf.keras.metrics.BinaryAccuracy()
  >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]])
  >>> m.result().numpy()
  0.75

  >>> m.reset_states()
  >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]],
  ...                sample_weight=[1, 0, 0, 1])
  >>> m.result().numpy()
  0.5

  Usage with `compile()` API:

  ```python
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.BinaryAccuracy()])
  ```
  """

  def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5):
    super(BinaryAccuracy, self).__init__(
        binary_accuracy, name, dtype=dtype, threshold=threshold)


@keras_export('keras.metrics.CategoricalAccuracy')
class CategoricalAccuracy(MeanMetricWrapper):
  """Calculates how often predictions match one-hot labels.

  You can provide logits of classes as `y_pred`, since argmax of
  logits and probabilities are same.

  This metric creates two local variables, `total` and `count` that are used to
  compute the frequency with which `y_pred` matches `y_true`. This frequency is
  ultimately returned as `categorical accuracy`: an idempotent operation that
  simply divides `total` by `count`.

  `y_pred` and `y_true` should be passed in as vectors of probabilities, rather
  than as labels. If necessary, use `tf.one_hot` to expand `y_true` as a vector.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.CategoricalAccuracy()
  >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
  ...                 [0.05, 0.95, 0]])
  >>> m.result().numpy()
  0.5

  >>> m.reset_states()
  >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
  ...                 [0.05, 0.95, 0]],
  ...                sample_weight=[0.7, 0.3])
  >>> m.result().numpy()
  0.3

  Usage with `compile()` API:

  ```python
  model.compile(
    optimizer='sgd',
    loss='mse',
    metrics=[tf.keras.metrics.CategoricalAccuracy()])
  ```
  """

  def __init__(self, name='categorical_accuracy', dtype=None):
    super(CategoricalAccuracy, self).__init__(
        categorical_accuracy, name, dtype=dtype)


@keras_export('keras.metrics.SparseCategoricalAccuracy')
class SparseCategoricalAccuracy(MeanMetricWrapper):
  """Calculates how often predictions match integer labels.

  ```python
  acc = np.dot(sample_weight, np.equal(y_true, np.argmax(y_pred, axis=1)))
  ```

  You can provide logits of classes as `y_pred`, since argmax of
  logits and probabilities are same.

  This metric creates two local variables, `total` and `count` that are used to
  compute the frequency with which `y_pred` matches `y_true`. This frequency is
  ultimately returned as `sparse categorical accuracy`: an idempotent operation
  that simply divides `total` by `count`.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.SparseCategoricalAccuracy()
  >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]])
  >>> m.result().numpy()
  0.5

  >>> m.reset_states()
  >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]],
  ...                sample_weight=[0.7, 0.3])
  >>> m.result().numpy()
  0.3

  Usage with `compile()` API:

  ```python
  model.compile(
      optimizer='sgd',
      loss='mse',
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  ```
  """

  def __init__(self, name='sparse_categorical_accuracy', dtype=None):
    super(SparseCategoricalAccuracy, self).__init__(
        sparse_categorical_accuracy, name, dtype=dtype)


@keras_export('keras.metrics.TopKCategoricalAccuracy')
class TopKCategoricalAccuracy(MeanMetricWrapper):
  """Computes how often targets are in the top `K` predictions.

  Args:
    k: (Optional) Number of top elements to look at for computing accuracy.
      Defaults to 5.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)
  >>> m.update_state([[0, 0, 1], [0, 1, 0]],
  ...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
  >>> m.result().numpy()
  0.5

  >>> m.reset_states()
  >>> m.update_state([[0, 0, 1], [0, 1, 0]],
  ...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
  ...                sample_weight=[0.7, 0.3])
  >>> m.result().numpy()
  0.3

  Usage with `compile()` API:

  ```python
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.TopKCategoricalAccuracy()])
  ```
  """

  def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None):
    super(TopKCategoricalAccuracy, self).__init__(
        top_k_categorical_accuracy, name, dtype=dtype, k=k)


@keras_export('keras.metrics.SparseTopKCategoricalAccuracy')
class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
  """Computes how often integer targets are in the top `K` predictions.

  Args:
    k: (Optional) Number of top elements to look at for computing accuracy.
      Defaults to 5.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
  >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
  >>> m.result().numpy()
  0.5

  >>> m.reset_states()
  >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
  ...                sample_weight=[0.7, 0.3])
  >>> m.result().numpy()
  0.3

  Usage with `compile()` API:

  ```python
  model.compile(
    optimizer='sgd',
    loss='mse',
    metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy()])
  ```
  """

  def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None):
    super(SparseTopKCategoricalAccuracy, self).__init__(
        sparse_top_k_categorical_accuracy, name, dtype=dtype, k=k)


class _ConfusionMatrixConditionCount(Metric):
  """Calculates the number of the given confusion matrix condition.

  Args:
    confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` conditions.
    thresholds: (Optional) Defaults to 0.5. A float value or a python list/tuple
      of float threshold values in [0, 1]. A threshold is compared with
      prediction values to determine the truth value of predictions (i.e., above
      the threshold is `true`, below is `false`). One metric value is generated
      for each threshold value.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.
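
  A minimal sketch of the per-threshold counting (illustrative only; uses the
  `metrics_utils.ConfusionMatrix` enum imported above):

  ```python
  m = _ConfusionMatrixConditionCount(
      metrics_utils.ConfusionMatrix.TRUE_POSITIVES, thresholds=[0.3, 0.7])
  m.update_state([0, 1, 1, 1], [0.2, 0.5, 0.8, 0.9])
  m.result()  # [3.0, 2.0]: true positive counts at thresholds 0.3 and 0.7
  ```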
  """

  def __init__(self,
               confusion_matrix_cond,
               thresholds=None,
               name=None,
               dtype=None):
    super(_ConfusionMatrixConditionCount, self).__init__(name=name, dtype=dtype)
    self._confusion_matrix_cond = confusion_matrix_cond
    self.init_thresholds = thresholds
    self.thresholds = metrics_utils.parse_init_thresholds(
        thresholds, default_threshold=0.5)
    self.accumulator = self.add_weight(
        'accumulator',
        shape=(len(self.thresholds),),
        initializer=init_ops.zeros_initializer)

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates the metric statistics.

    Args:
      y_true: The ground truth values.
      y_pred: The predicted values.
      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
        be broadcastable to `y_true`.

    Returns:
      Update op.
    """
    return metrics_utils.update_confusion_matrix_variables(
        {self._confusion_matrix_cond: self.accumulator},
        y_true,
        y_pred,
        thresholds=self.thresholds,
        sample_weight=sample_weight)

  def result(self):
    if len(self.thresholds) == 1:
      result = self.accumulator[0]
    else:
      result = self.accumulator
    return ops.convert_to_tensor_v2_with_dispatch(result)

  def reset_states(self):
    num_thresholds = len(to_list(self.thresholds))
    K.batch_set_value(
        [(v, np.zeros((num_thresholds,))) for v in self.variables])

  def get_config(self):
    config = {'thresholds': self.init_thresholds}
    base_config = super(_ConfusionMatrixConditionCount, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.metrics.FalsePositives')
class FalsePositives(_ConfusionMatrixConditionCount):
  """Calculates the number of false positives.

  If `sample_weight` is given, calculates the sum of the weights of
  false positives. This metric creates one local variable, `accumulator`
  that is used to keep track of the number of false positives.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    thresholds: (Optional) Defaults to 0.5. A float value or a python
      list/tuple of float threshold values in [0, 1]. A threshold is compared
      with prediction values to determine the truth value of predictions
      (i.e., above the threshold is `true`, below is `false`). One metric
      value is generated for each threshold value.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.FalsePositives()
  >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1])
  >>> m.result().numpy()
  2.0

  >>> m.reset_states()
  >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0])
  >>> m.result().numpy()
  1.0

  Usage with `compile()` API:

  ```python
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.FalsePositives()])
  ```
  """

  def __init__(self, thresholds=None, name=None, dtype=None):
    super(FalsePositives, self).__init__(
        confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES,
        thresholds=thresholds,
        name=name,
        dtype=dtype)


@keras_export('keras.metrics.FalseNegatives')
class FalseNegatives(_ConfusionMatrixConditionCount):
  """Calculates the number of false negatives.

  If `sample_weight` is given, calculates the sum of the weights of
  false negatives. This metric creates one local variable, `accumulator`
  that is used to keep track of the number of false negatives.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    thresholds: (Optional) Defaults to 0.5. A float value or a python
      list/tuple of float threshold values in [0, 1]. A threshold is compared
      with prediction values to determine the truth value of predictions
      (i.e., above the threshold is `true`, below is `false`). One metric
      value is generated for each threshold value.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.FalseNegatives()
  >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
  >>> m.result().numpy()
  2.0

  >>> m.reset_states()
  >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0], sample_weight=[0, 0, 1, 0])
  >>> m.result().numpy()
  1.0

  Usage with `compile()` API:

  ```python
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.FalseNegatives()])
  ```
  """

  def __init__(self, thresholds=None, name=None, dtype=None):
    super(FalseNegatives, self).__init__(
        confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES,
        thresholds=thresholds,
        name=name,
        dtype=dtype)


@keras_export('keras.metrics.TrueNegatives')
class TrueNegatives(_ConfusionMatrixConditionCount):
  """Calculates the number of true negatives.

  If `sample_weight` is given, calculates the sum of the weights of
  true negatives. This metric creates one local variable, `accumulator`
  that is used to keep track of the number of true negatives.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    thresholds: (Optional) Defaults to 0.5. A float value or a python
      list/tuple of float threshold values in [0, 1]. A threshold is compared
      with prediction values to determine the truth value of predictions
      (i.e., above the threshold is `true`, below is `false`). One metric
      value is generated for each threshold value.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.TrueNegatives()
  >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0])
  >>> m.result().numpy()
  2.0

  >>> m.reset_states()
  >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0], sample_weight=[0, 0, 1, 0])
  >>> m.result().numpy()
  1.0

  Usage with `compile()` API:

  ```python
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.TrueNegatives()])
  ```
  """

  def __init__(self, thresholds=None, name=None, dtype=None):
    super(TrueNegatives, self).__init__(
        confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES,
        thresholds=thresholds,
        name=name,
        dtype=dtype)


@keras_export('keras.metrics.TruePositives')
class TruePositives(_ConfusionMatrixConditionCount):
  """Calculates the number of true positives.

  If `sample_weight` is given, calculates the sum of the weights of
  true positives. This metric creates one local variable, `accumulator`
  that is used to keep track of the number of true positives.
1158
1159  If `sample_weight` is `None`, weights default to 1.
1160  Use `sample_weight` of 0 to mask values.
1161
1162  Args:
1163    thresholds: (Optional) Defaults to 0.5. A float value or a python
1164      list/tuple of float threshold values in [0, 1]. A threshold is compared
1165      with prediction values to determine the truth value of predictions
1166      (i.e., above the threshold is `true`, below is `false`). One metric
1167      value is generated for each threshold value.
1168    name: (Optional) string name of the metric instance.
1169    dtype: (Optional) data type of the metric result.
1170
1171  Standalone usage:
1172
1173  >>> m = tf.keras.metrics.TruePositives()
1174  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
1175  >>> m.result().numpy()
1176  2.0
1177
1178  >>> m.reset_states()
1179  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
1180  >>> m.result().numpy()
1181  1.0
1182
1183  Usage with `compile()` API:
1184
1185  ```python
1186  model.compile(optimizer='sgd',
1187                loss='mse',
1188                metrics=[tf.keras.metrics.TruePositives()])
1189  ```
1190  """
1191
1192  def __init__(self, thresholds=None, name=None, dtype=None):
1193    super(TruePositives, self).__init__(
1194        confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES,
1195        thresholds=thresholds,
1196        name=name,
1197        dtype=dtype)
1198
1199
1200@keras_export('keras.metrics.Precision')
1201class Precision(Metric):
1202  """Computes the precision of the predictions with respect to the labels.
1203
1204  The metric creates two local variables, `true_positives` and `false_positives`
1205  that are used to compute the precision. This value is ultimately returned as
1206  `precision`, an idempotent operation that simply divides `true_positives`
1207  by the sum of `true_positives` and `false_positives`.
1208
1209  If `sample_weight` is `None`, weights default to 1.
1210  Use `sample_weight` of 0 to mask values.
1211
1212  If `top_k` is set, we'll calculate precision as how often on average a class
1213  among the top-k classes with the highest predicted values of a batch entry is
1214  correct and can be found in the label for that entry.
1215
1216  If `class_id` is specified, we calculate precision by considering only the
1217  entries in the batch for which `class_id` is above the threshold and/or in the
1218  top-k highest predictions, and computing the fraction of them for which
1219  `class_id` is indeed a correct label.
1220
1221  Args:
1222    thresholds: (Optional) A float value or a python list/tuple of float
1223      threshold values in [0, 1]. A threshold is compared with prediction
1224      values to determine the truth value of predictions (i.e., above the
1225      threshold is `true`, below is `false`). One metric value is generated
1226      for each threshold value. If neither thresholds nor top_k are set, the
1227      default is to calculate precision with `thresholds=0.5`.
1228    top_k: (Optional) Unset by default. An int value specifying the top-k
1229      predictions to consider when calculating precision.
1230    class_id: (Optional) Integer class ID for which we want binary metrics.
1231      This must be in the half-open interval `[0, num_classes)`, where
1232      `num_classes` is the last dimension of predictions.
1233    name: (Optional) string name of the metric instance.
1234    dtype: (Optional) data type of the metric result.
1235
1236  Standalone usage:
1237
1238  >>> m = tf.keras.metrics.Precision()
1239  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
1240  >>> m.result().numpy()
1241  0.6666667
1242
1243  >>> m.reset_states()
1244  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
1245  >>> m.result().numpy()
1246  1.0
1247
1248  >>> # With top_k=2, it will calculate precision over y_true[:2] and y_pred[:2]
1249  >>> m = tf.keras.metrics.Precision(top_k=2)
1250  >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
1251  >>> m.result().numpy()
1252  0.0
1253
1254  >>> # With top_k=4, it will calculate precision over y_true[:4] and y_pred[:4]
1255  >>> m = tf.keras.metrics.Precision(top_k=4)
1256  >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
1257  >>> m.result().numpy()
1258  0.5
1259
1260  Usage with `compile()` API:
1261
1262  ```python
1263  model.compile(optimizer='sgd',
1264                loss='mse',
1265                metrics=[tf.keras.metrics.Precision()])
1266  ```
1267  """
1268
1269  def __init__(self,
1270               thresholds=None,
1271               top_k=None,
1272               class_id=None,
1273               name=None,
1274               dtype=None):
1275    super(Precision, self).__init__(name=name, dtype=dtype)
1276    self.init_thresholds = thresholds
1277    self.top_k = top_k
1278    self.class_id = class_id
1279
1280    default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
1281    self.thresholds = metrics_utils.parse_init_thresholds(
1282        thresholds, default_threshold=default_threshold)
1283    self.true_positives = self.add_weight(
1284        'true_positives',
1285        shape=(len(self.thresholds),),
1286        initializer=init_ops.zeros_initializer)
1287    self.false_positives = self.add_weight(
1288        'false_positives',
1289        shape=(len(self.thresholds),),
1290        initializer=init_ops.zeros_initializer)
1291
1292  def update_state(self, y_true, y_pred, sample_weight=None):
1293    """Accumulates true positive and false positive statistics.
1294
1295    Args:
1296      y_true: The ground truth values, with the same dimensions as `y_pred`.
1297        Will be cast to `bool`.
1298      y_pred: The predicted values. Each element must be in the range `[0, 1]`.
1299      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1300        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1301        be broadcastable to `y_true`.
1302
1303    Returns:
1304      Update op.
1305    """
1306    return metrics_utils.update_confusion_matrix_variables(
1307        {
1308            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
1309            metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives
1310        },
1311        y_true,
1312        y_pred,
1313        thresholds=self.thresholds,
1314        top_k=self.top_k,
1315        class_id=self.class_id,
1316        sample_weight=sample_weight)
1317
1318  def result(self):
1319    result = math_ops.div_no_nan(self.true_positives,
1320                                 self.true_positives + self.false_positives)
1321    return result[0] if len(self.thresholds) == 1 else result
1322
1323  def reset_states(self):
1324    num_thresholds = len(to_list(self.thresholds))
1325    K.batch_set_value(
1326        [(v, np.zeros((num_thresholds,))) for v in self.variables])
1327
1328  def get_config(self):
1329    config = {
1330        'thresholds': self.init_thresholds,
1331        'top_k': self.top_k,
1332        'class_id': self.class_id
1333    }
1334    base_config = super(Precision, self).get_config()
1335    return dict(list(base_config.items()) + list(config.items()))
1336
1337
1338@keras_export('keras.metrics.Recall')
1339class Recall(Metric):
1340  """Computes the recall of the predictions with respect to the labels.
1341
1342  This metric creates two local variables, `true_positives` and
1343  `false_negatives`, that are used to compute the recall. This value is
1344  ultimately returned as `recall`, an idempotent operation that simply divides
1345  `true_positives` by the sum of `true_positives` and `false_negatives`.
1346
1347  If `sample_weight` is `None`, weights default to 1.
1348  Use `sample_weight` of 0 to mask values.
1349
1350  If `top_k` is set, recall will be computed as how often on average a class
1351  among the labels of a batch entry is in the top-k predictions.
1352
1353  If `class_id` is specified, we calculate recall by considering only the
1354  entries in the batch for which `class_id` is in the label, and computing the
1355  fraction of them for which `class_id` is above the threshold and/or in the
1356  top-k predictions.
1357
1358  Args:
1359    thresholds: (Optional) A float value or a python list/tuple of float
1360      threshold values in [0, 1]. A threshold is compared with prediction
1361      values to determine the truth value of predictions (i.e., above the
1362      threshold is `true`, below is `false`). One metric value is generated
1363      for each threshold value. If neither thresholds nor top_k are set, the
1364      default is to calculate recall with `thresholds=0.5`.
1365    top_k: (Optional) Unset by default. An int value specifying the top-k
1366      predictions to consider when calculating recall.
1367    class_id: (Optional) Integer class ID for which we want binary metrics.
1368      This must be in the half-open interval `[0, num_classes)`, where
1369      `num_classes` is the last dimension of predictions.
1370    name: (Optional) string name of the metric instance.
1371    dtype: (Optional) data type of the metric result.
1372
1373  Standalone usage:
1374
1375  >>> m = tf.keras.metrics.Recall()
1376  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
1377  >>> m.result().numpy()
1378  0.6666667
1379
1380  >>> m.reset_states()
1381  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
1382  >>> m.result().numpy()
1383  1.0
1384
1385  Usage with `compile()` API:
1386
1387  ```python
1388  model.compile(optimizer='sgd',
1389                loss='mse',
1390                metrics=[tf.keras.metrics.Recall()])
1391  ```
1392  """
1393
1394  def __init__(self,
1395               thresholds=None,
1396               top_k=None,
1397               class_id=None,
1398               name=None,
1399               dtype=None):
1400    super(Recall, self).__init__(name=name, dtype=dtype)
1401    self.init_thresholds = thresholds
1402    self.top_k = top_k
1403    self.class_id = class_id
1404
1405    default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
1406    self.thresholds = metrics_utils.parse_init_thresholds(
1407        thresholds, default_threshold=default_threshold)
1408    self.true_positives = self.add_weight(
1409        'true_positives',
1410        shape=(len(self.thresholds),),
1411        initializer=init_ops.zeros_initializer)
1412    self.false_negatives = self.add_weight(
1413        'false_negatives',
1414        shape=(len(self.thresholds),),
1415        initializer=init_ops.zeros_initializer)
1416
1417  def update_state(self, y_true, y_pred, sample_weight=None):
1418    """Accumulates true positive and false negative statistics.
1419
1420    Args:
1421      y_true: The ground truth values, with the same dimensions as `y_pred`.
1422        Will be cast to `bool`.
1423      y_pred: The predicted values. Each element must be in the range `[0, 1]`.
1424      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1425        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1426        be broadcastable to `y_true`.
1427
1428    Returns:
1429      Update op.
1430    """
1431    return metrics_utils.update_confusion_matrix_variables(
1432        {
1433            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
1434            metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives
1435        },
1436        y_true,
1437        y_pred,
1438        thresholds=self.thresholds,
1439        top_k=self.top_k,
1440        class_id=self.class_id,
1441        sample_weight=sample_weight)
1442
1443  def result(self):
1444    result = math_ops.div_no_nan(self.true_positives,
1445                                 self.true_positives + self.false_negatives)
1446    return result[0] if len(self.thresholds) == 1 else result
1447
1448  def reset_states(self):
1449    num_thresholds = len(to_list(self.thresholds))
1450    K.batch_set_value(
1451        [(v, np.zeros((num_thresholds,))) for v in self.variables])
1452
1453  def get_config(self):
1454    config = {
1455        'thresholds': self.init_thresholds,
1456        'top_k': self.top_k,
1457        'class_id': self.class_id
1458    }
1459    base_config = super(Recall, self).get_config()
1460    return dict(list(base_config.items()) + list(config.items()))
1461
1462
1463@six.add_metaclass(abc.ABCMeta)
1464class SensitivitySpecificityBase(Metric):
1465  """Abstract base class for computing sensitivity and specificity.
1466
1467  For additional information about specificity and sensitivity, see
1468  [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
1469  """
1470
1471  def __init__(self, value, num_thresholds=200, name=None, dtype=None):
1472    super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype)
1473    if num_thresholds <= 0:
1474      raise ValueError('`num_thresholds` must be > 0.')
1475    self.value = value
1476    self.true_positives = self.add_weight(
1477        'true_positives',
1478        shape=(num_thresholds,),
1479        initializer=init_ops.zeros_initializer)
1480    self.true_negatives = self.add_weight(
1481        'true_negatives',
1482        shape=(num_thresholds,),
1483        initializer=init_ops.zeros_initializer)
1484    self.false_positives = self.add_weight(
1485        'false_positives',
1486        shape=(num_thresholds,),
1487        initializer=init_ops.zeros_initializer)
1488    self.false_negatives = self.add_weight(
1489        'false_negatives',
1490        shape=(num_thresholds,),
1491        initializer=init_ops.zeros_initializer)
1492
1493    # Compute `num_thresholds` thresholds in [0, 1]
1494    if num_thresholds == 1:
1495      self.thresholds = [0.5]
1496    else:
1497      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
1498                    for i in range(num_thresholds - 2)]
1499      self.thresholds = [0.0] + thresholds + [1.0]
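    # For example (a sketch of the spacing above, not additional behavior):
    # num_thresholds=5 yields [0.25, 0.5, 0.75] from the loop, and the final
    # list is [0.0, 0.25, 0.5, 0.75, 1.0] -- evenly spaced thresholds that
    # always include both endpoints.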
1500
1501  def update_state(self, y_true, y_pred, sample_weight=None):
1502    """Accumulates confusion matrix statistics.
1503
1504    Args:
1505      y_true: The ground truth values.
1506      y_pred: The predicted values.
1507      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1508        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1509        be broadcastable to `y_true`.
1510
1511    Returns:
1512      Update op.
1513    """
1514    return metrics_utils.update_confusion_matrix_variables(
1515        {
1516            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
1517            metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
1518            metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
1519            metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
1520        },
1521        y_true,
1522        y_pred,
1523        thresholds=self.thresholds,
1524        sample_weight=sample_weight)
1525
1526  def reset_states(self):
1527    num_thresholds = len(self.thresholds)
1528    K.batch_set_value(
1529        [(v, np.zeros((num_thresholds,))) for v in self.variables])
1530
1531  def _find_max_under_constraint(self, constrained, dependent, predicate):
1532    """Returns the maximum of dependent_statistic that satisfies the constraint.
1533
1534    Args:
1535      constrained: The values over which the constraint is specified.
1536        A rank-1 tensor.
1537      dependent: From these values the maximum that satisfies the
1538        constraint is selected. Values in this tensor and in
1539        `constrained` are linked by having the same threshold at each
1540        position, hence this tensor must have the same shape.
1541      predicate: A binary boolean functor to be applied to arguments
1542        `constrained` and `self.value`, e.g. `tf.greater`.
1543
1544    Returns the maximal dependent value, or 0.0 if no value satisfies the constraint.
1545    """
1546    feasible = array_ops.where(predicate(constrained, self.value))
1547    feasible_exists = math_ops.greater(array_ops.size(feasible), 0)
1548
1549    def get_max():
1550      return math_ops.reduce_max(array_ops.gather(dependent, feasible))
1551
1552    return control_flow_ops.cond(feasible_exists, get_max, lambda: 0.0)
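    # Illustrative sketch (assumed values, not part of the computation): with
    # self.value = 0.5, constrained = [0.9, 0.7, 0.4],
    # dependent = [0.2, 0.6, 0.8] and predicate = tf.greater, indices {0, 1}
    # are feasible, so the result is max(0.2, 0.6) = 0.6. If no entry
    # satisfied the constraint, 0.0 would be returned instead.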
1553
1554
1555@keras_export('keras.metrics.SensitivityAtSpecificity')
1556class SensitivityAtSpecificity(SensitivitySpecificityBase):
1557  """Computes best sensitivity where specificity is >= specified value.
1558
1559  That is, it reports the sensitivity at a given specificity.
1560
1561  `Sensitivity` measures the proportion of actual positives that are correctly
1562  identified as such (tp / (tp + fn)).
1563  `Specificity` measures the proportion of actual negatives that are correctly
1564  identified as such (tn / (tn + fp)).
1565
1566  This metric creates four local variables, `true_positives`, `true_negatives`,
1567  `false_positives` and `false_negatives` that are used to compute the
1568  sensitivity at the given specificity. The threshold for the given specificity
1569  value is computed and used to evaluate the corresponding sensitivity.
1570
1571  If `sample_weight` is `None`, weights default to 1.
1572  Use `sample_weight` of 0 to mask values.
1573
1574  For additional information about specificity and sensitivity, see
1575  [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
1576
1577  Args:
1578    specificity: A scalar value in range `[0, 1]`.
1579    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1580      use for matching the given specificity.
1581    name: (Optional) string name of the metric instance.
1582    dtype: (Optional) data type of the metric result.
1583
1584  Standalone usage:
1585
1586  >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5)
1587  >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
1588  >>> m.result().numpy()
1589  0.5
1590
1591  >>> m.reset_states()
1592  >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
1593  ...                sample_weight=[1, 1, 2, 2, 1])
1594  >>> m.result().numpy()
1595  0.333333
1596
1597  Usage with `compile()` API:
1598
1599  ```python
1600  model.compile(
1601      optimizer='sgd',
1602      loss='mse',
1603      metrics=[tf.keras.metrics.SensitivityAtSpecificity(specificity=0.5)])
1604  ```
1605  """
1606
1607  def __init__(self, specificity, num_thresholds=200, name=None, dtype=None):
1608    if specificity < 0 or specificity > 1:
1609      raise ValueError('`specificity` must be in the range [0, 1].')
1610    self.specificity = specificity
1611    self.num_thresholds = num_thresholds
1612    super(SensitivityAtSpecificity, self).__init__(
1613        specificity, num_thresholds=num_thresholds, name=name, dtype=dtype)
1614
1615  def result(self):
1616    specificities = math_ops.div_no_nan(
1617        self.true_negatives, self.true_negatives + self.false_positives)
1618    sensitivities = math_ops.div_no_nan(
1619        self.true_positives, self.true_positives + self.false_negatives)
1620    return self._find_max_under_constraint(
1621        specificities, sensitivities, math_ops.greater_equal)
1622
1623  def get_config(self):
1624    config = {
1625        'num_thresholds': self.num_thresholds,
1626        'specificity': self.specificity
1627    }
1628    base_config = super(SensitivityAtSpecificity, self).get_config()
1629    return dict(list(base_config.items()) + list(config.items()))
1630
1631
1632@keras_export('keras.metrics.SpecificityAtSensitivity')
1633class SpecificityAtSensitivity(SensitivitySpecificityBase):
1634  """Computes best specificity where sensitivity is >= specified value.
1635
1636  `Sensitivity` measures the proportion of actual positives that are correctly
1637  identified as such (tp / (tp + fn)).
1638  `Specificity` measures the proportion of actual negatives that are correctly
1639  identified as such (tn / (tn + fp)).
1640
1641  This metric creates four local variables, `true_positives`, `true_negatives`,
1642  `false_positives` and `false_negatives` that are used to compute the
1643  specificity at the given sensitivity. The threshold for the given sensitivity
1644  value is computed and used to evaluate the corresponding specificity.
1645
1646  If `sample_weight` is `None`, weights default to 1.
1647  Use `sample_weight` of 0 to mask values.
1648
1649  For additional information about specificity and sensitivity, see
1650  [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
1651
1652  Args:
1653    sensitivity: A scalar value in range `[0, 1]`.
1654    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1655      use for matching the given sensitivity.
1656    name: (Optional) string name of the metric instance.
1657    dtype: (Optional) data type of the metric result.
1658
1659  Standalone usage:
1660
1661  >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5)
1662  >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
1663  >>> m.result().numpy()
1664  0.66666667
1665
1666  >>> m.reset_states()
1667  >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
1668  ...                sample_weight=[1, 1, 2, 2, 2])
1669  >>> m.result().numpy()
1670  0.5
1671
1672  Usage with `compile()` API:
1673
1674  ```python
1675  model.compile(
1676      optimizer='sgd',
1677      loss='mse',
1678      metrics=[tf.keras.metrics.SpecificityAtSensitivity(sensitivity=0.5)])
1679  ```
1680  """
1681
1682  def __init__(self, sensitivity, num_thresholds=200, name=None, dtype=None):
1683    if sensitivity < 0 or sensitivity > 1:
1684      raise ValueError('`sensitivity` must be in the range [0, 1].')
1685    self.sensitivity = sensitivity
1686    self.num_thresholds = num_thresholds
1687    super(SpecificityAtSensitivity, self).__init__(
1688        sensitivity, num_thresholds=num_thresholds, name=name, dtype=dtype)
1689
1690  def result(self):
1691    sensitivities = math_ops.div_no_nan(
1692        self.true_positives, self.true_positives + self.false_negatives)
1693    specificities = math_ops.div_no_nan(
1694        self.true_negatives, self.true_negatives + self.false_positives)
1695    return self._find_max_under_constraint(
1696        sensitivities, specificities, math_ops.greater_equal)
1697
1698  def get_config(self):
1699    config = {
1700        'num_thresholds': self.num_thresholds,
1701        'sensitivity': self.sensitivity
1702    }
1703    base_config = super(SpecificityAtSensitivity, self).get_config()
1704    return dict(list(base_config.items()) + list(config.items()))
1705
1706
1707@keras_export('keras.metrics.PrecisionAtRecall')
1708class PrecisionAtRecall(SensitivitySpecificityBase):
1709  """Computes best precision where recall is >= specified value.
1710
1711  This metric creates four local variables, `true_positives`, `true_negatives`,
1712  `false_positives` and `false_negatives` that are used to compute the
1713  precision at the given recall. The threshold for the given recall
1714  value is computed and used to evaluate the corresponding precision.
1715
1716  If `sample_weight` is `None`, weights default to 1.
1717  Use `sample_weight` of 0 to mask values.
1718
1719  Args:
1720    recall: A scalar value in range `[0, 1]`.
1721    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1722      use for matching the given recall.
1723    name: (Optional) string name of the metric instance.
1724    dtype: (Optional) data type of the metric result.
1725
1726  Standalone usage:
1727
1728  >>> m = tf.keras.metrics.PrecisionAtRecall(0.5)
1729  >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
1730  >>> m.result().numpy()
1731  0.5
1732
1733  >>> m.reset_states()
1734  >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
1735  ...                sample_weight=[2, 2, 2, 1, 1])
1736  >>> m.result().numpy()
1737  0.33333333
1738
1739  Usage with `compile()` API:
1740
1741  ```python
1742  model.compile(
1743      optimizer='sgd',
1744      loss='mse',
1745      metrics=[tf.keras.metrics.PrecisionAtRecall(recall=0.8)])
1746  ```
1747  """
1748
1749  def __init__(self, recall, num_thresholds=200, name=None, dtype=None):
1750    if recall < 0 or recall > 1:
1751      raise ValueError('`recall` must be in the range [0, 1].')
1752    self.recall = recall
1753    self.num_thresholds = num_thresholds
1754    super(PrecisionAtRecall, self).__init__(
1755        value=recall,
1756        num_thresholds=num_thresholds,
1757        name=name,
1758        dtype=dtype)
1759
1760  def result(self):
1761    recalls = math_ops.div_no_nan(
1762        self.true_positives, self.true_positives + self.false_negatives)
1763    precisions = math_ops.div_no_nan(
1764        self.true_positives, self.true_positives + self.false_positives)
1765    return self._find_max_under_constraint(
1766        recalls, precisions, math_ops.greater_equal)
1767
1768  def get_config(self):
1769    config = {'num_thresholds': self.num_thresholds, 'recall': self.recall}
1770    base_config = super(PrecisionAtRecall, self).get_config()
1771    return dict(list(base_config.items()) + list(config.items()))
1772
1773
1774@keras_export('keras.metrics.RecallAtPrecision')
1775class RecallAtPrecision(SensitivitySpecificityBase):
1776  """Computes best recall where precision is >= specified value.
1777
1778  For a given score-label distribution, the required precision might not
1779  be achievable; in this case, 0.0 is returned as recall.
1780
1781  This metric creates four local variables, `true_positives`, `true_negatives`,
1782  `false_positives` and `false_negatives` that are used to compute the
1783  recall at the given precision. The threshold for the given precision
1784  value is computed and used to evaluate the corresponding recall.
1785
1786  If `sample_weight` is `None`, weights default to 1.
1787  Use `sample_weight` of 0 to mask values.
1788
1789  Args:
1790    precision: A scalar value in range `[0, 1]`.
1791    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1792      use for matching the given precision.
1793    name: (Optional) string name of the metric instance.
1794    dtype: (Optional) data type of the metric result.
1795
1796  Standalone usage:
1797
1798  >>> m = tf.keras.metrics.RecallAtPrecision(0.8)
1799  >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
1800  >>> m.result().numpy()
1801  0.5
1802
1803  >>> m.reset_states()
1804  >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
1805  ...                sample_weight=[1, 0, 0, 1])
1806  >>> m.result().numpy()
1807  1.0
1808
1809  Usage with `compile()` API:
1810
1811  ```python
1812  model.compile(
1813      optimizer='sgd',
1814      loss='mse',
1815      metrics=[tf.keras.metrics.RecallAtPrecision(precision=0.8)])
1816  ```
1817  """
1818
1819  def __init__(self, precision, num_thresholds=200, name=None, dtype=None):
1820    if precision < 0 or precision > 1:
1821      raise ValueError('`precision` must be in the range [0, 1].')
1822    self.precision = precision
1823    self.num_thresholds = num_thresholds
1824    super(RecallAtPrecision, self).__init__(
1825        value=precision,
1826        num_thresholds=num_thresholds,
1827        name=name,
1828        dtype=dtype)
1829
1830  def result(self):
1831    precisions = math_ops.div_no_nan(
1832        self.true_positives, self.true_positives + self.false_positives)
1833    recalls = math_ops.div_no_nan(
1834        self.true_positives, self.true_positives + self.false_negatives)
1835    return self._find_max_under_constraint(
1836        precisions, recalls, math_ops.greater_equal)
1837
1838  def get_config(self):
1839    config = {'num_thresholds': self.num_thresholds,
1840              'precision': self.precision}
1841    base_config = super(RecallAtPrecision, self).get_config()
1842    return dict(list(base_config.items()) + list(config.items()))
1843
1844
1845@keras_export('keras.metrics.AUC')
1846class AUC(Metric):
1847  """Approximates the AUC (Area under the curve) of the ROC or PR curves.
1848
1849  The AUC (Area under the curve) of the ROC (Receiver operating
1850  characteristic; default) or PR (Precision Recall) curves is a quality measure
1851  of binary classifiers. Unlike accuracy, and like cross-entropy
1852  losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
1853
1854  This class approximates AUCs using a Riemann sum: during the metric
1855  accumulation phase, predictions are accumulated within predefined buckets
1856  by value. The AUC is then computed by interpolating per-bucket averages. These
1857  buckets define the evaluated operational points.
1858
1859  This metric creates four local variables, `true_positives`, `true_negatives`,
1860  `false_positives` and `false_negatives` that are used to compute the AUC.
1861  To discretize the AUC curve, a linearly spaced set of thresholds is used to
1862  compute pairs of recall and precision values. The area under the ROC-curve is
1863  therefore computed using the height of the recall values by the false positive
1864  rate, while the area under the PR-curve is computed using the height of
1865  the precision values by the recall.
1866
1867  This value is ultimately returned as `auc`, an idempotent operation that
1868  computes the area under a discretized curve of precision versus recall values
1869  (computed using the aforementioned variables). The `num_thresholds` variable
1870  controls the degree of discretization with larger numbers of thresholds more
1871  closely approximating the true AUC. The quality of the approximation may vary
1872  dramatically depending on `num_thresholds`. The `thresholds` parameter can be
1873  used to manually specify thresholds which split the predictions more evenly.
1874
1875  For the best approximation of the real AUC, `predictions` should be distributed
1876  approximately uniformly in the range [0, 1] (if `from_logits=False`). The
1877  quality of the AUC approximation may be poor if this is not the case. Setting
1878  `summation_method` to 'minoring' or 'majoring' can help quantify the error in
1879  the approximation by providing a lower or upper bound estimate of the AUC.
1880
1881  If `sample_weight` is `None`, weights default to 1.
1882  Use `sample_weight` of 0 to mask values.
1883
1884  Args:
1885    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1886      use when discretizing the roc curve. Values must be > 1.
1887    curve: (Optional) Specifies the name of the curve to be computed, 'ROC'
1888      [default] or 'PR' for the Precision-Recall-curve.
1889    summation_method: (Optional) Specifies the [Riemann summation method](
1890        https://en.wikipedia.org/wiki/Riemann_sum) used.
1891        'interpolation' (default) applies mid-point summation scheme for `ROC`.
1892        For PR-AUC, interpolates (true/false) positives but not the ratio that
1893        is precision (see Davis & Goadrich 2006 for details);
1894        'minoring' applies left summation
1895        for increasing intervals and right summation for decreasing intervals;
1896        'majoring' does the opposite.
1897    name: (Optional) string name of the metric instance.
1898    dtype: (Optional) data type of the metric result.
1899    thresholds: (Optional) A list of floating point values to use as the
1900      thresholds for discretizing the curve. If set, the `num_thresholds`
1901      parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
1902      equal to {-epsilon, 1+epsilon} for a small positive epsilon value will
1903      be automatically included with these to correctly handle predictions
1904      equal to exactly 0 or 1.
1905    multi_label: boolean indicating whether multilabel data should be
1906      treated as such, wherein AUC is computed separately for each label and
1907      then averaged across labels, or (when False) whether the data should be
1908      flattened into a single label before AUC computation. In the latter
1909      case, when multilabel data is passed to AUC, each label-prediction pair
1910      is treated as an individual data point. Should be set to False for
1911      multi-class data.
1912    num_labels: (Optional) The number of labels, used when `multi_label` is
1913      True. If `num_labels` is not specified, then state variables get created
1914      on the first call to `update_state`.
1915    label_weights: (Optional) list, array, or tensor of non-negative weights
1916      used to compute AUCs for multilabel data. When `multi_label` is True,
1917      the weights are applied to the individual label AUCs when they are
1918      averaged to produce the multi-label AUC. When it's False, they are used
1919      to weight the individual label predictions in computing the confusion
1920      matrix on the flattened data. Note that this is unlike class_weights in
1921      that class_weights weights the example depending on the value of its
1922      label, whereas label_weights depends only on the index of that label
1923      before flattening; therefore `label_weights` should not be used for
1924      multi-class data.
1925    from_logits: boolean indicating whether the predictions (`y_pred` in
1926      `update_state`) are probabilities or sigmoid logits. As a rule of thumb,
1927      when using a keras loss, the `from_logits` constructor argument of the
1928      loss should match the AUC `from_logits` constructor argument.
1929
1930  Standalone usage:
1931
1932  >>> m = tf.keras.metrics.AUC(num_thresholds=3)
1933  >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
1934  >>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
1935  >>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
1936  >>> # recall = [1, 0.5, 0], fp_rate = [1, 0, 0]
1937  >>> # auc = ((((1+0.5)/2)*(1-0))+ (((0.5+0)/2)*(0-0))) = 0.75
1938  >>> m.result().numpy()
1939  0.75
1940
1941  >>> m.reset_states()
1942  >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
1943  ...                sample_weight=[1, 0, 0, 1])
1944  >>> m.result().numpy()
1945  1.0
1946
1947  Usage with `compile()` API:
1948
1949  ```python
1950  # Reports the AUC of a model outputting a probability.
1951  model.compile(optimizer='sgd',
1952                loss=tf.keras.losses.BinaryCrossentropy(),
1953                metrics=[tf.keras.metrics.AUC()])
1954
1955  # Reports the AUC of a model outputting a logit.
1956  model.compile(optimizer='sgd',
1957                loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
1958                metrics=[tf.keras.metrics.AUC(from_logits=True)])
1959  ```
1960  """
1961
1962  def __init__(self,
1963               num_thresholds=200,
1964               curve='ROC',
1965               summation_method='interpolation',
1966               name=None,
1967               dtype=None,
1968               thresholds=None,
1969               multi_label=False,
1970               num_labels=None,
1971               label_weights=None,
1972               from_logits=False):
1973    # Validate configurations.
1974    if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
1975        metrics_utils.AUCCurve):
1976      raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format(
1977          curve, list(metrics_utils.AUCCurve)))
1978    if isinstance(
1979        summation_method,
1980        metrics_utils.AUCSummationMethod) and summation_method not in list(
1981            metrics_utils.AUCSummationMethod):
1982      raise ValueError(
1983          'Invalid summation method: "{}". Valid options are: "{}"'.format(
1984              summation_method, list(metrics_utils.AUCSummationMethod)))
1985
1986    # Update properties.
1987    if thresholds is not None:
1988      # If specified, use the supplied thresholds.
1989      self.num_thresholds = len(thresholds) + 2
1990      thresholds = sorted(thresholds)
1991    else:
1992      if num_thresholds <= 1:
1993        raise ValueError('`num_thresholds` must be > 1.')
1994
1995      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
1996      # (0, 1).
1997      self.num_thresholds = num_thresholds
1998      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
1999                    for i in range(num_thresholds - 2)]
2000
2001    # Add an endpoint "threshold" below zero and above one for either
2002    # threshold method to account for floating point imprecisions.
2003    self._thresholds = np.array([0.0 - K.epsilon()] + thresholds +
2004                                [1.0 + K.epsilon()])
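    # Sketch of the resulting spacing (assuming the defaults above): with
    # num_thresholds=3 and no explicit `thresholds`, the interpolation yields
    # [0.5], so self._thresholds becomes [-epsilon, 0.5, 1 + epsilon] -- the
    # values quoted in the class docstring example.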
2005
2006    if isinstance(curve, metrics_utils.AUCCurve):
2007      self.curve = curve
2008    else:
2009      self.curve = metrics_utils.AUCCurve.from_str(curve)
2010    if isinstance(summation_method, metrics_utils.AUCSummationMethod):
2011      self.summation_method = summation_method
2012    else:
2013      self.summation_method = metrics_utils.AUCSummationMethod.from_str(
2014          summation_method)
2015    super(AUC, self).__init__(name=name, dtype=dtype)
2016
2017    # Handle multilabel arguments.
2018    self.multi_label = multi_label
2019    if label_weights is not None:
2020      label_weights = constant_op.constant(label_weights, dtype=self.dtype)
2021      checks = [
2022          check_ops.assert_non_negative(
2023              label_weights,
2024              message='All values of `label_weights` must be non-negative.')
2025      ]
2026      with ops.control_dependencies(checks):
2027        self.label_weights = label_weights
2028
2029    else:
2030      self.label_weights = None
2031
2032    self._from_logits = from_logits
2033
2034    self._built = False
2035    if self.multi_label:
2036      if num_labels:
2037        shape = tensor_shape.TensorShape([None, num_labels])
2038        self._build(shape)
2039    else:
2040      if num_labels:
2041        raise ValueError(
2042            '`num_labels` is needed only when `multi_label` is True.')
2043      self._build(None)
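    # Descriptive note: when `multi_label=True` and `num_labels` is not given,
    # state-variable creation is deferred until the first `update_state` call,
    # which invokes `_build` with the observed prediction shape.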
2044
2045  @property
2046  def thresholds(self):
2047    """The thresholds used for evaluating AUC."""
2048    return list(self._thresholds)
2049
2050  def _build(self, shape):
2051    """Initialize TP, FP, TN, and FN tensors, given the shape of the data."""
2052    if self.multi_label:
2053      if shape.ndims != 2:
2054        raise ValueError('`y_true` must have rank=2 when `multi_label` is '
2055                         'True. Found rank %s.' % shape.ndims)
2056      self._num_labels = shape[1]
2057      variable_shape = tensor_shape.TensorShape(
2058          [tensor_shape.Dimension(self.num_thresholds), self._num_labels])
2059
2060    else:
2061      variable_shape = tensor_shape.TensorShape(
2062          [tensor_shape.Dimension(self.num_thresholds)])
2063    self._build_input_shape = shape
2064    # Create metric variables
2065    self.true_positives = self.add_weight(
2066        'true_positives',
2067        shape=variable_shape,
2068        initializer=init_ops.zeros_initializer)
2069    self.true_negatives = self.add_weight(
2070        'true_negatives',
2071        shape=variable_shape,
2072        initializer=init_ops.zeros_initializer)
2073    self.false_positives = self.add_weight(
2074        'false_positives',
2075        shape=variable_shape,
2076        initializer=init_ops.zeros_initializer)
2077    self.false_negatives = self.add_weight(
2078        'false_negatives',
2079        shape=variable_shape,
2080        initializer=init_ops.zeros_initializer)
2081
2082    if self.multi_label:
2083      with ops.init_scope():
2084        # This should only be necessary for handling v1 behavior. In v2, AUC
2085        # should be initialized outside of any tf.functions, and therefore in
2086        # eager mode.
2087        if not context.executing_eagerly():
2088          K._initialize_variables(K._get_session())  # pylint: disable=protected-access
2089
2090    self._built = True
2091
2092  def update_state(self, y_true, y_pred, sample_weight=None):
2093    """Accumulates confusion matrix statistics.
2094
2095    Args:
2096      y_true: The ground truth values.
2097      y_pred: The predicted values.
2098      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
2099        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
2100        be broadcastable to `y_true`.
2101
2102    Returns:
2103      Update op.
2104    """
2105    deps = []
2106    if not self._built:
2107      self._build(tensor_shape.TensorShape(y_pred.shape))
2108
2109    if self.multi_label or (self.label_weights is not None):
2110      # y_true should have shape (number of examples, number of labels).
2111      shapes = [
2112          (y_true, ('N', 'L'))
2113      ]
2114      if self.multi_label:
2115        # TP, TN, FP, and FN should all have shape
2116        # (number of thresholds, number of labels).
2117        shapes.extend([(self.true_positives, ('T', 'L')),
2118                       (self.true_negatives, ('T', 'L')),
2119                       (self.false_positives, ('T', 'L')),
2120                       (self.false_negatives, ('T', 'L'))])
2121      if self.label_weights is not None:
2122        # label_weights should be of length equal to the number of labels.
2123        shapes.append((self.label_weights, ('L',)))
2124      deps = [
2125          check_ops.assert_shapes(
2126              shapes, message='Number of labels is not consistent.')
2127      ]
2128
2129    # Only forward label_weights to update_confusion_matrix_variables when
2130    # multi_label is False. Otherwise the averaging of individual label AUCs is
2131    # handled in AUC.result
2132    label_weights = None if self.multi_label else self.label_weights
2133
2134    if self._from_logits:
2135      y_pred = activations.sigmoid(y_pred)
2136
2137    with ops.control_dependencies(deps):
2138      return metrics_utils.update_confusion_matrix_variables(
2139          {
2140              metrics_utils.ConfusionMatrix.TRUE_POSITIVES:
2141                  self.true_positives,
2142              metrics_utils.ConfusionMatrix.TRUE_NEGATIVES:
2143                  self.true_negatives,
2144              metrics_utils.ConfusionMatrix.FALSE_POSITIVES:
2145                  self.false_positives,
2146              metrics_utils.ConfusionMatrix.FALSE_NEGATIVES:
2147                  self.false_negatives,
2148          },
2149          y_true,
2150          y_pred,
2151          self._thresholds,
2152          sample_weight=sample_weight,
2153          multi_label=self.multi_label,
2154          label_weights=label_weights)
2155
2156  def interpolate_pr_auc(self):
2157    """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
2158
2159    https://www.biostat.wisc.edu/~page/rocpr.pdf
2160
2161    Note here we derive & use a closed formula not present in the paper
2162    as follows:
2163
2164      Precision = TP / (TP + FP) = TP / P
2165
2166    Modeling all of TP (true positive), FP (false positive) and their sum
2167    P = TP + FP (predicted positive) as varying linearly within each interval
2168    [A, B] between successive thresholds, we get
2169
2170      Precision slope = dTP / dP
2171                      = (TP_B - TP_A) / (P_B - P_A)
2172                      = (TP - TP_A) / (P - P_A)
2173      Precision = (TP_A + slope * (P - P_A)) / P
2174
2175    The area within the interval is (slope / total_pos_weight) times
2176
2177      int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
2178      int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
2179
2180    where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
2181
2182      int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
2183
2184    Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
2185
2186      slope * [dTP + intercept *  log(P_B / P_A)] / total_pos_weight
2187
2188    where dTP == TP_B - TP_A.
2189
2190    Note that when P_A == 0 the above calculation simplifies into
2191
2192      int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
2193
2194    which is really equivalent to imputing constant precision throughout the
2195    first bucket having >0 true positives.
2196
2197    Returns:
2198      pr_auc: an approximation of the area under the P-R curve.
2199    """
2200    dtp = self.true_positives[:self.num_thresholds -
2201                              1] - self.true_positives[1:]
2202    p = self.true_positives + self.false_positives
2203    dp = p[:self.num_thresholds - 1] - p[1:]
2204    prec_slope = math_ops.div_no_nan(
2205        dtp, math_ops.maximum(dp, 0), name='prec_slope')
2206    intercept = self.true_positives[1:] - math_ops.multiply(prec_slope, p[1:])
2207
2208    safe_p_ratio = array_ops.where(
2209        math_ops.logical_and(p[:self.num_thresholds - 1] > 0, p[1:] > 0),
2210        math_ops.div_no_nan(
2211            p[:self.num_thresholds - 1],
2212            math_ops.maximum(p[1:], 0),
2213            name='recall_relative_ratio'),
2214        array_ops.ones_like(p[1:]))
2215
2216    pr_auc_increment = math_ops.div_no_nan(
2217        prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
2218        math_ops.maximum(self.true_positives[1:] + self.false_negatives[1:], 0),
2219        name='pr_auc_increment')
2220
2221    if self.multi_label:
2222      by_label_auc = math_ops.reduce_sum(
2223          pr_auc_increment, name=self.name + '_by_label', axis=0)
2224      if self.label_weights is None:
2225        # Evenly weighted average of the label AUCs.
2226        return math_ops.reduce_mean(by_label_auc, name=self.name)
2227      else:
2228        # Weighted average of the label AUCs.
2229        return math_ops.div_no_nan(
2230            math_ops.reduce_sum(
2231                math_ops.multiply(by_label_auc, self.label_weights)),
2232            math_ops.reduce_sum(self.label_weights),
2233            name=self.name)
2234    else:
2235      return math_ops.reduce_sum(pr_auc_increment, name='interpolate_pr_auc')
2236
2237  def result(self):
2238    if (self.curve == metrics_utils.AUCCurve.PR and
2239        self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION
2240       ):
2241      # This use case is different and is handled separately.
2242      return self.interpolate_pr_auc()
2243
2244    # Set `x` and `y` values for the curves based on `curve` config.
2245    recall = math_ops.div_no_nan(self.true_positives,
2246                                 self.true_positives + self.false_negatives)
2247    if self.curve == metrics_utils.AUCCurve.ROC:
2248      fp_rate = math_ops.div_no_nan(self.false_positives,
2249                                    self.false_positives + self.true_negatives)
2250      x = fp_rate
2251      y = recall
2252    else:  # curve == 'PR'.
2253      precision = math_ops.div_no_nan(
2254          self.true_positives, self.true_positives + self.false_positives)
2255      x = recall
2256      y = precision
2257
2258    # Find the rectangle heights based on `summation_method`.
2259    if self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION:
2260      # Note: the case ('PR', 'interpolation') has been handled above.
2261      heights = (y[:self.num_thresholds - 1] + y[1:]) / 2.
2262    elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING:
2263      heights = math_ops.minimum(y[:self.num_thresholds - 1], y[1:])
2264    else:  # self.summation_method = metrics_utils.AUCSummationMethod.MAJORING:
2265      heights = math_ops.maximum(y[:self.num_thresholds - 1], y[1:])
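    # Sketch of the three schemes on assumed values: for y = [1.0, 0.5, 0.0] at
    # successive thresholds, 'interpolation' uses heights [0.75, 0.25],
    # 'minoring' uses [0.5, 0.0] (a lower bound) and 'majoring' uses [1.0, 0.5]
    # (an upper bound), before multiplying by the widths taken from x below.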
2266
2267    # Sum up the areas of all the rectangles.
2268    if self.multi_label:
2269      riemann_terms = math_ops.multiply(x[:self.num_thresholds - 1] - x[1:],
2270                                        heights)
2271      by_label_auc = math_ops.reduce_sum(
2272          riemann_terms, name=self.name + '_by_label', axis=0)
2273
2274      if self.label_weights is None:
2275        # Unweighted average of the label AUCs.
2276        return math_ops.reduce_mean(by_label_auc, name=self.name)
2277      else:
2278        # Weighted average of the label AUCs.
2279        return math_ops.div_no_nan(
2280            math_ops.reduce_sum(
2281                math_ops.multiply(by_label_auc, self.label_weights)),
2282            math_ops.reduce_sum(self.label_weights),
2283            name=self.name)
2284    else:
2285      return math_ops.reduce_sum(
2286          math_ops.multiply(x[:self.num_thresholds - 1] - x[1:], heights),
2287          name=self.name)
2288
2289  def reset_states(self):
2290    if self.multi_label:
2291      K.batch_set_value([(v, np.zeros((self.num_thresholds, self._num_labels)))
2292                         for v in self.variables])
2293    else:
2294      K.batch_set_value([
2295          (v, np.zeros((self.num_thresholds,))) for v in self.variables
2296      ])
2297
2298  def get_config(self):
2299    if is_tensor_or_variable(self.label_weights):
2300      label_weights = K.eval(self.label_weights)
2301    else:
2302      label_weights = self.label_weights
2303    config = {
2304        'num_thresholds': self.num_thresholds,
2305        'curve': self.curve.value,
2306        'summation_method': self.summation_method.value,
2307        # We remove the endpoint thresholds as an inverse of how the thresholds
2308        # were initialized. This ensures that a metric initialized from this
2309        # config has the same thresholds.
2310        'thresholds': self.thresholds[1:-1],
2311        'multi_label': self.multi_label,
2312        'label_weights': label_weights
2313    }
2314    base_config = super(AUC, self).get_config()
2315    return dict(list(base_config.items()) + list(config.items()))
2316
2317
2318@keras_export('keras.metrics.CosineSimilarity')
2319class CosineSimilarity(MeanMetricWrapper):
2320  """Computes the cosine similarity between the labels and predictions.
2321
2322  `cosine similarity = (a . b) / ||a|| ||b||`
2323
2324  See: [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity).
2325
2326  This metric keeps the average cosine similarity between `predictions` and
2327  `labels` over a stream of data.
2328
2329  Args:
2330    name: (Optional) string name of the metric instance.
2331    dtype: (Optional) data type of the metric result.
2332    axis: (Optional) Defaults to -1. The dimension along which the cosine
2333      similarity is computed.
2334
2335  Standalone usage:
2336
2337  >>> # l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]]
2338  >>> # l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]]
2339  >>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]]
2340  >>> # result = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))
2341  >>> #        = ((0. + 0.) +  (0.5 + 0.5)) / 2
2342  >>> m = tf.keras.metrics.CosineSimilarity(axis=1)
2343  >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]])
2344  >>> m.result().numpy()
2345  0.49999997
2346
2347  >>> m.reset_states()
2348  >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]],
2349  ...                sample_weight=[0.3, 0.7])
2350  >>> m.result().numpy()
2351  0.6999999
2352
2353  Usage with `compile()` API:
2354
2355  ```python
2356  model.compile(
2357      optimizer='sgd',
2358      loss='mse',
2359      metrics=[tf.keras.metrics.CosineSimilarity(axis=1)])
2360  ```
2361  """
2362
2363  def __init__(self, name='cosine_similarity', dtype=None, axis=-1):
2364    super(CosineSimilarity, self).__init__(
2365        cosine_similarity, name, dtype=dtype, axis=axis)
2366
2367
2368@keras_export('keras.metrics.MeanAbsoluteError')
2369class MeanAbsoluteError(MeanMetricWrapper):
2370  """Computes the mean absolute error between the labels and predictions.
2371
2372  Args:
2373    name: (Optional) string name of the metric instance.
2374    dtype: (Optional) data type of the metric result.
2375
2376  Standalone usage:
2377
2378  >>> m = tf.keras.metrics.MeanAbsoluteError()
2379  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2380  >>> m.result().numpy()
2381  0.25
2382
2383  >>> m.reset_states()
2384  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2385  ...                sample_weight=[1, 0])
2386  >>> m.result().numpy()
2387  0.5
2388
2389  Usage with `compile()` API:
2390
2391  ```python
2392  model.compile(
2393      optimizer='sgd',
2394      loss='mse',
2395      metrics=[tf.keras.metrics.MeanAbsoluteError()])
2396  ```
2397  """
2398
2399  def __init__(self, name='mean_absolute_error', dtype=None):
2400    super(MeanAbsoluteError, self).__init__(
2401        mean_absolute_error, name, dtype=dtype)
2402
2403
2404@keras_export('keras.metrics.MeanAbsolutePercentageError')
2405class MeanAbsolutePercentageError(MeanMetricWrapper):
2406  """Computes the mean absolute percentage error between `y_true` and `y_pred`.
2407
2408  Args:
2409    name: (Optional) string name of the metric instance.
2410    dtype: (Optional) data type of the metric result.
2411
2412  Standalone usage:
2413
2414  >>> m = tf.keras.metrics.MeanAbsolutePercentageError()
2415  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2416  >>> m.result().numpy()
2417  250000000.0
2418
2419  >>> m.reset_states()
2420  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2421  ...                sample_weight=[1, 0])
2422  >>> m.result().numpy()
2423  500000000.0
2424
2425  Usage with `compile()` API:
2426
2427  ```python
2428  model.compile(
2429      optimizer='sgd',
2430      loss='mse',
2431      metrics=[tf.keras.metrics.MeanAbsolutePercentageError()])
2432  ```
2433  """
2434
2435  def __init__(self, name='mean_absolute_percentage_error', dtype=None):
2436    super(MeanAbsolutePercentageError, self).__init__(
2437        mean_absolute_percentage_error, name, dtype=dtype)
2438
2439
2440@keras_export('keras.metrics.MeanSquaredError')
2441class MeanSquaredError(MeanMetricWrapper):
2442  """Computes the mean squared error between `y_true` and `y_pred`.
2443
2444  Args:
2445    name: (Optional) string name of the metric instance.
2446    dtype: (Optional) data type of the metric result.
2447
2448  Standalone usage:
2449
2450  >>> m = tf.keras.metrics.MeanSquaredError()
2451  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2452  >>> m.result().numpy()
2453  0.25
2454
2455  >>> m.reset_states()
2456  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2457  ...                sample_weight=[1, 0])
2458  >>> m.result().numpy()
2459  0.5
2460
2461  Usage with `compile()` API:
2462
2463  ```python
2464  model.compile(
2465      optimizer='sgd',
2466      loss='mse',
2467      metrics=[tf.keras.metrics.MeanSquaredError()])
2468  ```
2469  """
2470
2471  def __init__(self, name='mean_squared_error', dtype=None):
2472    super(MeanSquaredError, self).__init__(
2473        mean_squared_error, name, dtype=dtype)
2474
2475
2476@keras_export('keras.metrics.MeanSquaredLogarithmicError')
2477class MeanSquaredLogarithmicError(MeanMetricWrapper):
2478  """Computes the mean squared logarithmic error between `y_true` and `y_pred`.
2479
2480  Args:
2481    name: (Optional) string name of the metric instance.
2482    dtype: (Optional) data type of the metric result.
2483
2484  Standalone usage:
2485
2486  >>> m = tf.keras.metrics.MeanSquaredLogarithmicError()
2487  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2488  >>> m.result().numpy()
2489  0.12011322
2490
2491  >>> m.reset_states()
2492  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2493  ...                sample_weight=[1, 0])
2494  >>> m.result().numpy()
2495  0.24022643
2496
2497  Usage with `compile()` API:
2498
2499  ```python
2500  model.compile(
2501      optimizer='sgd',
2502      loss='mse',
2503      metrics=[tf.keras.metrics.MeanSquaredLogarithmicError()])
2504  ```
2505  """
2506
2507  def __init__(self, name='mean_squared_logarithmic_error', dtype=None):
2508    super(MeanSquaredLogarithmicError, self).__init__(
2509        mean_squared_logarithmic_error, name, dtype=dtype)
2510
2511
2512@keras_export('keras.metrics.Hinge')
2513class Hinge(MeanMetricWrapper):
2514  """Computes the hinge metric between `y_true` and `y_pred`.
2515
2516  `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
2517  provided, we will convert them to -1 or 1.
2518
2519  Args:
2520    name: (Optional) string name of the metric instance.
2521    dtype: (Optional) data type of the metric result.
2522
2523  Standalone usage:
2524
2525  >>> m = tf.keras.metrics.Hinge()
2526  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
2527  >>> m.result().numpy()
2528  1.3
2529
2530  >>> m.reset_states()
2531  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
2532  ...                sample_weight=[1, 0])
2533  >>> m.result().numpy()
2534  1.1
2535
2536  Usage with `compile()` API:
2537
2538  ```python
2539  model.compile(optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.Hinge()])
2540  ```
2541  """
2542
2543  def __init__(self, name='hinge', dtype=None):
2544    super(Hinge, self).__init__(hinge, name, dtype=dtype)
2545
2546
2547@keras_export('keras.metrics.SquaredHinge')
2548class SquaredHinge(MeanMetricWrapper):
2549  """Computes the squared hinge metric between `y_true` and `y_pred`.
2550
2551  `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
2552  provided, we will convert them to -1 or 1.
2553
2554  Args:
2555    name: (Optional) string name of the metric instance.
2556    dtype: (Optional) data type of the metric result.
2557
2558  Standalone usage:
2559
2560  >>> m = tf.keras.metrics.SquaredHinge()
2561  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
2562  >>> m.result().numpy()
2563  1.86
2564
2565  >>> m.reset_states()
2566  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
2567  ...                sample_weight=[1, 0])
2568  >>> m.result().numpy()
2569  1.46
2570
2571  Usage with `compile()` API:
2572
2573  ```python
2574  model.compile(
2575      optimizer='sgd',
2576      loss='mse',
2577      metrics=[tf.keras.metrics.SquaredHinge()])
2578  ```
2579  """
2580
2581  def __init__(self, name='squared_hinge', dtype=None):
2582    super(SquaredHinge, self).__init__(squared_hinge, name, dtype=dtype)
2583
2584
2585@keras_export('keras.metrics.CategoricalHinge')
2586class CategoricalHinge(MeanMetricWrapper):
2587  """Computes the categorical hinge metric between `y_true` and `y_pred`.
2588
2589  Args:
2590    name: (Optional) string name of the metric instance.
2591    dtype: (Optional) data type of the metric result.
2592
2593  Standalone usage:
2594
2595  >>> m = tf.keras.metrics.CategoricalHinge()
2596  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
2597  >>> m.result().numpy()
2598  1.4000001
2599
2600  >>> m.reset_states()
2601  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
2602  ...                sample_weight=[1, 0])
2603  >>> m.result().numpy()
2604  1.2
2605
2606  Usage with `compile()` API:
2607
2608  ```python
2609  model.compile(
2610      optimizer='sgd',
2611      loss='mse',
2612      metrics=[tf.keras.metrics.CategoricalHinge()])
2613  ```
2614  """
2615
2616  def __init__(self, name='categorical_hinge', dtype=None):
2617    super(CategoricalHinge, self).__init__(categorical_hinge, name, dtype=dtype)
2618
2619
2620@keras_export('keras.metrics.RootMeanSquaredError')
2621class RootMeanSquaredError(Mean):
2622  """Computes root mean squared error metric between `y_true` and `y_pred`.
2623
2624  Standalone usage:
2625
2626  >>> m = tf.keras.metrics.RootMeanSquaredError()
2627  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2628  >>> m.result().numpy()
2629  0.5
2630
2631  >>> m.reset_states()
2632  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2633  ...                sample_weight=[1, 0])
2634  >>> m.result().numpy()
2635  0.70710677
2636
2637  Usage with `compile()` API:
2638
2639  ```python
2640  model.compile(
2641      optimizer='sgd',
2642      loss='mse',
2643      metrics=[tf.keras.metrics.RootMeanSquaredError()])
2644  ```
2645  """
2646
2647  def __init__(self, name='root_mean_squared_error', dtype=None):
2648    super(RootMeanSquaredError, self).__init__(name, dtype=dtype)
2649
2650  def update_state(self, y_true, y_pred, sample_weight=None):
2651    """Accumulates root mean squared error statistics.
2652
2653    Args:
2654      y_true: The ground truth values.
2655      y_pred: The predicted values.
2656      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
2657        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
2658        be broadcastable to `y_true`.
2659
2660    Returns:
2661      Update op.
2662    """
2663    y_true = math_ops.cast(y_true, self._dtype)
2664    y_pred = math_ops.cast(y_pred, self._dtype)
2665    y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
2666        y_pred, y_true)
2667    error_sq = math_ops.squared_difference(y_pred, y_true)
2668    return super(RootMeanSquaredError, self).update_state(
2669        error_sq, sample_weight=sample_weight)
2670
2671  def result(self):
2672    return math_ops.sqrt(math_ops.div_no_nan(self.total, self.count))
2673
2674
2675@keras_export('keras.metrics.LogCoshError')
2676class LogCoshError(MeanMetricWrapper):
2677  """Computes the logarithm of the hyperbolic cosine of the prediction error.
2678
2679  `logcosh = log((exp(x) + exp(-x))/2)`, where x is the error (y_pred - y_true).
2680
2681  Args:
2682    name: (Optional) string name of the metric instance.
2683    dtype: (Optional) data type of the metric result.
2684
2685  Standalone usage:
2686
2687  >>> m = tf.keras.metrics.LogCoshError()
2688  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2689  >>> m.result().numpy()
2690  0.10844523
2691
2692  >>> m.reset_states()
2693  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2694  ...                sample_weight=[1, 0])
2695  >>> m.result().numpy()
2696  0.21689045
2697
2698  Usage with `compile()` API:
2699
2700  ```python
2701  model.compile(optimizer='sgd',
2702                loss='mse',
2703                metrics=[tf.keras.metrics.LogCoshError()])
2704  ```
2705  """
2706
2707  def __init__(self, name='logcosh', dtype=None):
2708    super(LogCoshError, self).__init__(logcosh, name, dtype=dtype)
2709
2710
2711@keras_export('keras.metrics.Poisson')
2712class Poisson(MeanMetricWrapper):
2713  """Computes the Poisson metric between `y_true` and `y_pred`.
2714
2715  `metric = y_pred - y_true * log(y_pred)`
2716
2717  Args:
2718    name: (Optional) string name of the metric instance.
2719    dtype: (Optional) data type of the metric result.
2720
2721  Standalone usage:
2722
2723  >>> m = tf.keras.metrics.Poisson()
2724  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2725  >>> m.result().numpy()
2726  0.49999997
2727
2728  >>> m.reset_states()
2729  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2730  ...                sample_weight=[1, 0])
2731  >>> m.result().numpy()
2732  0.99999994
2733
2734  Usage with `compile()` API:
2735
2736  ```python
2737  model.compile(optimizer='sgd',
2738                loss='mse',
2739                metrics=[tf.keras.metrics.Poisson()])
2740  ```
2741  """
2742
2743  def __init__(self, name='poisson', dtype=None):
2744    super(Poisson, self).__init__(poisson, name, dtype=dtype)
2745
2746
2747@keras_export('keras.metrics.KLDivergence')
2748class KLDivergence(MeanMetricWrapper):
2749  """Computes Kullback-Leibler divergence metric between `y_true` and `y_pred`.
2750
2751  `metric = y_true * log(y_true / y_pred)`
2752
2753  Args:
2754    name: (Optional) string name of the metric instance.
2755    dtype: (Optional) data type of the metric result.
2756
2757  Standalone usage:
2758
2759  >>> m = tf.keras.metrics.KLDivergence()
2760  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
2761  >>> m.result().numpy()
2762  0.45814306
2763
2764  >>> m.reset_states()
2765  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
2766  ...                sample_weight=[1, 0])
2767  >>> m.result().numpy()
2768  0.9162892
2769
2770  Usage with `compile()` API:
2771
2772  ```python
2773  model.compile(optimizer='sgd',
2774                loss='mse',
2775                metrics=[tf.keras.metrics.KLDivergence()])
2776  ```
2777  """
2778
2779  def __init__(self, name='kullback_leibler_divergence', dtype=None):
2780    super(KLDivergence, self).__init__(
2781        kullback_leibler_divergence, name, dtype=dtype)
2782
2783
2784@keras_export('keras.metrics.MeanIoU')
2785class MeanIoU(Metric):
2786  """Computes the mean Intersection-Over-Union metric.
2787
2788  Mean Intersection-Over-Union is a common evaluation metric for semantic image
2789  segmentation, which first computes the IOU for each semantic class and then
2790  computes the average over classes. IOU is defined as follows:
2791    IOU = true_positive / (true_positive + false_positive + false_negative).
2792  The predictions are accumulated in a confusion matrix, weighted by
2793  `sample_weight` and the metric is then calculated from it.
2794
2795  If `sample_weight` is `None`, weights default to 1.
2796  Use `sample_weight` of 0 to mask values.
2797
2798  Args:
2799    num_classes: The possible number of labels the prediction task can have.
2800      This value must be provided, since a confusion matrix of dimension =
2801      [num_classes, num_classes] will be allocated.
2802    name: (Optional) string name of the metric instance.
2803    dtype: (Optional) data type of the metric result.
2804
2805  Standalone usage:
2806
2807  >>> # cm = [[1, 1],
2808  >>> #        [1, 1]]
2809  >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
2810  >>> # iou = true_positives / (sum_row + sum_col - true_positives)
2811  >>> # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33
2812  >>> m = tf.keras.metrics.MeanIoU(num_classes=2)
2813  >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
2814  >>> m.result().numpy()
2815  0.33333334
2816
2817  >>> m.reset_states()
2818  >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],
2819  ...                sample_weight=[0.3, 0.3, 0.3, 0.1])
2820  >>> m.result().numpy()
2821  0.23809525
2822
2823  Usage with `compile()` API:
2824
2825  ```python
2826  model.compile(
2827    optimizer='sgd',
2828    loss='mse',
2829    metrics=[tf.keras.metrics.MeanIoU(num_classes=2)])
2830  ```
2831  """
2832
2833  def __init__(self, num_classes, name=None, dtype=None):
2834    super(MeanIoU, self).__init__(name=name, dtype=dtype)
2835    self.num_classes = num_classes
2836
2837    # Variable to accumulate the predictions in the confusion matrix.
2838    self.total_cm = self.add_weight(
2839        'total_confusion_matrix',
2840        shape=(num_classes, num_classes),
2841        initializer=init_ops.zeros_initializer)
2842
2843  def update_state(self, y_true, y_pred, sample_weight=None):
2844    """Accumulates the confusion matrix statistics.
2845
2846    Args:
2847      y_true: The ground truth values.
2848      y_pred: The predicted values.
2849      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
2850        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
2851        be broadcastable to `y_true`.
2852
2853    Returns:
2854      Update op.
2855    """
2856
2857    y_true = math_ops.cast(y_true, self._dtype)
2858    y_pred = math_ops.cast(y_pred, self._dtype)
2859
2860    # Flatten the input if its rank > 1.
2861    if y_pred.shape.ndims > 1:
2862      y_pred = array_ops.reshape(y_pred, [-1])
2863
2864    if y_true.shape.ndims > 1:
2865      y_true = array_ops.reshape(y_true, [-1])
2866
2867    if sample_weight is not None:
2868      sample_weight = math_ops.cast(sample_weight, self._dtype)
2869      if sample_weight.shape.ndims > 1:
2870        sample_weight = array_ops.reshape(sample_weight, [-1])
2871
2872    # Accumulate the prediction to current confusion matrix.
2873    current_cm = confusion_matrix.confusion_matrix(
2874        y_true,
2875        y_pred,
2876        self.num_classes,
2877        weights=sample_weight,
2878        dtype=self._dtype)
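    # Sketch (values from the docstring example above): with
    # y_true = [0, 0, 1, 1] and y_pred = [0, 1, 0, 1], current_cm is
    # [[1, 1], [1, 1]] (rows: true class, columns: predicted class); repeated
    # calls keep summing such matrices into total_cm.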
2879    return self.total_cm.assign_add(current_cm)
2880
2881  def result(self):
2882    """Compute the mean intersection-over-union via the confusion matrix."""
2883    sum_over_row = math_ops.cast(
2884        math_ops.reduce_sum(self.total_cm, axis=0), dtype=self._dtype)
2885    sum_over_col = math_ops.cast(
2886        math_ops.reduce_sum(self.total_cm, axis=1), dtype=self._dtype)
2887    true_positives = math_ops.cast(
2888        array_ops.tensor_diag_part(self.total_cm), dtype=self._dtype)
2889
2890    # sum_over_row + sum_over_col =
2891    #     2 * true_positives + false_positives + false_negatives.
2892    denominator = sum_over_row + sum_over_col - true_positives
2893
2894    # The mean is only computed over classes that appear in the
2895    # label or prediction tensor. If the denominator is 0, we need to
2896    # ignore the class.
2897    num_valid_entries = math_ops.reduce_sum(
2898        math_ops.cast(math_ops.not_equal(denominator, 0), dtype=self._dtype))
2899
2900    iou = math_ops.div_no_nan(true_positives, denominator)
2901
2902    return math_ops.div_no_nan(
2903        math_ops.reduce_sum(iou, name='mean_iou'), num_valid_entries)
2904
2905  def reset_states(self):
2906    K.set_value(self.total_cm, np.zeros((self.num_classes, self.num_classes)))
2907
2908  def get_config(self):
2909    config = {'num_classes': self.num_classes}
2910    base_config = super(MeanIoU, self).get_config()
2911    return dict(list(base_config.items()) + list(config.items()))
2912
2913
2914@keras_export('keras.metrics.MeanTensor')
2915class MeanTensor(Metric):
2916  """Computes the element-wise (weighted) mean of the given tensors.
2917
2918  `MeanTensor` returns a tensor with the same shape of the input tensors. The
2919  mean value is updated by keeping local variables `total` and `count`. The
2920  `total` tracks the sum of the weighted values, and `count` stores the sum of
2921  the weighted counts.
2922
2923  Args:
2924    name: (Optional) string name of the metric instance.
2925    dtype: (Optional) data type of the metric result.
2926    shape: (Optional) A list of integers, a tuple of integers, or a 1-D Tensor
2927      of type int32. If not specified, the shape is inferred from the values at
2928      the first call of update_state.
2929
2930  Standalone usage:
2931
2932  >>> m = tf.keras.metrics.MeanTensor()
2933  >>> m.update_state([0, 1, 2, 3])
2934  >>> m.update_state([4, 5, 6, 7])
2935  >>> m.result().numpy()
2936  array([2., 3., 4., 5.], dtype=float32)
2937
2938  >>> m.update_state([12, 10, 8, 6], sample_weight=[0, 0.2, 0.5, 1])
2939  >>> m.result().numpy()
2940  array([2.       , 3.6363635, 4.8      , 5.3333335], dtype=float32)
2941
2942  >>> m = tf.keras.metrics.MeanTensor(dtype=tf.float64, shape=(1, 4))
2943  >>> m.result().numpy()
2944  array([[0., 0., 0., 0.]])
2945  >>> m.update_state([[0, 1, 2, 3]])
2946  >>> m.update_state([[4, 5, 6, 7]])
2947  >>> m.result().numpy()
2948  array([[2., 3., 4., 5.]])
2949  """
2950
2951  def __init__(self, name='mean_tensor', dtype=None, shape=None):
2952    super(MeanTensor, self).__init__(name=name, dtype=dtype)
2953    self._shape = None
2954    self._total = None
2955    self._count = None
2956    self._built = False
2957    if shape is not None:
2958      self._build(shape)
2959
2960  def _build(self, shape):
2961    self._shape = tensor_shape.TensorShape(shape)
2962    self._build_input_shape = self._shape
2963    # Create new state variables
2964    self._total = self.add_weight(
2965        'total', shape=shape, initializer=init_ops.zeros_initializer)
2966    self._count = self.add_weight(
2967        'count', shape=shape, initializer=init_ops.zeros_initializer)
2968    with ops.init_scope():
2969      if not context.executing_eagerly():
2970        K._initialize_variables(K._get_session())  # pylint: disable=protected-access
2971    self._built = True
2972
2973  @property
2974  def total(self):
2975    return self._total if self._built else None
2976
2977  @property
2978  def count(self):
2979    return self._count if self._built else None
2980
2981  def update_state(self, values, sample_weight=None):
2982    """Accumulates statistics for computing the element-wise mean.
2983
2984    Args:
2985      values: Per-example value.
2986      sample_weight: Optional weighting of each example. Defaults to 1.
2987
2988    Returns:
2989      Update op.
2990    """
2991    values = math_ops.cast(values, self._dtype)
2992    if not self._built:
2993      self._build(values.shape)
2994    elif values.shape != self._shape:
2995      raise ValueError('MeanTensor input values must always have the same '
2996                       'shape. Expected shape (set during the first call): {}. '
2997                       'Got: {}'.format(self._shape, values.shape))
2998
2999    num_values = array_ops.ones_like(values)
3000    if sample_weight is not None:
3001      sample_weight = math_ops.cast(sample_weight, self._dtype)
3002
3003      # Update dimensions of weights to match with values if possible.
3004      values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions(
3005          values, sample_weight=sample_weight)
3006      try:
3007        # Broadcast weights if possible.
3008        sample_weight = weights_broadcast_ops.broadcast_weights(
3009            sample_weight, values)
3010      except ValueError:
3011        # Reduce values to same ndim as weight array
3012        ndim = K.ndim(values)
3013        weight_ndim = K.ndim(sample_weight)
3014        values = math_ops.reduce_mean(
3015            values, axis=list(range(weight_ndim, ndim)))
3016
3017      num_values = math_ops.multiply(num_values, sample_weight)
3018      values = math_ops.multiply(values, sample_weight)
3019
3020    update_total_op = self._total.assign_add(values)
3021    with ops.control_dependencies([update_total_op]):
3022      return self._count.assign_add(num_values)
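
  # Worked example for the class doctest: after update_state([0, 1, 2, 3]) and
  # update_state([4, 5, 6, 7]),
  #   total = [4., 6., 8., 10.],  count = [2., 2., 2., 2.],
  # so result() returns total / count = [2., 3., 4., 5.].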
3023
3024  def result(self):
3025    if not self._built:
3026      raise ValueError(
3027          'MeanTensor does not have any result yet. Please call the MeanTensor '
3028          'instance or use `.update_state(value)` before retrieving the result.'
3029          )
3030    return math_ops.div_no_nan(self.total, self.count)
3031
3032  def reset_states(self):
3033    if self._built:
3034      K.batch_set_value(
3035          [(v, np.zeros(self._shape.as_list())) for v in self.variables])
3036
3037
3038@keras_export('keras.metrics.BinaryCrossentropy')
3039class BinaryCrossentropy(MeanMetricWrapper):
3040  """Computes the crossentropy metric between the labels and predictions.
3041
3042  This is the crossentropy metric class to be used when there are only two
3043  label classes (0 and 1).
3044
3045  Args:
3046    name: (Optional) string name of the metric instance.
3047    dtype: (Optional) data type of the metric result.
3048    from_logits: (Optional) Whether output is expected to be a logits tensor.
3049      By default, we consider that output encodes a probability distribution.
3050    label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are
3051      smoothed, meaning the confidence on label values is relaxed.
3052      e.g. `label_smoothing=0.2` means that we will use a value of `0.1` for
3053      label `0` and `0.9` for label `1`.
3054
3055  Standalone usage:
3056
3057  >>> m = tf.keras.metrics.BinaryCrossentropy()
3058  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
3059  >>> m.result().numpy()
3060  0.81492424
3061
3062  >>> m.reset_states()
3063  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
3064  ...                sample_weight=[1, 0])
3065  >>> m.result().numpy()
3066  0.9162905
3067
3068  Usage with `compile()` API:
3069
3070  ```python
3071  model.compile(
3072      optimizer='sgd',
3073      loss='mse',
3074      metrics=[tf.keras.metrics.BinaryCrossentropy()])
3075  ```
3076  """
3077
3078  def __init__(self,
3079               name='binary_crossentropy',
3080               dtype=None,
3081               from_logits=False,
3082               label_smoothing=0):
3083    super(BinaryCrossentropy, self).__init__(
3084        binary_crossentropy,
3085        name,
3086        dtype=dtype,
3087        from_logits=from_logits,
3088        label_smoothing=label_smoothing)
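
# Worked example for the doctest above: with y_true=[[0, 1], [0, 0]] and
# y_pred=[[0.6, 0.4], [0.4, 0.6]], the per-sample crossentropies (averaged over
# the two outputs, ignoring the tiny epsilon clipping) are
#   [-log(0.4), (-log(0.6) - log(0.4)) / 2] ~= [0.9163, 0.7136]
# and their unweighted mean ~= 0.8149, matching 0.81492424 above.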
3089
3090
3091@keras_export('keras.metrics.CategoricalCrossentropy')
3092class CategoricalCrossentropy(MeanMetricWrapper):
3093  """Computes the crossentropy metric between the labels and predictions.
3094
3095  This is the crossentropy metric class to be used when there are multiple
3096  label classes (2 or more). Here we assume that labels are given as a `one_hot`
3097  representation. e.g., when the label values are [2, 0, 1],
3098  `y_true` = [[0, 0, 1], [1, 0, 0], [0, 1, 0]].
3099
3100  Args:
3101    name: (Optional) string name of the metric instance.
3102    dtype: (Optional) data type of the metric result.
3103    from_logits: (Optional) Whether output is expected to be a logits tensor.
3104      By default, we consider that output encodes a probability distribution.
3105    label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are
3106      smoothed, meaning the confidence on label values is relaxed. e.g.
3107      `label_smoothing=0.2` means that we will use a value of `0.1` for label
3108      `0` and `0.9` for label `1`.
3109
3110  Standalone usage:
3111
3112  >>> # EPSILON = 1e-7, y = y_true, y' = y_pred
3113  >>> # y' = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON)
3114  >>> # y' = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]
3115  >>> # xent = -sum(y * log(y'), axis = -1)
3116  >>> #      = -((log 0.95), (log 0.1))
3117  >>> #      = [0.051, 2.302]
3118  >>> # Reduced xent = (0.051 + 2.302) / 2
3119  >>> m = tf.keras.metrics.CategoricalCrossentropy()
3120  >>> m.update_state([[0, 1, 0], [0, 0, 1]],
3121  ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
3122  >>> m.result().numpy()
3123  1.1769392
3124
3125  >>> m.reset_states()
3126  >>> m.update_state([[0, 1, 0], [0, 0, 1]],
3127  ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]],
3128  ...                sample_weight=tf.constant([0.3, 0.7]))
3129  >>> m.result().numpy()
3130  1.6271976
3131
3132  Usage with `compile()` API:
3133
3134  ```python
3135  model.compile(
3136    optimizer='sgd',
3137    loss='mse',
3138    metrics=[tf.keras.metrics.CategoricalCrossentropy()])
3139  ```
3140  """
3141
3142  def __init__(self,
3143               name='categorical_crossentropy',
3144               dtype=None,
3145               from_logits=False,
3146               label_smoothing=0):
3147    super(CategoricalCrossentropy, self).__init__(
3148        categorical_crossentropy,
3149        name,
3150        dtype=dtype,
3151        from_logits=from_logits,
3152        label_smoothing=label_smoothing)
3153
3154
3155@keras_export('keras.metrics.SparseCategoricalCrossentropy')
3156class SparseCategoricalCrossentropy(MeanMetricWrapper):
3157  """Computes the crossentropy metric between the labels and predictions.
3158
3159  Use this crossentropy metric when there are two or more label classes.
3160  We expect labels to be provided as integers. If you want to provide labels
3161  using `one-hot` representation, please use `CategoricalCrossentropy` metric.
3162  There should be `# classes` floating point values per feature for `y_pred`
3163  and a single floating point value per feature for `y_true`.
3164
3165  In the snippet below, there is a single floating point value per example for
3166  `y_true` and `# classes` floating point values per example for `y_pred`.
3167  The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is
3168  `[batch_size, num_classes]`.
3169
3170  Args:
3171    name: (Optional) string name of the metric instance.
3172    dtype: (Optional) data type of the metric result.
3173    from_logits: (Optional) Whether output is expected to be a logits tensor.
3174      By default, we consider that output encodes a probability distribution.
3175    axis: (Optional) Defaults to -1. The dimension along which the metric is
3176      computed.
3177
3178  Standalone usage:
3179
3180  >>> # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]]
3181  >>> # logits = log(y_pred)
3182  >>> # softmax = exp(logits) / sum(exp(logits), axis=-1)
3183  >>> # softmax = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]
3184  >>> # xent = -sum(y * log(softmax), 1)
3185  >>> # log(softmax) = [[-2.9957, -0.0513, -16.1181],
3186  >>> #                [-2.3026, -0.2231, -2.3026]]
3187  >>> # y_true * log(softmax) = [[0, -0.0513, 0], [0, 0, -2.3026]]
3188  >>> # xent = [0.0513, 2.3026]
3189  >>> # Reduced xent = (0.0513 + 2.3026) / 2
3190  >>> m = tf.keras.metrics.SparseCategoricalCrossentropy()
3191  >>> m.update_state([1, 2],
3192  ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
3193  >>> m.result().numpy()
3194  1.1769392
3195
3196  >>> m.reset_states()
3197  >>> m.update_state([1, 2],
3198  ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]],
3199  ...                sample_weight=tf.constant([0.3, 0.7]))
3200  >>> m.result().numpy()
3201  1.6271976
3202
3203  Usage with `compile()` API:
3204
3205  ```python
3206  model.compile(
3207    optimizer='sgd',
3208    loss='mse',
3209    metrics=[tf.keras.metrics.SparseCategoricalCrossentropy()])
3210  ```
3211  """
3212
3213  def __init__(self,
3214               name='sparse_categorical_crossentropy',
3215               dtype=None,
3216               from_logits=False,
3217               axis=-1):
3218    super(SparseCategoricalCrossentropy, self).__init__(
3219        sparse_categorical_crossentropy,
3220        name,
3221        dtype=dtype,
3222        from_logits=from_logits,
3223        axis=axis)
3224
3225
3226class SumOverBatchSize(Reduce):
3227  """Computes the weighted sum over batch size of the given values.
3228
3229  For example, if values is [1, 3, 5, 7] then the metric value is 4.
3230  If the weights were specified as [1, 1, 0, 0] then the value would be 1.
3231
3232  This metric creates two variables, `total` and `count`, that are used to
3233  compute the average of `values`. This average is ultimately returned as the
3234  sum over batch size, which is an idempotent operation that simply divides
3235  `total` by `count`.
3236
3237  If `sample_weight` is `None`, weights default to 1.  Use `sample_weight` of 0
3238  to mask values.
3239  """
3240
3241  def __init__(self, name='sum_over_batch_size', dtype=None):
3242    super(SumOverBatchSize, self).__init__(
3243        reduction=metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
3244        name=name,
3245        dtype=dtype)
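
# Worked arithmetic for the docstring example: values [1, 3, 5, 7] give
# total = 16 and count = 4 (the batch size), so the result is 16 / 4 = 4.
# With weights [1, 1, 0, 0], total = 1 + 3 = 4 while count stays 4, so the
# result is 4 / 4 = 1.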
3246
3247
3248class SumOverBatchSizeMetricWrapper(SumOverBatchSize):
3249  """Wraps a function with the `SumOverBatchSizeMetricWrapper` metric."""
3250
3251  def __init__(self, fn, name=None, dtype=None, **kwargs):
3252    """Creates a `SumOverBatchSizeMetricWrapper` instance.
3253
3254    Args:
3255      fn: The metric function to wrap, with signature `fn(y_true, y_pred,
3256        **kwargs)`.
3257      name: (Optional) string name of the metric instance.
3258      dtype: (Optional) data type of the metric result.
3259      **kwargs: The keyword arguments that are passed on to `fn`.
3260    """
3261    super(SumOverBatchSizeMetricWrapper, self).__init__(name=name, dtype=dtype)
3262    self._fn = fn
3263    self._fn_kwargs = kwargs
3264
3265  def update_state(self, y_true, y_pred, sample_weight=None):
3266    y_true = math_ops.cast(y_true, self._dtype)
3267    y_pred = math_ops.cast(y_pred, self._dtype)
3268    y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
3269        y_pred, y_true)
3270
3271    ag_fn = autograph.tf_convert(self._fn, ag_ctx.control_status_ctx())
3272    matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
3273    return super(SumOverBatchSizeMetricWrapper, self).update_state(
3274        matches, sample_weight=sample_weight)
3275
3276  def get_config(self):
3277    config = {}
3278    for k, v in six.iteritems(self._fn_kwargs):
3279      config[k] = K.eval(v) if is_tensor_or_variable(v) else v
3280    base_config = super(SumOverBatchSizeMetricWrapper, self).get_config()
3281    return dict(list(base_config.items()) + list(config.items()))
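
# Minimal usage sketch for this internal wrapper (illustrative, names assumed):
#   metric = SumOverBatchSizeMetricWrapper(mean_squared_error, name='mse')
#   metric.update_state(y_true, y_pred)  # accumulates sum of per-sample MSE
#   metric.result()                      # divides by the number of samples seen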
3282
3283
3284def accuracy(y_true, y_pred):
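  """Computes element-wise equality of labels and predictions as floats.

  Ragged inputs are converted to their flat values, `y_pred` is cast to the
  dtype of `y_true` if the dtypes differ, and the result is 1.0 where the two
  values match and 0.0 elsewhere.
  """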
3285  [y_pred, y_true], _ = \
3286      metrics_utils.ragged_assert_compatible_and_get_flat_values(
3287          [y_pred, y_true])
3288  y_pred.shape.assert_is_compatible_with(y_true.shape)
3289  if y_true.dtype != y_pred.dtype:
3290    y_pred = math_ops.cast(y_pred, y_true.dtype)
3291  return math_ops.cast(math_ops.equal(y_true, y_pred), K.floatx())
3292
3293
3294@keras_export('keras.metrics.binary_accuracy')
3295@dispatch.add_dispatch_support
3296def binary_accuracy(y_true, y_pred, threshold=0.5):
3297  """Calculates how often predictions match binary labels.
3298
3299  Standalone usage:
3300  >>> y_true = [[1], [1], [0], [0]]
3301  >>> y_pred = [[1], [1], [0], [0]]
3302  >>> m = tf.keras.metrics.binary_accuracy(y_true, y_pred)
3303  >>> assert m.shape == (4,)
3304  >>> m.numpy()
3305  array([1., 1., 1., 1.], dtype=float32)
3306
3307  Args:
3308    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
3309    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
3310    threshold: (Optional) Float representing the threshold for deciding whether
3311      prediction values are 1 or 0.
3312
3313  Returns:
3314    Binary accuracy values. shape = `[batch_size, d0, .. dN-1]`
3315  """
3316  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
3317  threshold = math_ops.cast(threshold, y_pred.dtype)
3318  y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype)
3319  return K.mean(math_ops.equal(y_true, y_pred), axis=-1)
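
# Illustrative example (assumed values): with the default threshold of 0.5,
#   binary_accuracy([[1.], [0.]], [[0.6], [0.3]])
# yields [1., 1.], since 0.6 is thresholded to 1 and 0.3 to 0.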
3320
3321
3322@keras_export('keras.metrics.categorical_accuracy')
3323@dispatch.add_dispatch_support
3324def categorical_accuracy(y_true, y_pred):
3325  """Calculates how often predictions match one-hot labels.
3326
3327  Standalone usage:
3328  >>> y_true = [[0, 0, 1], [0, 1, 0]]
3329  >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
3330  >>> m = tf.keras.metrics.categorical_accuracy(y_true, y_pred)
3331  >>> assert m.shape == (2,)
3332  >>> m.numpy()
3333  array([0., 1.], dtype=float32)
3334
3335  You can provide logits of classes as `y_pred`, since the argmax of
3336  logits and probabilities is the same.
3337
3338  Args:
3339    y_true: One-hot ground truth values.
3340    y_pred: The prediction values.
3341
3342  Returns:
3343    Categorical accuracy values.
3344  """
3345  return math_ops.cast(
3346      math_ops.equal(
3347          math_ops.argmax(y_true, axis=-1), math_ops.argmax(y_pred, axis=-1)),
3348      K.floatx())
3349
3350
3351@keras_export('keras.metrics.sparse_categorical_accuracy')
3352@dispatch.add_dispatch_support
3353def sparse_categorical_accuracy(y_true, y_pred):
3354  """Calculates how often predictions match integer labels.
3355
3356  Standalone usage:
3357  >>> y_true = [2, 1]
3358  >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
3359  >>> m = tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred)
3360  >>> assert m.shape == (2,)
3361  >>> m.numpy()
3362  array([0., 1.], dtype=float32)
3363
3364  You can provide logits of classes as `y_pred`, since the argmax of
3365  logits and probabilities is the same.
3366
3367  Args:
3368    y_true: Integer ground truth values.
3369    y_pred: The prediction values.
3370
3371  Returns:
3372    Sparse categorical accuracy values.
3373  """
3374  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
3375  y_true = ops.convert_to_tensor_v2_with_dispatch(y_true)
3376  y_pred_rank = y_pred.shape.ndims
3377  y_true_rank = y_true.shape.ndims
3378  # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
3379  if (y_true_rank is not None) and (y_pred_rank is not None) and (len(
3380      K.int_shape(y_true)) == len(K.int_shape(y_pred))):
3381    y_true = array_ops.squeeze(y_true, [-1])
3382  y_pred = math_ops.argmax(y_pred, axis=-1)
3383
3384  # If the predicted output and actual output types don't match, force cast them
3385  # to match.
3386  if K.dtype(y_pred) != K.dtype(y_true):
3387    y_pred = math_ops.cast(y_pred, K.dtype(y_true))
3388
3389  return math_ops.cast(math_ops.equal(y_true, y_pred), K.floatx())
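
# Illustrative example (assumed values): labels with a trailing dimension of 1
# are squeezed first, so
#   sparse_categorical_accuracy([[2], [1]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
# also yields [0., 1.], matching the flat-label doctest above.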
3390
3391
3392@keras_export('keras.metrics.top_k_categorical_accuracy')
3393@dispatch.add_dispatch_support
3394def top_k_categorical_accuracy(y_true, y_pred, k=5):
3395  """Computes how often targets are in the top `K` predictions.
3396
3397  Standalone usage:
3398  >>> y_true = [[0, 0, 1], [0, 1, 0]]
3399  >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
3400  >>> m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)
3401  >>> assert m.shape == (2,)
3402  >>> m.numpy()
3403  array([1., 1.], dtype=float32)
3404
3405  Args:
3406    y_true: The ground truth values.
3407    y_pred: The prediction values.
3408    k: (Optional) Number of top elements to look at for computing accuracy.
3409      Defaults to 5.
3410
3411  Returns:
3412    Top K categorical accuracy value.
3413  """
3414  return math_ops.cast(
3415      nn.in_top_k(y_pred, math_ops.argmax(y_true, axis=-1), k), K.floatx())
3416
3417
3418@keras_export('keras.metrics.sparse_top_k_categorical_accuracy')
3419@dispatch.add_dispatch_support
3420def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
3421  """Computes how often integer targets are in the top `K` predictions.
3422
3423  Standalone usage:
3424  >>> y_true = [2, 1]
3425  >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
3426  >>> m = tf.keras.metrics.sparse_top_k_categorical_accuracy(
3427  ...     y_true, y_pred, k=3)
3428  >>> assert m.shape == (2,)
3429  >>> m.numpy()
3430  array([1., 1.], dtype=float32)
3431
3432  Args:
3433    y_true: tensor of true targets.
3434    y_pred: tensor of predicted targets.
3435    k: (Optional) Number of top elements to look at for computing accuracy.
3436      Defaults to 5.
3437
3438  Returns:
3439    Sparse top K categorical accuracy value.
3440  """
3441  y_pred_rank = ops.convert_to_tensor_v2_with_dispatch(y_pred).shape.ndims
3442  y_true_rank = ops.convert_to_tensor_v2_with_dispatch(y_true).shape.ndims
3443  # Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,)
3444  if (y_true_rank is not None) and (y_pred_rank is not None):
3445    if y_pred_rank > 2:
3446      y_pred = array_ops.reshape(y_pred, [-1, y_pred.shape[-1]])
3447    if y_true_rank > 1:
3448      y_true = array_ops.reshape(y_true, [-1])
3449
3450  return math_ops.cast(
3451      nn.in_top_k(y_pred, math_ops.cast(y_true, 'int32'), k), K.floatx())
3452
3453
3454def cosine_proximity(y_true, y_pred, axis=-1):
3455  """Computes the cosine similarity between labels and predictions.
3456
3457  Args:
3458    y_true: The ground truth values.
3459    y_pred: The prediction values.
3460    axis: (Optional) Defaults to -1. The dimension along which the cosine
3461      similarity is computed.
3462
3463  Returns:
3464    Cosine similarity value.
3465  """
3466  y_true = nn.l2_normalize(y_true, axis=axis)
3467  y_pred = nn.l2_normalize(y_pred, axis=axis)
3468  return math_ops.reduce_sum(y_true * y_pred, axis=axis)
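
# Illustrative example (assumed values):
#   cosine_proximity([[0., 1.]], [[1., 1.]])
# normalizes both rows and returns their dot product, ~[0.70710677],
# i.e. 1 / sqrt(2).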
3469
3470# Aliases
3471
3472acc = ACC = accuracy
3473bce = BCE = binary_crossentropy
3474mse = MSE = mean_squared_error
3475mae = MAE = mean_absolute_error
3476mape = MAPE = mean_absolute_percentage_error
3477msle = MSLE = mean_squared_logarithmic_error
3478cosine_similarity = cosine_proximity
3479log_cosh = logcosh
3480
3481
3482def clone_metric(metric):
3483  """Returns a clone of the metric if stateful, otherwise returns it as is."""
3484  if isinstance(metric, Metric):
3485    with ops.init_scope():
3486      return metric.__class__.from_config(metric.get_config())
3487  return metric
3488
3489
3490def clone_metrics(metrics):
3491  """Clones the given metric list/dict."""
3492  return nest.map_structure(clone_metric, metrics)
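
# Illustrative sketch (assumed example):
#   clone_metrics({'output': [tf.keras.metrics.Accuracy(), 'mse']})
# rebuilds the stateful `Accuracy` instance from its config while passing the
# 'mse' string through unchanged.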
3493
3494
3495@keras_export('keras.metrics.serialize')
3496def serialize(metric):
3497  """Serializes metric function or `Metric` instance.
3498
3499  Args:
3500    metric: A Keras `Metric` instance or a metric function.
3501
3502  Returns:
3503    Metric configuration dictionary.
3504  """
3505  return serialize_keras_object(metric)
3506
3507
3508@keras_export('keras.metrics.deserialize')
3509def deserialize(config, custom_objects=None):
3510  """Deserializes a serialized metric class/function instance.
3511
3512  Args:
3513    config: Metric configuration.
3514    custom_objects: Optional dictionary mapping names (strings) to custom
3515      objects (classes and functions) to be considered during deserialization.
3516
3517  Returns:
3518      A Keras `Metric` instance or a metric function.
3519  """
3520  return deserialize_keras_object(
3521      config,
3522      module_objects=globals(),
3523      custom_objects=custom_objects,
3524      printable_module_name='metric function')
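
# Illustrative round-trip (assumed example):
#   config = tf.keras.metrics.serialize(tf.keras.metrics.Precision(name='p'))
#   metric = tf.keras.metrics.deserialize(config)  # a fresh Precision instance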
3525
3526
3527@keras_export('keras.metrics.get')
3528def get(identifier):
3529  """Retrieves a Keras metric as a `function`/`Metric` class instance.
3530
3531  The `identifier` may be the string name of a metric function or class.
3532
3533  >>> metric = tf.keras.metrics.get("categorical_crossentropy")
3534  >>> type(metric)
3535  <class 'function'>
3536  >>> metric = tf.keras.metrics.get("CategoricalCrossentropy")
3537  >>> type(metric)
3538  <class '...tensorflow.python.keras.metrics.CategoricalCrossentropy'>
3539
3540  You can also specify `config` of the metric to this function by passing dict
3541  containing `class_name` and `config` as an identifier. Also note that the
3542  `class_name` must map to a `Metric` class
3543
3544  >>> identifier = {"class_name": "CategoricalCrossentropy",
3545  ...               "config": {"from_logits": True}}
3546  >>> metric = tf.keras.metrics.get(identifier)
3547  >>> type(metric)
3548  <class '...tensorflow.python.keras.metrics.CategoricalCrossentropy'>
3549
3550  Args:
3551    identifier: A metric identifier. One of None, a string name of a metric
3552      function/class, a metric configuration dictionary, a metric function, or
3553      a metric class instance.
3554
3555  Returns:
3556    A Keras metric as a `function`/ `Metric` class instance.
3557
3558  Raises:
3559    ValueError: If `identifier` cannot be interpreted.
3560  """
3561  if isinstance(identifier, dict):
3562    return deserialize(identifier)
3563  elif isinstance(identifier, six.string_types):
3564    return deserialize(str(identifier))
3565  elif callable(identifier):
3566    return identifier
3567  else:
3568    raise ValueError(
3569        'Could not interpret metric function identifier: {}'.format(identifier))
3570
3571
3572def is_built_in(cls):
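  """Returns True if `cls` is a metric class defined in this module."""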
3573  return cls.__module__ == Metric.__module__
3574