# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Confusion matrix related utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None):
  """Squeeze last dim if ranks differ from expected by exactly 1.

  In the common case where we expect shapes to match, `expected_rank_diff`
  defaults to 0, and we squeeze the last dimension of the larger rank if they
  differ by 1.

  But, for example, if `labels` contains class IDs and `predictions` contains 1
  probability per class, we expect `predictions` to have 1 more dimension than
  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
  `labels` if `rank(predictions) - rank(labels) == 0`, and
  `predictions` if `rank(predictions) - rank(labels) == 2`.

  Squeezing is skipped when the last dimension is statically known to have a
  size other than 1. Scalars (rank 0) have no last dimension and are never
  squeezed.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed.
  """
  with ops.name_scope(name, 'remove_squeezable_dimensions',
                      [labels, predictions]):
    predictions = ops.convert_to_tensor(predictions)
    labels = ops.convert_to_tensor(labels)
    predictions_shape = predictions.get_shape()
    predictions_rank = predictions_shape.ndims
    labels_shape = labels.get_shape()
    labels_rank = labels_shape.ndims
    if (labels_rank is not None) and (predictions_rank is not None):
      # Use static rank.
      rank_diff = predictions_rank - labels_rank
      # The `rank > 0` guards stop `dims[-1]` from raising IndexError on a
      # scalar: a rank-0 tensor has an empty `dims` list and nothing to squeeze.
      if (rank_diff == expected_rank_diff + 1 and predictions_rank > 0 and
          predictions_shape.dims[-1].is_compatible_with(1)):
        predictions = array_ops.squeeze(predictions, [-1])
      elif (rank_diff == expected_rank_diff - 1 and labels_rank > 0 and
            labels_shape.dims[-1].is_compatible_with(1)):
        labels = array_ops.squeeze(labels, [-1])
      return labels, predictions

    # Use dynamic rank. At least one static rank is unknown here, so the rank
    # comparison is done in-graph with `cond`.
    rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
    # Only build a squeeze branch when the last dim could be 1. That means
    # either the rank is unknown, or the rank is known, non-zero, and the last
    # dim is compatible with 1.
    if (predictions_rank is None) or (
        predictions_rank > 0 and
        predictions_shape.dims[-1].is_compatible_with(1)):
      predictions = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff + 1, rank_diff),
          lambda: array_ops.squeeze(predictions, [-1]),
          lambda: predictions)
    if (labels_rank is None) or (
        labels_rank > 0 and
        labels_shape.dims[-1].is_compatible_with(1)):
      labels = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff - 1, rank_diff),
          lambda: array_ops.squeeze(labels, [-1]),
          lambda: labels)
    return labels, predictions


