1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for metrics.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import math 23 24import numpy as np 25from six.moves import xrange # pylint: disable=redefined-builtin 26 27from tensorflow.python.framework import constant_op 28from tensorflow.python.framework import dtypes as dtypes_lib 29from tensorflow.python.framework import errors_impl 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import sparse_tensor 32from tensorflow.python.framework import test_util 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import data_flow_ops 35from tensorflow.python.ops import math_ops 36from tensorflow.python.ops import metrics 37from tensorflow.python.ops import random_ops 38from tensorflow.python.ops import variables 39import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import 40import tensorflow.python.ops.nn_grad # pylint: disable=unused-import 41from tensorflow.python.platform import test 42 43NAN = float('nan') 44 45 46def _enqueue_vector(sess, queue, values, shape=None): 47 if not shape: 48 shape = (1, len(values)) 49 dtype = queue.dtypes[0] 50 sess.run( 51 queue.enqueue(constant_op.constant( 52 values, 
def _binary_2d_label_to_2d_sparse_value(labels):
  """Convert a dense 2D binary indicator to a sparse matrix of class IDs.

  Only entries equal to 1 appear in the result; every other entry must be 0.

  Args:
    labels: Dense 2D binary indicator, shape [batch_size, num_classes].

  Returns:
    `SparseTensorValue` with dense shape [batch_size, num_classes]. For each
    row, the class indices of its `1` entries are stored left-packed, i.e. at
    columns 0, 1, ... of that row.
  """
  indices = []
  values = []
  for batch, row in enumerate(labels):
    slot = 0  # Next free column of this row in the sparse result.
    for class_id, indicator in enumerate(row):
      if indicator == 1:
        indices.append([batch, slot])
        values.append(class_id)
        slot += 1
      else:
        assert indicator == 0
  shape = [len(labels), len(labels[0])]
  return sparse_tensor.SparseTensorValue(
      np.array(indices, np.int64),
      np.array(values, np.int64), np.array(shape, np.int64))


def _binary_2d_label_to_1d_sparse_value(labels):
  """Convert a dense 2D one-hot indicator to a sparse vector of class IDs.

  Only entries equal to 1 appear in the result; every other entry must be 0.

  Args:
    labels: Dense 2D binary indicator, shape [batch_size, num_classes]. Each
      row must contain exactly 1 `1` value.

  Returns:
    `SparseTensorValue` of shape [batch_size]. Values are indices of `1` values
    along the last dimension of `labels`.

  Raises:
    ValueError: if there is not exactly 1 `1` value per row of `labels`.
  """
  indices = []
  values = []
  for batch, row in enumerate(labels):
    for class_id, indicator in enumerate(row):
      if indicator == 1:
        indices.append([batch])
        values.append(class_id)
      else:
        assert indicator == 0
  # Exactly one hit per row means the collected batch indices are 0..n-1.
  if indices != [[i] for i in range(len(labels))]:
    raise ValueError('Expected 1 label/example, got %s.' % indices)
  shape = [len(labels)]
  return sparse_tensor.SparseTensorValue(
      np.array(indices, np.int64),
      np.array(values, np.int64), np.array(shape, np.int64))


def _binary_3d_label_to_sparse_value(labels):
  """Convert dense 3D binary indicator tensor to sparse tensor.

  Only entries equal to 1 appear in the result; every other entry must be 0.

  Args:
    labels: Dense 3D binary indicator tensor.

  Returns:
    `SparseTensorValue` whose values are indices along the last dimension of
    `labels`, left-packed along that dimension. Its dense shape equals the
    shape of `labels`.
  """
  indices = []
  values = []
  for d0, labels_d0 in enumerate(labels):
    for d1, labels_d1 in enumerate(labels_d0):
      d2 = 0  # Next free position along the last dimension.
      for class_id, label in enumerate(labels_d1):
        if label == 1:
          values.append(class_id)
          indices.append([d0, d1, d2])
          d2 += 1
        else:
          assert label == 0
  shape = [len(labels), len(labels[0]), len(labels[0][0])]
  return sparse_tensor.SparseTensorValue(
      np.array(indices, np.int64),
      np.array(values, np.int64), np.array(shape, np.int64))
