# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Confusion matrix related utilities."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


def _last_dim_may_be_one(shape):
  """Returns whether the last dimension of `shape` could be squeezed away.

  A shape qualifies when its rank is statically unknown (the decision is then
  left to runtime), or when it has at least one dimension and its last
  dimension is compatible with 1. A shape of known rank 0 never qualifies,
  because a scalar has no last dimension to squeeze.

  Args:
    shape: A `TensorShape`.

  Returns:
    A Python `bool`.
  """
  if shape.ndims is None:
    return True
  # Guard on ndims > 0 so that `dims[-1]` is never evaluated on a scalar.
  return shape.ndims > 0 and shape.dims[-1].is_compatible_with(1)


def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None):
  """Squeeze last dim if ranks differ from expected by exactly 1.

  In the common case where we expect shapes to match, `expected_rank_diff`
  defaults to 0, and we squeeze the last dimension of the larger rank if they
  differ by 1.

  But, for example, if `labels` contains class IDs and `predictions` contains 1
  probability per class, we expect `predictions` to have 1 more dimension than
  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
  `labels` if `rank(predictions) - rank(labels) == 0`, and
  `predictions` if `rank(predictions) - rank(labels) == 2`.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  A tensor whose static rank is known to be 0 is never squeezed, since it has
  no last dimension.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed.
  """
  with ops.name_scope(name, 'remove_squeezable_dimensions',
                      [labels, predictions]):
    predictions = ops.convert_to_tensor(predictions)
    labels = ops.convert_to_tensor(labels)
    predictions_shape = predictions.get_shape()
    predictions_rank = predictions_shape.ndims
    labels_shape = labels.get_shape()
    labels_rank = labels_shape.ndims
    if (labels_rank is not None) and (predictions_rank is not None):
      # Both ranks are known while building the graph: decide statically so no
      # extra rank/cond ops are added.
      rank_diff = predictions_rank - labels_rank
      if (rank_diff == expected_rank_diff + 1 and
          _last_dim_may_be_one(predictions_shape)):
        predictions = array_ops.squeeze(predictions, [-1])
      elif (rank_diff == expected_rank_diff - 1 and
            _last_dim_may_be_one(labels_shape)):
        labels = array_ops.squeeze(labels, [-1])
      return labels, predictions

    # At least one rank is unknown: compare ranks at runtime and squeeze via
    # `cond`. A side whose static shape already rules out a size-1 last
    # dimension (including a known scalar) is left untouched.
    rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
    if _last_dim_may_be_one(predictions_shape):
      predictions = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff + 1, rank_diff),
          lambda: array_ops.squeeze(predictions, [-1]),
          lambda: predictions)
    if _last_dim_may_be_one(labels_shape):
      labels = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff - 1, rank_diff),
          lambda: array_ops.squeeze(labels, [-1]),
          lambda: labels)
    return labels, predictions