@tf_export('math.confusion_matrix', v1=[])
def confusion_matrix(labels,
                     predictions,
                     num_classes=None,
                     weights=None,
                     dtype=dtypes.int32,
                     name=None):
  """Computes the confusion matrix from predictions and labels.

  The matrix columns represent the prediction labels and the rows represent the
  real labels. The confusion matrix is always a 2-D array of shape `[n, n]`,
  where `n` is the number of valid labels for a given classification task. Both
  prediction and labels must be 1-D arrays of the same shape in order for this
  function to work.

  If `num_classes` is `None`, then `num_classes` will be set to one plus the
  maximum value in either predictions or labels. Class labels are expected to
  start at 0. For example, if `num_classes` is 3, then the possible labels
  would be `[0, 1, 2]`.

  If `weights` is not `None`, then each prediction contributes its
  corresponding weight to the total value of the confusion matrix cell.

  For example:

  ```python
    tf.math.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
  ```

  Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
  resulting in a 5x5 confusion matrix.

  Args:
    labels: 1-D `Tensor` of real labels for the classification task.
    predictions: 1-D `Tensor` of predictions for a given classification.
    num_classes: The possible number of labels the classification task can
                 have. If this value is not provided, it will be calculated
                 using both predictions and labels array.
    weights: An optional `Tensor` whose shape matches `predictions`.
    dtype: Data type of the confusion matrix.
    name: Scope name.

  Returns:
    A `Tensor` of type `dtype` with shape `[n, n]` representing the confusion
    matrix, where `n` is the number of possible labels in the classification
    task.

  Raises:
    ValueError: If both predictions and labels are not 1-D vectors and have
      mismatched shapes, or if `weights` is not `None` and its shape doesn't
      match `predictions`.
  """
  with ops.name_scope(name, 'confusion_matrix',
                      (predictions, labels, num_classes, weights)) as name:
    # Drop a trailing size-1 dim (e.g. shape [N, 1] -> [N]) when the ranks of
    # labels and predictions differ by exactly one.
    labels, predictions = remove_squeezable_dimensions(
        ops.convert_to_tensor(labels, name='labels'),
        ops.convert_to_tensor(
            predictions, name='predictions'))
    # Both are stacked into the SparseTensor `indices` below, so give them a
    # common int64 dtype.
    predictions = math_ops.cast(predictions, dtypes.int64)
    labels = math_ops.cast(labels, dtypes.int64)

    # Sanity checks - underflow or overflow can cause memory corruption.
    labels = control_flow_ops.with_dependencies(
        [check_ops.assert_non_negative(
            labels, message='`labels` contains negative values')],
        labels)
    predictions = control_flow_ops.with_dependencies(
        [check_ops.assert_non_negative(
            predictions, message='`predictions` contains negative values')],
        predictions)

    if num_classes is None:
      # Infer the matrix size from the largest class id seen in either input.
      num_classes = math_ops.maximum(math_ops.reduce_max(predictions),
                                     math_ops.reduce_max(labels)) + 1
    else:
      # An explicit size is given, so every id must fall inside [0, n).
      num_classes_int64 = math_ops.cast(num_classes, dtypes.int64)
      labels = control_flow_ops.with_dependencies(
          [check_ops.assert_less(
              labels, num_classes_int64, message='`labels` out of bound')],
          labels)
      predictions = control_flow_ops.with_dependencies(
          [check_ops.assert_less(
              predictions, num_classes_int64,
              message='`predictions` out of bound')],
          predictions)

    if weights is not None:
      weights = ops.convert_to_tensor(weights, name='weights')
      # NOTE(review): `weights` is not passed through
      # remove_squeezable_dimensions, so a trailing size-1 dim on `weights`
      # alone will fail this static compatibility check.
      predictions.get_shape().assert_is_compatible_with(weights.get_shape())
      weights = math_ops.cast(weights, dtype)

    # Row index = true label, column index = prediction; one entry per example.
    shape = array_ops.stack([num_classes, num_classes])
    indices = array_ops.stack([labels, predictions], axis=1)
    values = (array_ops.ones_like(predictions, dtype)
              if weights is None else weights)
    cm_sparse = sparse_tensor.SparseTensor(
        indices=indices,
        values=values,
        dense_shape=math_ops.cast(shape, dtypes.int64))
    zero_matrix = array_ops.zeros(math_ops.cast(shape, dtypes.int32), dtype)

    # Densify the per-example entries by adding them onto an all-zero matrix.
    return sparse_ops.sparse_add(zero_matrix, cm_sparse)


@tf_export(v1=['math.confusion_matrix', 'confusion_matrix'])
@deprecation.deprecated_endpoints('confusion_matrix', 'train.confusion_matrix')
def confusion_matrix_v1(labels,
                        predictions,
                        num_classes=None,
                        dtype=dtypes.int32,
                        name=None,
                        weights=None):
  """Computes the confusion matrix from predictions and labels.

  The matrix columns represent the prediction labels and the rows represent the
  real labels. The confusion matrix is always a 2-D array of shape `[n, n]`,
  where `n` is the number of valid labels for a given classification task. Both
  prediction and labels must be 1-D arrays of the same shape in order for this
  function to work.

  If `num_classes` is `None`, then `num_classes` will be set to one plus the
  maximum value in either predictions or labels. Class labels are expected to
  start at 0. For example, if `num_classes` is 3, then the possible labels
  would be `[0, 1, 2]`.

  If `weights` is not `None`, then each prediction contributes its
  corresponding weight to the total value of the confusion matrix cell.

  For example:

  ```python
    tf.math.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
  ```

  Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
  resulting in a 5x5 confusion matrix.

  This is the TF1 endpoint. It differs from `confusion_matrix` only in the
  positional order of `dtype`, `name` and `weights`, and forwards to it.

  Args:
    labels: 1-D `Tensor` of real labels for the classification task.
    predictions: 1-D `Tensor` of predictions for a given classification.
    num_classes: The possible number of labels the classification task can have.
      If this value is not provided, it will be calculated using both
      predictions and labels array.
    dtype: Data type of the confusion matrix.
    name: Scope name.
    weights: An optional `Tensor` whose shape matches `predictions`.

  Returns:
    A `Tensor` of type `dtype` with shape `[n, n]` representing the confusion
    matrix, where `n` is the number of possible labels in the classification
    task.

  Raises:
    ValueError: If both predictions and labels are not 1-D vectors and have
      mismatched shapes, or if `weights` is not `None` and its shape doesn't
      match `predictions`.
  """
  return confusion_matrix(labels, predictions, num_classes, weights, dtype,
                          name)