% actual) 160 161 162def _assert_metric_variables(test_case, expected): 163 test_case.assertEqual( 164 set(expected), set(v.name for v in variables.local_variables())) 165 test_case.assertEqual( 166 set(expected), 167 set(v.name for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES))) 168 169 170def _test_values(shape): 171 return np.reshape(np.cumsum(np.ones(shape)), newshape=shape) 172 173 174class MeanTest(test.TestCase): 175 176 def setUp(self): 177 ops.reset_default_graph() 178 179 @test_util.run_deprecated_v1 180 def testVars(self): 181 metrics.mean(array_ops.ones([4, 3])) 182 _assert_metric_variables(self, ('mean/count:0', 'mean/total:0')) 183 184 @test_util.run_deprecated_v1 185 def testMetricsCollection(self): 186 my_collection_name = '__metrics__' 187 mean, _ = metrics.mean( 188 array_ops.ones([4, 3]), metrics_collections=[my_collection_name]) 189 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 190 191 @test_util.run_deprecated_v1 192 def testUpdatesCollection(self): 193 my_collection_name = '__updates__' 194 _, update_op = metrics.mean( 195 array_ops.ones([4, 3]), updates_collections=[my_collection_name]) 196 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 197 198 @test_util.run_deprecated_v1 199 def testBasic(self): 200 with self.cached_session() as sess: 201 values_queue = data_flow_ops.FIFOQueue( 202 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) 203 _enqueue_vector(sess, values_queue, [0, 1]) 204 _enqueue_vector(sess, values_queue, [-4.2, 9.1]) 205 _enqueue_vector(sess, values_queue, [6.5, 0]) 206 _enqueue_vector(sess, values_queue, [-3.2, 4.0]) 207 values = values_queue.dequeue() 208 209 mean, update_op = metrics.mean(values) 210 211 self.evaluate(variables.local_variables_initializer()) 212 for _ in range(4): 213 self.evaluate(update_op) 214 self.assertAlmostEqual(1.65, self.evaluate(mean), 5) 215 216 @test_util.run_deprecated_v1 217 def testUpdateOpsReturnsCurrentValue(self): 218 with 
self.cached_session() as sess: 219 values_queue = data_flow_ops.FIFOQueue( 220 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) 221 _enqueue_vector(sess, values_queue, [0, 1]) 222 _enqueue_vector(sess, values_queue, [-4.2, 9.1]) 223 _enqueue_vector(sess, values_queue, [6.5, 0]) 224 _enqueue_vector(sess, values_queue, [-3.2, 4.0]) 225 values = values_queue.dequeue() 226 227 mean, update_op = metrics.mean(values) 228 229 self.evaluate(variables.local_variables_initializer()) 230 231 self.assertAlmostEqual(0.5, self.evaluate(update_op), 5) 232 self.assertAlmostEqual(1.475, self.evaluate(update_op), 5) 233 self.assertAlmostEqual(12.4 / 6.0, self.evaluate(update_op), 5) 234 self.assertAlmostEqual(1.65, self.evaluate(update_op), 5) 235 236 self.assertAlmostEqual(1.65, self.evaluate(mean), 5) 237 238 @test_util.run_deprecated_v1 239 def testUnweighted(self): 240 values = _test_values((3, 2, 4, 1)) 241 mean_results = ( 242 metrics.mean(values), 243 metrics.mean(values, weights=1.0), 244 metrics.mean(values, weights=np.ones((1, 1, 1))), 245 metrics.mean(values, weights=np.ones((1, 1, 1, 1))), 246 metrics.mean(values, weights=np.ones((1, 1, 1, 1, 1))), 247 metrics.mean(values, weights=np.ones((1, 1, 4))), 248 metrics.mean(values, weights=np.ones((1, 1, 4, 1))), 249 metrics.mean(values, weights=np.ones((1, 2, 1))), 250 metrics.mean(values, weights=np.ones((1, 2, 1, 1))), 251 metrics.mean(values, weights=np.ones((1, 2, 4))), 252 metrics.mean(values, weights=np.ones((1, 2, 4, 1))), 253 metrics.mean(values, weights=np.ones((3, 1, 1))), 254 metrics.mean(values, weights=np.ones((3, 1, 1, 1))), 255 metrics.mean(values, weights=np.ones((3, 1, 4))), 256 metrics.mean(values, weights=np.ones((3, 1, 4, 1))), 257 metrics.mean(values, weights=np.ones((3, 2, 1))), 258 metrics.mean(values, weights=np.ones((3, 2, 1, 1))), 259 metrics.mean(values, weights=np.ones((3, 2, 4))), 260 metrics.mean(values, weights=np.ones((3, 2, 4, 1))), 261 metrics.mean(values, weights=np.ones((3, 2, 4, 1, 1))),) 262 
expected = np.mean(values) 263 with self.cached_session(): 264 variables.local_variables_initializer().run() 265 for mean_result in mean_results: 266 mean, update_op = mean_result 267 self.assertAlmostEqual(expected, self.evaluate(update_op)) 268 self.assertAlmostEqual(expected, self.evaluate(mean)) 269 270 def _test_3d_weighted(self, values, weights): 271 expected = ( 272 np.sum(np.multiply(weights, values)) / 273 np.sum(np.multiply(weights, np.ones_like(values))) 274 ) 275 mean, update_op = metrics.mean(values, weights=weights) 276 with self.cached_session(): 277 variables.local_variables_initializer().run() 278 self.assertAlmostEqual(expected, self.evaluate(update_op), places=5) 279 self.assertAlmostEqual(expected, self.evaluate(mean), places=5) 280 281 @test_util.run_deprecated_v1 282 def test1x1x1Weighted(self): 283 self._test_3d_weighted( 284 _test_values((3, 2, 4)), 285 weights=np.asarray((5,)).reshape((1, 1, 1))) 286 287 @test_util.run_deprecated_v1 288 def test1x1xNWeighted(self): 289 self._test_3d_weighted( 290 _test_values((3, 2, 4)), 291 weights=np.asarray((5, 7, 11, 3)).reshape((1, 1, 4))) 292 293 @test_util.run_deprecated_v1 294 def test1xNx1Weighted(self): 295 self._test_3d_weighted( 296 _test_values((3, 2, 4)), 297 weights=np.asarray((5, 11)).reshape((1, 2, 1))) 298 299 @test_util.run_deprecated_v1 300 def test1xNxNWeighted(self): 301 self._test_3d_weighted( 302 _test_values((3, 2, 4)), 303 weights=np.asarray((5, 7, 11, 3, 2, 13, 7, 5)).reshape((1, 2, 4))) 304 305 @test_util.run_deprecated_v1 306 def testNx1x1Weighted(self): 307 self._test_3d_weighted( 308 _test_values((3, 2, 4)), 309 weights=np.asarray((5, 7, 11)).reshape((3, 1, 1))) 310 311 @test_util.run_deprecated_v1 312 def testNx1xNWeighted(self): 313 self._test_3d_weighted( 314 _test_values((3, 2, 4)), 315 weights=np.asarray(( 316 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3)).reshape((3, 1, 4))) 317 318 @test_util.run_deprecated_v1 319 def testNxNxNWeighted(self): 320 self._test_3d_weighted( 321 
class MeanTensorTest(test.TestCase):
  """Tests for `metrics.mean_tensor`."""

  # The four (1, 2) batches streamed through most tests in this class.
  _VALUE_ROWS = ([0, 1], [-4.2, 9.1], [6.5, 0], [-3.2, 4.0])

  def setUp(self):
    ops.reset_default_graph()

  def _dequeue_op(self, sess, rows, shape):
    """Fill a float32 FIFOQueue with `rows` and return its dequeue tensor.

    Args:
      sess: Session used to run the enqueue ops.
      rows: Sequence of batches; the queue capacity is `len(rows)`.
      shape: Static shape of every batch.

    Returns:
      The tensor produced by `queue.dequeue()`.
    """
    queue = data_flow_ops.FIFOQueue(
        len(rows), dtypes=dtypes_lib.float32, shapes=shape)
    for row in rows:
      _enqueue_vector(sess, queue, row, shape=shape)
    return queue.dequeue()

  def _weighted_mean(self, weight_rows, weight_shape):
    """Return the final `mean_tensor` of `_VALUE_ROWS` under `weight_rows`.

    Args:
      weight_rows: One weight batch per value batch.
      weight_shape: Static shape of every weight batch.

    Returns:
      The evaluated mean after one update per batch.
    """
    with self.cached_session() as sess:
      values = self._dequeue_op(sess, self._VALUE_ROWS, (1, 2))
      weights = self._dequeue_op(sess, weight_rows, weight_shape)

      mean, update_op = metrics.mean_tensor(values, weights)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(len(weight_rows)):
        self.evaluate(update_op)
      return self.evaluate(mean)

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.mean_tensor(array_ops.ones([4, 3]))
    _assert_metric_variables(self,
                             ('mean/total_tensor:0', 'mean/count_tensor:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.mean_tensor(
        array_ops.ones([4, 3]), metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.mean_tensor(
        array_ops.ones([4, 3]), updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testBasic(self):
    with self.cached_session() as sess:
      values = self._dequeue_op(sess, self._VALUE_ROWS, (1, 2))

      mean, update_op = metrics.mean_tensor(values)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(4):
        self.evaluate(update_op)
      self.assertAllClose([[-0.9 / 4., 3.525]], self.evaluate(mean))

  @test_util.run_deprecated_v1
  def testMultiDimensional(self):
    rows = ([[[1, 2], [1, 2]], [[1, 2], [1, 2]]],
            [[[1, 2], [1, 2]], [[3, 4], [9, 10]]])
    with self.cached_session() as sess:
      values = self._dequeue_op(sess, rows, (2, 2, 2))

      mean, update_op = metrics.mean_tensor(values)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(2):
        self.evaluate(update_op)
      self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]],
                          self.evaluate(mean))

  @test_util.run_deprecated_v1
  def testUpdateOpsReturnsCurrentValue(self):
    with self.cached_session() as sess:
      values = self._dequeue_op(sess, self._VALUE_ROWS, (1, 2))

      mean, update_op = metrics.mean_tensor(values)

      self.evaluate(variables.local_variables_initializer())

      # The tolerance is passed by keyword: the third positional argument of
      # `assertAllClose` is `rtol`, so a bare `5` would disable the check.
      running_means = (
          [[0, 1]],
          [[-2.1, 5.05]],
          [[2.3 / 3., 10.1 / 3.]],
          [[-0.9 / 4., 3.525]],
      )
      for expected in running_means:
        self.assertAllClose(expected, self.evaluate(update_op), atol=1e-5)

      self.assertAllClose(
          [[-0.9 / 4., 3.525]], self.evaluate(mean), atol=1e-5)

  @test_util.run_deprecated_v1
  def testBinaryWeighted1d(self):
    # Only batches 0 and 2 carry weight: mean of [0, 1] and [6.5, 0].
    mean = self._weighted_mean(([[1]], [[0]], [[1]], [[0]]), (1, 1))
    self.assertAllClose([[3.25, 0.5]], mean, atol=1e-5)

  @test_util.run_deprecated_v1
  def testWeighted1d(self):
    # Weights sum to 0.025. Column 0: 0.02 / 0.025 = 0.8.
    # Column 1: (0.0025 + 0.0455 + 0 + 0.03) / 0.025 = 3.12.
    mean = self._weighted_mean(
        ([[0.0025]], [[0.005]], [[0.01]], [[0.0075]]), (1, 1))
    self.assertAllClose([[0.8, 3.12]], mean, atol=1e-5)

  @test_util.run_deprecated_v1
  def testWeighted2d_1(self):
    # Column 0 uses batches 0-1, column 1 uses batches 0 and 2.
    mean = self._weighted_mean(([1, 1], [1, 0], [0, 1], [0, 0]), (1, 2))
    self.assertAllClose([[-2.1, 0.5]], mean, atol=1e-5)

  @test_util.run_deprecated_v1
  def testWeighted2d_2(self):
    # Column 0 never receives weight; column 1 uses batches 0 and 2.
    mean = self._weighted_mean(([0, 1], [0, 0], [0, 1], [0, 0]), (1, 2))
    self.assertAllClose([[0, 0.5]], mean, atol=1e-5)
self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 586 587 @test_util.run_deprecated_v1 588 def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self): 589 predictions = array_ops.ones((10, 3)) 590 labels = array_ops.ones((10, 4)) 591 with self.assertRaises(ValueError): 592 metrics.accuracy(labels, predictions) 593 594 @test_util.run_deprecated_v1 595 def testPredictionsAndWeightsOfDifferentSizeRaisesValueError(self): 596 predictions = array_ops.ones((10, 3)) 597 labels = array_ops.ones((10, 3)) 598 weights = array_ops.ones((9, 3)) 599 with self.assertRaises(ValueError): 600 metrics.accuracy(labels, predictions, weights) 601 602 @test_util.run_deprecated_v1 603 def testValueTensorIsIdempotent(self): 604 predictions = random_ops.random_uniform( 605 (10, 3), maxval=3, dtype=dtypes_lib.int64, seed=1) 606 labels = random_ops.random_uniform( 607 (10, 3), maxval=3, dtype=dtypes_lib.int64, seed=1) 608 accuracy, update_op = metrics.accuracy(labels, predictions) 609 610 with self.cached_session(): 611 self.evaluate(variables.local_variables_initializer()) 612 613 # Run several updates. 614 for _ in range(10): 615 self.evaluate(update_op) 616 617 # Then verify idempotency. 618 initial_accuracy = self.evaluate(accuracy) 619 for _ in range(10): 620 self.assertEqual(initial_accuracy, self.evaluate(accuracy)) 621 622 @test_util.run_deprecated_v1 623 def testMultipleUpdates(self): 624 with self.cached_session() as sess: 625 # Create the queue that populates the predictions. 626 preds_queue = data_flow_ops.FIFOQueue( 627 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) 628 _enqueue_vector(sess, preds_queue, [0]) 629 _enqueue_vector(sess, preds_queue, [1]) 630 _enqueue_vector(sess, preds_queue, [2]) 631 _enqueue_vector(sess, preds_queue, [1]) 632 predictions = preds_queue.dequeue() 633 634 # Create the queue that populates the labels. 
635 labels_queue = data_flow_ops.FIFOQueue( 636 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) 637 _enqueue_vector(sess, labels_queue, [0]) 638 _enqueue_vector(sess, labels_queue, [1]) 639 _enqueue_vector(sess, labels_queue, [1]) 640 _enqueue_vector(sess, labels_queue, [2]) 641 labels = labels_queue.dequeue() 642 643 accuracy, update_op = metrics.accuracy(labels, predictions) 644 645 self.evaluate(variables.local_variables_initializer()) 646 for _ in xrange(3): 647 self.evaluate(update_op) 648 self.assertEqual(0.5, self.evaluate(update_op)) 649 self.assertEqual(0.5, self.evaluate(accuracy)) 650 651 @test_util.run_deprecated_v1 652 def testEffectivelyEquivalentSizes(self): 653 predictions = array_ops.ones((40, 1)) 654 labels = array_ops.ones((40,)) 655 with self.cached_session(): 656 accuracy, update_op = metrics.accuracy(labels, predictions) 657 658 self.evaluate(variables.local_variables_initializer()) 659 self.assertEqual(1.0, self.evaluate(update_op)) 660 self.assertEqual(1.0, self.evaluate(accuracy)) 661 662 @test_util.run_deprecated_v1 663 def testEffectivelyEquivalentSizesWithScalarWeight(self): 664 predictions = array_ops.ones((40, 1)) 665 labels = array_ops.ones((40,)) 666 with self.cached_session(): 667 accuracy, update_op = metrics.accuracy(labels, predictions, weights=2.0) 668 669 self.evaluate(variables.local_variables_initializer()) 670 self.assertEqual(1.0, self.evaluate(update_op)) 671 self.assertEqual(1.0, self.evaluate(accuracy)) 672 673 @test_util.run_deprecated_v1 674 def testEffectivelyEquivalentSizesWithStaticShapedWeight(self): 675 predictions = ops.convert_to_tensor([1, 1, 1]) # shape 3, 676 labels = array_ops.expand_dims(ops.convert_to_tensor([1, 0, 0]), 677 1) # shape 3, 1 678 weights = array_ops.expand_dims(ops.convert_to_tensor([100, 1, 1]), 679 1) # shape 3, 1 680 681 with self.cached_session(): 682 accuracy, update_op = metrics.accuracy(labels, predictions, weights) 683 684 self.evaluate(variables.local_variables_initializer()) 685 # if 
streaming_accuracy does not flatten the weight, accuracy would be 686 # 0.33333334 due to an intended broadcast of weight. Due to flattening, 687 # it will be higher than .95 688 self.assertGreater(self.evaluate(update_op), .95) 689 self.assertGreater(self.evaluate(accuracy), .95) 690 691 @test_util.run_deprecated_v1 692 def testEffectivelyEquivalentSizesWithDynamicallyShapedWeight(self): 693 predictions = ops.convert_to_tensor([1, 1, 1]) # shape 3, 694 labels = array_ops.expand_dims(ops.convert_to_tensor([1, 0, 0]), 695 1) # shape 3, 1 696 697 weights = [[100], [1], [1]] # shape 3, 1 698 weights_placeholder = array_ops.placeholder( 699 dtype=dtypes_lib.int32, name='weights') 700 feed_dict = {weights_placeholder: weights} 701 702 with self.cached_session(): 703 accuracy, update_op = metrics.accuracy(labels, predictions, 704 weights_placeholder) 705 706 self.evaluate(variables.local_variables_initializer()) 707 # if streaming_accuracy does not flatten the weight, accuracy would be 708 # 0.33333334 due to an intended broadcast of weight. Due to flattening, 709 # it will be higher than .95 710 self.assertGreater(update_op.eval(feed_dict=feed_dict), .95) 711 self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95) 712 713 @test_util.run_deprecated_v1 714 def testMultipleUpdatesWithWeightedValues(self): 715 with self.cached_session() as sess: 716 # Create the queue that populates the predictions. 717 preds_queue = data_flow_ops.FIFOQueue( 718 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) 719 _enqueue_vector(sess, preds_queue, [0]) 720 _enqueue_vector(sess, preds_queue, [1]) 721 _enqueue_vector(sess, preds_queue, [2]) 722 _enqueue_vector(sess, preds_queue, [1]) 723 predictions = preds_queue.dequeue() 724 725 # Create the queue that populates the labels. 
class PrecisionTest(test.TestCase):
  """Tests for `metrics.precision`."""

  # Fixture shared by the weighted tests. Row 0 has a false positive in
  # column 0 and a true positive in column 2; row 1 has a true positive in
  # column 0 and a false positive in column 2.
  _PREDICTIONS = ((1, 0, 1, 0), (1, 0, 1, 0))
  _LABELS = ((0, 1, 1, 0), (1, 0, 0, 1))
  _WEIGHTS_2D = ((1, 2, 3, 4), (4, 3, 2, 1))

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.precision(
        predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
    _assert_metric_variables(self, ('precision/false_positives/count:0',
                                    'precision/true_positives/count:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.precision(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.precision(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    # `maxval` is exclusive for integer dtypes, so 2 is needed to draw both
    # 0 and 1. Distinct seeds keep predictions and labels from coinciding.
    predictions = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
    precision, update_op = metrics.precision(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency: reading the value must not change it.
      initial_precision = self.evaluate(precision)
      for _ in range(10):
        self.assertEqual(initial_precision, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(inputs)
    labels = constant_op.constant(inputs)
    precision, update_op = metrics.precision(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(1.0, self.evaluate(update_op), 6)
      self.assertAlmostEqual(1.0, self.evaluate(precision), 6)

  @test_util.run_deprecated_v1
  def testSomeCorrect_multipleInputDtypes(self):
    for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
      predictions = math_ops.cast(
          constant_op.constant([1, 0, 1, 0], shape=(1, 4)), dtype=dtype)
      labels = math_ops.cast(
          constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype)
      precision, update_op = metrics.precision(labels, predictions)

      with self.cached_session():
        self.evaluate(variables.local_variables_initializer())
        self.assertAlmostEqual(0.5, self.evaluate(update_op))
        self.assertAlmostEqual(0.5, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testWeighted1d(self):
    predictions = constant_op.constant(self._PREDICTIONS)
    labels = constant_op.constant(self._LABELS)
    precision, update_op = metrics.precision(
        labels, predictions, weights=constant_op.constant([[2], [5]]))

    with self.cached_session():
      variables.local_variables_initializer().run()
      weighted_tp = 2.0 + 5.0
      weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(expected_precision, self.evaluate(update_op))
      self.assertAlmostEqual(expected_precision, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testWeightedScalar_placeholders(self):
    predictions = array_ops.placeholder(dtype=dtypes_lib.float32)
    labels = array_ops.placeholder(dtype=dtypes_lib.float32)
    feed_dict = {predictions: self._PREDICTIONS, labels: self._LABELS}
    precision, update_op = metrics.precision(labels, predictions, weights=2)

    with self.cached_session():
      variables.local_variables_initializer().run()
      weighted_tp = 2.0 + 2.0
      weighted_positives = (2.0 + 2.0) + (2.0 + 2.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(
          expected_precision, update_op.eval(feed_dict=feed_dict))
      self.assertAlmostEqual(
          expected_precision, precision.eval(feed_dict=feed_dict))

  @test_util.run_deprecated_v1
  def testWeighted1d_placeholders(self):
    predictions = array_ops.placeholder(dtype=dtypes_lib.float32)
    labels = array_ops.placeholder(dtype=dtypes_lib.float32)
    feed_dict = {predictions: self._PREDICTIONS, labels: self._LABELS}
    precision, update_op = metrics.precision(
        labels, predictions, weights=constant_op.constant([[2], [5]]))

    with self.cached_session():
      variables.local_variables_initializer().run()
      weighted_tp = 2.0 + 5.0
      weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(
          expected_precision, update_op.eval(feed_dict=feed_dict))
      self.assertAlmostEqual(
          expected_precision, precision.eval(feed_dict=feed_dict))

  @test_util.run_deprecated_v1
  def testWeighted2d(self):
    predictions = constant_op.constant(self._PREDICTIONS)
    labels = constant_op.constant(self._LABELS)
    precision, update_op = metrics.precision(
        labels,
        predictions,
        weights=constant_op.constant(self._WEIGHTS_2D))

    with self.cached_session():
      variables.local_variables_initializer().run()
      weighted_tp = 3.0 + 4.0
      weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(expected_precision, self.evaluate(update_op))
      self.assertAlmostEqual(expected_precision, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testWeighted2d_placeholders(self):
    predictions = array_ops.placeholder(dtype=dtypes_lib.float32)
    labels = array_ops.placeholder(dtype=dtypes_lib.float32)
    feed_dict = {predictions: self._PREDICTIONS, labels: self._LABELS}
    precision, update_op = metrics.precision(
        labels,
        predictions,
        weights=constant_op.constant(self._WEIGHTS_2D))

    with self.cached_session():
      variables.local_variables_initializer().run()
      weighted_tp = 3.0 + 4.0
      weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(
          expected_precision, update_op.eval(feed_dict=feed_dict))
      self.assertAlmostEqual(
          expected_precision, precision.eval(feed_dict=feed_dict))

  @test_util.run_deprecated_v1
  def testAllIncorrect(self):
    inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(inputs)
    labels = constant_op.constant(1 - inputs)
    precision, update_op = metrics.precision(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_op)
      self.assertAlmostEqual(0, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testZeroTrueAndFalsePositivesGivesZeroPrecision(self):
    predictions = constant_op.constant([0, 0, 0, 0])
    labels = constant_op.constant([0, 0, 0, 0])
    precision, update_op = metrics.precision(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_op)
      self.assertEqual(0.0, self.evaluate(precision))
random_ops.random_uniform( 988 (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) 989 recall, update_op = metrics.recall(labels, predictions) 990 991 with self.cached_session(): 992 self.evaluate(variables.local_variables_initializer()) 993 994 # Run several updates. 995 for _ in range(10): 996 self.evaluate(update_op) 997 998 # Then verify idempotency. 999 initial_recall = self.evaluate(recall) 1000 for _ in range(10): 1001 self.assertEqual(initial_recall, self.evaluate(recall)) 1002 1003 @test_util.run_deprecated_v1 1004 def testAllCorrect(self): 1005 np_inputs = np.random.randint(0, 2, size=(100, 1)) 1006 1007 predictions = constant_op.constant(np_inputs) 1008 labels = constant_op.constant(np_inputs) 1009 recall, update_op = metrics.recall(labels, predictions) 1010 1011 with self.cached_session(): 1012 self.evaluate(variables.local_variables_initializer()) 1013 self.evaluate(update_op) 1014 self.assertAlmostEqual(1.0, self.evaluate(recall), 6) 1015 1016 @test_util.run_deprecated_v1 1017 def testSomeCorrect_multipleInputDtypes(self): 1018 for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): 1019 predictions = math_ops.cast( 1020 constant_op.constant([1, 0, 1, 0], shape=(1, 4)), dtype=dtype) 1021 labels = math_ops.cast( 1022 constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype) 1023 recall, update_op = metrics.recall(labels, predictions) 1024 1025 with self.cached_session(): 1026 self.evaluate(variables.local_variables_initializer()) 1027 self.assertAlmostEqual(0.5, self.evaluate(update_op)) 1028 self.assertAlmostEqual(0.5, self.evaluate(recall)) 1029 1030 @test_util.run_deprecated_v1 1031 def testWeighted1d(self): 1032 predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]]) 1033 labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) 1034 weights = constant_op.constant([[2], [5]]) 1035 recall, update_op = metrics.recall(labels, predictions, weights=weights) 1036 1037 with self.cached_session(): 1038 
self.evaluate(variables.local_variables_initializer()) 1039 weighted_tp = 2.0 + 5.0 1040 weighted_t = (2.0 + 2.0) + (5.0 + 5.0) 1041 expected_precision = weighted_tp / weighted_t 1042 self.assertAlmostEqual(expected_precision, self.evaluate(update_op)) 1043 self.assertAlmostEqual(expected_precision, self.evaluate(recall)) 1044 1045 @test_util.run_deprecated_v1 1046 def testWeighted2d(self): 1047 predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]]) 1048 labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) 1049 weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]) 1050 recall, update_op = metrics.recall(labels, predictions, weights=weights) 1051 1052 with self.cached_session(): 1053 self.evaluate(variables.local_variables_initializer()) 1054 weighted_tp = 3.0 + 1.0 1055 weighted_t = (2.0 + 3.0) + (4.0 + 1.0) 1056 expected_precision = weighted_tp / weighted_t 1057 self.assertAlmostEqual(expected_precision, self.evaluate(update_op)) 1058 self.assertAlmostEqual(expected_precision, self.evaluate(recall)) 1059 1060 @test_util.run_deprecated_v1 1061 def testAllIncorrect(self): 1062 np_inputs = np.random.randint(0, 2, size=(100, 1)) 1063 1064 predictions = constant_op.constant(np_inputs) 1065 labels = constant_op.constant(1 - np_inputs) 1066 recall, update_op = metrics.recall(labels, predictions) 1067 1068 with self.cached_session(): 1069 self.evaluate(variables.local_variables_initializer()) 1070 self.evaluate(update_op) 1071 self.assertEqual(0, self.evaluate(recall)) 1072 1073 @test_util.run_deprecated_v1 1074 def testZeroTruePositivesAndFalseNegativesGivesZeroRecall(self): 1075 predictions = array_ops.zeros((1, 4)) 1076 labels = array_ops.zeros((1, 4)) 1077 recall, update_op = metrics.recall(labels, predictions) 1078 1079 with self.cached_session(): 1080 self.evaluate(variables.local_variables_initializer()) 1081 self.evaluate(update_op) 1082 self.assertEqual(0, self.evaluate(recall)) 1083 1084 1085class AUCTest(test.TestCase): 1086 1087 def 
setUp(self): 1088 np.random.seed(1) 1089 ops.reset_default_graph() 1090 1091 @test_util.run_deprecated_v1 1092 def testVars(self): 1093 metrics.auc(predictions=array_ops.ones((10, 1)), 1094 labels=array_ops.ones((10, 1))) 1095 _assert_metric_variables(self, 1096 ('auc/true_positives:0', 'auc/false_negatives:0', 1097 'auc/false_positives:0', 'auc/true_negatives:0')) 1098 1099 @test_util.run_deprecated_v1 1100 def testMetricsCollection(self): 1101 my_collection_name = '__metrics__' 1102 mean, _ = metrics.auc(predictions=array_ops.ones((10, 1)), 1103 labels=array_ops.ones((10, 1)), 1104 metrics_collections=[my_collection_name]) 1105 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 1106 1107 @test_util.run_deprecated_v1 1108 def testUpdatesCollection(self): 1109 my_collection_name = '__updates__' 1110 _, update_op = metrics.auc(predictions=array_ops.ones((10, 1)), 1111 labels=array_ops.ones((10, 1)), 1112 updates_collections=[my_collection_name]) 1113 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 1114 1115 @test_util.run_deprecated_v1 1116 def testValueTensorIsIdempotent(self): 1117 predictions = random_ops.random_uniform( 1118 (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) 1119 labels = random_ops.random_uniform( 1120 (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) 1121 auc, update_op = metrics.auc(labels, predictions) 1122 1123 with self.cached_session(): 1124 self.evaluate(variables.local_variables_initializer()) 1125 1126 # Run several updates. 1127 for _ in range(10): 1128 self.evaluate(update_op) 1129 1130 # Then verify idempotency. 
1131 initial_auc = self.evaluate(auc) 1132 for _ in range(10): 1133 self.assertAlmostEqual(initial_auc, self.evaluate(auc), 5) 1134 1135 @test_util.run_deprecated_v1 1136 def testAllCorrect(self): 1137 self.allCorrectAsExpected('ROC') 1138 1139 def allCorrectAsExpected(self, curve): 1140 inputs = np.random.randint(0, 2, size=(100, 1)) 1141 1142 with self.cached_session(): 1143 predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) 1144 labels = constant_op.constant(inputs) 1145 auc, update_op = metrics.auc(labels, predictions, curve=curve) 1146 1147 self.evaluate(variables.local_variables_initializer()) 1148 self.assertEqual(1, self.evaluate(update_op)) 1149 1150 self.assertEqual(1, self.evaluate(auc)) 1151 1152 @test_util.run_deprecated_v1 1153 def testSomeCorrect_multipleLabelDtypes(self): 1154 with self.cached_session(): 1155 for label_dtype in ( 1156 dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): 1157 predictions = constant_op.constant( 1158 [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) 1159 labels = math_ops.cast( 1160 constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=label_dtype) 1161 auc, update_op = metrics.auc(labels, predictions) 1162 1163 self.evaluate(variables.local_variables_initializer()) 1164 self.assertAlmostEqual(0.5, self.evaluate(update_op)) 1165 1166 self.assertAlmostEqual(0.5, self.evaluate(auc)) 1167 1168 @test_util.run_deprecated_v1 1169 def testWeighted1d(self): 1170 with self.cached_session(): 1171 predictions = constant_op.constant( 1172 [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) 1173 labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) 1174 weights = constant_op.constant([2], shape=(1, 1)) 1175 auc, update_op = metrics.auc(labels, predictions, weights=weights) 1176 1177 self.evaluate(variables.local_variables_initializer()) 1178 self.assertAlmostEqual(0.5, self.evaluate(update_op), 5) 1179 1180 self.assertAlmostEqual(0.5, self.evaluate(auc), 5) 1181 1182 @test_util.run_deprecated_v1 
1183 def testWeighted2d(self): 1184 with self.cached_session(): 1185 predictions = constant_op.constant( 1186 [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) 1187 labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) 1188 weights = constant_op.constant([1, 2, 3, 4], shape=(1, 4)) 1189 auc, update_op = metrics.auc(labels, predictions, weights=weights) 1190 1191 self.evaluate(variables.local_variables_initializer()) 1192 self.assertAlmostEqual(0.7, self.evaluate(update_op), 5) 1193 1194 self.assertAlmostEqual(0.7, self.evaluate(auc), 5) 1195 1196 @test_util.run_deprecated_v1 1197 def testManualThresholds(self): 1198 with self.cached_session(): 1199 # Verifies that thresholds passed in to the `thresholds` parameter are 1200 # used correctly. 1201 # The default thresholds do not split the second and third predictions. 1202 # Thus, when we provide manual thresholds which correctly split it, we get 1203 # an accurate AUC value. 1204 predictions = constant_op.constant( 1205 [0.12, 0.3001, 0.3003, 0.72], shape=(1, 4), dtype=dtypes_lib.float32) 1206 labels = constant_op.constant([0, 1, 0, 1], shape=(1, 4)) 1207 weights = constant_op.constant([1, 1, 1, 1], shape=(1, 4)) 1208 thresholds = [0.0, 0.2, 0.3002, 0.6, 1.0] 1209 default_auc, default_update_op = metrics.auc(labels, 1210 predictions, 1211 weights=weights) 1212 manual_auc, manual_update_op = metrics.auc(labels, 1213 predictions, 1214 weights=weights, 1215 thresholds=thresholds) 1216 1217 self.evaluate(variables.local_variables_initializer()) 1218 self.assertAlmostEqual(0.875, self.evaluate(default_update_op), 3) 1219 self.assertAlmostEqual(0.875, self.evaluate(default_auc), 3) 1220 1221 self.assertAlmostEqual(0.75, self.evaluate(manual_update_op), 3) 1222 self.assertAlmostEqual(0.75, self.evaluate(manual_auc), 3) 1223 1224 # Regarding the AUC-PR tests: note that the preferred method when 1225 # calculating AUC-PR is summation_method='careful_interpolation'. 
1226 @test_util.run_deprecated_v1 1227 def testCorrectAUCPRSpecialCase(self): 1228 with self.cached_session(): 1229 predictions = constant_op.constant( 1230 [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32) 1231 labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4)) 1232 auc, update_op = metrics.auc(labels, predictions, curve='PR', 1233 summation_method='careful_interpolation') 1234 1235 self.evaluate(variables.local_variables_initializer()) 1236 # expected ~= 0.79726744594 1237 expected = 1 - math.log(1.5) / 2 1238 self.assertAlmostEqual(expected, self.evaluate(update_op), delta=1e-3) 1239 self.assertAlmostEqual(expected, self.evaluate(auc), delta=1e-3) 1240 1241 @test_util.run_deprecated_v1 1242 def testCorrectAnotherAUCPRSpecialCase(self): 1243 with self.cached_session(): 1244 predictions = constant_op.constant( 1245 [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], 1246 shape=(1, 7), 1247 dtype=dtypes_lib.float32) 1248 labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7)) 1249 auc, update_op = metrics.auc(labels, predictions, curve='PR', 1250 summation_method='careful_interpolation') 1251 1252 self.evaluate(variables.local_variables_initializer()) 1253 # expected ~= 0.61350593198 1254 expected = (2.5 - 2 * math.log(4./3) - 0.25 * math.log(7./5)) / 3 1255 self.assertAlmostEqual(expected, self.evaluate(update_op), delta=1e-3) 1256 self.assertAlmostEqual(expected, self.evaluate(auc), delta=1e-3) 1257 1258 @test_util.run_deprecated_v1 1259 def testThirdCorrectAUCPRSpecialCase(self): 1260 with self.cached_session(): 1261 predictions = constant_op.constant( 1262 [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5], 1263 shape=(1, 7), 1264 dtype=dtypes_lib.float32) 1265 labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7)) 1266 auc, update_op = metrics.auc(labels, predictions, curve='PR', 1267 summation_method='careful_interpolation') 1268 1269 self.evaluate(variables.local_variables_initializer()) 1270 # expected ~= 0.90410597584 1271 expected = 1 - 
math.log(4./3) / 3 1272 self.assertAlmostEqual(expected, self.evaluate(update_op), delta=1e-3) 1273 self.assertAlmostEqual(expected, self.evaluate(auc), delta=1e-3) 1274 1275 @test_util.run_deprecated_v1 1276 def testIncorrectAUCPRSpecialCase(self): 1277 with self.cached_session(): 1278 predictions = constant_op.constant( 1279 [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32) 1280 labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4)) 1281 auc, update_op = metrics.auc(labels, predictions, curve='PR', 1282 summation_method='trapezoidal') 1283 1284 self.evaluate(variables.local_variables_initializer()) 1285 self.assertAlmostEqual(0.79166, self.evaluate(update_op), delta=1e-3) 1286 1287 self.assertAlmostEqual(0.79166, self.evaluate(auc), delta=1e-3) 1288 1289 @test_util.run_deprecated_v1 1290 def testAnotherIncorrectAUCPRSpecialCase(self): 1291 with self.cached_session(): 1292 predictions = constant_op.constant( 1293 [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], 1294 shape=(1, 7), 1295 dtype=dtypes_lib.float32) 1296 labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7)) 1297 auc, update_op = metrics.auc(labels, predictions, curve='PR', 1298 summation_method='trapezoidal') 1299 1300 self.evaluate(variables.local_variables_initializer()) 1301 self.assertAlmostEqual(0.610317, self.evaluate(update_op), delta=1e-3) 1302 1303 self.assertAlmostEqual(0.610317, self.evaluate(auc), delta=1e-3) 1304 1305 @test_util.run_deprecated_v1 1306 def testThirdIncorrectAUCPRSpecialCase(self): 1307 with self.cached_session(): 1308 predictions = constant_op.constant( 1309 [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5], 1310 shape=(1, 7), 1311 dtype=dtypes_lib.float32) 1312 labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7)) 1313 auc, update_op = metrics.auc(labels, predictions, curve='PR', 1314 summation_method='trapezoidal') 1315 1316 self.evaluate(variables.local_variables_initializer()) 1317 self.assertAlmostEqual(0.90277, self.evaluate(update_op), delta=1e-3) 
1318 1319 self.assertAlmostEqual(0.90277, self.evaluate(auc), delta=1e-3) 1320 1321 @test_util.run_deprecated_v1 1322 def testAllIncorrect(self): 1323 inputs = np.random.randint(0, 2, size=(100, 1)) 1324 1325 with self.cached_session(): 1326 predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) 1327 labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) 1328 auc, update_op = metrics.auc(labels, predictions) 1329 1330 self.evaluate(variables.local_variables_initializer()) 1331 self.assertAlmostEqual(0, self.evaluate(update_op)) 1332 1333 self.assertAlmostEqual(0, self.evaluate(auc)) 1334 1335 @test_util.run_deprecated_v1 1336 def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self): 1337 with self.cached_session(): 1338 predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) 1339 labels = array_ops.zeros([4]) 1340 auc, update_op = metrics.auc(labels, predictions) 1341 1342 self.evaluate(variables.local_variables_initializer()) 1343 self.assertAlmostEqual(1, self.evaluate(update_op), 6) 1344 1345 self.assertAlmostEqual(1, self.evaluate(auc), 6) 1346 1347 @test_util.run_deprecated_v1 1348 def testRecallOneAndPrecisionOneGivesOnePRAUC(self): 1349 with self.cached_session(): 1350 predictions = array_ops.ones([4], dtype=dtypes_lib.float32) 1351 labels = array_ops.ones([4]) 1352 auc, update_op = metrics.auc(labels, predictions, curve='PR') 1353 1354 self.evaluate(variables.local_variables_initializer()) 1355 self.assertAlmostEqual(1, self.evaluate(update_op), 6) 1356 1357 self.assertAlmostEqual(1, self.evaluate(auc), 6) 1358 1359 def np_auc(self, predictions, labels, weights): 1360 """Computes the AUC explicitly using Numpy. 1361 1362 Args: 1363 predictions: an ndarray with shape [N]. 1364 labels: an ndarray with shape [N]. 1365 weights: an ndarray with shape [N]. 1366 1367 Returns: 1368 the area under the ROC curve. 
1369 """ 1370 if weights is None: 1371 weights = np.ones(np.size(predictions)) 1372 is_positive = labels > 0 1373 num_positives = np.sum(weights[is_positive]) 1374 num_negatives = np.sum(weights[~is_positive]) 1375 1376 # Sort descending: 1377 inds = np.argsort(-predictions) 1378 1379 sorted_labels = labels[inds] 1380 sorted_weights = weights[inds] 1381 is_positive = sorted_labels > 0 1382 1383 tp = np.cumsum(sorted_weights * is_positive) / num_positives 1384 return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives 1385 1386 @test_util.run_deprecated_v1 1387 def testWithMultipleUpdates(self): 1388 num_samples = 1000 1389 batch_size = 10 1390 num_batches = int(num_samples / batch_size) 1391 1392 # Create the labels and data. 1393 labels = np.random.randint(0, 2, size=num_samples) 1394 noise = np.random.normal(0.0, scale=0.2, size=num_samples) 1395 predictions = 0.4 + 0.2 * labels + noise 1396 predictions[predictions > 1] = 1 1397 predictions[predictions < 0] = 0 1398 1399 def _enqueue_as_batches(x, enqueue_ops): 1400 x_batches = x.astype(np.float32).reshape((num_batches, batch_size)) 1401 x_queue = data_flow_ops.FIFOQueue( 1402 num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,)) 1403 for i in range(num_batches): 1404 enqueue_ops[i].append(x_queue.enqueue(x_batches[i, :])) 1405 return x_queue.dequeue() 1406 1407 for weights in (None, np.ones(num_samples), np.random.exponential( 1408 scale=1.0, size=num_samples)): 1409 expected_auc = self.np_auc(predictions, labels, weights) 1410 1411 with self.cached_session() as sess: 1412 enqueue_ops = [[] for i in range(num_batches)] 1413 tf_predictions = _enqueue_as_batches(predictions, enqueue_ops) 1414 tf_labels = _enqueue_as_batches(labels, enqueue_ops) 1415 tf_weights = (_enqueue_as_batches(weights, enqueue_ops) if 1416 weights is not None else None) 1417 1418 for i in range(num_batches): 1419 sess.run(enqueue_ops[i]) 1420 1421 auc, update_op = metrics.auc(tf_labels, 1422 tf_predictions, 1423 curve='ROC', 
1424 num_thresholds=500, 1425 weights=tf_weights) 1426 1427 self.evaluate(variables.local_variables_initializer()) 1428 for i in range(num_batches): 1429 self.evaluate(update_op) 1430 1431 # Since this is only approximate, we can't expect a 6 digits match. 1432 # Although with higher number of samples/thresholds we should see the 1433 # accuracy improving 1434 self.assertAlmostEqual(expected_auc, self.evaluate(auc), 2) 1435 1436 1437class SpecificityAtSensitivityTest(test.TestCase): 1438 1439 def setUp(self): 1440 np.random.seed(1) 1441 ops.reset_default_graph() 1442 1443 @test_util.run_deprecated_v1 1444 def testVars(self): 1445 metrics.specificity_at_sensitivity( 1446 predictions=array_ops.ones((10, 1)), 1447 labels=array_ops.ones((10, 1)), 1448 sensitivity=0.7) 1449 _assert_metric_variables(self, 1450 ('specificity_at_sensitivity/true_positives:0', 1451 'specificity_at_sensitivity/false_negatives:0', 1452 'specificity_at_sensitivity/false_positives:0', 1453 'specificity_at_sensitivity/true_negatives:0')) 1454 1455 @test_util.run_deprecated_v1 1456 def testMetricsCollection(self): 1457 my_collection_name = '__metrics__' 1458 mean, _ = metrics.specificity_at_sensitivity( 1459 predictions=array_ops.ones((10, 1)), 1460 labels=array_ops.ones((10, 1)), 1461 sensitivity=0.7, 1462 metrics_collections=[my_collection_name]) 1463 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 1464 1465 @test_util.run_deprecated_v1 1466 def testUpdatesCollection(self): 1467 my_collection_name = '__updates__' 1468 _, update_op = metrics.specificity_at_sensitivity( 1469 predictions=array_ops.ones((10, 1)), 1470 labels=array_ops.ones((10, 1)), 1471 sensitivity=0.7, 1472 updates_collections=[my_collection_name]) 1473 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 1474 1475 @test_util.run_deprecated_v1 1476 def testValueTensorIsIdempotent(self): 1477 predictions = random_ops.random_uniform( 1478 (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) 
1479 labels = random_ops.random_uniform( 1480 (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1) 1481 specificity, update_op = metrics.specificity_at_sensitivity( 1482 labels, predictions, sensitivity=0.7) 1483 1484 with self.cached_session(): 1485 self.evaluate(variables.local_variables_initializer()) 1486 1487 # Run several updates. 1488 for _ in range(10): 1489 self.evaluate(update_op) 1490 1491 # Then verify idempotency. 1492 initial_specificity = self.evaluate(specificity) 1493 for _ in range(10): 1494 self.assertAlmostEqual(initial_specificity, self.evaluate(specificity), 1495 5) 1496 1497 @test_util.run_deprecated_v1 1498 def testAllCorrect(self): 1499 inputs = np.random.randint(0, 2, size=(100, 1)) 1500 1501 predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) 1502 labels = constant_op.constant(inputs) 1503 specificity, update_op = metrics.specificity_at_sensitivity( 1504 labels, predictions, sensitivity=0.7) 1505 1506 with self.cached_session(): 1507 self.evaluate(variables.local_variables_initializer()) 1508 self.assertEqual(1, self.evaluate(update_op)) 1509 self.assertEqual(1, self.evaluate(specificity)) 1510 1511 @test_util.run_deprecated_v1 1512 def testSomeCorrectHighSensitivity(self): 1513 predictions_values = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.45, 0.5, 0.8, 0.9] 1514 labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] 1515 1516 predictions = constant_op.constant( 1517 predictions_values, dtype=dtypes_lib.float32) 1518 labels = constant_op.constant(labels_values) 1519 specificity, update_op = metrics.specificity_at_sensitivity( 1520 labels, predictions, sensitivity=0.8) 1521 1522 with self.cached_session(): 1523 self.evaluate(variables.local_variables_initializer()) 1524 self.assertAlmostEqual(1.0, self.evaluate(update_op)) 1525 self.assertAlmostEqual(1.0, self.evaluate(specificity)) 1526 1527 @test_util.run_deprecated_v1 1528 def testSomeCorrectLowSensitivity(self): 1529 predictions_values = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 
0.26] 1530 labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] 1531 1532 predictions = constant_op.constant( 1533 predictions_values, dtype=dtypes_lib.float32) 1534 labels = constant_op.constant(labels_values) 1535 specificity, update_op = metrics.specificity_at_sensitivity( 1536 labels, predictions, sensitivity=0.4) 1537 1538 with self.cached_session(): 1539 self.evaluate(variables.local_variables_initializer()) 1540 1541 self.assertAlmostEqual(0.6, self.evaluate(update_op)) 1542 self.assertAlmostEqual(0.6, self.evaluate(specificity)) 1543 1544 @test_util.run_deprecated_v1 1545 def testWeighted1d_multipleLabelDtypes(self): 1546 for label_dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): 1547 predictions_values = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 0.26] 1548 labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] 1549 weights_values = [3] 1550 1551 predictions = constant_op.constant( 1552 predictions_values, dtype=dtypes_lib.float32) 1553 labels = math_ops.cast(labels_values, dtype=label_dtype) 1554 weights = constant_op.constant(weights_values) 1555 specificity, update_op = metrics.specificity_at_sensitivity( 1556 labels, predictions, weights=weights, sensitivity=0.4) 1557 1558 with self.cached_session(): 1559 self.evaluate(variables.local_variables_initializer()) 1560 1561 self.assertAlmostEqual(0.6, self.evaluate(update_op)) 1562 self.assertAlmostEqual(0.6, self.evaluate(specificity)) 1563 1564 @test_util.run_deprecated_v1 1565 def testWeighted2d(self): 1566 predictions_values = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 0.26] 1567 labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] 1568 weights_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 1569 1570 predictions = constant_op.constant( 1571 predictions_values, dtype=dtypes_lib.float32) 1572 labels = constant_op.constant(labels_values) 1573 weights = constant_op.constant(weights_values) 1574 specificity, update_op = metrics.specificity_at_sensitivity( 1575 labels, predictions, weights=weights, 
sensitivity=0.4) 1576 1577 with self.cached_session(): 1578 self.evaluate(variables.local_variables_initializer()) 1579 1580 self.assertAlmostEqual(8.0 / 15.0, self.evaluate(update_op)) 1581 self.assertAlmostEqual(8.0 / 15.0, self.evaluate(specificity)) 1582 1583 1584class SensitivityAtSpecificityTest(test.TestCase): 1585 1586 def setUp(self): 1587 np.random.seed(1) 1588 ops.reset_default_graph() 1589 1590 @test_util.run_deprecated_v1 1591 def testVars(self): 1592 metrics.sensitivity_at_specificity( 1593 predictions=array_ops.ones((10, 1)), 1594 labels=array_ops.ones((10, 1)), 1595 specificity=0.7) 1596 _assert_metric_variables(self, 1597 ('sensitivity_at_specificity/true_positives:0', 1598 'sensitivity_at_specificity/false_negatives:0', 1599 'sensitivity_at_specificity/false_positives:0', 1600 'sensitivity_at_specificity/true_negatives:0')) 1601 1602 @test_util.run_deprecated_v1 1603 def testMetricsCollection(self): 1604 my_collection_name = '__metrics__' 1605 mean, _ = metrics.sensitivity_at_specificity( 1606 predictions=array_ops.ones((10, 1)), 1607 labels=array_ops.ones((10, 1)), 1608 specificity=0.7, 1609 metrics_collections=[my_collection_name]) 1610 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 1611 1612 @test_util.run_deprecated_v1 1613 def testUpdatesCollection(self): 1614 my_collection_name = '__updates__' 1615 _, update_op = metrics.sensitivity_at_specificity( 1616 predictions=array_ops.ones((10, 1)), 1617 labels=array_ops.ones((10, 1)), 1618 specificity=0.7, 1619 updates_collections=[my_collection_name]) 1620 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 1621 1622 @test_util.run_deprecated_v1 1623 def testValueTensorIsIdempotent(self): 1624 predictions = random_ops.random_uniform( 1625 (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) 1626 labels = random_ops.random_uniform( 1627 (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1) 1628 sensitivity, update_op = metrics.sensitivity_at_specificity( 1629 
labels, predictions, specificity=0.7) 1630 1631 with self.cached_session(): 1632 self.evaluate(variables.local_variables_initializer()) 1633 1634 # Run several updates. 1635 for _ in range(10): 1636 self.evaluate(update_op) 1637 1638 # Then verify idempotency. 1639 initial_sensitivity = self.evaluate(sensitivity) 1640 for _ in range(10): 1641 self.assertAlmostEqual(initial_sensitivity, self.evaluate(sensitivity), 1642 5) 1643 1644 @test_util.run_deprecated_v1 1645 def testAllCorrect(self): 1646 inputs = np.random.randint(0, 2, size=(100, 1)) 1647 1648 predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) 1649 labels = constant_op.constant(inputs) 1650 specificity, update_op = metrics.sensitivity_at_specificity( 1651 labels, predictions, specificity=0.7) 1652 1653 with self.cached_session(): 1654 self.evaluate(variables.local_variables_initializer()) 1655 self.assertAlmostEqual(1.0, self.evaluate(update_op), 6) 1656 self.assertAlmostEqual(1.0, self.evaluate(specificity), 6) 1657 1658 @test_util.run_deprecated_v1 1659 def testSomeCorrectHighSpecificity(self): 1660 predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9] 1661 labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] 1662 1663 predictions = constant_op.constant( 1664 predictions_values, dtype=dtypes_lib.float32) 1665 labels = constant_op.constant(labels_values) 1666 specificity, update_op = metrics.sensitivity_at_specificity( 1667 labels, predictions, specificity=0.8) 1668 1669 with self.cached_session(): 1670 self.evaluate(variables.local_variables_initializer()) 1671 self.assertAlmostEqual(0.8, self.evaluate(update_op)) 1672 self.assertAlmostEqual(0.8, self.evaluate(specificity)) 1673 1674 @test_util.run_deprecated_v1 1675 def testSomeCorrectLowSpecificity(self): 1676 predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] 1677 labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] 1678 1679 predictions = constant_op.constant( 1680 predictions_values, 
dtype=dtypes_lib.float32) 1681 labels = constant_op.constant(labels_values) 1682 specificity, update_op = metrics.sensitivity_at_specificity( 1683 labels, predictions, specificity=0.4) 1684 1685 with self.cached_session(): 1686 self.evaluate(variables.local_variables_initializer()) 1687 self.assertAlmostEqual(0.6, self.evaluate(update_op)) 1688 self.assertAlmostEqual(0.6, self.evaluate(specificity)) 1689 1690 @test_util.run_deprecated_v1 1691 def testWeighted_multipleLabelDtypes(self): 1692 for label_dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): 1693 predictions_values = [ 1694 0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] 1695 labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] 1696 weights_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 1697 1698 predictions = constant_op.constant( 1699 predictions_values, dtype=dtypes_lib.float32) 1700 labels = math_ops.cast(labels_values, dtype=label_dtype) 1701 weights = constant_op.constant(weights_values) 1702 specificity, update_op = metrics.sensitivity_at_specificity( 1703 labels, predictions, weights=weights, specificity=0.4) 1704 1705 with self.cached_session(): 1706 self.evaluate(variables.local_variables_initializer()) 1707 self.assertAlmostEqual(0.675, self.evaluate(update_op)) 1708 self.assertAlmostEqual(0.675, self.evaluate(specificity)) 1709 1710 1711# TODO(nsilberman): Break this up into two sets of tests. 
class PrecisionRecallThresholdsTest(test.TestCase):
  """Tests for `precision_at_thresholds` and `recall_at_thresholds`."""

  def setUp(self):
    # Seed numpy and start from an empty default graph before every test.
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    """Checks which metric variables `precision_at_thresholds` creates."""
    metrics.precision_at_thresholds(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        thresholds=[0, 0.5, 1.0])
    _assert_metric_variables(self, (
        'precision_at_thresholds/true_positives:0',
        'precision_at_thresholds/false_positives:0',
    ))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    """Value tensors must be added to the requested metrics collection."""
    my_collection_name = '__metrics__'
    prec, _ = metrics.precision_at_thresholds(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        thresholds=[0, 0.5, 1.0],
        metrics_collections=[my_collection_name])
    rec, _ = metrics.recall_at_thresholds(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        thresholds=[0, 0.5, 1.0],
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [prec, rec])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    """Update ops must be added to the requested updates collection."""
    my_collection_name = '__updates__'
    _, precision_op = metrics.precision_at_thresholds(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        thresholds=[0, 0.5, 1.0],
        updates_collections=[my_collection_name])
    _, recall_op = metrics.recall_at_thresholds(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        thresholds=[0, 0.5, 1.0],
        updates_collections=[my_collection_name])
    self.assertListEqual(
        ops.get_collection(my_collection_name), [precision_op, recall_op])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    """Reading the value tensors repeatedly must not change their result."""
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
    labels = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
    thresholds = [0, 0.5, 1.0]
    prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
                                                    thresholds)
    rec, rec_op = metrics.recall_at_thresholds(labels, predictions, thresholds)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate([prec_op, rec_op])

      # Then verify idempotency. Only the value tensors are evaluated here:
      # running the update ops inside this loop would accumulate new batches
      # and the check would no longer be about idempotency.
      initial_prec = self.evaluate(prec)
      initial_rec = self.evaluate(rec)
      for _ in range(10):
        self.assertAllClose(initial_prec, self.evaluate(prec))
        self.assertAllClose(initial_rec, self.evaluate(rec))

  # TODO(nsilberman): fix tests (passing but incorrect).
  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    """Predictions equal to the labels: precision and recall are both 1."""
    inputs = np.random.randint(0, 2, size=(100, 1))

    with self.cached_session():
      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
      labels = constant_op.constant(inputs)
      thresholds = [0.5]
      prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
                                                      thresholds)
      rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
                                                 thresholds)

      self.evaluate(variables.local_variables_initializer())
      self.evaluate([prec_op, rec_op])

      self.assertEqual(1, self.evaluate(prec))
      self.assertEqual(1, self.evaluate(rec))

  @test_util.run_deprecated_v1
  def testSomeCorrect_multipleLabelDtypes(self):
    """Half-correct predictions give 0.5 / 0.5 for every label dtype."""
    with self.cached_session():
      for label_dtype in (
          dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
        predictions = constant_op.constant(
            [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
        labels = math_ops.cast(
            constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=label_dtype)
        thresholds = [0.5]
        prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
                                                        thresholds)
        rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
                                                   thresholds)

        self.evaluate(variables.local_variables_initializer())
        self.evaluate([prec_op, rec_op])

        self.assertAlmostEqual(0.5, self.evaluate(prec))
        self.assertAlmostEqual(0.5, self.evaluate(rec))

  @test_util.run_deprecated_v1
  def testAllIncorrect(self):
    """Labels are the complement of the predictions: both metrics are 0."""
    inputs = np.random.randint(0, 2, size=(100, 1))

    with self.cached_session():
      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
      labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
      thresholds = [0.5]
      prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
                                                      thresholds)
      rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
                                                 thresholds)

      self.evaluate(variables.local_variables_initializer())
      self.evaluate([prec_op, rec_op])

      self.assertAlmostEqual(0, self.evaluate(prec))
      self.assertAlmostEqual(0, self.evaluate(rec))

  @test_util.run_deprecated_v1
  def testWeights1d(self):
    """Per-row weights of shape (2, 1): the first row carries weight 0."""
    with self.cached_session():
      predictions = constant_op.constant(
          [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
      labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
      weights = constant_op.constant(
          [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32)
      thresholds = [0.5, 1.1]
      prec, prec_op = metrics.precision_at_thresholds(
          labels, predictions, thresholds, weights=weights)
      rec, rec_op = metrics.recall_at_thresholds(
          labels, predictions, thresholds, weights=weights)

      # One result per threshold: split them and reshape each to a scalar so
      # they can be compared with assertAlmostEqual.
      [prec_low, prec_high] = array_ops.split(
          value=prec, num_or_size_splits=2, axis=0)
      prec_low = array_ops.reshape(prec_low, shape=())
      prec_high = array_ops.reshape(prec_high, shape=())
      [rec_low, rec_high] = array_ops.split(
          value=rec, num_or_size_splits=2, axis=0)
      rec_low = array_ops.reshape(rec_low, shape=())
      rec_high = array_ops.reshape(rec_high, shape=())

      self.evaluate(variables.local_variables_initializer())
      self.evaluate([prec_op, rec_op])

      self.assertAlmostEqual(1.0, self.evaluate(prec_low), places=5)
      self.assertAlmostEqual(0.0, self.evaluate(prec_high), places=5)
      self.assertAlmostEqual(1.0, self.evaluate(rec_low), places=5)
      self.assertAlmostEqual(0.0, self.evaluate(rec_high), places=5)

  @test_util.run_deprecated_v1
  def testWeights2d(self):
    """Per-element weights of shape (2, 2): the first row carries weight 0."""
    with self.cached_session():
      predictions = constant_op.constant(
          [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
      labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
      weights = constant_op.constant(
          [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32)
      thresholds = [0.5, 1.1]
      prec, prec_op = metrics.precision_at_thresholds(
          labels, predictions, thresholds, weights=weights)
      rec, rec_op = metrics.recall_at_thresholds(
          labels, predictions, thresholds, weights=weights)

      # One result per threshold: split them and reshape each to a scalar so
      # they can be compared with assertAlmostEqual.
      [prec_low, prec_high] = array_ops.split(
          value=prec, num_or_size_splits=2, axis=0)
      prec_low = array_ops.reshape(prec_low, shape=())
      prec_high = array_ops.reshape(prec_high, shape=())
      [rec_low, rec_high] = array_ops.split(
          value=rec, num_or_size_splits=2, axis=0)
      rec_low = array_ops.reshape(rec_low, shape=())
      rec_high = array_ops.reshape(rec_high, shape=())

      self.evaluate(variables.local_variables_initializer())
      self.evaluate([prec_op, rec_op])

      self.assertAlmostEqual(1.0, self.evaluate(prec_low), places=5)
      self.assertAlmostEqual(0.0, self.evaluate(prec_high), places=5)
      self.assertAlmostEqual(1.0, self.evaluate(rec_low), places=5)
      self.assertAlmostEqual(0.0, self.evaluate(rec_high), places=5)

  @test_util.run_deprecated_v1
  def testExtremeThresholds(self):
    """Thresholds below and above every prediction value."""
    with self.cached_session():
      predictions = constant_op.constant(
          [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
      thresholds = [-1.0, 2.0]  # lower/higher than any values
      prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
                                                      thresholds)
      rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
                                                 thresholds)

      [prec_low, prec_high] = array_ops.split(
          value=prec, num_or_size_splits=2, axis=0)
      [rec_low, rec_high] = array_ops.split(
          value=rec, num_or_size_splits=2, axis=0)

      self.evaluate(variables.local_variables_initializer())
      self.evaluate([prec_op, rec_op])

      self.assertAlmostEqual(0.75, self.evaluate(prec_low))
      self.assertAlmostEqual(0.0, self.evaluate(prec_high))
      self.assertAlmostEqual(1.0, self.evaluate(rec_low))
      self.assertAlmostEqual(0.0, self.evaluate(rec_high))

  @test_util.run_deprecated_v1
  def testZeroLabelsPredictions(self):
    """All-zero labels and predictions: both metrics are expected to be 0."""
    with self.cached_session():
      predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
      labels = array_ops.zeros([4])
      thresholds = [0.5]
      prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
                                                      thresholds)
      rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
                                                 thresholds)

      self.evaluate(variables.local_variables_initializer())
      self.evaluate([prec_op, rec_op])

      self.assertAlmostEqual(0, self.evaluate(prec), 6)
      self.assertAlmostEqual(0, self.evaluate(rec), 6)

  @test_util.run_deprecated_v1
  def testWithMultipleUpdates(self):
    """Streams the data in batches and compares with numpy-computed values."""
    num_samples = 1000
    batch_size = 10
    num_batches = num_samples // batch_size

    # Create the labels and data.
    labels = np.random.randint(0, 2, size=(num_samples, 1))
    noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1))
    predictions = 0.4 + 0.2 * labels + noise
    predictions[predictions > 1] = 1
    predictions[predictions < 0] = 0
    thresholds = [0.3]

    # Reference confusion counts. True negatives are not needed for either
    # precision or recall, so they are not counted.
    predicted_positive = predictions > thresholds[0]
    actual_positive = labels == 1
    tp = int(np.sum(predicted_positive & actual_positive))
    fp = int(np.sum(predicted_positive & ~actual_positive))
    fn = int(np.sum(~predicted_positive & actual_positive))
    # Small constant in the denominators avoids a division by zero.
    epsilon = 1e-7
    expected_prec = tp / (epsilon + tp + fp)
    expected_rec = tp / (epsilon + tp + fn)

    labels = labels.astype(np.float32)
    predictions = predictions.astype(np.float32)

    with self.cached_session() as sess:
      # Reshape the data so its easy to queue up:
      predictions_batches = predictions.reshape((batch_size, num_batches))
      labels_batches = labels.reshape((batch_size, num_batches))

      # Enqueue the data:
      predictions_queue = data_flow_ops.FIFOQueue(
          num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
      labels_queue = data_flow_ops.FIFOQueue(
          num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))

      for i in range(num_batches):
        tf_prediction = constant_op.constant(predictions_batches[:, i])
        tf_label = constant_op.constant(labels_batches[:, i])
        sess.run([
            predictions_queue.enqueue(tf_prediction),
            labels_queue.enqueue(tf_label)
        ])

      tf_predictions = predictions_queue.dequeue()
      tf_labels = labels_queue.dequeue()

      prec, prec_op = metrics.precision_at_thresholds(tf_labels, tf_predictions,
                                                      thresholds)
      rec, rec_op = metrics.recall_at_thresholds(tf_labels, tf_predictions,
                                                 thresholds)

      self.evaluate(variables.local_variables_initializer())
      # One update per enqueued batch.
      for _ in range(num_batches):
        self.evaluate([prec_op, rec_op])
      # Since this is only approximate, we can't expect a 6 digits match.
      # Although with higher number of samples/thresholds we should see the
      # accuracy improving
      self.assertAlmostEqual(expected_prec, self.evaluate(prec), 2)
      self.assertAlmostEqual(expected_rec, self.evaluate(rec), 2)


def _test_precision_at_k(predictions,
                         labels,
                         k,
                         expected,
                         class_id=None,
                         weights=None,
                         test_case=None):
  """Builds `metrics.precision_at_k` in a fresh graph and checks its result.

  Args:
    predictions: Values converted to a float32 constant and passed as
      `predictions`.
    labels: Passed through unchanged as `labels`.
    k: Passed through as `k`.
    expected: Expected value of both the update op and the metric tensor; NaN
      is compared with `_assert_nan`, anything else with `assertEqual`.
    class_id: Passed through as `class_id`.
    weights: Optional values converted to a float32 constant.
    test_case: Test case providing the session and the assertion methods.
  """
  with ops.Graph().as_default() as g, test_case.test_session(g):
    if weights is not None:
      weights = constant_op.constant(weights, dtypes_lib.float32)
    metric, update = metrics.precision_at_k(
        predictions=constant_op.constant(predictions, dtypes_lib.float32),
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)

    # Fails without initialized vars.
    test_case.assertRaises(errors_impl.OpError, metric.eval)
    test_case.assertRaises(errors_impl.OpError, update.eval)
    variables.variables_initializer(variables.local_variables()).run()

    # Run per-step op and assert expected values.
    if math.isnan(expected):
      _assert_nan(test_case, update.eval())
      _assert_nan(test_case, metric.eval())
    else:
      test_case.assertEqual(expected, update.eval())
      test_case.assertEqual(expected, metric.eval())


def _test_precision_at_top_k(
    predictions_idx,
    labels,
    expected,
    k=None,
    class_id=None,
    weights=None,
    test_case=None):
  """Builds `metrics.precision_at_top_k` in a fresh graph and checks it.

  Args:
    predictions_idx: Values converted to an int32 constant and passed as
      `predictions_idx`.
    labels: Passed through unchanged as `labels`.
    expected: Expected value of both the update op and the metric tensor; NaN
      is compared with `_assert_nan`, anything else with `assertEqual`.
    k: Passed through as `k`.
    class_id: Passed through as `class_id`.
    weights: Optional values converted to a float32 constant.
    test_case: Test case providing the session and the assertion methods.
  """
  with ops.Graph().as_default() as g, test_case.test_session(g):
    if weights is not None:
      weights = constant_op.constant(weights, dtypes_lib.float32)
    metric, update = metrics.precision_at_top_k(
        predictions_idx=constant_op.constant(predictions_idx, dtypes_lib.int32),
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)

    # Fails without initialized vars.
    test_case.assertRaises(errors_impl.OpError, metric.eval)
    test_case.assertRaises(errors_impl.OpError, update.eval)
    variables.variables_initializer(variables.local_variables()).run()

    # Run per-step op and assert expected values. NaN goes through the same
    # `_assert_nan` helper as the sibling `_test_*` functions.
    if math.isnan(expected):
      _assert_nan(test_case, update.eval())
      _assert_nan(test_case, metric.eval())
    else:
      test_case.assertEqual(expected, update.eval())
      test_case.assertEqual(expected, metric.eval())


def _test_average_precision_at_k(predictions,
                                 labels,
                                 k,
                                 expected,
                                 weights=None,
                                 test_case=None):
  """Builds `metrics.average_precision_at_k` in a fresh graph and checks it.

  Args:
    predictions: Values converted to a float32 constant.
    labels: Passed through unchanged.
    k: Passed through unchanged.
    expected: Expected value of both the update op and the metric tensor; NaN
      is compared with `_assert_nan`, anything else with `assertAlmostEqual`.
    weights: Optional values converted to a float32 constant.
    test_case: Test case providing the session and the assertion methods.
  """
  with ops.Graph().as_default() as g, test_case.test_session(g):
    if weights is not None:
      weights = constant_op.constant(weights, dtypes_lib.float32)
    predictions = constant_op.constant(predictions, dtypes_lib.float32)
    metric, update = metrics.average_precision_at_k(
        labels, predictions, k, weights=weights)

    # Fails without initialized vars.
    test_case.assertRaises(errors_impl.OpError, metric.eval)
    test_case.assertRaises(errors_impl.OpError, update.eval)
    variables.variables_initializer(variables.local_variables()).run()

    # Run per-step op and assert expected values.
    if math.isnan(expected):
      _assert_nan(test_case, update.eval())
      _assert_nan(test_case, metric.eval())
    else:
      test_case.assertAlmostEqual(expected, update.eval())
      test_case.assertAlmostEqual(expected, metric.eval())


class SingleLabelPrecisionAtKTest(test.TestCase):
  """Precision@k tests where every example has exactly one label."""

  def setUp(self):
    self._predictions = ((0.1, 0.3, 0.2, 0.4), (0.1, 0.2, 0.3, 0.4))
    self._predictions_idx = [[3], [3]]
    indicator_labels = ((0, 0, 0, 1), (0, 0, 1, 0))
    class_labels = (3, 2)
    # Sparse vs dense, and 1d vs 2d labels should all be handled the same.
    self._labels = (
        _binary_2d_label_to_1d_sparse_value(indicator_labels),
        _binary_2d_label_to_2d_sparse_value(indicator_labels),
        np.array(class_labels, dtype=np.int64),
        np.array([[class_id] for class_id in class_labels], dtype=np.int64))
    # Bind this test case once so the tests can call the helpers directly.
    self._test_precision_at_k = functools.partial(
        _test_precision_at_k, test_case=self)
    self._test_precision_at_top_k = functools.partial(
        _test_precision_at_top_k, test_case=self)
    self._test_average_precision_at_k = functools.partial(
        _test_average_precision_at_k, test_case=self)

  @test_util.run_deprecated_v1
  def test_at_k1_nan(self):
    """Classes without any prediction are expected to give NaN."""
    for labels in self._labels:
      # Classes 0,1,2 have 0 predictions, classes -1 and 4 are out of range.
      for class_id in (-1, 0, 1, 2, 4):
        self._test_precision_at_k(
            self._predictions, labels, k=1, expected=NAN, class_id=class_id)
        self._test_precision_at_top_k(
            self._predictions_idx, labels, k=1, expected=NAN, class_id=class_id)

  @test_util.run_deprecated_v1
  def test_at_k1(self):
    """Precision@1 for class 3 and over all classes."""
    for labels in self._labels:
      # Class 3: 1 label, 2 predictions, 1 correct.
      self._test_precision_at_k(
          self._predictions, labels, k=1, expected=1.0 / 2, class_id=3)
      self._test_precision_at_top_k(
          self._predictions_idx, labels, k=1, expected=1.0 / 2, class_id=3)

      # All classes: 2 labels, 2 predictions, 1 correct.
      self._test_precision_at_k(
          self._predictions, labels, k=1, expected=1.0 / 2)
      self._test_precision_at_top_k(
          self._predictions_idx, labels, k=1, expected=1.0 / 2)
      self._test_average_precision_at_k(
          self._predictions, labels, k=1, expected=1.0 / 2)


class MultiLabelPrecisionAtKTest(test.TestCase):
  """Precision@k tests where examples may carry several labels."""

  def setUp(self):
    # Bind this test case once so the tests can call the helpers directly.
    self._test_precision_at_k = functools.partial(
        _test_precision_at_k, test_case=self)
    self._test_precision_at_top_k = functools.partial(
        _test_precision_at_top_k, test_case=self)
    self._test_average_precision_at_k = functools.partial(
        _test_average_precision_at_k, test_case=self)

  @test_util.run_deprecated_v1
  def test_average_precision(self):
    """Precision and average precision at k=1..4 for two examples."""
    # Example 1.
    # Matches example here:
    # fastml.com/what-you-wanted-to-know-about-mean-average-precision
    labels_ex1 = (0, 1, 2, 3, 4)
    labels = np.array([labels_ex1], dtype=np.int64)
    predictions_ex1 = (0.2, 0.1, 0.0, 0.4, 0.0, 0.5, 0.3)
    predictions = (predictions_ex1,)
    predictions_idx_ex1 = (5, 3, 6, 0, 1)
    precision_ex1 = (0.0 / 1, 1.0 / 2, 1.0 / 3, 2.0 / 4)
    avg_precision_ex1 = (0.0 / 1, precision_ex1[1] / 2, precision_ex1[1] / 3,
                         (precision_ex1[1] + precision_ex1[3]) / 4)
    for i in xrange(4):
      k = i + 1
      self._test_precision_at_k(
          predictions, labels, k, expected=precision_ex1[i])
      self._test_precision_at_top_k(
          (predictions_idx_ex1[:k],), labels, k=k, expected=precision_ex1[i])
      self._test_average_precision_at_k(
          predictions, labels, k, expected=avg_precision_ex1[i])

    # Example 2.
    labels_ex2 = (0, 2, 4, 5, 6)
    labels = np.array([labels_ex2], dtype=np.int64)
    predictions_ex2 = (0.3, 0.5, 0.0, 0.4, 0.0, 0.1, 0.2)
    predictions = (predictions_ex2,)
    predictions_idx_ex2 = (1, 3, 0, 6, 5)
    precision_ex2 = (0.0 / 1, 0.0 / 2, 1.0 / 3, 2.0 / 4)
    avg_precision_ex2 = (0.0 / 1, 0.0 / 2, precision_ex2[2] / 3,
                         (precision_ex2[2] + precision_ex2[3]) / 4)
    for i in xrange(4):
      k = i + 1
      self._test_precision_at_k(
          predictions, labels, k, expected=precision_ex2[i])
      self._test_precision_at_top_k(
          (predictions_idx_ex2[:k],), labels, k=k, expected=precision_ex2[i])
      self._test_average_precision_at_k(
          predictions, labels, k, expected=avg_precision_ex2[i])

    # Both examples, we expect both precision and average precision to be the
    # average of the 2 examples.
    labels = np.array([labels_ex1, labels_ex2], dtype=np.int64)
    predictions = (predictions_ex1, predictions_ex2)
    streaming_precision = [(ex1 + ex2) / 2
                           for ex1, ex2 in zip(precision_ex1, precision_ex2)]
    streaming_average_precision = [
        (ex1 + ex2) / 2
        for ex1, ex2 in zip(avg_precision_ex1, avg_precision_ex2)
    ]
    for i in xrange(4):
      k = i + 1
      predictions_idx = (predictions_idx_ex1[:k], predictions_idx_ex2[:k])
      self._test_precision_at_k(
          predictions, labels, k, expected=streaming_precision[i])
      self._test_precision_at_top_k(
          predictions_idx, labels, k=k, expected=streaming_precision[i])
      self._test_average_precision_at_k(
          predictions, labels, k, expected=streaming_average_precision[i])

    # Weighted examples, we expect streaming average precision to be the
    # weighted average of the 2 examples.
    weights = (0.3, 0.6)
    streaming_average_precision = [
        (weights[0] * ex1 + weights[1] * ex2) / (weights[0] + weights[1])
        for ex1, ex2 in zip(avg_precision_ex1, avg_precision_ex2)
    ]
    for i in xrange(4):
      k = i + 1
      self._test_average_precision_at_k(
          predictions,
          labels,
          k,
          expected=streaming_average_precision[i],
          weights=weights)

  @test_util.run_deprecated_v1
  def test_average_precision_some_labels_out_of_range(self):
    """Tests that labels outside the [0, n_classes) range are ignored."""
    labels_ex1 = (-1, 0, 1, 2, 3, 4, 7)
    labels = np.array([labels_ex1], dtype=np.int64)
    predictions_ex1 = (0.2, 0.1, 0.0, 0.4, 0.0, 0.5, 0.3)
    predictions = (predictions_ex1,)
    predictions_idx_ex1 = (5, 3, 6, 0, 1)
    precision_ex1 = (0.0 / 1, 1.0 / 2, 1.0 / 3, 2.0 / 4)
    avg_precision_ex1 = (0.0 / 1, precision_ex1[1] / 2, precision_ex1[1] / 3,
                         (precision_ex1[1] + precision_ex1[3]) / 4)
    for i in xrange(4):
      k = i + 1
      self._test_precision_at_k(
          predictions, labels, k, expected=precision_ex1[i])
      self._test_precision_at_top_k(
          (predictions_idx_ex1[:k],), labels, k=k, expected=precision_ex1[i])
      self._test_average_precision_at_k(
          predictions, labels, k, expected=avg_precision_ex1[i])

  @test_util.run_deprecated_v1
  def test_average_precision_different_num_labels(self):
    """Tests the case where the numbers of labels differ across examples."""
    predictions = [[0.4, 0.3, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4]]
    sparse_labels = _binary_2d_label_to_2d_sparse_value(
        [[0, 0, 1, 1], [0, 0, 0, 1]])
    dense_labels = np.array([[2, 3], [3, -1]], dtype=np.int64)
    predictions_idx_ex1 = np.array(((0, 1, 2, 3), (3, 2, 1, 0)))
    precision_ex1 = ((0.0 / 1, 0.0 / 2, 1.0 / 3, 2.0 / 4),
                     (1.0 / 1, 1.0 / 2, 1.0 / 3, 1.0 / 4))
    mean_precision_ex1 = np.mean(precision_ex1, axis=0)
    avg_precision_ex1 = (
        (0.0 / 1, 0.0 / 2, 1.0 / 3 / 2, (1.0 / 3 + 2.0 / 4) / 2),
        (1.0 / 1, 1.0 / 1, 1.0 / 1, 1.0 / 1))
    mean_avg_precision_ex1 = np.mean(avg_precision_ex1, axis=0)
    for labels in (sparse_labels, dense_labels):
      for i in xrange(4):
        k = i + 1
        self._test_precision_at_k(
            predictions, labels, k, expected=mean_precision_ex1[i])
        self._test_precision_at_top_k(
            predictions_idx_ex1[:, :k], labels, k=k,
            expected=mean_precision_ex1[i])
        self._test_average_precision_at_k(
            predictions, labels, k, expected=mean_avg_precision_ex1[i])

  @test_util.run_deprecated_v1
  def test_three_labels_at_k5_no_predictions(self):
    """Classes that never appear in the top 5 are expected to give NaN."""
    predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                   [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
    predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
    sparse_labels = _binary_2d_label_to_2d_sparse_value(
        [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
    dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)

    for labels in (sparse_labels, dense_labels):
      # Classes 1,3,8 have 0 predictions, classes -1 and 10 are out of range.
      for class_id in (-1, 1, 3, 8, 10):
        self._test_precision_at_k(
            predictions, labels, k=5, expected=NAN, class_id=class_id)
        self._test_precision_at_top_k(
            predictions_idx, labels, k=5, expected=NAN, class_id=class_id)

  @test_util.run_deprecated_v1
  def test_three_labels_at_k5_no_labels(self):
    """Predicted classes that are never labelled are expected to give 0."""
    predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                   [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
    predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
    sparse_labels = _binary_2d_label_to_2d_sparse_value(
        [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
    dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)

    for labels in (sparse_labels, dense_labels):
      # Classes 0,4,6,9: 0 labels, >=1 prediction.
      for class_id in (0, 4, 6, 9):
        self._test_precision_at_k(
            predictions, labels, k=5, expected=0.0, class_id=class_id)
        self._test_precision_at_top_k(
            predictions_idx, labels, k=5, expected=0.0, class_id=class_id)

  @test_util.run_deprecated_v1
  def test_three_labels_at_k5(self):
    """Precision@5 per class and over all classes, sparse and dense labels."""
    predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                   [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
    predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
    sparse_labels = _binary_2d_label_to_2d_sparse_value(
        [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
    dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)

    for labels in (sparse_labels, dense_labels):
      # Class 2: 2 labels, 2 correct predictions.
      self._test_precision_at_k(
          predictions, labels, k=5, expected=2.0 / 2, class_id=2)
      self._test_precision_at_top_k(
          predictions_idx, labels, k=5, expected=2.0 / 2, class_id=2)

      # Class 5: 1 label, 1 correct prediction.
      self._test_precision_at_k(
          predictions, labels, k=5, expected=1.0 / 1, class_id=5)
      self._test_precision_at_top_k(
          predictions_idx, labels, k=5, expected=1.0 / 1, class_id=5)

      # Class 7: 1 label, 1 incorrect prediction.
      self._test_precision_at_k(
          predictions, labels, k=5, expected=0.0 / 1, class_id=7)
      self._test_precision_at_top_k(
          predictions_idx, labels, k=5, expected=0.0 / 1, class_id=7)

      # All classes: 10 predictions, 3 correct.
      self._test_precision_at_k(
          predictions, labels, k=5, expected=3.0 / 10)
      self._test_precision_at_top_k(
          predictions_idx, labels, k=5, expected=3.0 / 10)

  @test_util.run_deprecated_v1
  def test_three_labels_at_k5_some_out_of_range(self):
    """Tests that labels outside the [0, n_classes) range are ignored."""
    predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                   [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
    predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
    sp_labels = sparse_tensor.SparseTensorValue(
        indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
                 [1, 3]],
        # values -1 and 10 are outside the [0, n_classes) range and are ignored.
        values=np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64),
        dense_shape=[2, 4])

    # Class 2: 2 labels, 2 correct predictions.
    self._test_precision_at_k(
        predictions, sp_labels, k=5, expected=2.0 / 2, class_id=2)
    self._test_precision_at_top_k(
        predictions_idx, sp_labels, k=5, expected=2.0 / 2, class_id=2)

    # Class 5: 1 label, 1 correct prediction.
    self._test_precision_at_k(
        predictions, sp_labels, k=5, expected=1.0 / 1, class_id=5)
    self._test_precision_at_top_k(
        predictions_idx, sp_labels, k=5, expected=1.0 / 1, class_id=5)

    # Class 7: 1 label, 1 incorrect prediction.
    self._test_precision_at_k(
        predictions, sp_labels, k=5, expected=0.0 / 1, class_id=7)
    self._test_precision_at_top_k(
        predictions_idx, sp_labels, k=5, expected=0.0 / 1, class_id=7)

    # All classes: 10 predictions, 3 correct.
    self._test_precision_at_k(
        predictions, sp_labels, k=5, expected=3.0 / 10)
    self._test_precision_at_top_k(
        predictions_idx, sp_labels, k=5, expected=3.0 / 10)

  @test_util.run_deprecated_v1
  def test_3d_nan(self):
    """3-D inputs: classes without any prediction are expected to give NaN."""
    predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
                   [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
                    [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
    predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
                       [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
    labels = _binary_3d_label_to_sparse_value(
        [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
         [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])

    # Classes 1,3,8 have 0 predictions, classes -1 and 10 are out of range.
    for class_id in (-1, 1, 3, 8, 10):
      self._test_precision_at_k(
          predictions, labels, k=5, expected=NAN, class_id=class_id)
      self._test_precision_at_top_k(
          predictions_idx, labels, k=5, expected=NAN, class_id=class_id)

  @test_util.run_deprecated_v1
  def test_3d_no_labels(self):
    """3-D inputs: predicted but never labelled classes are expected to be 0."""
    predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
                   [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
                    [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
    predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
                       [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
    labels = _binary_3d_label_to_sparse_value(
        [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
         [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])

    # Classes 0,4,6,9: 0 labels, >=1 prediction.
    for class_id in (0, 4, 6, 9):
      self._test_precision_at_k(
          predictions, labels, k=5, expected=0.0, class_id=class_id)
      self._test_precision_at_top_k(
          predictions_idx, labels, k=5, expected=0.0, class_id=class_id)

  @test_util.run_deprecated_v1
  def test_3d(self):
    """3-D inputs: precision@5 per class and over all classes."""
    predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
                   [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
                    [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
    predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
                       [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
    labels = _binary_3d_label_to_sparse_value(
        [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
         [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])

    # Class 2: 4 predictions, all correct.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=4.0 / 4, class_id=2)
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=4.0 / 4, class_id=2)

    # Class 5: 2 predictions, both correct.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=2.0 / 2, class_id=5)
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=2.0 / 2, class_id=5)

    # Class 7: 2 predictions, 1 correct.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=1.0 / 2, class_id=7)
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=1.0 / 2, class_id=7)

    # All classes: 20 predictions, 7 correct.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=7.0 / 20)
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=7.0 / 20)

  @test_util.run_deprecated_v1
  def test_3d_ignore_some(self):
    """3-D inputs with 0/1 weights that mask out part of the examples."""
    predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
                   [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
                    [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
    predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
                       [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
    labels = _binary_3d_label_to_sparse_value(
        [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
         [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])

    # Class 2: 2 predictions, both correct.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
        weights=[[1], [0]])
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=2.0 / 2.0, class_id=2,
        weights=[[1], [0]])

    # Class 2: 2 predictions, both correct.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
        weights=[[0], [1]])
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=2.0 / 2.0, class_id=2,
        weights=[[0], [1]])

    # Class 7: 1 incorrect prediction.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=0.0 / 1.0, class_id=7,
        weights=[[1], [0]])
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=0.0 / 1.0, class_id=7,
        weights=[[1], [0]])

    # Class 7: 1 correct prediction.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=1.0 / 1.0, class_id=7,
        weights=[[0], [1]])
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=1.0 / 1.0, class_id=7,
        weights=[[0], [1]])

    # Class 7: no predictions.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=NAN, class_id=7,
        weights=[[1, 0], [0, 1]])
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=NAN, class_id=7,
        weights=[[1, 0], [0, 1]])

    # Class 7: 2 predictions, 1 correct.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=1.0 / 2.0, class_id=7,
        weights=[[0, 1], [1, 0]])
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=1.0 / 2.0, class_id=7,
        weights=[[0, 1], [1, 0]])


def _test_recall_at_k(predictions,
                      labels,
                      k,
                      expected,
                      class_id=None,
                      weights=None,
                      test_case=None):
  """Builds `metrics.recall_at_k` in a fresh graph and checks its result.

  Args:
    predictions: Values converted to a float32 constant and passed as
      `predictions`.
    labels: Passed through unchanged as `labels`.
    k: Passed through as `k`.
    expected: Expected value of both the update op and the metric tensor; NaN
      is compared with `_assert_nan`, anything else with `assertEqual`.
    class_id: Passed through as `class_id`.
    weights: Optional values converted to a float32 constant.
    test_case: Test case providing the session and the assertion methods.
  """
  with ops.Graph().as_default() as g, test_case.test_session(g):
    if weights is not None:
      weights = constant_op.constant(weights, dtypes_lib.float32)
    metric, update = metrics.recall_at_k(
        predictions=constant_op.constant(predictions, dtypes_lib.float32),
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)

    # Fails without initialized vars.
    test_case.assertRaises(errors_impl.OpError, metric.eval)
    test_case.assertRaises(errors_impl.OpError, update.eval)
    variables.variables_initializer(variables.local_variables()).run()

    # Run per-step op and assert expected values.
    if math.isnan(expected):
      _assert_nan(test_case, update.eval())
      _assert_nan(test_case, metric.eval())
    else:
      test_case.assertEqual(expected, update.eval())
      test_case.assertEqual(expected, metric.eval())


def _test_recall_at_top_k(
    predictions_idx,
    labels,
    expected,
    k=None,
    class_id=None,
    weights=None,
    test_case=None):
  """Builds `metrics.recall_at_top_k` in a fresh graph and checks its result.

  Args:
    predictions_idx: Values converted to an int32 constant and passed as
      `predictions_idx`.
    labels: Passed through unchanged as `labels`.
    expected: Expected value of both the update op and the metric tensor; NaN
      is compared with `_assert_nan`, anything else with `assertEqual`.
    k: Passed through as `k`.
    class_id: Passed through as `class_id`.
    weights: Optional values converted to a float32 constant.
    test_case: Test case providing the session and the assertion methods.
  """
  with ops.Graph().as_default() as g, test_case.test_session(g):
    if weights is not None:
      weights = constant_op.constant(weights, dtypes_lib.float32)
    metric, update = metrics.recall_at_top_k(
        predictions_idx=constant_op.constant(predictions_idx, dtypes_lib.int32),
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)

    # Fails without initialized vars.
    test_case.assertRaises(errors_impl.OpError, metric.eval)
    test_case.assertRaises(errors_impl.OpError, update.eval)
    variables.variables_initializer(variables.local_variables()).run()

    # Run per-step op and assert expected values.
    if math.isnan(expected):
      _assert_nan(test_case, update.eval())
      _assert_nan(test_case, metric.eval())
    else:
      test_case.assertEqual(expected, update.eval())
      test_case.assertEqual(expected, metric.eval())


class SingleLabelRecallAtKTest(test.TestCase):

  def setUp(self):
    self._predictions = ((0.1, 0.3, 0.2, 0.4), (0.1, 0.2, 0.3, 0.4))
    self._predictions_idx = [[3], [3]]
    indicator_labels = ((0, 0, 0, 1), (0, 0, 1, 0))
    class_labels = (3, 2)
    # Sparse vs dense, and 1d vs 2d labels should all be handled the same.
2601 self._labels = ( 2602 _binary_2d_label_to_1d_sparse_value(indicator_labels), 2603 _binary_2d_label_to_2d_sparse_value(indicator_labels), np.array( 2604 class_labels, dtype=np.int64), np.array( 2605 [[class_id] for class_id in class_labels], dtype=np.int64)) 2606 self._test_recall_at_k = functools.partial( 2607 _test_recall_at_k, test_case=self) 2608 self._test_recall_at_top_k = functools.partial( 2609 _test_recall_at_top_k, test_case=self) 2610 2611 @test_util.run_deprecated_v1 2612 def test_at_k1_nan(self): 2613 # Classes 0,1 have 0 labels, 0 predictions, classes -1 and 4 are out of 2614 # range. 2615 for labels in self._labels: 2616 for class_id in (-1, 0, 1, 4): 2617 self._test_recall_at_k( 2618 self._predictions, labels, k=1, expected=NAN, class_id=class_id) 2619 self._test_recall_at_top_k( 2620 self._predictions_idx, labels, k=1, expected=NAN, class_id=class_id) 2621 2622 @test_util.run_deprecated_v1 2623 def test_at_k1_no_predictions(self): 2624 for labels in self._labels: 2625 # Class 2: 0 predictions. 2626 self._test_recall_at_k( 2627 self._predictions, labels, k=1, expected=0.0, class_id=2) 2628 self._test_recall_at_top_k( 2629 self._predictions_idx, labels, k=1, expected=0.0, class_id=2) 2630 2631 @test_util.run_deprecated_v1 2632 def test_one_label_at_k1(self): 2633 for labels in self._labels: 2634 # Class 3: 1 label, 2 predictions, 1 correct. 2635 self._test_recall_at_k( 2636 self._predictions, labels, k=1, expected=1.0 / 1, class_id=3) 2637 self._test_recall_at_top_k( 2638 self._predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3) 2639 2640 # All classes: 2 labels, 2 predictions, 1 correct. 
2641 self._test_recall_at_k(self._predictions, labels, k=1, expected=1.0 / 2) 2642 self._test_recall_at_top_k( 2643 self._predictions_idx, labels, k=1, expected=1.0 / 2) 2644 2645 @test_util.run_deprecated_v1 2646 def test_one_label_at_k1_weighted_class_id3(self): 2647 predictions = self._predictions 2648 predictions_idx = self._predictions_idx 2649 for labels in self._labels: 2650 # Class 3: 1 label, 2 predictions, 1 correct. 2651 self._test_recall_at_k( 2652 predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,)) 2653 self._test_recall_at_top_k( 2654 predictions_idx, labels, k=1, expected=NAN, class_id=3, 2655 weights=(0.0,)) 2656 self._test_recall_at_k( 2657 predictions, labels, k=1, expected=1.0 / 1, class_id=3, 2658 weights=(1.0,)) 2659 self._test_recall_at_top_k( 2660 predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3, 2661 weights=(1.0,)) 2662 self._test_recall_at_k( 2663 predictions, labels, k=1, expected=1.0 / 1, class_id=3, 2664 weights=(2.0,)) 2665 self._test_recall_at_top_k( 2666 predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3, 2667 weights=(2.0,)) 2668 self._test_recall_at_k( 2669 predictions, labels, k=1, expected=NAN, class_id=3, 2670 weights=(0.0, 1.0)) 2671 self._test_recall_at_top_k( 2672 predictions_idx, labels, k=1, expected=NAN, class_id=3, 2673 weights=(0.0, 1.0)) 2674 self._test_recall_at_k( 2675 predictions, labels, k=1, expected=1.0 / 1, class_id=3, 2676 weights=(1.0, 0.0)) 2677 self._test_recall_at_top_k( 2678 predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3, 2679 weights=(1.0, 0.0)) 2680 self._test_recall_at_k( 2681 predictions, labels, k=1, expected=2.0 / 2, class_id=3, 2682 weights=(2.0, 3.0)) 2683 self._test_recall_at_top_k( 2684 predictions_idx, labels, k=1, expected=2.0 / 2, class_id=3, 2685 weights=(2.0, 3.0)) 2686 2687 @test_util.run_deprecated_v1 2688 def test_one_label_at_k1_weighted(self): 2689 predictions = self._predictions 2690 predictions_idx = self._predictions_idx 2691 for labels in 
self._labels: 2692 # All classes: 2 labels, 2 predictions, 1 correct. 2693 self._test_recall_at_k( 2694 predictions, labels, k=1, expected=NAN, weights=(0.0,)) 2695 self._test_recall_at_top_k( 2696 predictions_idx, labels, k=1, expected=NAN, weights=(0.0,)) 2697 self._test_recall_at_k( 2698 predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,)) 2699 self._test_recall_at_top_k( 2700 predictions_idx, labels, k=1, expected=1.0 / 2, weights=(1.0,)) 2701 self._test_recall_at_k( 2702 predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,)) 2703 self._test_recall_at_top_k( 2704 predictions_idx, labels, k=1, expected=1.0 / 2, weights=(2.0,)) 2705 self._test_recall_at_k( 2706 predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0)) 2707 self._test_recall_at_top_k( 2708 predictions_idx, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0)) 2709 self._test_recall_at_k( 2710 predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0)) 2711 self._test_recall_at_top_k( 2712 predictions_idx, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0)) 2713 self._test_recall_at_k( 2714 predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0)) 2715 self._test_recall_at_top_k( 2716 predictions_idx, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0)) 2717 2718 2719class MultiLabel2dRecallAtKTest(test.TestCase): 2720 2721 def setUp(self): 2722 self._predictions = ((0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9), 2723 (0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6)) 2724 self._predictions_idx = ((9, 4, 6, 2, 0), (5, 7, 2, 9, 6)) 2725 indicator_labels = ((0, 0, 1, 0, 0, 0, 0, 1, 1, 0), 2726 (0, 1, 1, 0, 0, 1, 0, 0, 0, 0)) 2727 class_labels = ((2, 7, 8), (1, 2, 5)) 2728 # Sparse vs dense labels should be handled the same. 
2729 self._labels = (_binary_2d_label_to_2d_sparse_value(indicator_labels), 2730 np.array( 2731 class_labels, dtype=np.int64)) 2732 self._test_recall_at_k = functools.partial( 2733 _test_recall_at_k, test_case=self) 2734 self._test_recall_at_top_k = functools.partial( 2735 _test_recall_at_top_k, test_case=self) 2736 2737 @test_util.run_deprecated_v1 2738 def test_at_k5_nan(self): 2739 for labels in self._labels: 2740 # Classes 0,3,4,6,9 have 0 labels, class 10 is out of range. 2741 for class_id in (0, 3, 4, 6, 9, 10): 2742 self._test_recall_at_k( 2743 self._predictions, labels, k=5, expected=NAN, class_id=class_id) 2744 self._test_recall_at_top_k( 2745 self._predictions_idx, labels, k=5, expected=NAN, class_id=class_id) 2746 2747 @test_util.run_deprecated_v1 2748 def test_at_k5_no_predictions(self): 2749 for labels in self._labels: 2750 # Class 8: 1 label, no predictions. 2751 self._test_recall_at_k( 2752 self._predictions, labels, k=5, expected=0.0 / 1, class_id=8) 2753 self._test_recall_at_top_k( 2754 self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=8) 2755 2756 @test_util.run_deprecated_v1 2757 def test_at_k5(self): 2758 for labels in self._labels: 2759 # Class 2: 2 labels, both correct. 2760 self._test_recall_at_k( 2761 self._predictions, labels, k=5, expected=2.0 / 2, class_id=2) 2762 self._test_recall_at_top_k( 2763 self._predictions_idx, labels, k=5, expected=2.0 / 2, class_id=2) 2764 2765 # Class 5: 1 label, incorrect. 2766 self._test_recall_at_k( 2767 self._predictions, labels, k=5, expected=1.0 / 1, class_id=5) 2768 self._test_recall_at_top_k( 2769 self._predictions_idx, labels, k=5, expected=1.0 / 1, class_id=5) 2770 2771 # Class 7: 1 label, incorrect. 2772 self._test_recall_at_k( 2773 self._predictions, labels, k=5, expected=0.0 / 1, class_id=7) 2774 self._test_recall_at_top_k( 2775 self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=7) 2776 2777 # All classes: 6 labels, 3 correct. 
2778 self._test_recall_at_k(self._predictions, labels, k=5, expected=3.0 / 6) 2779 self._test_recall_at_top_k( 2780 self._predictions_idx, labels, k=5, expected=3.0 / 6) 2781 2782 @test_util.run_deprecated_v1 2783 def test_at_k5_some_out_of_range(self): 2784 """Tests that labels outside the [0, n_classes) count in denominator.""" 2785 labels = sparse_tensor.SparseTensorValue( 2786 indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2], 2787 [1, 3]], 2788 # values -1 and 10 are outside the [0, n_classes) range. 2789 values=np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64), 2790 dense_shape=[2, 4]) 2791 2792 # Class 2: 2 labels, both correct. 2793 self._test_recall_at_k( 2794 self._predictions, labels, k=5, expected=2.0 / 2, class_id=2) 2795 self._test_recall_at_top_k( 2796 self._predictions_idx, labels, k=5, expected=2.0 / 2, class_id=2) 2797 2798 # Class 5: 1 label, incorrect. 2799 self._test_recall_at_k( 2800 self._predictions, labels, k=5, expected=1.0 / 1, class_id=5) 2801 self._test_recall_at_top_k( 2802 self._predictions_idx, labels, k=5, expected=1.0 / 1, class_id=5) 2803 2804 # Class 7: 1 label, incorrect. 2805 self._test_recall_at_k( 2806 self._predictions, labels, k=5, expected=0.0 / 1, class_id=7) 2807 self._test_recall_at_top_k( 2808 self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=7) 2809 2810 # All classes: 8 labels, 3 correct. 
2811 self._test_recall_at_k(self._predictions, labels, k=5, expected=3.0 / 8) 2812 self._test_recall_at_top_k( 2813 self._predictions_idx, labels, k=5, expected=3.0 / 8) 2814 2815 2816class MultiLabel3dRecallAtKTest(test.TestCase): 2817 2818 def setUp(self): 2819 self._predictions = (((0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9), 2820 (0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6)), 2821 ((0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6), 2822 (0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9))) 2823 self._predictions_idx = (((9, 4, 6, 2, 0), (5, 7, 2, 9, 6)), 2824 ((5, 7, 2, 9, 6), (9, 4, 6, 2, 0))) 2825 # Note: We don't test dense labels here, since examples have different 2826 # numbers of labels. 2827 self._labels = _binary_3d_label_to_sparse_value((( 2828 (0, 0, 1, 0, 0, 0, 0, 1, 1, 0), (0, 1, 1, 0, 0, 1, 0, 0, 0, 0)), ( 2829 (0, 1, 1, 0, 0, 1, 0, 1, 0, 0), (0, 0, 1, 0, 0, 0, 0, 0, 1, 0)))) 2830 self._test_recall_at_k = functools.partial( 2831 _test_recall_at_k, test_case=self) 2832 self._test_recall_at_top_k = functools.partial( 2833 _test_recall_at_top_k, test_case=self) 2834 2835 @test_util.run_deprecated_v1 2836 def test_3d_nan(self): 2837 # Classes 0,3,4,6,9 have 0 labels, class 10 is out of range. 2838 for class_id in (0, 3, 4, 6, 9, 10): 2839 self._test_recall_at_k( 2840 self._predictions, self._labels, k=5, expected=NAN, class_id=class_id) 2841 self._test_recall_at_top_k( 2842 self._predictions_idx, self._labels, k=5, expected=NAN, 2843 class_id=class_id) 2844 2845 @test_util.run_deprecated_v1 2846 def test_3d_no_predictions(self): 2847 # Classes 1,8 have 0 predictions, >=1 label. 2848 for class_id in (1, 8): 2849 self._test_recall_at_k( 2850 self._predictions, self._labels, k=5, expected=0.0, class_id=class_id) 2851 self._test_recall_at_top_k( 2852 self._predictions_idx, self._labels, k=5, expected=0.0, 2853 class_id=class_id) 2854 2855 @test_util.run_deprecated_v1 2856 def test_3d(self): 2857 # Class 2: 4 labels, all correct. 
2858 self._test_recall_at_k( 2859 self._predictions, self._labels, k=5, expected=4.0 / 4, class_id=2) 2860 self._test_recall_at_top_k( 2861 self._predictions_idx, self._labels, k=5, expected=4.0 / 4, 2862 class_id=2) 2863 2864 # Class 5: 2 labels, both correct. 2865 self._test_recall_at_k( 2866 self._predictions, self._labels, k=5, expected=2.0 / 2, class_id=5) 2867 self._test_recall_at_top_k( 2868 self._predictions_idx, self._labels, k=5, expected=2.0 / 2, 2869 class_id=5) 2870 2871 # Class 7: 2 labels, 1 incorrect. 2872 self._test_recall_at_k( 2873 self._predictions, self._labels, k=5, expected=1.0 / 2, class_id=7) 2874 self._test_recall_at_top_k( 2875 self._predictions_idx, self._labels, k=5, expected=1.0 / 2, 2876 class_id=7) 2877 2878 # All classes: 12 labels, 7 correct. 2879 self._test_recall_at_k( 2880 self._predictions, self._labels, k=5, expected=7.0 / 12) 2881 self._test_recall_at_top_k( 2882 self._predictions_idx, self._labels, k=5, expected=7.0 / 12) 2883 2884 @test_util.run_deprecated_v1 2885 def test_3d_ignore_all(self): 2886 for class_id in xrange(10): 2887 self._test_recall_at_k( 2888 self._predictions, self._labels, k=5, expected=NAN, class_id=class_id, 2889 weights=[[0], [0]]) 2890 self._test_recall_at_top_k( 2891 self._predictions_idx, self._labels, k=5, expected=NAN, 2892 class_id=class_id, weights=[[0], [0]]) 2893 self._test_recall_at_k( 2894 self._predictions, self._labels, k=5, expected=NAN, class_id=class_id, 2895 weights=[[0, 0], [0, 0]]) 2896 self._test_recall_at_top_k( 2897 self._predictions_idx, self._labels, k=5, expected=NAN, 2898 class_id=class_id, weights=[[0, 0], [0, 0]]) 2899 self._test_recall_at_k( 2900 self._predictions, self._labels, k=5, expected=NAN, weights=[[0], [0]]) 2901 self._test_recall_at_top_k( 2902 self._predictions_idx, self._labels, k=5, expected=NAN, 2903 weights=[[0], [0]]) 2904 self._test_recall_at_k( 2905 self._predictions, self._labels, k=5, expected=NAN, 2906 weights=[[0, 0], [0, 0]]) 2907 
self._test_recall_at_top_k( 2908 self._predictions_idx, self._labels, k=5, expected=NAN, 2909 weights=[[0, 0], [0, 0]]) 2910 2911 @test_util.run_deprecated_v1 2912 def test_3d_ignore_some(self): 2913 # Class 2: 2 labels, both correct. 2914 self._test_recall_at_k( 2915 self._predictions, self._labels, k=5, expected=2.0 / 2.0, class_id=2, 2916 weights=[[1], [0]]) 2917 self._test_recall_at_top_k( 2918 self._predictions_idx, self._labels, k=5, expected=2.0 / 2.0, 2919 class_id=2, weights=[[1], [0]]) 2920 2921 # Class 2: 2 labels, both correct. 2922 self._test_recall_at_k( 2923 self._predictions, self._labels, k=5, expected=2.0 / 2.0, class_id=2, 2924 weights=[[0], [1]]) 2925 self._test_recall_at_top_k( 2926 self._predictions_idx, self._labels, k=5, expected=2.0 / 2.0, 2927 class_id=2, weights=[[0], [1]]) 2928 2929 # Class 7: 1 label, correct. 2930 self._test_recall_at_k( 2931 self._predictions, self._labels, k=5, expected=1.0 / 1.0, class_id=7, 2932 weights=[[0], [1]]) 2933 self._test_recall_at_top_k( 2934 self._predictions_idx, self._labels, k=5, expected=1.0 / 1.0, 2935 class_id=7, weights=[[0], [1]]) 2936 2937 # Class 7: 1 label, incorrect. 2938 self._test_recall_at_k( 2939 self._predictions, self._labels, k=5, expected=0.0 / 1.0, class_id=7, 2940 weights=[[1], [0]]) 2941 self._test_recall_at_top_k( 2942 self._predictions_idx, self._labels, k=5, expected=0.0 / 1.0, 2943 class_id=7, weights=[[1], [0]]) 2944 2945 # Class 7: 2 labels, 1 correct. 2946 self._test_recall_at_k( 2947 self._predictions, self._labels, k=5, expected=1.0 / 2.0, class_id=7, 2948 weights=[[1, 0], [1, 0]]) 2949 self._test_recall_at_top_k( 2950 self._predictions_idx, self._labels, k=5, expected=1.0 / 2.0, 2951 class_id=7, weights=[[1, 0], [1, 0]]) 2952 2953 # Class 7: No labels. 
2954 self._test_recall_at_k( 2955 self._predictions, self._labels, k=5, expected=NAN, class_id=7, 2956 weights=[[0, 1], [0, 1]]) 2957 self._test_recall_at_top_k( 2958 self._predictions_idx, self._labels, k=5, expected=NAN, class_id=7, 2959 weights=[[0, 1], [0, 1]]) 2960 2961 2962class MeanAbsoluteErrorTest(test.TestCase): 2963 2964 def setUp(self): 2965 ops.reset_default_graph() 2966 2967 @test_util.run_deprecated_v1 2968 def testVars(self): 2969 metrics.mean_absolute_error( 2970 predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) 2971 _assert_metric_variables( 2972 self, ('mean_absolute_error/count:0', 'mean_absolute_error/total:0')) 2973 2974 @test_util.run_deprecated_v1 2975 def testMetricsCollection(self): 2976 my_collection_name = '__metrics__' 2977 mean, _ = metrics.mean_absolute_error( 2978 predictions=array_ops.ones((10, 1)), 2979 labels=array_ops.ones((10, 1)), 2980 metrics_collections=[my_collection_name]) 2981 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 2982 2983 @test_util.run_deprecated_v1 2984 def testUpdatesCollection(self): 2985 my_collection_name = '__updates__' 2986 _, update_op = metrics.mean_absolute_error( 2987 predictions=array_ops.ones((10, 1)), 2988 labels=array_ops.ones((10, 1)), 2989 updates_collections=[my_collection_name]) 2990 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 2991 2992 @test_util.run_deprecated_v1 2993 def testValueTensorIsIdempotent(self): 2994 predictions = random_ops.random_normal((10, 3), seed=1) 2995 labels = random_ops.random_normal((10, 3), seed=2) 2996 error, update_op = metrics.mean_absolute_error(labels, predictions) 2997 2998 with self.cached_session(): 2999 self.evaluate(variables.local_variables_initializer()) 3000 3001 # Run several updates. 3002 for _ in range(10): 3003 self.evaluate(update_op) 3004 3005 # Then verify idempotency. 
3006 initial_error = self.evaluate(error) 3007 for _ in range(10): 3008 self.assertEqual(initial_error, self.evaluate(error)) 3009 3010 @test_util.run_deprecated_v1 3011 def testSingleUpdateWithErrorAndWeights(self): 3012 predictions = constant_op.constant( 3013 [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) 3014 labels = constant_op.constant( 3015 [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32) 3016 weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4)) 3017 3018 error, update_op = metrics.mean_absolute_error(labels, predictions, weights) 3019 3020 with self.cached_session(): 3021 self.evaluate(variables.local_variables_initializer()) 3022 self.assertEqual(3, self.evaluate(update_op)) 3023 self.assertEqual(3, self.evaluate(error)) 3024 3025 3026class MeanRelativeErrorTest(test.TestCase): 3027 3028 def setUp(self): 3029 ops.reset_default_graph() 3030 3031 @test_util.run_deprecated_v1 3032 def testVars(self): 3033 metrics.mean_relative_error( 3034 predictions=array_ops.ones((10, 1)), 3035 labels=array_ops.ones((10, 1)), 3036 normalizer=array_ops.ones((10, 1))) 3037 _assert_metric_variables( 3038 self, ('mean_relative_error/count:0', 'mean_relative_error/total:0')) 3039 3040 @test_util.run_deprecated_v1 3041 def testMetricsCollection(self): 3042 my_collection_name = '__metrics__' 3043 mean, _ = metrics.mean_relative_error( 3044 predictions=array_ops.ones((10, 1)), 3045 labels=array_ops.ones((10, 1)), 3046 normalizer=array_ops.ones((10, 1)), 3047 metrics_collections=[my_collection_name]) 3048 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 3049 3050 @test_util.run_deprecated_v1 3051 def testUpdatesCollection(self): 3052 my_collection_name = '__updates__' 3053 _, update_op = metrics.mean_relative_error( 3054 predictions=array_ops.ones((10, 1)), 3055 labels=array_ops.ones((10, 1)), 3056 normalizer=array_ops.ones((10, 1)), 3057 updates_collections=[my_collection_name]) 3058 self.assertListEqual(ops.get_collection(my_collection_name), 
[update_op]) 3059 3060 @test_util.run_deprecated_v1 3061 def testValueTensorIsIdempotent(self): 3062 predictions = random_ops.random_normal((10, 3), seed=1) 3063 labels = random_ops.random_normal((10, 3), seed=2) 3064 normalizer = random_ops.random_normal((10, 3), seed=3) 3065 error, update_op = metrics.mean_relative_error(labels, predictions, 3066 normalizer) 3067 3068 with self.cached_session(): 3069 self.evaluate(variables.local_variables_initializer()) 3070 3071 # Run several updates. 3072 for _ in range(10): 3073 self.evaluate(update_op) 3074 3075 # Then verify idempotency. 3076 initial_error = self.evaluate(error) 3077 for _ in range(10): 3078 self.assertEqual(initial_error, self.evaluate(error)) 3079 3080 @test_util.run_deprecated_v1 3081 def testSingleUpdateNormalizedByLabels(self): 3082 np_predictions = np.asarray([2, 4, 6, 8], dtype=np.float32) 3083 np_labels = np.asarray([1, 3, 2, 3], dtype=np.float32) 3084 expected_error = np.mean( 3085 np.divide(np.absolute(np_predictions - np_labels), np_labels)) 3086 3087 predictions = constant_op.constant( 3088 np_predictions, shape=(1, 4), dtype=dtypes_lib.float32) 3089 labels = constant_op.constant(np_labels, shape=(1, 4)) 3090 3091 error, update_op = metrics.mean_relative_error( 3092 labels, predictions, normalizer=labels) 3093 3094 with self.cached_session(): 3095 self.evaluate(variables.local_variables_initializer()) 3096 self.assertEqual(expected_error, self.evaluate(update_op)) 3097 self.assertEqual(expected_error, self.evaluate(error)) 3098 3099 @test_util.run_deprecated_v1 3100 def testSingleUpdateNormalizedByZeros(self): 3101 np_predictions = np.asarray([2, 4, 6, 8], dtype=np.float32) 3102 3103 predictions = constant_op.constant( 3104 np_predictions, shape=(1, 4), dtype=dtypes_lib.float32) 3105 labels = constant_op.constant( 3106 [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32) 3107 3108 error, update_op = metrics.mean_relative_error( 3109 labels, predictions, normalizer=array_ops.zeros_like(labels)) 
3110 3111 with self.cached_session(): 3112 self.evaluate(variables.local_variables_initializer()) 3113 self.assertEqual(0.0, self.evaluate(update_op)) 3114 self.assertEqual(0.0, self.evaluate(error)) 3115 3116 3117class MeanSquaredErrorTest(test.TestCase): 3118 3119 def setUp(self): 3120 ops.reset_default_graph() 3121 3122 @test_util.run_deprecated_v1 3123 def testVars(self): 3124 metrics.mean_squared_error( 3125 predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) 3126 _assert_metric_variables( 3127 self, ('mean_squared_error/count:0', 'mean_squared_error/total:0')) 3128 3129 @test_util.run_deprecated_v1 3130 def testMetricsCollection(self): 3131 my_collection_name = '__metrics__' 3132 mean, _ = metrics.mean_squared_error( 3133 predictions=array_ops.ones((10, 1)), 3134 labels=array_ops.ones((10, 1)), 3135 metrics_collections=[my_collection_name]) 3136 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 3137 3138 @test_util.run_deprecated_v1 3139 def testUpdatesCollection(self): 3140 my_collection_name = '__updates__' 3141 _, update_op = metrics.mean_squared_error( 3142 predictions=array_ops.ones((10, 1)), 3143 labels=array_ops.ones((10, 1)), 3144 updates_collections=[my_collection_name]) 3145 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 3146 3147 @test_util.run_deprecated_v1 3148 def testValueTensorIsIdempotent(self): 3149 predictions = random_ops.random_normal((10, 3), seed=1) 3150 labels = random_ops.random_normal((10, 3), seed=2) 3151 error, update_op = metrics.mean_squared_error(labels, predictions) 3152 3153 with self.cached_session(): 3154 self.evaluate(variables.local_variables_initializer()) 3155 3156 # Run several updates. 3157 for _ in range(10): 3158 self.evaluate(update_op) 3159 3160 # Then verify idempotency. 
3161 initial_error = self.evaluate(error) 3162 for _ in range(10): 3163 self.assertEqual(initial_error, self.evaluate(error)) 3164 3165 @test_util.run_deprecated_v1 3166 def testSingleUpdateZeroError(self): 3167 predictions = array_ops.zeros((1, 3), dtype=dtypes_lib.float32) 3168 labels = array_ops.zeros((1, 3), dtype=dtypes_lib.float32) 3169 3170 error, update_op = metrics.mean_squared_error(labels, predictions) 3171 3172 with self.cached_session(): 3173 self.evaluate(variables.local_variables_initializer()) 3174 self.assertEqual(0, self.evaluate(update_op)) 3175 self.assertEqual(0, self.evaluate(error)) 3176 3177 @test_util.run_deprecated_v1 3178 def testSingleUpdateWithError(self): 3179 predictions = constant_op.constant( 3180 [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32) 3181 labels = constant_op.constant( 3182 [1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32) 3183 3184 error, update_op = metrics.mean_squared_error(labels, predictions) 3185 3186 with self.cached_session(): 3187 self.evaluate(variables.local_variables_initializer()) 3188 self.assertEqual(6, self.evaluate(update_op)) 3189 self.assertEqual(6, self.evaluate(error)) 3190 3191 @test_util.run_deprecated_v1 3192 def testSingleUpdateWithErrorAndWeights(self): 3193 predictions = constant_op.constant( 3194 [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) 3195 labels = constant_op.constant( 3196 [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32) 3197 weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4)) 3198 3199 error, update_op = metrics.mean_squared_error(labels, predictions, weights) 3200 3201 with self.cached_session(): 3202 self.evaluate(variables.local_variables_initializer()) 3203 self.assertEqual(13, self.evaluate(update_op)) 3204 self.assertEqual(13, self.evaluate(error)) 3205 3206 @test_util.run_deprecated_v1 3207 def testMultipleBatchesOfSizeOne(self): 3208 with self.cached_session() as sess: 3209 # Create the queue that populates the predictions. 
3210 preds_queue = data_flow_ops.FIFOQueue( 3211 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 3212 _enqueue_vector(sess, preds_queue, [10, 8, 6]) 3213 _enqueue_vector(sess, preds_queue, [-4, 3, -1]) 3214 predictions = preds_queue.dequeue() 3215 3216 # Create the queue that populates the labels. 3217 labels_queue = data_flow_ops.FIFOQueue( 3218 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 3219 _enqueue_vector(sess, labels_queue, [1, 3, 2]) 3220 _enqueue_vector(sess, labels_queue, [2, 4, 6]) 3221 labels = labels_queue.dequeue() 3222 3223 error, update_op = metrics.mean_squared_error(labels, predictions) 3224 3225 self.evaluate(variables.local_variables_initializer()) 3226 self.evaluate(update_op) 3227 self.assertAlmostEqual(208.0 / 6, self.evaluate(update_op), 5) 3228 3229 self.assertAlmostEqual(208.0 / 6, self.evaluate(error), 5) 3230 3231 @test_util.run_deprecated_v1 3232 def testMetricsComputedConcurrently(self): 3233 with self.cached_session() as sess: 3234 # Create the queue that populates one set of predictions. 3235 preds_queue0 = data_flow_ops.FIFOQueue( 3236 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 3237 _enqueue_vector(sess, preds_queue0, [10, 8, 6]) 3238 _enqueue_vector(sess, preds_queue0, [-4, 3, -1]) 3239 predictions0 = preds_queue0.dequeue() 3240 3241 # Create the queue that populates one set of predictions. 3242 preds_queue1 = data_flow_ops.FIFOQueue( 3243 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 3244 _enqueue_vector(sess, preds_queue1, [0, 1, 1]) 3245 _enqueue_vector(sess, preds_queue1, [1, 1, 0]) 3246 predictions1 = preds_queue1.dequeue() 3247 3248 # Create the queue that populates one set of labels. 3249 labels_queue0 = data_flow_ops.FIFOQueue( 3250 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 3251 _enqueue_vector(sess, labels_queue0, [1, 3, 2]) 3252 _enqueue_vector(sess, labels_queue0, [2, 4, 6]) 3253 labels0 = labels_queue0.dequeue() 3254 3255 # Create the queue that populates another set of labels. 
3256 labels_queue1 = data_flow_ops.FIFOQueue( 3257 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 3258 _enqueue_vector(sess, labels_queue1, [-5, -3, -1]) 3259 _enqueue_vector(sess, labels_queue1, [5, 4, 3]) 3260 labels1 = labels_queue1.dequeue() 3261 3262 mse0, update_op0 = metrics.mean_squared_error( 3263 labels0, predictions0, name='msd0') 3264 mse1, update_op1 = metrics.mean_squared_error( 3265 labels1, predictions1, name='msd1') 3266 3267 self.evaluate(variables.local_variables_initializer()) 3268 self.evaluate([update_op0, update_op1]) 3269 self.evaluate([update_op0, update_op1]) 3270 3271 mse0, mse1 = self.evaluate([mse0, mse1]) 3272 self.assertAlmostEqual(208.0 / 6, mse0, 5) 3273 self.assertAlmostEqual(79.0 / 6, mse1, 5) 3274 3275 @test_util.run_deprecated_v1 3276 def testMultipleMetricsOnMultipleBatchesOfSizeOne(self): 3277 with self.cached_session() as sess: 3278 # Create the queue that populates the predictions. 3279 preds_queue = data_flow_ops.FIFOQueue( 3280 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 3281 _enqueue_vector(sess, preds_queue, [10, 8, 6]) 3282 _enqueue_vector(sess, preds_queue, [-4, 3, -1]) 3283 predictions = preds_queue.dequeue() 3284 3285 # Create the queue that populates the labels. 
3286 labels_queue = data_flow_ops.FIFOQueue( 3287 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 3288 _enqueue_vector(sess, labels_queue, [1, 3, 2]) 3289 _enqueue_vector(sess, labels_queue, [2, 4, 6]) 3290 labels = labels_queue.dequeue() 3291 3292 mae, ma_update_op = metrics.mean_absolute_error(labels, predictions) 3293 mse, ms_update_op = metrics.mean_squared_error(labels, predictions) 3294 3295 self.evaluate(variables.local_variables_initializer()) 3296 self.evaluate([ma_update_op, ms_update_op]) 3297 self.evaluate([ma_update_op, ms_update_op]) 3298 3299 self.assertAlmostEqual(32.0 / 6, self.evaluate(mae), 5) 3300 self.assertAlmostEqual(208.0 / 6, self.evaluate(mse), 5) 3301 3302 3303class RootMeanSquaredErrorTest(test.TestCase): 3304 3305 def setUp(self): 3306 ops.reset_default_graph() 3307 3308 @test_util.run_deprecated_v1 3309 def testVars(self): 3310 metrics.root_mean_squared_error( 3311 predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) 3312 _assert_metric_variables( 3313 self, 3314 ('root_mean_squared_error/count:0', 'root_mean_squared_error/total:0')) 3315 3316 @test_util.run_deprecated_v1 3317 def testMetricsCollection(self): 3318 my_collection_name = '__metrics__' 3319 mean, _ = metrics.root_mean_squared_error( 3320 predictions=array_ops.ones((10, 1)), 3321 labels=array_ops.ones((10, 1)), 3322 metrics_collections=[my_collection_name]) 3323 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 3324 3325 @test_util.run_deprecated_v1 3326 def testUpdatesCollection(self): 3327 my_collection_name = '__updates__' 3328 _, update_op = metrics.root_mean_squared_error( 3329 predictions=array_ops.ones((10, 1)), 3330 labels=array_ops.ones((10, 1)), 3331 updates_collections=[my_collection_name]) 3332 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 3333 3334 @test_util.run_deprecated_v1 3335 def testValueTensorIsIdempotent(self): 3336 predictions = random_ops.random_normal((10, 3), seed=1) 3337 labels = 
random_ops.random_normal((10, 3), seed=2) 3338 error, update_op = metrics.root_mean_squared_error(labels, predictions) 3339 3340 with self.cached_session(): 3341 self.evaluate(variables.local_variables_initializer()) 3342 3343 # Run several updates. 3344 for _ in range(10): 3345 self.evaluate(update_op) 3346 3347 # Then verify idempotency. 3348 initial_error = self.evaluate(error) 3349 for _ in range(10): 3350 self.assertEqual(initial_error, self.evaluate(error)) 3351 3352 @test_util.run_deprecated_v1 3353 def testSingleUpdateZeroError(self): 3354 with self.cached_session(): 3355 predictions = constant_op.constant( 3356 0.0, shape=(1, 3), dtype=dtypes_lib.float32) 3357 labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32) 3358 3359 rmse, update_op = metrics.root_mean_squared_error(labels, predictions) 3360 3361 self.evaluate(variables.local_variables_initializer()) 3362 self.assertEqual(0, self.evaluate(update_op)) 3363 3364 self.assertEqual(0, self.evaluate(rmse)) 3365 3366 @test_util.run_deprecated_v1 3367 def testSingleUpdateWithError(self): 3368 with self.cached_session(): 3369 predictions = constant_op.constant( 3370 [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32) 3371 labels = constant_op.constant( 3372 [1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32) 3373 3374 rmse, update_op = metrics.root_mean_squared_error(labels, predictions) 3375 3376 self.evaluate(variables.local_variables_initializer()) 3377 self.assertAlmostEqual(math.sqrt(6), self.evaluate(update_op), 5) 3378 self.assertAlmostEqual(math.sqrt(6), self.evaluate(rmse), 5) 3379 3380 @test_util.run_deprecated_v1 3381 def testSingleUpdateWithErrorAndWeights(self): 3382 with self.cached_session(): 3383 predictions = constant_op.constant( 3384 [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) 3385 labels = constant_op.constant( 3386 [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32) 3387 weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4)) 3388 3389 rmse, update_op = 
metrics.root_mean_squared_error(labels, predictions, 3390 weights) 3391 3392 self.evaluate(variables.local_variables_initializer()) 3393 self.assertAlmostEqual(math.sqrt(13), self.evaluate(update_op)) 3394 3395 self.assertAlmostEqual(math.sqrt(13), self.evaluate(rmse), 5) 3396 3397 3398def _reweight(predictions, labels, weights): 3399 return (np.concatenate([[p] * int(w) for p, w in zip(predictions, weights)]), 3400 np.concatenate([[l] * int(w) for l, w in zip(labels, weights)])) 3401 3402 3403class MeanCosineDistanceTest(test.TestCase): 3404 3405 def setUp(self): 3406 ops.reset_default_graph() 3407 3408 @test_util.run_deprecated_v1 3409 def testVars(self): 3410 metrics.mean_cosine_distance( 3411 predictions=array_ops.ones((10, 3)), 3412 labels=array_ops.ones((10, 3)), 3413 dim=1) 3414 _assert_metric_variables(self, ( 3415 'mean_cosine_distance/count:0', 3416 'mean_cosine_distance/total:0', 3417 )) 3418 3419 @test_util.run_deprecated_v1 3420 def testMetricsCollection(self): 3421 my_collection_name = '__metrics__' 3422 mean, _ = metrics.mean_cosine_distance( 3423 predictions=array_ops.ones((10, 3)), 3424 labels=array_ops.ones((10, 3)), 3425 dim=1, 3426 metrics_collections=[my_collection_name]) 3427 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 3428 3429 @test_util.run_deprecated_v1 3430 def testUpdatesCollection(self): 3431 my_collection_name = '__updates__' 3432 _, update_op = metrics.mean_cosine_distance( 3433 predictions=array_ops.ones((10, 3)), 3434 labels=array_ops.ones((10, 3)), 3435 dim=1, 3436 updates_collections=[my_collection_name]) 3437 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 3438 3439 @test_util.run_deprecated_v1 3440 def testValueTensorIsIdempotent(self): 3441 predictions = random_ops.random_normal((10, 3), seed=1) 3442 labels = random_ops.random_normal((10, 3), seed=2) 3443 error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=1) 3444 3445 with self.cached_session(): 3446 
self.evaluate(variables.local_variables_initializer()) 3447 3448 # Run several updates. 3449 for _ in range(10): 3450 self.evaluate(update_op) 3451 3452 # Then verify idempotency. 3453 initial_error = self.evaluate(error) 3454 for _ in range(10): 3455 self.assertEqual(initial_error, self.evaluate(error)) 3456 3457 @test_util.run_deprecated_v1 3458 def testSingleUpdateZeroError(self): 3459 np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0')) 3460 3461 predictions = constant_op.constant( 3462 np_labels, shape=(1, 3, 3), dtype=dtypes_lib.float32) 3463 labels = constant_op.constant( 3464 np_labels, shape=(1, 3, 3), dtype=dtypes_lib.float32) 3465 3466 error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2) 3467 3468 with self.cached_session(): 3469 self.evaluate(variables.local_variables_initializer()) 3470 self.assertEqual(0, self.evaluate(update_op)) 3471 self.assertEqual(0, self.evaluate(error)) 3472 3473 @test_util.run_deprecated_v1 3474 def testSingleUpdateWithError1(self): 3475 np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0')) 3476 np_predictions = np.matrix(('1 0 0;' '0 0 -1;' '1 0 0')) 3477 3478 predictions = constant_op.constant( 3479 np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3480 labels = constant_op.constant( 3481 np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3482 3483 error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2) 3484 3485 with self.cached_session(): 3486 self.evaluate(variables.local_variables_initializer()) 3487 self.assertAlmostEqual(1, self.evaluate(update_op), 5) 3488 self.assertAlmostEqual(1, self.evaluate(error), 5) 3489 3490 @test_util.run_deprecated_v1 3491 def testSingleUpdateWithError2(self): 3492 np_predictions = np.matrix( 3493 ('0.819031913261206 0.567041924552012 0.087465312324590;' 3494 '-0.665139432070255 -0.739487441769973 -0.103671883216994;' 3495 '0.707106781186548 -0.707106781186548 0')) 3496 np_labels = np.matrix( 3497 ('0.819031913261206 0.567041924552012 
0.087465312324590;' 3498 '0.665139432070255 0.739487441769973 0.103671883216994;' 3499 '0.707106781186548 0.707106781186548 0')) 3500 3501 predictions = constant_op.constant( 3502 np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3503 labels = constant_op.constant( 3504 np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3505 error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2) 3506 3507 with self.cached_session(): 3508 self.evaluate(variables.local_variables_initializer()) 3509 self.assertAlmostEqual(1.0, self.evaluate(update_op), 5) 3510 self.assertAlmostEqual(1.0, self.evaluate(error), 5) 3511 3512 @test_util.run_deprecated_v1 3513 def testSingleUpdateWithErrorAndWeights1(self): 3514 np_predictions = np.matrix(('1 0 0;' '0 0 -1;' '1 0 0')) 3515 np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0')) 3516 3517 predictions = constant_op.constant( 3518 np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3519 labels = constant_op.constant( 3520 np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3521 weights = constant_op.constant( 3522 [1, 0, 0], shape=(3, 1, 1), dtype=dtypes_lib.float32) 3523 3524 error, update_op = metrics.mean_cosine_distance( 3525 labels, predictions, dim=2, weights=weights) 3526 3527 with self.cached_session(): 3528 self.evaluate(variables.local_variables_initializer()) 3529 self.assertEqual(0, self.evaluate(update_op)) 3530 self.assertEqual(0, self.evaluate(error)) 3531 3532 @test_util.run_deprecated_v1 3533 def testSingleUpdateWithErrorAndWeights2(self): 3534 np_predictions = np.matrix(('1 0 0;' '0 0 -1;' '1 0 0')) 3535 np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0')) 3536 3537 predictions = constant_op.constant( 3538 np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3539 labels = constant_op.constant( 3540 np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3541 weights = constant_op.constant( 3542 [0, 1, 1], shape=(3, 1, 1), dtype=dtypes_lib.float32) 3543 3544 error, update_op = 
class PcntBelowThreshTest(test.TestCase):
  """Exercises `metrics.percentage_below`."""

  def setUp(self):
    # Every test builds its metric ops in a fresh default graph.
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    expected_variables = (
        'percentage_below_threshold/count:0',
        'percentage_below_threshold/total:0',
    )
    metrics.percentage_below(values=array_ops.ones((10,)), threshold=2)
    _assert_metric_variables(self, expected_variables)

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    collection = '__metrics__'
    value_tensor, _ = metrics.percentage_below(
        values=array_ops.ones((10,)),
        threshold=2,
        metrics_collections=[collection])
    registered = ops.get_collection(collection)
    self.assertListEqual(registered, [value_tensor])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    collection = '__updates__'
    _, update = metrics.percentage_below(
        values=array_ops.ones((10,)),
        threshold=2,
        updates_collections=[collection])
    registered = ops.get_collection(collection)
    self.assertListEqual(registered, [update])

  @test_util.run_deprecated_v1
  def testOneUpdate(self):
    # (threshold, metric name) pairs, built in this order.
    cases = ((100, 'high'), (7, 'medium'), (1, 'low'))
    expected_fractions = (1.0, 0.75, 0.0)

    with self.cached_session():
      values = constant_op.constant(
          [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)

      value_tensors = []
      update_ops = []
      for threshold, metric_name in cases:
        value_tensor, update = metrics.percentage_below(
            values, threshold, name=metric_name)
        value_tensors.append(value_tensor)
        update_ops.append(update)

      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_ops)

      fractions = self.evaluate(value_tensors)
      for expected, actual in zip(expected_fractions, fractions):
        self.assertAlmostEqual(expected, actual, 5)

  @test_util.run_deprecated_v1
  def testSomePresentOneUpdate(self):
    # (threshold, metric name) pairs, built in this order.
    cases = ((100, 'high'), (7, 'medium'), (1, 'low'))
    expected_fractions = (1.0, 0.5, 0.0)

    with self.cached_session():
      values = constant_op.constant(
          [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
      # Zero weights mask out the two middle values (4 and 6).
      weights = constant_op.constant(
          [1, 0, 0, 1], shape=(1, 4), dtype=dtypes_lib.float32)

      value_tensors = []
      update_ops = []
      for threshold, metric_name in cases:
        value_tensor, update = metrics.percentage_below(
            values, threshold, weights=weights, name=metric_name)
        value_tensors.append(value_tensor)
        update_ops.append(update)

      self.evaluate(variables.local_variables_initializer())
      self.assertListEqual([1.0, 0.5, 0.0], self.evaluate(update_ops))

      fractions = self.evaluate(value_tensors)
      for expected, actual in zip(expected_fractions, fractions):
        self.assertAlmostEqual(expected, actual, 5)
predictions=array_ops.ones([10, 1]), 3656 labels=array_ops.ones([10, 1]), 3657 num_classes=2, 3658 updates_collections=[my_collection_name]) 3659 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 3660 3661 @test_util.run_deprecated_v1 3662 def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self): 3663 predictions = array_ops.ones([10, 3]) 3664 labels = array_ops.ones([10, 4]) 3665 with self.assertRaises(ValueError): 3666 metrics.mean_iou(labels, predictions, num_classes=2) 3667 3668 @test_util.run_deprecated_v1 3669 def testLabelsAndWeightsOfDifferentSizeRaisesValueError(self): 3670 predictions = array_ops.ones([10]) 3671 labels = array_ops.ones([10]) 3672 weights = array_ops.zeros([9]) 3673 with self.assertRaises(ValueError): 3674 metrics.mean_iou(labels, predictions, num_classes=2, weights=weights) 3675 3676 @test_util.run_deprecated_v1 3677 def testValueTensorIsIdempotent(self): 3678 num_classes = 3 3679 predictions = random_ops.random_uniform( 3680 [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1) 3681 labels = random_ops.random_uniform( 3682 [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1) 3683 mean_iou, update_op = metrics.mean_iou( 3684 labels, predictions, num_classes=num_classes) 3685 3686 with self.cached_session(): 3687 self.evaluate(variables.local_variables_initializer()) 3688 3689 # Run several updates. 3690 for _ in range(10): 3691 self.evaluate(update_op) 3692 3693 # Then verify idempotency. 3694 initial_mean_iou = self.evaluate(mean_iou) 3695 for _ in range(10): 3696 self.assertEqual(initial_mean_iou, self.evaluate(mean_iou)) 3697 3698 @test_util.run_deprecated_v1 3699 def testMultipleUpdates(self): 3700 num_classes = 3 3701 with self.cached_session() as sess: 3702 # Create the queue that populates the predictions. 
3703 preds_queue = data_flow_ops.FIFOQueue( 3704 5, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3705 _enqueue_vector(sess, preds_queue, [0]) 3706 _enqueue_vector(sess, preds_queue, [1]) 3707 _enqueue_vector(sess, preds_queue, [2]) 3708 _enqueue_vector(sess, preds_queue, [1]) 3709 _enqueue_vector(sess, preds_queue, [0]) 3710 predictions = preds_queue.dequeue() 3711 3712 # Create the queue that populates the labels. 3713 labels_queue = data_flow_ops.FIFOQueue( 3714 5, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3715 _enqueue_vector(sess, labels_queue, [0]) 3716 _enqueue_vector(sess, labels_queue, [1]) 3717 _enqueue_vector(sess, labels_queue, [1]) 3718 _enqueue_vector(sess, labels_queue, [2]) 3719 _enqueue_vector(sess, labels_queue, [1]) 3720 labels = labels_queue.dequeue() 3721 3722 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3723 3724 self.evaluate(variables.local_variables_initializer()) 3725 for _ in range(5): 3726 self.evaluate(update_op) 3727 desired_output = np.mean([1.0 / 2.0, 1.0 / 4.0, 0.]) 3728 self.assertEqual(desired_output, self.evaluate(miou)) 3729 3730 @test_util.run_deprecated_v1 3731 def testMultipleUpdatesWithWeights(self): 3732 num_classes = 2 3733 with self.cached_session() as sess: 3734 # Create the queue that populates the predictions. 3735 preds_queue = data_flow_ops.FIFOQueue( 3736 6, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3737 _enqueue_vector(sess, preds_queue, [0]) 3738 _enqueue_vector(sess, preds_queue, [1]) 3739 _enqueue_vector(sess, preds_queue, [0]) 3740 _enqueue_vector(sess, preds_queue, [1]) 3741 _enqueue_vector(sess, preds_queue, [0]) 3742 _enqueue_vector(sess, preds_queue, [1]) 3743 predictions = preds_queue.dequeue() 3744 3745 # Create the queue that populates the labels. 
3746 labels_queue = data_flow_ops.FIFOQueue( 3747 6, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3748 _enqueue_vector(sess, labels_queue, [0]) 3749 _enqueue_vector(sess, labels_queue, [1]) 3750 _enqueue_vector(sess, labels_queue, [1]) 3751 _enqueue_vector(sess, labels_queue, [0]) 3752 _enqueue_vector(sess, labels_queue, [0]) 3753 _enqueue_vector(sess, labels_queue, [1]) 3754 labels = labels_queue.dequeue() 3755 3756 # Create the queue that populates the weights. 3757 weights_queue = data_flow_ops.FIFOQueue( 3758 6, dtypes=dtypes_lib.float32, shapes=(1, 1)) 3759 _enqueue_vector(sess, weights_queue, [1.0]) 3760 _enqueue_vector(sess, weights_queue, [1.0]) 3761 _enqueue_vector(sess, weights_queue, [1.0]) 3762 _enqueue_vector(sess, weights_queue, [0.0]) 3763 _enqueue_vector(sess, weights_queue, [1.0]) 3764 _enqueue_vector(sess, weights_queue, [0.0]) 3765 weights = weights_queue.dequeue() 3766 3767 mean_iou, update_op = metrics.mean_iou( 3768 labels, predictions, num_classes, weights=weights) 3769 3770 variables.local_variables_initializer().run() 3771 for _ in range(6): 3772 self.evaluate(update_op) 3773 desired_output = np.mean([2.0 / 3.0, 1.0 / 2.0]) 3774 self.assertAlmostEqual(desired_output, self.evaluate(mean_iou)) 3775 3776 @test_util.run_deprecated_v1 3777 def testMultipleUpdatesWithMissingClass(self): 3778 # Test the case where there are no predictions and labels for 3779 # one class, and thus there is one row and one column with 3780 # zero entries in the confusion matrix. 3781 num_classes = 3 3782 with self.cached_session() as sess: 3783 # Create the queue that populates the predictions. 3784 # There is no prediction for class 2. 
3785 preds_queue = data_flow_ops.FIFOQueue( 3786 5, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3787 _enqueue_vector(sess, preds_queue, [0]) 3788 _enqueue_vector(sess, preds_queue, [1]) 3789 _enqueue_vector(sess, preds_queue, [1]) 3790 _enqueue_vector(sess, preds_queue, [1]) 3791 _enqueue_vector(sess, preds_queue, [0]) 3792 predictions = preds_queue.dequeue() 3793 3794 # Create the queue that populates the labels. 3795 # There is label for class 2. 3796 labels_queue = data_flow_ops.FIFOQueue( 3797 5, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3798 _enqueue_vector(sess, labels_queue, [0]) 3799 _enqueue_vector(sess, labels_queue, [1]) 3800 _enqueue_vector(sess, labels_queue, [1]) 3801 _enqueue_vector(sess, labels_queue, [0]) 3802 _enqueue_vector(sess, labels_queue, [1]) 3803 labels = labels_queue.dequeue() 3804 3805 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3806 3807 self.evaluate(variables.local_variables_initializer()) 3808 for _ in range(5): 3809 self.evaluate(update_op) 3810 desired_output = np.mean([1.0 / 3.0, 2.0 / 4.0]) 3811 self.assertAlmostEqual(desired_output, self.evaluate(miou)) 3812 3813 @test_util.run_deprecated_v1 3814 def testUpdateOpEvalIsAccumulatedConfusionMatrix(self): 3815 predictions = array_ops.concat( 3816 [ 3817 constant_op.constant( 3818 0, shape=[5]), constant_op.constant( 3819 1, shape=[5]) 3820 ], 3821 0) 3822 labels = array_ops.concat( 3823 [ 3824 constant_op.constant( 3825 0, shape=[3]), constant_op.constant( 3826 1, shape=[7]) 3827 ], 3828 0) 3829 num_classes = 2 3830 with self.cached_session(): 3831 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3832 self.evaluate(variables.local_variables_initializer()) 3833 confusion_matrix = self.evaluate(update_op) 3834 self.assertAllEqual([[3, 0], [2, 5]], confusion_matrix) 3835 desired_miou = np.mean([3. / 5., 5. 
/ 7.]) 3836 self.assertAlmostEqual(desired_miou, self.evaluate(miou)) 3837 3838 @test_util.run_deprecated_v1 3839 def testAllCorrect(self): 3840 predictions = array_ops.zeros([40]) 3841 labels = array_ops.zeros([40]) 3842 num_classes = 1 3843 with self.cached_session(): 3844 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3845 self.evaluate(variables.local_variables_initializer()) 3846 self.assertEqual(40, self.evaluate(update_op)[0]) 3847 self.assertEqual(1.0, self.evaluate(miou)) 3848 3849 @test_util.run_deprecated_v1 3850 def testAllWrong(self): 3851 predictions = array_ops.zeros([40]) 3852 labels = array_ops.ones([40]) 3853 num_classes = 2 3854 with self.cached_session(): 3855 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3856 self.evaluate(variables.local_variables_initializer()) 3857 self.assertAllEqual([[0, 0], [40, 0]], update_op) 3858 self.assertEqual(0., self.evaluate(miou)) 3859 3860 @test_util.run_deprecated_v1 3861 def testResultsWithSomeMissing(self): 3862 predictions = array_ops.concat( 3863 [ 3864 constant_op.constant( 3865 0, shape=[5]), constant_op.constant( 3866 1, shape=[5]) 3867 ], 3868 0) 3869 labels = array_ops.concat( 3870 [ 3871 constant_op.constant( 3872 0, shape=[3]), constant_op.constant( 3873 1, shape=[7]) 3874 ], 3875 0) 3876 num_classes = 2 3877 weights = array_ops.concat( 3878 [ 3879 constant_op.constant( 3880 0, shape=[1]), constant_op.constant( 3881 1, shape=[8]), constant_op.constant( 3882 0, shape=[1]) 3883 ], 3884 0) 3885 with self.cached_session(): 3886 miou, update_op = metrics.mean_iou( 3887 labels, predictions, num_classes, weights=weights) 3888 self.evaluate(variables.local_variables_initializer()) 3889 self.assertAllEqual([[2, 0], [2, 4]], update_op) 3890 desired_miou = np.mean([2. / 4., 4. 
/ 6.]) 3891 self.assertAlmostEqual(desired_miou, self.evaluate(miou)) 3892 3893 @test_util.run_deprecated_v1 3894 def testMissingClassInLabels(self): 3895 labels = constant_op.constant([ 3896 [[0, 0, 1, 1, 0, 0], 3897 [1, 0, 0, 0, 0, 1]], 3898 [[1, 1, 1, 1, 1, 1], 3899 [0, 0, 0, 0, 0, 0]]]) 3900 predictions = constant_op.constant([ 3901 [[0, 0, 2, 1, 1, 0], 3902 [0, 1, 2, 2, 0, 1]], 3903 [[0, 0, 2, 1, 1, 1], 3904 [1, 1, 2, 0, 0, 0]]]) 3905 num_classes = 3 3906 with self.cached_session(): 3907 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3908 self.evaluate(variables.local_variables_initializer()) 3909 self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op) 3910 self.assertAlmostEqual( 3911 1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 / (0 + 5 + 0)), 3912 self.evaluate(miou)) 3913 3914 @test_util.run_deprecated_v1 3915 def testMissingClassOverallSmall(self): 3916 labels = constant_op.constant([0]) 3917 predictions = constant_op.constant([0]) 3918 num_classes = 2 3919 with self.cached_session(): 3920 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3921 self.evaluate(variables.local_variables_initializer()) 3922 self.assertAllEqual([[1, 0], [0, 0]], update_op) 3923 self.assertAlmostEqual(1, self.evaluate(miou)) 3924 3925 @test_util.run_deprecated_v1 3926 def testMissingClassOverallLarge(self): 3927 labels = constant_op.constant([ 3928 [[0, 0, 1, 1, 0, 0], 3929 [1, 0, 0, 0, 0, 1]], 3930 [[1, 1, 1, 1, 1, 1], 3931 [0, 0, 0, 0, 0, 0]]]) 3932 predictions = constant_op.constant([ 3933 [[0, 0, 1, 1, 0, 0], 3934 [1, 1, 0, 0, 1, 1]], 3935 [[0, 0, 0, 1, 1, 1], 3936 [1, 1, 1, 0, 0, 0]]]) 3937 num_classes = 3 3938 with self.cached_session(): 3939 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3940 self.evaluate(variables.local_variables_initializer()) 3941 self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op) 3942 self.assertAlmostEqual(1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)), 3943 
class MeanPerClassAccuracyTest(test.TestCase):
  """Tests for `metrics.mean_per_class_accuracy`."""

  def setUp(self):
    # Fixed NumPy seed and a fresh default graph for every test.
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.mean_per_class_accuracy(
        predictions=array_ops.ones([10, 1]),
        labels=array_ops.ones([10, 1]),
        num_classes=2)
    _assert_metric_variables(self, ('mean_accuracy/count:0',
                                    'mean_accuracy/total:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollections(self):
    my_collection_name = '__metrics__'
    mean_accuracy, _ = metrics.mean_per_class_accuracy(
        predictions=array_ops.ones([10, 1]),
        labels=array_ops.ones([10, 1]),
        num_classes=2,
        metrics_collections=[my_collection_name])
    self.assertListEqual(
        ops.get_collection(my_collection_name), [mean_accuracy])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.mean_per_class_accuracy(
        predictions=array_ops.ones([10, 1]),
        labels=array_ops.ones([10, 1]),
        num_classes=2,
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self):
    predictions = array_ops.ones([10, 3])
    labels = array_ops.ones([10, 4])
    with self.assertRaises(ValueError):
      metrics.mean_per_class_accuracy(labels, predictions, num_classes=2)

  @test_util.run_deprecated_v1
  def testLabelsAndWeightsOfDifferentSizeRaisesValueError(self):
    predictions = array_ops.ones([10])
    labels = array_ops.ones([10])
    weights = array_ops.zeros([9])
    with self.assertRaises(ValueError):
      metrics.mean_per_class_accuracy(
          labels, predictions, num_classes=2, weights=weights)

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    num_classes = 3
    predictions = random_ops.random_uniform(
        [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
        [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1)
    mean_accuracy, update_op = metrics.mean_per_class_accuracy(
        labels, predictions, num_classes=num_classes)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_mean_accuracy = self.evaluate(mean_accuracy)
      for _ in range(10):
        self.assertEqual(initial_mean_accuracy, self.evaluate(mean_accuracy))

  # This scenario previously ran at the end of testValueTensorIsIdempotent
  # because its `def` line was missing; it is its own test case again so it
  # gets a fresh graph from setUp and is reported separately.
  @test_util.run_deprecated_v1
  def testMultipleUpdates(self):
    num_classes = 3
    with self.cached_session() as sess:
      # Create the queue that populates the predictions.
      preds_queue = data_flow_ops.FIFOQueue(
          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [2])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [0])
      predictions = preds_queue.dequeue()

      # Create the queue that populates the labels.
      labels_queue = data_flow_ops.FIFOQueue(
          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [2])
      _enqueue_vector(sess, labels_queue, [1])
      labels = labels_queue.dequeue()

      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
          labels, predictions, num_classes)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(5):
        self.evaluate(update_op)
      # Per-class accuracy: class 0 -> 1/1, class 1 -> 1/3, class 2 -> 0/1.
      desired_output = np.mean([1.0, 1.0 / 3.0, 0.0])
      self.assertAlmostEqual(desired_output, self.evaluate(mean_accuracy))

  @test_util.run_deprecated_v1
  def testMultipleUpdatesWithWeights(self):
    num_classes = 2
    with self.cached_session() as sess:
      # Create the queue that populates the predictions.
      preds_queue = data_flow_ops.FIFOQueue(
          6, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      predictions = preds_queue.dequeue()

      # Create the queue that populates the labels.
      labels_queue = data_flow_ops.FIFOQueue(
          6, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      labels = labels_queue.dequeue()

      # Create the queue that populates the weights.
      weights_queue = data_flow_ops.FIFOQueue(
          6, dtypes=dtypes_lib.float32, shapes=(1, 1))
      _enqueue_vector(sess, weights_queue, [1.0])
      _enqueue_vector(sess, weights_queue, [0.5])
      _enqueue_vector(sess, weights_queue, [1.0])
      _enqueue_vector(sess, weights_queue, [0.0])
      _enqueue_vector(sess, weights_queue, [1.0])
      _enqueue_vector(sess, weights_queue, [0.0])
      weights = weights_queue.dequeue()

      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
          labels, predictions, num_classes, weights=weights)

      # Initialize through self.evaluate, like the other tests in this class.
      self.evaluate(variables.local_variables_initializer())
      for _ in range(6):
        self.evaluate(update_op)
      # Weighted per-class accuracy: class 0 -> 2.0/2.0, class 1 -> 0.5/1.5.
      desired_output = np.mean([2.0 / 2.0, 0.5 / 1.5])
      self.assertAlmostEqual(desired_output, self.evaluate(mean_accuracy))

  @test_util.run_deprecated_v1
  def testMultipleUpdatesWithMissingClass(self):
    # Test the case where there are no predictions and labels for
    # one class, and thus there is one row and one column with
    # zero entries in the confusion matrix.
    num_classes = 3
    with self.cached_session() as sess:
      # Create the queue that populates the predictions.
      # There is no prediction for class 2.
      preds_queue = data_flow_ops.FIFOQueue(
          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [0])
      predictions = preds_queue.dequeue()

      # Create the queue that populates the labels.
      # There is no label for class 2 either.
      labels_queue = data_flow_ops.FIFOQueue(
          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      labels = labels_queue.dequeue()

      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
          labels, predictions, num_classes)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(5):
        self.evaluate(update_op)
      # The expected mean counts the absent class 2 as accuracy 0.
      desired_output = np.mean([1.0 / 2.0, 2.0 / 3.0, 0.])
      self.assertAlmostEqual(desired_output, self.evaluate(mean_accuracy))

  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    predictions = array_ops.zeros([40])
    labels = array_ops.zeros([40])
    num_classes = 1
    with self.cached_session():
      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
          labels, predictions, num_classes)
      self.evaluate(variables.local_variables_initializer())
      self.assertEqual(1.0, self.evaluate(update_op)[0])
      self.assertEqual(1.0, self.evaluate(mean_accuracy))

  @test_util.run_deprecated_v1
  def testAllWrong(self):
    predictions = array_ops.zeros([40])
    labels = array_ops.ones([40])
    num_classes = 2
    with self.cached_session():
      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
          labels, predictions, num_classes)
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual([0.0, 0.0], update_op)
      self.assertEqual(0., self.evaluate(mean_accuracy))

  @test_util.run_deprecated_v1
  def testResultsWithSomeMissing(self):
    predictions = array_ops.concat([
        constant_op.constant(0, shape=[5]), constant_op.constant(1, shape=[5])
    ], 0)
    labels = array_ops.concat([
        constant_op.constant(0, shape=[3]), constant_op.constant(1, shape=[7])
    ], 0)
    num_classes = 2
    # Zero weights on the first and last example exclude them from the metric.
    weights = array_ops.concat([
        constant_op.constant(0, shape=[1]), constant_op.constant(1, shape=[8]),
        constant_op.constant(0, shape=[1])
    ], 0)
    with self.cached_session():
      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
          labels, predictions, num_classes, weights=weights)
      self.evaluate(variables.local_variables_initializer())
      desired_accuracy = np.array([2. / 2., 4. / 6.], dtype=np.float32)
      self.assertAllEqual(desired_accuracy, update_op)
      desired_mean_accuracy = np.mean(desired_accuracy)
      self.assertAlmostEqual(desired_mean_accuracy,
                             self.evaluate(mean_accuracy))
4226 tn, tn_update_op = metrics.false_negatives( 4227 labels=labels, predictions=predictions, weights=weights) 4228 4229 with self.cached_session(): 4230 self.evaluate(variables.local_variables_initializer()) 4231 self.assertAllClose(0., tn) 4232 self.assertAllClose(5., tn_update_op) 4233 self.assertAllClose(5., tn) 4234 4235 4236class FalseNegativesAtThresholdsTest(test.TestCase): 4237 4238 def setUp(self): 4239 np.random.seed(1) 4240 ops.reset_default_graph() 4241 4242 @test_util.run_deprecated_v1 4243 def testVars(self): 4244 metrics.false_negatives_at_thresholds( 4245 predictions=array_ops.ones((10, 1)), 4246 labels=array_ops.ones((10, 1)), 4247 thresholds=[0.15, 0.5, 0.85]) 4248 _assert_metric_variables(self, ('false_negatives/false_negatives:0',)) 4249 4250 @test_util.run_deprecated_v1 4251 def testUnweighted(self): 4252 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 4253 (0.2, 0.9, 0.7, 0.6), 4254 (0.1, 0.2, 0.4, 0.3))) 4255 labels = constant_op.constant(((0, 1, 1, 0), 4256 (1, 0, 0, 0), 4257 (0, 0, 0, 0))) 4258 fn, fn_update_op = metrics.false_negatives_at_thresholds( 4259 predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85]) 4260 4261 with self.cached_session(): 4262 self.evaluate(variables.local_variables_initializer()) 4263 self.assertAllEqual((0, 0, 0), fn) 4264 self.assertAllEqual((0, 2, 3), fn_update_op) 4265 self.assertAllEqual((0, 2, 3), fn) 4266 4267 @test_util.run_deprecated_v1 4268 def testWeighted(self): 4269 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 4270 (0.2, 0.9, 0.7, 0.6), 4271 (0.1, 0.2, 0.4, 0.3))) 4272 labels = constant_op.constant(((0, 1, 1, 0), 4273 (1, 0, 0, 0), 4274 (0, 0, 0, 0))) 4275 fn, fn_update_op = metrics.false_negatives_at_thresholds( 4276 predictions=predictions, 4277 labels=labels, 4278 weights=((3.0,), (5.0,), (7.0,)), 4279 thresholds=[0.15, 0.5, 0.85]) 4280 4281 with self.cached_session(): 4282 self.evaluate(variables.local_variables_initializer()) 4283 self.assertAllEqual((0.0, 
0.0, 0.0), fn) 4284 self.assertAllEqual((0.0, 8.0, 11.0), fn_update_op) 4285 self.assertAllEqual((0.0, 8.0, 11.0), fn) 4286 4287 4288class FalsePositivesTest(test.TestCase): 4289 4290 def setUp(self): 4291 np.random.seed(1) 4292 ops.reset_default_graph() 4293 4294 @test_util.run_deprecated_v1 4295 def testVars(self): 4296 metrics.false_positives( 4297 labels=(0, 1, 0, 1), 4298 predictions=(0, 0, 1, 1)) 4299 _assert_metric_variables(self, ('false_positives/count:0',)) 4300 4301 @test_util.run_deprecated_v1 4302 def testUnweighted(self): 4303 labels = constant_op.constant(((0, 1, 0, 1, 0), 4304 (0, 0, 1, 1, 1), 4305 (1, 1, 1, 1, 0), 4306 (0, 0, 0, 0, 1))) 4307 predictions = constant_op.constant(((0, 0, 1, 1, 0), 4308 (1, 1, 1, 1, 1), 4309 (0, 1, 0, 1, 0), 4310 (1, 1, 1, 1, 1))) 4311 tn, tn_update_op = metrics.false_positives( 4312 labels=labels, predictions=predictions) 4313 4314 with self.cached_session(): 4315 self.evaluate(variables.local_variables_initializer()) 4316 self.assertAllClose(0., tn) 4317 self.assertAllClose(7., tn_update_op) 4318 self.assertAllClose(7., tn) 4319 4320 @test_util.run_deprecated_v1 4321 def testWeighted(self): 4322 labels = constant_op.constant(((0, 1, 0, 1, 0), 4323 (0, 0, 1, 1, 1), 4324 (1, 1, 1, 1, 0), 4325 (0, 0, 0, 0, 1))) 4326 predictions = constant_op.constant(((0, 0, 1, 1, 0), 4327 (1, 1, 1, 1, 1), 4328 (0, 1, 0, 1, 0), 4329 (1, 1, 1, 1, 1))) 4330 weights = constant_op.constant((1., 1.5, 2., 2.5)) 4331 tn, tn_update_op = metrics.false_positives( 4332 labels=labels, predictions=predictions, weights=weights) 4333 4334 with self.cached_session(): 4335 self.evaluate(variables.local_variables_initializer()) 4336 self.assertAllClose(0., tn) 4337 self.assertAllClose(14., tn_update_op) 4338 self.assertAllClose(14., tn) 4339 4340 4341class FalsePositivesAtThresholdsTest(test.TestCase): 4342 4343 def setUp(self): 4344 np.random.seed(1) 4345 ops.reset_default_graph() 4346 4347 @test_util.run_deprecated_v1 4348 def testVars(self): 4349 
metrics.false_positives_at_thresholds( 4350 predictions=array_ops.ones((10, 1)), 4351 labels=array_ops.ones((10, 1)), 4352 thresholds=[0.15, 0.5, 0.85]) 4353 _assert_metric_variables(self, ('false_positives/false_positives:0',)) 4354 4355 @test_util.run_deprecated_v1 4356 def testUnweighted(self): 4357 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 4358 (0.2, 0.9, 0.7, 0.6), 4359 (0.1, 0.2, 0.4, 0.3))) 4360 labels = constant_op.constant(((0, 1, 1, 0), 4361 (1, 0, 0, 0), 4362 (0, 0, 0, 0))) 4363 fp, fp_update_op = metrics.false_positives_at_thresholds( 4364 predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85]) 4365 4366 with self.cached_session(): 4367 self.evaluate(variables.local_variables_initializer()) 4368 self.assertAllEqual((0, 0, 0), fp) 4369 self.assertAllEqual((7, 4, 2), fp_update_op) 4370 self.assertAllEqual((7, 4, 2), fp) 4371 4372 @test_util.run_deprecated_v1 4373 def testWeighted(self): 4374 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 4375 (0.2, 0.9, 0.7, 0.6), 4376 (0.1, 0.2, 0.4, 0.3))) 4377 labels = constant_op.constant(((0, 1, 1, 0), 4378 (1, 0, 0, 0), 4379 (0, 0, 0, 0))) 4380 fp, fp_update_op = metrics.false_positives_at_thresholds( 4381 predictions=predictions, 4382 labels=labels, 4383 weights=((1.0, 2.0, 3.0, 5.0), 4384 (7.0, 11.0, 13.0, 17.0), 4385 (19.0, 23.0, 29.0, 31.0)), 4386 thresholds=[0.15, 0.5, 0.85]) 4387 4388 with self.cached_session(): 4389 self.evaluate(variables.local_variables_initializer()) 4390 self.assertAllEqual((0.0, 0.0, 0.0), fp) 4391 self.assertAllEqual((125.0, 42.0, 12.0), fp_update_op) 4392 self.assertAllEqual((125.0, 42.0, 12.0), fp) 4393 4394 4395class TrueNegativesTest(test.TestCase): 4396 4397 def setUp(self): 4398 np.random.seed(1) 4399 ops.reset_default_graph() 4400 4401 @test_util.run_deprecated_v1 4402 def testVars(self): 4403 metrics.true_negatives( 4404 labels=(0, 1, 0, 1), 4405 predictions=(0, 0, 1, 1)) 4406 _assert_metric_variables(self, ('true_negatives/count:0',)) 
4407 4408 @test_util.run_deprecated_v1 4409 def testUnweighted(self): 4410 labels = constant_op.constant(((0, 1, 0, 1, 0), 4411 (0, 0, 1, 1, 1), 4412 (1, 1, 1, 1, 0), 4413 (0, 0, 0, 0, 1))) 4414 predictions = constant_op.constant(((0, 0, 1, 1, 0), 4415 (1, 1, 1, 1, 1), 4416 (0, 1, 0, 1, 0), 4417 (1, 1, 1, 1, 1))) 4418 tn, tn_update_op = metrics.true_negatives( 4419 labels=labels, predictions=predictions) 4420 4421 with self.cached_session(): 4422 self.evaluate(variables.local_variables_initializer()) 4423 self.assertAllClose(0., tn) 4424 self.assertAllClose(3., tn_update_op) 4425 self.assertAllClose(3., tn) 4426 4427 @test_util.run_deprecated_v1 4428 def testWeighted(self): 4429 labels = constant_op.constant(((0, 1, 0, 1, 0), 4430 (0, 0, 1, 1, 1), 4431 (1, 1, 1, 1, 0), 4432 (0, 0, 0, 0, 1))) 4433 predictions = constant_op.constant(((0, 0, 1, 1, 0), 4434 (1, 1, 1, 1, 1), 4435 (0, 1, 0, 1, 0), 4436 (1, 1, 1, 1, 1))) 4437 weights = constant_op.constant((1., 1.5, 2., 2.5)) 4438 tn, tn_update_op = metrics.true_negatives( 4439 labels=labels, predictions=predictions, weights=weights) 4440 4441 with self.cached_session(): 4442 self.evaluate(variables.local_variables_initializer()) 4443 self.assertAllClose(0., tn) 4444 self.assertAllClose(4., tn_update_op) 4445 self.assertAllClose(4., tn) 4446 4447 4448class TrueNegativesAtThresholdsTest(test.TestCase): 4449 4450 def setUp(self): 4451 np.random.seed(1) 4452 ops.reset_default_graph() 4453 4454 @test_util.run_deprecated_v1 4455 def testVars(self): 4456 metrics.true_negatives_at_thresholds( 4457 predictions=array_ops.ones((10, 1)), 4458 labels=array_ops.ones((10, 1)), 4459 thresholds=[0.15, 0.5, 0.85]) 4460 _assert_metric_variables(self, ('true_negatives/true_negatives:0',)) 4461 4462 @test_util.run_deprecated_v1 4463 def testUnweighted(self): 4464 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 4465 (0.2, 0.9, 0.7, 0.6), 4466 (0.1, 0.2, 0.4, 0.3))) 4467 labels = constant_op.constant(((0, 1, 1, 0), 4468 (1, 0, 0, 0), 
4469 (0, 0, 0, 0))) 4470 tn, tn_update_op = metrics.true_negatives_at_thresholds( 4471 predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85]) 4472 4473 with self.cached_session(): 4474 self.evaluate(variables.local_variables_initializer()) 4475 self.assertAllEqual((0, 0, 0), tn) 4476 self.assertAllEqual((2, 5, 7), tn_update_op) 4477 self.assertAllEqual((2, 5, 7), tn) 4478 4479 @test_util.run_deprecated_v1 4480 def testWeighted(self): 4481 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 4482 (0.2, 0.9, 0.7, 0.6), 4483 (0.1, 0.2, 0.4, 0.3))) 4484 labels = constant_op.constant(((0, 1, 1, 0), 4485 (1, 0, 0, 0), 4486 (0, 0, 0, 0))) 4487 tn, tn_update_op = metrics.true_negatives_at_thresholds( 4488 predictions=predictions, 4489 labels=labels, 4490 weights=((0.0, 2.0, 3.0, 5.0),), 4491 thresholds=[0.15, 0.5, 0.85]) 4492 4493 with self.cached_session(): 4494 self.evaluate(variables.local_variables_initializer()) 4495 self.assertAllEqual((0.0, 0.0, 0.0), tn) 4496 self.assertAllEqual((5.0, 15.0, 23.0), tn_update_op) 4497 self.assertAllEqual((5.0, 15.0, 23.0), tn) 4498 4499 4500class TruePositivesTest(test.TestCase): 4501 4502 def setUp(self): 4503 np.random.seed(1) 4504 ops.reset_default_graph() 4505 4506 @test_util.run_deprecated_v1 4507 def testVars(self): 4508 metrics.true_positives( 4509 labels=(0, 1, 0, 1), 4510 predictions=(0, 0, 1, 1)) 4511 _assert_metric_variables(self, ('true_positives/count:0',)) 4512 4513 @test_util.run_deprecated_v1 4514 def testUnweighted(self): 4515 labels = constant_op.constant(((0, 1, 0, 1, 0), 4516 (0, 0, 1, 1, 1), 4517 (1, 1, 1, 1, 0), 4518 (0, 0, 0, 0, 1))) 4519 predictions = constant_op.constant(((0, 0, 1, 1, 0), 4520 (1, 1, 1, 1, 1), 4521 (0, 1, 0, 1, 0), 4522 (1, 1, 1, 1, 1))) 4523 tn, tn_update_op = metrics.true_positives( 4524 labels=labels, predictions=predictions) 4525 4526 with self.cached_session(): 4527 self.evaluate(variables.local_variables_initializer()) 4528 self.assertAllClose(0., tn) 4529 
self.assertAllClose(7., tn_update_op) 4530 self.assertAllClose(7., tn) 4531 4532 @test_util.run_deprecated_v1 4533 def testWeighted(self): 4534 labels = constant_op.constant(((0, 1, 0, 1, 0), 4535 (0, 0, 1, 1, 1), 4536 (1, 1, 1, 1, 0), 4537 (0, 0, 0, 0, 1))) 4538 predictions = constant_op.constant(((0, 0, 1, 1, 0), 4539 (1, 1, 1, 1, 1), 4540 (0, 1, 0, 1, 0), 4541 (1, 1, 1, 1, 1))) 4542 weights = constant_op.constant((1., 1.5, 2., 2.5)) 4543 tn, tn_update_op = metrics.true_positives( 4544 labels=labels, predictions=predictions, weights=weights) 4545 4546 with self.cached_session(): 4547 self.evaluate(variables.local_variables_initializer()) 4548 self.assertAllClose(0., tn) 4549 self.assertAllClose(12., tn_update_op) 4550 self.assertAllClose(12., tn) 4551 4552 4553class TruePositivesAtThresholdsTest(test.TestCase): 4554 4555 def setUp(self): 4556 np.random.seed(1) 4557 ops.reset_default_graph() 4558 4559 @test_util.run_deprecated_v1 4560 def testVars(self): 4561 metrics.true_positives_at_thresholds( 4562 predictions=array_ops.ones((10, 1)), 4563 labels=array_ops.ones((10, 1)), 4564 thresholds=[0.15, 0.5, 0.85]) 4565 _assert_metric_variables(self, ('true_positives/true_positives:0',)) 4566 4567 @test_util.run_deprecated_v1 4568 def testUnweighted(self): 4569 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 4570 (0.2, 0.9, 0.7, 0.6), 4571 (0.1, 0.2, 0.4, 0.3))) 4572 labels = constant_op.constant(((0, 1, 1, 0), 4573 (1, 0, 0, 0), 4574 (0, 0, 0, 0))) 4575 tp, tp_update_op = metrics.true_positives_at_thresholds( 4576 predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85]) 4577 4578 with self.cached_session(): 4579 self.evaluate(variables.local_variables_initializer()) 4580 self.assertAllEqual((0, 0, 0), tp) 4581 self.assertAllEqual((3, 1, 0), tp_update_op) 4582 self.assertAllEqual((3, 1, 0), tp) 4583 4584 @test_util.run_deprecated_v1 4585 def testWeighted(self): 4586 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 4587 (0.2, 0.9, 0.7, 
0.6), 4588 (0.1, 0.2, 0.4, 0.3))) 4589 labels = constant_op.constant(((0, 1, 1, 0), 4590 (1, 0, 0, 0), 4591 (0, 0, 0, 0))) 4592 tp, tp_update_op = metrics.true_positives_at_thresholds( 4593 predictions=predictions, labels=labels, weights=37.0, 4594 thresholds=[0.15, 0.5, 0.85]) 4595 4596 with self.cached_session(): 4597 self.evaluate(variables.local_variables_initializer()) 4598 self.assertAllEqual((0.0, 0.0, 0.0), tp) 4599 self.assertAllEqual((111.0, 37.0, 0.0), tp_update_op) 4600 self.assertAllEqual((111.0, 37.0, 0.0), tp) 4601 4602 4603if __name__ == '__main__': 4604 test.main() 4605