@tf_export('math.confusion_matrix', v1=[])
@dispatch.add_dispatch_support
def confusion_matrix(labels,
                     predictions,
                     num_classes=None,
                     weights=None,
                     dtype=dtypes.int32,
                     name=None):
  """Computes the confusion matrix from predictions and labels.

  The matrix columns represent the prediction labels and the rows represent the
  real labels. The confusion matrix is always a 2-D array of shape `[n, n]`,
  where `n` is the number of valid labels for a given classification task. Both
  `predictions` and `labels` must be 1-D arrays of the same shape in order for
  this function to work. A trailing dimension of size 1 on either input is
  squeezed away first when the two ranks differ by exactly one.

  If `num_classes` is `None`, then `num_classes` will be set to one plus the
  maximum value in either predictions or labels. Class labels are expected to
  start at 0. For example, if `num_classes` is 3, then the possible labels
  would be `[0, 1, 2]`.

  If `weights` is not `None`, then each prediction contributes its
  corresponding weight to the total value of the confusion matrix cell.

  For example:

  ```python
    tf.math.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
  ```

  Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
  resulting in a 5x5 confusion matrix.

  Args:
    labels: 1-D `Tensor` of real labels for the classification task.
    predictions: 1-D `Tensor` of predictions for a given classification.
    num_classes: The possible number of labels the classification task can
                 have. If this value is not provided, it will be calculated
                 using both predictions and labels array.
    weights: An optional `Tensor` whose shape matches `predictions`.
    dtype: Data type of the confusion matrix.
    name: Scope name.

  Returns:
    A `Tensor` of type `dtype` with shape `[n, n]` representing the confusion
    matrix, where `n` is the number of possible labels in the classification
    task.

  Raises:
    ValueError: If `predictions` and `labels` are not 1-D vectors of matching
      shape, or if `weights` is not `None` and its shape doesn't match
      `predictions`.
  """
  with ops.name_scope(name, 'confusion_matrix',
                      (predictions, labels, num_classes, weights)) as name:
    labels, predictions = remove_squeezable_dimensions(
        ops.convert_to_tensor(labels, name='labels'),
        ops.convert_to_tensor(
            predictions, name='predictions'))
    # Both are used as scatter indices below, which are built as int64.
    predictions = math_ops.cast(predictions, dtypes.int64)
    labels = math_ops.cast(labels, dtypes.int64)

    # Sanity checks - underflow or overflow can cause memory corruption.
    labels = control_flow_ops.with_dependencies(
        [check_ops.assert_non_negative(
            labels, message='`labels` contains negative values')],
        labels)
    predictions = control_flow_ops.with_dependencies(
        [check_ops.assert_non_negative(
            predictions, message='`predictions` contains negative values')],
        predictions)

    if num_classes is None:
      # NOTE(review): for empty inputs `reduce_max` yields the dtype's lowest
      # value, so the derived size is negative and the scatter below fails;
      # confirm whether an empty [0, 0] result would be preferable.
      num_classes = math_ops.maximum(math_ops.reduce_max(predictions),
                                     math_ops.reduce_max(labels)) + 1
    else:
      # Upper-bound checks are only possible when the caller supplies the
      # class count; otherwise it is derived from the data above.
      num_classes_int64 = math_ops.cast(num_classes, dtypes.int64)
      labels = control_flow_ops.with_dependencies(
          [check_ops.assert_less(
              labels, num_classes_int64, message='`labels` out of bound')],
          labels)
      predictions = control_flow_ops.with_dependencies(
          [check_ops.assert_less(
              predictions, num_classes_int64,
              message='`predictions` out of bound')],
          predictions)

    if weights is not None:
      weights = ops.convert_to_tensor(weights, name='weights')
      predictions.get_shape().assert_is_compatible_with(weights.get_shape())
      weights = math_ops.cast(weights, dtype)

    shape = array_ops.stack([num_classes, num_classes])
    # Row index = true label, column index = predicted label.
    indices = array_ops.stack([labels, predictions], axis=1)
    values = (array_ops.ones_like(predictions, dtype)
              if weights is None else weights)
    # scatter_nd sums updates that target the same index, so repeated
    # (label, prediction) pairs accumulate their counts / weights.
    return array_ops.scatter_nd(
        indices=indices,
        updates=values,
        shape=math_ops.cast(shape, dtypes.int64))


@tf_export(v1=['math.confusion_matrix', 'confusion_matrix'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('confusion_matrix', 'train.confusion_matrix')
def confusion_matrix_v1(labels,
                        predictions,
                        num_classes=None,
                        dtype=dtypes.int32,
                        name=None,
                        weights=None):
  """Computes the confusion matrix from predictions and labels.

  The matrix columns represent the prediction labels and the rows represent the
  real labels. The confusion matrix is always a 2-D array of shape `[n, n]`,
  where `n` is the number of valid labels for a given classification task. Both
  `predictions` and `labels` must be 1-D arrays of the same shape in order for
  this function to work.

  If `num_classes` is `None`, then `num_classes` will be set to one plus the
  maximum value in either predictions or labels. Class labels are expected to
  start at 0. For example, if `num_classes` is 3, then the possible labels
  would be `[0, 1, 2]`.

  If `weights` is not `None`, then each prediction contributes its
  corresponding weight to the total value of the confusion matrix cell.

  For example:

  ```python
    tf.math.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
  ```

  Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
  resulting in a 5x5 confusion matrix.

  Args:
    labels: 1-D `Tensor` of real labels for the classification task.
    predictions: 1-D `Tensor` of predictions for a given classification.
    num_classes: The possible number of labels the classification task can have.
      If this value is not provided, it will be calculated using both
      predictions and labels array.
    dtype: Data type of the confusion matrix.
    name: Scope name.
    weights: An optional `Tensor` whose shape matches `predictions`.

  Returns:
    A `Tensor` of type `dtype` with shape `[n, n]` representing the confusion
    matrix, where `n` is the number of possible labels in the classification
    task.

  Raises:
    ValueError: If `predictions` and `labels` are not 1-D vectors of matching
      shape, or if `weights` is not `None` and its shape doesn't match
      `predictions`.
  """
  # The v1 signature orders `weights` last; forward by keyword so the mapping
  # onto the v2 signature cannot silently drift.
  return confusion_matrix(
      labels,
      predictions,
      num_classes=num_classes,
      weights=weights,
      dtype=dtype,
      name=name)