• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Confusion matrix related utilities."""
16
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.ops import array_ops
20from tensorflow.python.ops import check_ops
21from tensorflow.python.ops import control_flow_ops
22from tensorflow.python.ops import math_ops
23from tensorflow.python.util import deprecation
24from tensorflow.python.util import dispatch
25from tensorflow.python.util.tf_export import tf_export
26
27
def _last_dim_is_squeezable(shape):
  """Returns whether `shape` has a last dimension that may be of size 1.

  Args:
    shape: A `TensorShape` whose rank is statically known.

  Returns:
    `False` for rank-0 shapes, since there is no dimension to squeeze;
    otherwise whether the last dimension is compatible with size 1.
  """
  dims = shape.dims
  # An empty `dims` list means a scalar; indexing `dims[-1]` would raise
  # `IndexError`, so report "not squeezable" instead.
  return bool(dims) and dims[-1].is_compatible_with(1)


def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None):
  """Squeeze last dim if ranks differ from expected by exactly 1.

  In the common case where we expect shapes to match, `expected_rank_diff`
  defaults to 0, and we squeeze the last dimension of the larger rank if they
  differ by 1.

  But, for example, if `labels` contains class IDs and `predictions` contains 1
  probability per class, we expect `predictions` to have 1 more dimension than
  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
  `labels` if `rank(predictions) - rank(labels) == 0`, and
  `predictions` if `rank(predictions) - rank(labels) == 2`.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  A tensor whose static rank is 0 has no last dimension and is always returned
  unsqueezed.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed.
  """
  with ops.name_scope(name, 'remove_squeezable_dimensions',
                      [labels, predictions]):
    predictions = ops.convert_to_tensor(predictions)
    labels = ops.convert_to_tensor(labels)
    predictions_shape = predictions.get_shape()
    predictions_rank = predictions_shape.ndims
    labels_shape = labels.get_shape()
    labels_rank = labels_shape.ndims
    if (labels_rank is not None) and (predictions_rank is not None):
      # Both ranks are known statically: decide in Python, adding at most one
      # squeeze op to the graph.
      rank_diff = predictions_rank - labels_rank
      if (rank_diff == expected_rank_diff + 1 and
          _last_dim_is_squeezable(predictions_shape)):
        predictions = array_ops.squeeze(predictions, [-1])
      elif (rank_diff == expected_rank_diff - 1 and
            _last_dim_is_squeezable(labels_shape)):
        labels = array_ops.squeeze(labels, [-1])
      return labels, predictions

    # At least one rank is unknown: compare ranks at run time.
    rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
    # Only add the conditional squeeze when the static shape does not already
    # rule it out (rank 0, or a last dimension known to differ from 1).
    if (predictions_rank is None) or _last_dim_is_squeezable(
        predictions_shape):
      predictions = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff + 1, rank_diff),
          lambda: array_ops.squeeze(predictions, [-1]),
          lambda: predictions)
    if (labels_rank is None) or _last_dim_is_squeezable(labels_shape):
      labels = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff - 1, rank_diff),
          lambda: array_ops.squeeze(labels, [-1]),
          lambda: labels)
    return labels, predictions
88
89
@tf_export('math.confusion_matrix', v1=[])
@dispatch.add_dispatch_support
def confusion_matrix(labels,
                     predictions,
                     num_classes=None,
                     weights=None,
                     dtype=dtypes.int32,
                     name=None):
  """Builds a confusion matrix from true labels and predicted labels.

  Rows of the result are indexed by the real label and columns by the
  predicted label. The result is always a 2-D `[n, n]` tensor, with `n` the
  number of valid labels of the classification task. `labels` and
  `predictions` are expected to be 1-D and of the same shape.

  When `num_classes` is `None`, it is inferred as one plus the largest value
  found in either `predictions` or `labels`. Class ids are expected to start
  at 0, so `num_classes=3` corresponds to the labels `[0, 1, 2]`.

  When `weights` is given, each prediction adds its corresponding weight
  (rather than 1) to the matching confusion matrix cell.

  For example:

  ```python
    tf.math.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
  ```

  Here the possible labels are taken to be `[0, 1, 2, 3, 4]`, which yields a
  5x5 confusion matrix.

  Args:
    labels: 1-D `Tensor` of real labels for the classification task.
    predictions: 1-D `Tensor` of predictions for a given classification.
    num_classes: The possible number of labels the classification task can
                 have. If omitted, it is computed from both `predictions` and
                 `labels`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    dtype: Data type of the confusion matrix.
    name: Scope name.

  Returns:
    A `Tensor` of type `dtype` with shape `[n, n]` representing the confusion
    matrix, where `n` is the number of possible labels in the classification
    task.

  Raises:
    ValueError: If both predictions and labels are not 1-D vectors and have
      mismatched shapes, or if `weights` is not `None` and its shape doesn't
      match `predictions`.
  """
  with ops.name_scope(name, 'confusion_matrix',
                      (predictions, labels, num_classes, weights)) as name:
    labels_tensor = ops.convert_to_tensor(labels, name='labels')
    predictions_tensor = ops.convert_to_tensor(predictions, name='predictions')
    labels, predictions = remove_squeezable_dimensions(
        labels_tensor, predictions_tensor)
    predictions = math_ops.cast(predictions, dtypes.int64)
    labels = math_ops.cast(labels, dtypes.int64)

    # Class ids end up as scatter indices below, so reject negative values
    # before they are used.
    labels_non_negative = check_ops.assert_non_negative(
        labels, message='`labels` contains negative values')
    labels = control_flow_ops.with_dependencies([labels_non_negative], labels)
    predictions_non_negative = check_ops.assert_non_negative(
        predictions, message='`predictions` contains negative values')
    predictions = control_flow_ops.with_dependencies(
        [predictions_non_negative], predictions)

    if num_classes is None:
      # Infer the matrix size from the largest class id seen on either side.
      largest_prediction = math_ops.reduce_max(predictions)
      largest_label = math_ops.reduce_max(labels)
      num_classes = math_ops.maximum(largest_prediction, largest_label) + 1
    else:
      # An explicit size was given: every class id must be below it.
      class_bound = math_ops.cast(num_classes, dtypes.int64)
      labels_in_range = check_ops.assert_less(
          labels, class_bound, message='`labels` out of bound')
      labels = control_flow_ops.with_dependencies([labels_in_range], labels)
      predictions_in_range = check_ops.assert_less(
          predictions, class_bound, message='`predictions` out of bound')
      predictions = control_flow_ops.with_dependencies(
          [predictions_in_range], predictions)

    if weights is not None:
      weights = ops.convert_to_tensor(weights, name='weights')
      predictions.get_shape().assert_is_compatible_with(weights.get_shape())
      weights = math_ops.cast(weights, dtype)

    matrix_shape = array_ops.stack([num_classes, num_classes])
    # Each (label, prediction) pair addresses one cell of the matrix.
    cell_indices = array_ops.stack([labels, predictions], axis=1)
    if weights is None:
      cell_updates = array_ops.ones_like(predictions, dtype)
    else:
      cell_updates = weights
    matrix_shape_int64 = math_ops.cast(matrix_shape, dtypes.int64)
    return array_ops.scatter_nd(
        indices=cell_indices,
        updates=cell_updates,
        shape=matrix_shape_int64)
195
196
@tf_export(v1=['math.confusion_matrix', 'confusion_matrix'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('confusion_matrix', 'train.confusion_matrix')
def confusion_matrix_v1(labels,
                        predictions,
                        num_classes=None,
                        dtype=dtypes.int32,
                        name=None,
                        weights=None):
  """Builds a confusion matrix from true labels and predicted labels.

  This is the v1 endpoint. It takes its arguments in the order
  `(labels, predictions, num_classes, dtype, name, weights)` and forwards them
  to `confusion_matrix`.

  Rows of the result are indexed by the real label and columns by the
  predicted label. The result is always a 2-D `[n, n]` tensor, with `n` the
  number of valid labels of the classification task. `labels` and
  `predictions` are expected to be 1-D and of the same shape.

  When `num_classes` is `None`, it is inferred as one plus the largest value
  found in either `predictions` or `labels`. Class ids are expected to start
  at 0, so `num_classes=3` corresponds to the labels `[0, 1, 2]`.

  When `weights` is given, each prediction adds its corresponding weight
  (rather than 1) to the matching confusion matrix cell.

  For example:

  ```python
    tf.math.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
  ```

  Here the possible labels are taken to be `[0, 1, 2, 3, 4]`, which yields a
  5x5 confusion matrix.

  Args:
    labels: 1-D `Tensor` of real labels for the classification task.
    predictions: 1-D `Tensor` of predictions for a given classification.
    num_classes: The possible number of labels the classification task can have.
      If omitted, it is computed from both `predictions` and `labels`.
    dtype: Data type of the confusion matrix.
    name: Scope name.
    weights: An optional `Tensor` whose shape matches `predictions`.

  Returns:
    A `Tensor` of type `dtype` with shape `[n, n]` representing the confusion
    matrix, where `n` is the number of possible labels in the classification
    task.

  Raises:
    ValueError: If both predictions and labels are not 1-D vectors and have
      mismatched shapes, or if `weights` is not `None` and its shape doesn't
      match `predictions`.
  """
  # Pass by keyword so the differing parameter order of the two endpoints
  # cannot cause a mix-up.
  return confusion_matrix(
      labels=labels,
      predictions=predictions,
      num_classes=num_classes,
      weights=weights,
      dtype=dtype,
      name=name)
258