# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metrics."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import math

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
import tensorflow.python.ops.data_flow_grad  # pylint: disable=unused-import
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test

NAN = float('nan')


def _enqueue_vector(sess, queue, values, shape=None):
  if not shape:
    shape = (1, len(values))
  dtype = queue.dtypes[0]
  sess.run(
      queue.enqueue(constant_op.constant(
          values, dtype=dtype, shape=shape)))


def _binary_2d_label_to_2d_sparse_value(labels):
56  """Convert dense 2D binary indicator to sparse ID.
57
58  Only 1 values in `labels` are included in result.
59
60  Args:
61    labels: Dense 2D binary indicator, shape [batch_size, num_classes].
62
63  Returns:
64    `SparseTensorValue` of shape [batch_size, num_classes], where num_classes
65    is the number of `1` values in each row of `labels`. Values are indices
66    of `1` values along the last dimension of `labels`.
67  """
  indices = []
  values = []
  batch = 0
  for row in labels:
    label = 0
    xi = 0
    for x in row:
      if x == 1:
        indices.append([batch, xi])
        values.append(label)
        xi += 1
      else:
        assert x == 0
      label += 1
    batch += 1
  shape = [len(labels), len(labels[0])]
  return sparse_tensor.SparseTensorValue(
      np.array(indices, np.int64),
      np.array(values, np.int64), np.array(shape, np.int64))
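# For example, _binary_2d_label_to_2d_sparse_value([[0, 1, 0], [1, 0, 1]])
# should yield indices [[0, 0], [1, 0], [1, 1]], values [1, 0, 2] and
# dense_shape [2, 3].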


def _binary_2d_label_to_1d_sparse_value(labels):
90  """Convert dense 2D binary indicator to sparse ID.
91
92  Only 1 values in `labels` are included in result.
93
94  Args:
95    labels: Dense 2D binary indicator, shape [batch_size, num_classes]. Each
96    row must contain exactly 1 `1` value.
97
98  Returns:
99    `SparseTensorValue` of shape [batch_size]. Values are indices of `1` values
100    along the last dimension of `labels`.
101
102  Raises:
103    ValueError: if there is not exactly 1 `1` value per row of `labels`.
104  """
  indices = []
  values = []
  batch = 0
  for row in labels:
    label = 0
    xi = 0
    for x in row:
      if x == 1:
        indices.append([batch])
        values.append(label)
        xi += 1
      else:
        assert x == 0
      label += 1
    batch += 1
  if indices != [[i] for i in range(len(labels))]:
    raise ValueError('Expected 1 label/example, got %s.' % indices)
  shape = [len(labels)]
  return sparse_tensor.SparseTensorValue(
      np.array(indices, np.int64),
      np.array(values, np.int64), np.array(shape, np.int64))
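# For example, the one-hot rows [[0, 1, 0], [0, 0, 1]] should produce indices
# [[0], [1]], values [1, 2] and dense_shape [2], i.e. one class id per example.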


def _binary_3d_label_to_sparse_value(labels):
129  """Convert dense 3D binary indicator tensor to sparse tensor.
130
131  Only 1 values in `labels` are included in result.
132
133  Args:
134    labels: Dense 2D binary indicator tensor.
135
136  Returns:
137    `SparseTensorValue` whose values are indices along the last dimension of
138    `labels`.
139  """
  indices = []
  values = []
  for d0, labels_d0 in enumerate(labels):
    for d1, labels_d1 in enumerate(labels_d0):
      d2 = 0
      for class_id, label in enumerate(labels_d1):
        if label == 1:
          values.append(class_id)
          indices.append([d0, d1, d2])
          d2 += 1
        else:
          assert label == 0
  shape = [len(labels), len(labels[0]), len(labels[0][0])]
  return sparse_tensor.SparseTensorValue(
      np.array(indices, np.int64),
      np.array(values, np.int64), np.array(shape, np.int64))
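# For example, the 3D indicator [[[0, 1], [1, 0]]] should produce indices
# [[0, 0, 0], [0, 1, 0]], values [1, 0] and dense_shape [1, 2, 2].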


def _assert_nan(test_case, actual):
  test_case.assertTrue(math.isnan(actual), 'Expected NAN, got %s.' % actual)


def _assert_metric_variables(test_case, expected):
  test_case.assertEqual(
      set(expected), set(v.name for v in variables.local_variables()))
  test_case.assertEqual(
      set(expected),
      set(v.name for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES)))


def _test_values(shape):
  return np.reshape(np.cumsum(np.ones(shape)), newshape=shape)
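# For example, _test_values((2, 3)) should produce [[1., 2., 3.], [4., 5., 6.]]
# -- a running count reshaped to `shape`.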


class MeanTest(test.TestCase):

  def setUp(self):
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.mean(array_ops.ones([4, 3]))
    _assert_metric_variables(self, ('mean/count:0', 'mean/total:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.mean(
        array_ops.ones([4, 3]), metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.mean(
        array_ops.ones([4, 3]), updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testBasic(self):
    with self.cached_session() as sess:
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      mean, update_op = metrics.mean(values)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(4):
        self.evaluate(update_op)
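      # The four enqueued rows sum to 13.2 over 8 elements, so the streaming
      # mean should be 13.2 / 8 = 1.65.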
      self.assertAlmostEqual(1.65, self.evaluate(mean), 5)

  @test_util.run_deprecated_v1
  def testUpdateOpsReturnsCurrentValue(self):
    with self.cached_session() as sess:
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      mean, update_op = metrics.mean(values)

      self.evaluate(variables.local_variables_initializer())
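      # Running means after each update: (0 + 1) / 2 = 0.5, then 5.9 / 4 =
      # 1.475, then 12.4 / 6, and finally 13.2 / 8 = 1.65.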

      self.assertAlmostEqual(0.5, self.evaluate(update_op), 5)
      self.assertAlmostEqual(1.475, self.evaluate(update_op), 5)
      self.assertAlmostEqual(12.4 / 6.0, self.evaluate(update_op), 5)
      self.assertAlmostEqual(1.65, self.evaluate(update_op), 5)

      self.assertAlmostEqual(1.65, self.evaluate(mean), 5)

  @test_util.run_deprecated_v1
  def testUnweighted(self):
    values = _test_values((3, 2, 4, 1))
    mean_results = (
        metrics.mean(values),
        metrics.mean(values, weights=1.0),
        metrics.mean(values, weights=np.ones((1, 1, 1))),
        metrics.mean(values, weights=np.ones((1, 1, 1, 1))),
        metrics.mean(values, weights=np.ones((1, 1, 1, 1, 1))),
        metrics.mean(values, weights=np.ones((1, 1, 4))),
        metrics.mean(values, weights=np.ones((1, 1, 4, 1))),
        metrics.mean(values, weights=np.ones((1, 2, 1))),
        metrics.mean(values, weights=np.ones((1, 2, 1, 1))),
        metrics.mean(values, weights=np.ones((1, 2, 4))),
        metrics.mean(values, weights=np.ones((1, 2, 4, 1))),
        metrics.mean(values, weights=np.ones((3, 1, 1))),
        metrics.mean(values, weights=np.ones((3, 1, 1, 1))),
        metrics.mean(values, weights=np.ones((3, 1, 4))),
        metrics.mean(values, weights=np.ones((3, 1, 4, 1))),
        metrics.mean(values, weights=np.ones((3, 2, 1))),
        metrics.mean(values, weights=np.ones((3, 2, 1, 1))),
        metrics.mean(values, weights=np.ones((3, 2, 4))),
        metrics.mean(values, weights=np.ones((3, 2, 4, 1))),
        metrics.mean(values, weights=np.ones((3, 2, 4, 1, 1))),)
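    # Every weight above broadcasts to all-ones against `values`, so each
    # result should match the plain unweighted mean computed next.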
    expected = np.mean(values)
    with self.cached_session():
      variables.local_variables_initializer().run()
      for mean_result in mean_results:
        mean, update_op = mean_result
        self.assertAlmostEqual(expected, self.evaluate(update_op))
        self.assertAlmostEqual(expected, self.evaluate(mean))

  def _test_3d_weighted(self, values, weights):
    expected = (
        np.sum(np.multiply(weights, values)) /
        np.sum(np.multiply(weights, np.ones_like(values)))
    )
    mean, update_op = metrics.mean(values, weights=weights)
    with self.cached_session():
      variables.local_variables_initializer().run()
      self.assertAlmostEqual(expected, self.evaluate(update_op), places=5)
      self.assertAlmostEqual(expected, self.evaluate(mean), places=5)

  @test_util.run_deprecated_v1
  def test1x1x1Weighted(self):
    self._test_3d_weighted(
        _test_values((3, 2, 4)),
        weights=np.asarray((5,)).reshape((1, 1, 1)))

  @test_util.run_deprecated_v1
  def test1x1xNWeighted(self):
    self._test_3d_weighted(
        _test_values((3, 2, 4)),
        weights=np.asarray((5, 7, 11, 3)).reshape((1, 1, 4)))

  @test_util.run_deprecated_v1
  def test1xNx1Weighted(self):
    self._test_3d_weighted(
        _test_values((3, 2, 4)),
        weights=np.asarray((5, 11)).reshape((1, 2, 1)))

  @test_util.run_deprecated_v1
  def test1xNxNWeighted(self):
    self._test_3d_weighted(
        _test_values((3, 2, 4)),
        weights=np.asarray((5, 7, 11, 3, 2, 13, 7, 5)).reshape((1, 2, 4)))

  @test_util.run_deprecated_v1
  def testNx1x1Weighted(self):
    self._test_3d_weighted(
        _test_values((3, 2, 4)),
        weights=np.asarray((5, 7, 11)).reshape((3, 1, 1)))

  @test_util.run_deprecated_v1
  def testNx1xNWeighted(self):
    self._test_3d_weighted(
        _test_values((3, 2, 4)),
        weights=np.asarray((
            5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3)).reshape((3, 1, 4)))

  @test_util.run_deprecated_v1
  def testNxNxNWeighted(self):
    self._test_3d_weighted(
        _test_values((3, 2, 4)),
        weights=np.asarray((
            5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3,
            2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5)).reshape((3, 2, 4)))

  @test_util.run_deprecated_v1
  def testInvalidWeights(self):
    values_placeholder = array_ops.placeholder(dtype=dtypes_lib.float32)
    values = _test_values((3, 2, 4, 1))
    invalid_weights = (
        (1,),
        (1, 1),
        (3, 2),
        (2, 4, 1),
        (4, 2, 4, 1),
        (3, 3, 4, 1),
        (3, 2, 5, 1),
        (3, 2, 4, 2),
        (1, 1, 1, 1, 1))
    expected_error_msg = 'weights can not be broadcast to values'
    for invalid_weight in invalid_weights:
      # Static shapes.
      with self.assertRaisesRegex(ValueError, expected_error_msg):
        metrics.mean(values, invalid_weight)

      # Dynamic shapes.
      with self.assertRaisesRegex(errors_impl.OpError, expected_error_msg):
        with self.cached_session():
          _, update_op = metrics.mean(values_placeholder, invalid_weight)
          variables.local_variables_initializer().run()
          update_op.eval(feed_dict={values_placeholder: values})


class MeanTensorTest(test.TestCase):

  def setUp(self):
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.mean_tensor(array_ops.ones([4, 3]))
    _assert_metric_variables(self,
                             ('mean/total_tensor:0', 'mean/count_tensor:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.mean_tensor(
        array_ops.ones([4, 3]), metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.mean_tensor(
        array_ops.ones([4, 3]), updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testBasic(self):
    with self.cached_session() as sess:
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      mean, update_op = metrics.mean_tensor(values)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(4):
        self.evaluate(update_op)
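      # Column-wise means: (0 - 4.2 + 6.5 - 3.2) / 4 = -0.225 and
      # (1 + 9.1 + 0 + 4.0) / 4 = 3.525, matching the assertion below.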
      self.assertAllClose([[-0.9 / 4., 3.525]], self.evaluate(mean))

  @test_util.run_deprecated_v1
  def testMultiDimensional(self):
    with self.cached_session() as sess:
      values_queue = data_flow_ops.FIFOQueue(
          2, dtypes=dtypes_lib.float32, shapes=(2, 2, 2))
      _enqueue_vector(
          sess,
          values_queue, [[[1, 2], [1, 2]], [[1, 2], [1, 2]]],
          shape=(2, 2, 2))
      _enqueue_vector(
          sess,
          values_queue, [[[1, 2], [1, 2]], [[3, 4], [9, 10]]],
          shape=(2, 2, 2))
      values = values_queue.dequeue()

      mean, update_op = metrics.mean_tensor(values)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(2):
        self.evaluate(update_op)
      self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]],
                          self.evaluate(mean))

  @test_util.run_deprecated_v1
  def testUpdateOpsReturnsCurrentValue(self):
    with self.cached_session() as sess:
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      mean, update_op = metrics.mean_tensor(values)

      self.evaluate(variables.local_variables_initializer())

      self.assertAllClose([[0, 1]], self.evaluate(update_op), 5)
      self.assertAllClose([[-2.1, 5.05]], self.evaluate(update_op), 5)
      self.assertAllClose([[2.3 / 3., 10.1 / 3.]], self.evaluate(update_op), 5)
      self.assertAllClose([[-0.9 / 4., 3.525]], self.evaluate(update_op), 5)

      self.assertAllClose([[-0.9 / 4., 3.525]], self.evaluate(mean), 5)

  @test_util.run_deprecated_v1
  def testBinaryWeighted1d(self):
    with self.cached_session() as sess:
      # Create the queue that populates the values.
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      # Create the queue that populates the weights.
      weights_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 1))
      _enqueue_vector(sess, weights_queue, [[1]])
      _enqueue_vector(sess, weights_queue, [[0]])
      _enqueue_vector(sess, weights_queue, [[1]])
      _enqueue_vector(sess, weights_queue, [[0]])
      weights = weights_queue.dequeue()

      mean, update_op = metrics.mean_tensor(values, weights)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(4):
        self.evaluate(update_op)
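      # Only the first and third rows carry weight 1, so the expected
      # per-column means are (0 + 6.5) / 2 = 3.25 and (1 + 0) / 2 = 0.5.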
      self.assertAllClose([[3.25, 0.5]], self.evaluate(mean), 5)

  @test_util.run_deprecated_v1
  def testWeighted1d(self):
    with self.cached_session() as sess:
      # Create the queue that populates the values.
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      # Create the queue that populates the weights.
      weights_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 1))
      _enqueue_vector(sess, weights_queue, [[0.0025]])
      _enqueue_vector(sess, weights_queue, [[0.005]])
      _enqueue_vector(sess, weights_queue, [[0.01]])
      _enqueue_vector(sess, weights_queue, [[0.0075]])
      weights = weights_queue.dequeue()

      mean, update_op = metrics.mean_tensor(values, weights)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(4):
        self.evaluate(update_op)
      self.assertAllClose([[0.8, 3.52]], self.evaluate(mean), 5)

  @test_util.run_deprecated_v1
  def testWeighted2d_1(self):
    with self.cached_session() as sess:
      # Create the queue that populates the values.
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      # Create the queue that populates the weights.
      weights_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, weights_queue, [1, 1])
      _enqueue_vector(sess, weights_queue, [1, 0])
      _enqueue_vector(sess, weights_queue, [0, 1])
      _enqueue_vector(sess, weights_queue, [0, 0])
      weights = weights_queue.dequeue()

      mean, update_op = metrics.mean_tensor(values, weights)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(4):
        self.evaluate(update_op)
      self.assertAllClose([[-2.1, 0.5]], self.evaluate(mean), 5)

  @test_util.run_deprecated_v1
  def testWeighted2d_2(self):
    with self.cached_session() as sess:
      # Create the queue that populates the values.
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      # Create the queue that populates the weights.
      weights_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, weights_queue, [0, 1])
      _enqueue_vector(sess, weights_queue, [0, 0])
      _enqueue_vector(sess, weights_queue, [0, 1])
      _enqueue_vector(sess, weights_queue, [0, 0])
      weights = weights_queue.dequeue()

      mean, update_op = metrics.mean_tensor(values, weights)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(4):
        self.evaluate(update_op)
      self.assertAllClose([[0, 0.5]], self.evaluate(mean), 5)


class AccuracyTest(test.TestCase):

  def setUp(self):
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.accuracy(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        name='my_accuracy')
    _assert_metric_variables(self,
                             ('my_accuracy/count:0', 'my_accuracy/total:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.accuracy(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.accuracy(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self):
    predictions = array_ops.ones((10, 3))
    labels = array_ops.ones((10, 4))
    with self.assertRaises(ValueError):
      metrics.accuracy(labels, predictions)

  @test_util.run_deprecated_v1
  def testPredictionsAndWeightsOfDifferentSizeRaisesValueError(self):
    predictions = array_ops.ones((10, 3))
    labels = array_ops.ones((10, 3))
    weights = array_ops.ones((9, 3))
    with self.assertRaises(ValueError):
      metrics.accuracy(labels, predictions, weights)

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=3, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
        (10, 3), maxval=3, dtype=dtypes_lib.int64, seed=1)
    accuracy, update_op = metrics.accuracy(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_accuracy = self.evaluate(accuracy)
      for _ in range(10):
        self.assertEqual(initial_accuracy, self.evaluate(accuracy))

  @test_util.run_deprecated_v1
  def testMultipleUpdates(self):
    with self.cached_session() as sess:
      # Create the queue that populates the predictions.
      preds_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 1))
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [2])
      _enqueue_vector(sess, preds_queue, [1])
      predictions = preds_queue.dequeue()

      # Create the queue that populates the labels.
      labels_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 1))
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [2])
      labels = labels_queue.dequeue()

      accuracy, update_op = metrics.accuracy(labels, predictions)

      self.evaluate(variables.local_variables_initializer())
      for _ in xrange(3):
        self.evaluate(update_op)
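      # Predictions [0, 1, 2, 1] vs labels [0, 1, 1, 2]: the first two
      # examples match and the last two do not, so the fourth update and the
      # final metric both report 2 / 4 = 0.5.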
      self.assertEqual(0.5, self.evaluate(update_op))
      self.assertEqual(0.5, self.evaluate(accuracy))

  @test_util.run_deprecated_v1
  def testEffectivelyEquivalentSizes(self):
    predictions = array_ops.ones((40, 1))
    labels = array_ops.ones((40,))
    with self.cached_session():
      accuracy, update_op = metrics.accuracy(labels, predictions)

      self.evaluate(variables.local_variables_initializer())
      self.assertEqual(1.0, self.evaluate(update_op))
      self.assertEqual(1.0, self.evaluate(accuracy))

  @test_util.run_deprecated_v1
  def testEffectivelyEquivalentSizesWithScalarWeight(self):
    predictions = array_ops.ones((40, 1))
    labels = array_ops.ones((40,))
    with self.cached_session():
      accuracy, update_op = metrics.accuracy(labels, predictions, weights=2.0)

      self.evaluate(variables.local_variables_initializer())
      self.assertEqual(1.0, self.evaluate(update_op))
      self.assertEqual(1.0, self.evaluate(accuracy))

  @test_util.run_deprecated_v1
  def testEffectivelyEquivalentSizesWithStaticShapedWeight(self):
    predictions = ops.convert_to_tensor([1, 1, 1])  # shape 3,
    labels = array_ops.expand_dims(ops.convert_to_tensor([1, 0, 0]),
                                   1)  # shape 3, 1
    weights = array_ops.expand_dims(ops.convert_to_tensor([100, 1, 1]),
                                    1)  # shape 3, 1

    with self.cached_session():
      accuracy, update_op = metrics.accuracy(labels, predictions, weights)

      self.evaluate(variables.local_variables_initializer())
      # If metrics.accuracy did not flatten the weight, the accuracy would be
      # 0.33333334 due to an intended broadcast of the weight. Due to
      # flattening, it will be higher than .95.
      self.assertGreater(self.evaluate(update_op), .95)
      self.assertGreater(self.evaluate(accuracy), .95)

  @test_util.run_deprecated_v1
  def testEffectivelyEquivalentSizesWithDynamicallyShapedWeight(self):
    predictions = ops.convert_to_tensor([1, 1, 1])  # shape 3,
    labels = array_ops.expand_dims(ops.convert_to_tensor([1, 0, 0]),
                                   1)  # shape 3, 1

    weights = [[100], [1], [1]]  # shape 3, 1
    weights_placeholder = array_ops.placeholder(
        dtype=dtypes_lib.int32, name='weights')
    feed_dict = {weights_placeholder: weights}

    with self.cached_session():
      accuracy, update_op = metrics.accuracy(labels, predictions,
                                             weights_placeholder)

      self.evaluate(variables.local_variables_initializer())
      # If metrics.accuracy did not flatten the weight, the accuracy would be
      # 0.33333334 due to an intended broadcast of the weight. Due to
      # flattening, it will be higher than .95.
      self.assertGreater(update_op.eval(feed_dict=feed_dict), .95)
      self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95)

  @test_util.run_deprecated_v1
  def testMultipleUpdatesWithWeightedValues(self):
    with self.cached_session() as sess:
      # Create the queue that populates the predictions.
      preds_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 1))
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [2])
      _enqueue_vector(sess, preds_queue, [1])
      predictions = preds_queue.dequeue()

      # Create the queue that populates the labels.
      labels_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 1))
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [2])
      labels = labels_queue.dequeue()

      # Create the queue that populates the weights.
      weights_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.int64, shapes=(1, 1))
      _enqueue_vector(sess, weights_queue, [1])
      _enqueue_vector(sess, weights_queue, [1])
      _enqueue_vector(sess, weights_queue, [0])
      _enqueue_vector(sess, weights_queue, [0])
      weights = weights_queue.dequeue()

      accuracy, update_op = metrics.accuracy(labels, predictions, weights)

      self.evaluate(variables.local_variables_initializer())
      for _ in xrange(3):
        self.evaluate(update_op)
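      # With weights [1, 1, 0, 0] only the two correctly classified examples
      # count, so the weighted accuracy should be 1.0.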
      self.assertEqual(1.0, self.evaluate(update_op))
      self.assertEqual(1.0, self.evaluate(accuracy))


class PrecisionTest(test.TestCase):

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.precision(
        predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
    _assert_metric_variables(self, ('precision/false_positives/count:0',
                                    'precision/true_positives/count:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.precision(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.precision(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
    precision, update_op = metrics.precision(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_precision = self.evaluate(precision)
      for _ in range(10):
        self.assertEqual(initial_precision, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(inputs)
    labels = constant_op.constant(inputs)
    precision, update_op = metrics.precision(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(1.0, self.evaluate(update_op), 6)
      self.assertAlmostEqual(1.0, self.evaluate(precision), 6)

  @test_util.run_deprecated_v1
  def testSomeCorrect_multipleInputDtypes(self):
    for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
      predictions = math_ops.cast(
          constant_op.constant([1, 0, 1, 0], shape=(1, 4)), dtype=dtype)
      labels = math_ops.cast(
          constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype)
      precision, update_op = metrics.precision(labels, predictions)

      with self.cached_session():
        self.evaluate(variables.local_variables_initializer())
        self.assertAlmostEqual(0.5, self.evaluate(update_op))
        self.assertAlmostEqual(0.5, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testWeighted1d(self):
    predictions = constant_op.constant([[1, 0, 1, 0], [1, 0, 1, 0]])
    labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
    precision, update_op = metrics.precision(
        labels, predictions, weights=constant_op.constant([[2], [5]]))

    with self.cached_session():
      variables.local_variables_initializer().run()
      weighted_tp = 2.0 + 5.0
      weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(expected_precision, self.evaluate(update_op))
      self.assertAlmostEqual(expected_precision, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testWeightedScalar_placeholders(self):
    predictions = array_ops.placeholder(dtype=dtypes_lib.float32)
    labels = array_ops.placeholder(dtype=dtypes_lib.float32)
    feed_dict = {
        predictions: ((1, 0, 1, 0), (1, 0, 1, 0)),
        labels: ((0, 1, 1, 0), (1, 0, 0, 1))
    }
    precision, update_op = metrics.precision(labels, predictions, weights=2)

    with self.cached_session():
      variables.local_variables_initializer().run()
      weighted_tp = 2.0 + 2.0
      weighted_positives = (2.0 + 2.0) + (2.0 + 2.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(
          expected_precision, update_op.eval(feed_dict=feed_dict))
      self.assertAlmostEqual(
          expected_precision, precision.eval(feed_dict=feed_dict))

  @test_util.run_deprecated_v1
  def testWeighted1d_placeholders(self):
    predictions = array_ops.placeholder(dtype=dtypes_lib.float32)
    labels = array_ops.placeholder(dtype=dtypes_lib.float32)
    feed_dict = {
        predictions: ((1, 0, 1, 0), (1, 0, 1, 0)),
        labels: ((0, 1, 1, 0), (1, 0, 0, 1))
    }
    precision, update_op = metrics.precision(
        labels, predictions, weights=constant_op.constant([[2], [5]]))

    with self.cached_session():
      variables.local_variables_initializer().run()
      weighted_tp = 2.0 + 5.0
      weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(
          expected_precision, update_op.eval(feed_dict=feed_dict))
      self.assertAlmostEqual(
          expected_precision, precision.eval(feed_dict=feed_dict))

  @test_util.run_deprecated_v1
  def testWeighted2d(self):
    predictions = constant_op.constant([[1, 0, 1, 0], [1, 0, 1, 0]])
    labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
    precision, update_op = metrics.precision(
        labels,
        predictions,
        weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))

    with self.cached_session():
      variables.local_variables_initializer().run()
      weighted_tp = 3.0 + 4.0
      weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(expected_precision, self.evaluate(update_op))
      self.assertAlmostEqual(expected_precision, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testWeighted2d_placeholders(self):
    predictions = array_ops.placeholder(dtype=dtypes_lib.float32)
    labels = array_ops.placeholder(dtype=dtypes_lib.float32)
    feed_dict = {
        predictions: ((1, 0, 1, 0), (1, 0, 1, 0)),
        labels: ((0, 1, 1, 0), (1, 0, 0, 1))
    }
    precision, update_op = metrics.precision(
        labels,
        predictions,
        weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))

    with self.cached_session():
      variables.local_variables_initializer().run()
      weighted_tp = 3.0 + 4.0
      weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(
          expected_precision, update_op.eval(feed_dict=feed_dict))
      self.assertAlmostEqual(
          expected_precision, precision.eval(feed_dict=feed_dict))

  @test_util.run_deprecated_v1
  def testAllIncorrect(self):
    inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(inputs)
    labels = constant_op.constant(1 - inputs)
    precision, update_op = metrics.precision(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_op)
      self.assertAlmostEqual(0, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testZeroTrueAndFalsePositivesGivesZeroPrecision(self):
    predictions = constant_op.constant([0, 0, 0, 0])
    labels = constant_op.constant([0, 0, 0, 0])
    precision, update_op = metrics.precision(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_op)
      self.assertEqual(0.0, self.evaluate(precision))


class RecallTest(test.TestCase):

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.recall(
        predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
    _assert_metric_variables(
        self,
        ('recall/false_negatives/count:0', 'recall/true_positives/count:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.recall(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.recall(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
    recall, update_op = metrics.recall(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_recall = self.evaluate(recall)
      for _ in range(10):
        self.assertEqual(initial_recall, self.evaluate(recall))

  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    np_inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(np_inputs)
    labels = constant_op.constant(np_inputs)
    recall, update_op = metrics.recall(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_op)
      self.assertAlmostEqual(1.0, self.evaluate(recall), 6)

  @test_util.run_deprecated_v1
  def testSomeCorrect_multipleInputDtypes(self):
    for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
      predictions = math_ops.cast(
          constant_op.constant([1, 0, 1, 0], shape=(1, 4)), dtype=dtype)
      labels = math_ops.cast(
          constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype)
      recall, update_op = metrics.recall(labels, predictions)

      with self.cached_session():
        self.evaluate(variables.local_variables_initializer())
        self.assertAlmostEqual(0.5, self.evaluate(update_op))
        self.assertAlmostEqual(0.5, self.evaluate(recall))

  @test_util.run_deprecated_v1
  def testWeighted1d(self):
    predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
    labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
    weights = constant_op.constant([[2], [5]])
    recall, update_op = metrics.recall(labels, predictions, weights=weights)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      weighted_tp = 2.0 + 5.0
      weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
      expected_recall = weighted_tp / weighted_t
      self.assertAlmostEqual(expected_recall, self.evaluate(update_op))
      self.assertAlmostEqual(expected_recall, self.evaluate(recall))

  @test_util.run_deprecated_v1
  def testWeighted2d(self):
    predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
    labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
    weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]])
    recall, update_op = metrics.recall(labels, predictions, weights=weights)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      weighted_tp = 3.0 + 1.0
      weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
      expected_recall = weighted_tp / weighted_t
      self.assertAlmostEqual(expected_recall, self.evaluate(update_op))
      self.assertAlmostEqual(expected_recall, self.evaluate(recall))

  @test_util.run_deprecated_v1
  def testAllIncorrect(self):
    np_inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(np_inputs)
    labels = constant_op.constant(1 - np_inputs)
    recall, update_op = metrics.recall(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_op)
      self.assertEqual(0, self.evaluate(recall))

  @test_util.run_deprecated_v1
  def testZeroTruePositivesAndFalseNegativesGivesZeroRecall(self):
    predictions = array_ops.zeros((1, 4))
    labels = array_ops.zeros((1, 4))
    recall, update_op = metrics.recall(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_op)
      self.assertEqual(0, self.evaluate(recall))


class AUCTest(test.TestCase):

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.auc(predictions=array_ops.ones((10, 1)),
                labels=array_ops.ones((10, 1)))
    _assert_metric_variables(self,
                             ('auc/true_positives:0', 'auc/false_negatives:0',
                              'auc/false_positives:0', 'auc/true_negatives:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.auc(predictions=array_ops.ones((10, 1)),
                          labels=array_ops.ones((10, 1)),
                          metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.auc(predictions=array_ops.ones((10, 1)),
                               labels=array_ops.ones((10, 1)),
                               updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
    labels = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
    auc, update_op = metrics.auc(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_auc = self.evaluate(auc)
      for _ in range(10):
        self.assertAlmostEqual(initial_auc, self.evaluate(auc), 5)

  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    self.allCorrectAsExpected('ROC')

  def allCorrectAsExpected(self, curve):
    inputs = np.random.randint(0, 2, size=(100, 1))

    with self.cached_session():
      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
      labels = constant_op.constant(inputs)
      auc, update_op = metrics.auc(labels, predictions, curve=curve)

      self.evaluate(variables.local_variables_initializer())
      self.assertEqual(1, self.evaluate(update_op))

      self.assertEqual(1, self.evaluate(auc))

  @test_util.run_deprecated_v1
  def testSomeCorrect_multipleLabelDtypes(self):
    with self.cached_session():
      for label_dtype in (
          dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
        predictions = constant_op.constant(
            [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
        labels = math_ops.cast(
            constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=label_dtype)
        auc, update_op = metrics.auc(labels, predictions)

        self.evaluate(variables.local_variables_initializer())
        self.assertAlmostEqual(0.5, self.evaluate(update_op))

        self.assertAlmostEqual(0.5, self.evaluate(auc))

  @test_util.run_deprecated_v1
  def testWeighted1d(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
      weights = constant_op.constant([2], shape=(1, 1))
      auc, update_op = metrics.auc(labels, predictions, weights=weights)

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0.5, self.evaluate(update_op), 5)

      self.assertAlmostEqual(0.5, self.evaluate(auc), 5)

  @test_util.run_deprecated_v1
  def testWeighted2d(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
      weights = constant_op.constant([1, 2, 3, 4], shape=(1, 4))
      auc, update_op = metrics.auc(labels, predictions, weights=weights)

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0.7, self.evaluate(update_op), 5)

      self.assertAlmostEqual(0.7, self.evaluate(auc), 5)

  @test_util.run_deprecated_v1
  def testManualThresholds(self):
    with self.cached_session():
      # Verifies that thresholds passed in to the `thresholds` parameter are
      # used correctly.
      # The default thresholds do not split the second and third predictions.
      # Thus, when we provide manual thresholds that correctly split them, we
      # get an accurate AUC value.
      predictions = constant_op.constant(
          [0.12, 0.3001, 0.3003, 0.72], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
      weights = constant_op.constant([1, 1, 1, 1], shape=(1, 4))
      thresholds = [0.0, 0.2, 0.3002, 0.6, 1.0]
      default_auc, default_update_op = metrics.auc(labels,
                                                   predictions,
                                                   weights=weights)
      manual_auc, manual_update_op = metrics.auc(labels,
                                                 predictions,
                                                 weights=weights,
                                                 thresholds=thresholds)

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0.875, self.evaluate(default_update_op), 3)
      self.assertAlmostEqual(0.875, self.evaluate(default_auc), 3)

      self.assertAlmostEqual(0.75, self.evaluate(manual_update_op), 3)
      self.assertAlmostEqual(0.75, self.evaluate(manual_auc), 3)

  # Regarding the AUC-PR tests: note that the preferred method when
  # calculating AUC-PR is summation_method='careful_interpolation'.
  @test_util.run_deprecated_v1
  def testCorrectAUCPRSpecialCase(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
      auc, update_op = metrics.auc(labels, predictions, curve='PR',
                                   summation_method='careful_interpolation')

      self.evaluate(variables.local_variables_initializer())
      # expected ~= 0.79726744594
      expected = 1 - math.log(1.5) / 2
      self.assertAlmostEqual(expected, self.evaluate(update_op), delta=1e-3)
      self.assertAlmostEqual(expected, self.evaluate(auc), delta=1e-3)

  @test_util.run_deprecated_v1
  def testCorrectAnotherAUCPRSpecialCase(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
          shape=(1, 7),
          dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7))
      auc, update_op = metrics.auc(labels, predictions, curve='PR',
                                   summation_method='careful_interpolation')

      self.evaluate(variables.local_variables_initializer())
      # expected ~= 0.61350593198
      expected = (2.5 - 2 * math.log(4./3) - 0.25 * math.log(7./5)) / 3
      self.assertAlmostEqual(expected, self.evaluate(update_op), delta=1e-3)
      self.assertAlmostEqual(expected, self.evaluate(auc), delta=1e-3)

  @test_util.run_deprecated_v1
  def testThirdCorrectAUCPRSpecialCase(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
          shape=(1, 7),
          dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
      auc, update_op = metrics.auc(labels, predictions, curve='PR',
                                   summation_method='careful_interpolation')

      self.evaluate(variables.local_variables_initializer())
      # expected ~= 0.90410597584
      expected = 1 - math.log(4./3) / 3
      self.assertAlmostEqual(expected, self.evaluate(update_op), delta=1e-3)
      self.assertAlmostEqual(expected, self.evaluate(auc), delta=1e-3)

  @test_util.run_deprecated_v1
  def testIncorrectAUCPRSpecialCase(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
      auc, update_op = metrics.auc(labels, predictions, curve='PR',
                                   summation_method='trapezoidal')

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0.79166, self.evaluate(update_op), delta=1e-3)

      self.assertAlmostEqual(0.79166, self.evaluate(auc), delta=1e-3)

  @test_util.run_deprecated_v1
  def testAnotherIncorrectAUCPRSpecialCase(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
          shape=(1, 7),
          dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7))
      auc, update_op = metrics.auc(labels, predictions, curve='PR',
                                   summation_method='trapezoidal')

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0.610317, self.evaluate(update_op), delta=1e-3)

      self.assertAlmostEqual(0.610317, self.evaluate(auc), delta=1e-3)

  @test_util.run_deprecated_v1
  def testThirdIncorrectAUCPRSpecialCase(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
          shape=(1, 7),
          dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
      auc, update_op = metrics.auc(labels, predictions, curve='PR',
                                   summation_method='trapezoidal')

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0.90277, self.evaluate(update_op), delta=1e-3)

      self.assertAlmostEqual(0.90277, self.evaluate(auc), delta=1e-3)

  @test_util.run_deprecated_v1
  def testAllIncorrect(self):
    inputs = np.random.randint(0, 2, size=(100, 1))

    with self.cached_session():
      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
      labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
      auc, update_op = metrics.auc(labels, predictions)

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0, self.evaluate(update_op))

      self.assertAlmostEqual(0, self.evaluate(auc))

  @test_util.run_deprecated_v1
  def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self):
    with self.cached_session():
      predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
      labels = array_ops.zeros([4])
      auc, update_op = metrics.auc(labels, predictions)

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(1, self.evaluate(update_op), 6)

      self.assertAlmostEqual(1, self.evaluate(auc), 6)

  @test_util.run_deprecated_v1
  def testRecallOneAndPrecisionOneGivesOnePRAUC(self):
    with self.cached_session():
      predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
      labels = array_ops.ones([4])
      auc, update_op = metrics.auc(labels, predictions, curve='PR')

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(1, self.evaluate(update_op), 6)

      self.assertAlmostEqual(1, self.evaluate(auc), 6)

  def np_auc(self, predictions, labels, weights):
    """Computes the AUC explicitly using Numpy.

    Args:
      predictions: an ndarray with shape [N].
      labels: an ndarray with shape [N].
      weights: an ndarray with shape [N].

    Returns:
      the area under the ROC curve.
    """
1370    if weights is None:
1371      weights = np.ones(np.size(predictions))
1372    is_positive = labels > 0
1373    num_positives = np.sum(weights[is_positive])
1374    num_negatives = np.sum(weights[~is_positive])
1375
1376    # Sort descending:
1377    inds = np.argsort(-predictions)
1378
1379    sorted_labels = labels[inds]
1380    sorted_weights = weights[inds]
1381    is_positive = sorted_labels > 0
1382
1383    tp = np.cumsum(sorted_weights * is_positive) / num_positives
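    # `tp` is the cumulative weighted true-positive rate after each
    # prediction; averaging it over the negative examples (weighted by their
    # weights) gives the area under the ROC step function.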
1384    return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives
1385
1386  @test_util.run_deprecated_v1
1387  def testWithMultipleUpdates(self):
1388    num_samples = 1000
1389    batch_size = 10
1390    num_batches = int(num_samples / batch_size)
1391
1392    # Create the labels and data.
1393    labels = np.random.randint(0, 2, size=num_samples)
1394    noise = np.random.normal(0.0, scale=0.2, size=num_samples)
1395    predictions = 0.4 + 0.2 * labels + noise
1396    predictions[predictions > 1] = 1
1397    predictions[predictions < 0] = 0
1398
1399    def _enqueue_as_batches(x, enqueue_ops):
1400      x_batches = x.astype(np.float32).reshape((num_batches, batch_size))
1401      x_queue = data_flow_ops.FIFOQueue(
1402          num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
1403      for i in range(num_batches):
1404        enqueue_ops[i].append(x_queue.enqueue(x_batches[i, :]))
1405      return x_queue.dequeue()
1406
1407    for weights in (None, np.ones(num_samples), np.random.exponential(
1408        scale=1.0, size=num_samples)):
1409      expected_auc = self.np_auc(predictions, labels, weights)
1410
1411      with self.cached_session() as sess:
1412        enqueue_ops = [[] for i in range(num_batches)]
1413        tf_predictions = _enqueue_as_batches(predictions, enqueue_ops)
1414        tf_labels = _enqueue_as_batches(labels, enqueue_ops)
1415        tf_weights = (_enqueue_as_batches(weights, enqueue_ops) if
1416                      weights is not None else None)
1417
1418        for i in range(num_batches):
1419          sess.run(enqueue_ops[i])
1420
1421        auc, update_op = metrics.auc(tf_labels,
1422                                     tf_predictions,
1423                                     curve='ROC',
1424                                     num_thresholds=500,
1425                                     weights=tf_weights)
1426
1427        self.evaluate(variables.local_variables_initializer())
1428        for i in range(num_batches):
1429          self.evaluate(update_op)
1430
        # Since this is only approximate, we can't expect a 6-digit match,
        # although with a higher number of samples/thresholds we should see
        # the accuracy improve.
1434        self.assertAlmostEqual(expected_auc, self.evaluate(auc), 2)
1435
1436
1437class SpecificityAtSensitivityTest(test.TestCase):
1438
1439  def setUp(self):
1440    np.random.seed(1)
1441    ops.reset_default_graph()
1442
1443  @test_util.run_deprecated_v1
1444  def testVars(self):
1445    metrics.specificity_at_sensitivity(
1446        predictions=array_ops.ones((10, 1)),
1447        labels=array_ops.ones((10, 1)),
1448        sensitivity=0.7)
1449    _assert_metric_variables(self,
1450                             ('specificity_at_sensitivity/true_positives:0',
1451                              'specificity_at_sensitivity/false_negatives:0',
1452                              'specificity_at_sensitivity/false_positives:0',
1453                              'specificity_at_sensitivity/true_negatives:0'))
1454
1455  @test_util.run_deprecated_v1
1456  def testMetricsCollection(self):
1457    my_collection_name = '__metrics__'
1458    mean, _ = metrics.specificity_at_sensitivity(
1459        predictions=array_ops.ones((10, 1)),
1460        labels=array_ops.ones((10, 1)),
1461        sensitivity=0.7,
1462        metrics_collections=[my_collection_name])
1463    self.assertListEqual(ops.get_collection(my_collection_name), [mean])
1464
1465  @test_util.run_deprecated_v1
1466  def testUpdatesCollection(self):
1467    my_collection_name = '__updates__'
1468    _, update_op = metrics.specificity_at_sensitivity(
1469        predictions=array_ops.ones((10, 1)),
1470        labels=array_ops.ones((10, 1)),
1471        sensitivity=0.7,
1472        updates_collections=[my_collection_name])
1473    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
1474
1475  @test_util.run_deprecated_v1
1476  def testValueTensorIsIdempotent(self):
1477    predictions = random_ops.random_uniform(
1478        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
1479    labels = random_ops.random_uniform(
1480        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
1481    specificity, update_op = metrics.specificity_at_sensitivity(
1482        labels, predictions, sensitivity=0.7)
1483
1484    with self.cached_session():
1485      self.evaluate(variables.local_variables_initializer())
1486
1487      # Run several updates.
1488      for _ in range(10):
1489        self.evaluate(update_op)
1490
1491      # Then verify idempotency.
1492      initial_specificity = self.evaluate(specificity)
1493      for _ in range(10):
1494        self.assertAlmostEqual(initial_specificity, self.evaluate(specificity),
1495                               5)
1496
1497  @test_util.run_deprecated_v1
1498  def testAllCorrect(self):
1499    inputs = np.random.randint(0, 2, size=(100, 1))
1500
1501    predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
1502    labels = constant_op.constant(inputs)
1503    specificity, update_op = metrics.specificity_at_sensitivity(
1504        labels, predictions, sensitivity=0.7)
1505
1506    with self.cached_session():
1507      self.evaluate(variables.local_variables_initializer())
1508      self.assertEqual(1, self.evaluate(update_op))
1509      self.assertEqual(1, self.evaluate(specificity))
1510
1511  @test_util.run_deprecated_v1
1512  def testSomeCorrectHighSensitivity(self):
1513    predictions_values = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.45, 0.5, 0.8, 0.9]
1514    labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
1515
1516    predictions = constant_op.constant(
1517        predictions_values, dtype=dtypes_lib.float32)
1518    labels = constant_op.constant(labels_values)
1519    specificity, update_op = metrics.specificity_at_sensitivity(
1520        labels, predictions, sensitivity=0.8)
1521
1522    with self.cached_session():
1523      self.evaluate(variables.local_variables_initializer())
1524      self.assertAlmostEqual(1.0, self.evaluate(update_op))
1525      self.assertAlmostEqual(1.0, self.evaluate(specificity))
1526
1527  @test_util.run_deprecated_v1
1528  def testSomeCorrectLowSensitivity(self):
1529    predictions_values = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 0.26]
1530    labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
1531
1532    predictions = constant_op.constant(
1533        predictions_values, dtype=dtypes_lib.float32)
1534    labels = constant_op.constant(labels_values)
1535    specificity, update_op = metrics.specificity_at_sensitivity(
1536        labels, predictions, sensitivity=0.4)
1537
1538    with self.cached_session():
1539      self.evaluate(variables.local_variables_initializer())
1540
1541      self.assertAlmostEqual(0.6, self.evaluate(update_op))
1542      self.assertAlmostEqual(0.6, self.evaluate(specificity))
1543
1544  @test_util.run_deprecated_v1
1545  def testWeighted1d_multipleLabelDtypes(self):
1546    for label_dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
1547      predictions_values = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 0.26]
1548      labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
1549      weights_values = [3]
1550
1551      predictions = constant_op.constant(
1552          predictions_values, dtype=dtypes_lib.float32)
1553      labels = math_ops.cast(labels_values, dtype=label_dtype)
1554      weights = constant_op.constant(weights_values)
1555      specificity, update_op = metrics.specificity_at_sensitivity(
1556          labels, predictions, weights=weights, sensitivity=0.4)
1557
1558      with self.cached_session():
1559        self.evaluate(variables.local_variables_initializer())
1560
1561        self.assertAlmostEqual(0.6, self.evaluate(update_op))
1562        self.assertAlmostEqual(0.6, self.evaluate(specificity))
1563
1564  @test_util.run_deprecated_v1
1565  def testWeighted2d(self):
1566    predictions_values = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 0.26]
1567    labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
1568    weights_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
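    # The expected specificity of 8/15 below corresponds to a decision
    # threshold between 0.2 and 0.26 (weighted sensitivity 19/40 >= 0.4):
    # the weighted true negatives are 1 + 2 + 5 = 8 of a total negative
    # weight of 15.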
1569
1570    predictions = constant_op.constant(
1571        predictions_values, dtype=dtypes_lib.float32)
1572    labels = constant_op.constant(labels_values)
1573    weights = constant_op.constant(weights_values)
1574    specificity, update_op = metrics.specificity_at_sensitivity(
1575        labels, predictions, weights=weights, sensitivity=0.4)
1576
1577    with self.cached_session():
1578      self.evaluate(variables.local_variables_initializer())
1579
1580      self.assertAlmostEqual(8.0 / 15.0, self.evaluate(update_op))
1581      self.assertAlmostEqual(8.0 / 15.0, self.evaluate(specificity))
1582
1583
1584class SensitivityAtSpecificityTest(test.TestCase):
1585
1586  def setUp(self):
1587    np.random.seed(1)
1588    ops.reset_default_graph()
1589
1590  @test_util.run_deprecated_v1
1591  def testVars(self):
1592    metrics.sensitivity_at_specificity(
1593        predictions=array_ops.ones((10, 1)),
1594        labels=array_ops.ones((10, 1)),
1595        specificity=0.7)
1596    _assert_metric_variables(self,
1597                             ('sensitivity_at_specificity/true_positives:0',
1598                              'sensitivity_at_specificity/false_negatives:0',
1599                              'sensitivity_at_specificity/false_positives:0',
1600                              'sensitivity_at_specificity/true_negatives:0'))
1601
1602  @test_util.run_deprecated_v1
1603  def testMetricsCollection(self):
1604    my_collection_name = '__metrics__'
1605    mean, _ = metrics.sensitivity_at_specificity(
1606        predictions=array_ops.ones((10, 1)),
1607        labels=array_ops.ones((10, 1)),
1608        specificity=0.7,
1609        metrics_collections=[my_collection_name])
1610    self.assertListEqual(ops.get_collection(my_collection_name), [mean])
1611
1612  @test_util.run_deprecated_v1
1613  def testUpdatesCollection(self):
1614    my_collection_name = '__updates__'
1615    _, update_op = metrics.sensitivity_at_specificity(
1616        predictions=array_ops.ones((10, 1)),
1617        labels=array_ops.ones((10, 1)),
1618        specificity=0.7,
1619        updates_collections=[my_collection_name])
1620    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
1621
1622  @test_util.run_deprecated_v1
1623  def testValueTensorIsIdempotent(self):
1624    predictions = random_ops.random_uniform(
1625        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
1626    labels = random_ops.random_uniform(
1627        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
1628    sensitivity, update_op = metrics.sensitivity_at_specificity(
1629        labels, predictions, specificity=0.7)
1630
1631    with self.cached_session():
1632      self.evaluate(variables.local_variables_initializer())
1633
1634      # Run several updates.
1635      for _ in range(10):
1636        self.evaluate(update_op)
1637
1638      # Then verify idempotency.
1639      initial_sensitivity = self.evaluate(sensitivity)
1640      for _ in range(10):
1641        self.assertAlmostEqual(initial_sensitivity, self.evaluate(sensitivity),
1642                               5)
1643
1644  @test_util.run_deprecated_v1
1645  def testAllCorrect(self):
1646    inputs = np.random.randint(0, 2, size=(100, 1))
1647
1648    predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
1649    labels = constant_op.constant(inputs)
    sensitivity, update_op = metrics.sensitivity_at_specificity(
1651        labels, predictions, specificity=0.7)
1652
1653    with self.cached_session():
1654      self.evaluate(variables.local_variables_initializer())
1655      self.assertAlmostEqual(1.0, self.evaluate(update_op), 6)
      self.assertAlmostEqual(1.0, self.evaluate(sensitivity), 6)
1657
1658  @test_util.run_deprecated_v1
1659  def testSomeCorrectHighSpecificity(self):
1660    predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]
1661    labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
1662
1663    predictions = constant_op.constant(
1664        predictions_values, dtype=dtypes_lib.float32)
1665    labels = constant_op.constant(labels_values)
    sensitivity, update_op = metrics.sensitivity_at_specificity(
1667        labels, predictions, specificity=0.8)
1668
1669    with self.cached_session():
1670      self.evaluate(variables.local_variables_initializer())
1671      self.assertAlmostEqual(0.8, self.evaluate(update_op))
      self.assertAlmostEqual(0.8, self.evaluate(sensitivity))
1673
1674  @test_util.run_deprecated_v1
1675  def testSomeCorrectLowSpecificity(self):
1676    predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
1677    labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
1678
1679    predictions = constant_op.constant(
1680        predictions_values, dtype=dtypes_lib.float32)
1681    labels = constant_op.constant(labels_values)
    sensitivity, update_op = metrics.sensitivity_at_specificity(
1683        labels, predictions, specificity=0.4)
1684
1685    with self.cached_session():
1686      self.evaluate(variables.local_variables_initializer())
1687      self.assertAlmostEqual(0.6, self.evaluate(update_op))
      self.assertAlmostEqual(0.6, self.evaluate(sensitivity))
1689
1690  @test_util.run_deprecated_v1
1691  def testWeighted_multipleLabelDtypes(self):
1692    for label_dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
1693      predictions_values = [
1694          0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
1695      labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
1696      weights_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
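      # A specificity of 0.4 corresponds to a threshold between 0.2 and 0.25:
      # weighted true negatives 1 + 2 + 3 = 6 of 15; positives above the
      # threshold have weights 8 + 9 + 10 = 27 of 40, i.e. sensitivity 0.675.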
1697
1698      predictions = constant_op.constant(
1699          predictions_values, dtype=dtypes_lib.float32)
1700      labels = math_ops.cast(labels_values, dtype=label_dtype)
1701      weights = constant_op.constant(weights_values)
      sensitivity, update_op = metrics.sensitivity_at_specificity(
1703          labels, predictions, weights=weights, specificity=0.4)
1704
1705      with self.cached_session():
1706        self.evaluate(variables.local_variables_initializer())
1707        self.assertAlmostEqual(0.675, self.evaluate(update_op))
        self.assertAlmostEqual(0.675, self.evaluate(sensitivity))
1709
1710
1711# TODO(nsilberman): Break this up into two sets of tests.
1712class PrecisionRecallThresholdsTest(test.TestCase):
1713
1714  def setUp(self):
1715    np.random.seed(1)
1716    ops.reset_default_graph()
1717
1718  @test_util.run_deprecated_v1
1719  def testVars(self):
1720    metrics.precision_at_thresholds(
1721        predictions=array_ops.ones((10, 1)),
1722        labels=array_ops.ones((10, 1)),
1723        thresholds=[0, 0.5, 1.0])
1724    _assert_metric_variables(self, (
1725        'precision_at_thresholds/true_positives:0',
1726        'precision_at_thresholds/false_positives:0',
1727    ))
1728
1729  @test_util.run_deprecated_v1
1730  def testMetricsCollection(self):
1731    my_collection_name = '__metrics__'
1732    prec, _ = metrics.precision_at_thresholds(
1733        predictions=array_ops.ones((10, 1)),
1734        labels=array_ops.ones((10, 1)),
1735        thresholds=[0, 0.5, 1.0],
1736        metrics_collections=[my_collection_name])
1737    rec, _ = metrics.recall_at_thresholds(
1738        predictions=array_ops.ones((10, 1)),
1739        labels=array_ops.ones((10, 1)),
1740        thresholds=[0, 0.5, 1.0],
1741        metrics_collections=[my_collection_name])
1742    self.assertListEqual(ops.get_collection(my_collection_name), [prec, rec])
1743
1744  @test_util.run_deprecated_v1
1745  def testUpdatesCollection(self):
1746    my_collection_name = '__updates__'
1747    _, precision_op = metrics.precision_at_thresholds(
1748        predictions=array_ops.ones((10, 1)),
1749        labels=array_ops.ones((10, 1)),
1750        thresholds=[0, 0.5, 1.0],
1751        updates_collections=[my_collection_name])
1752    _, recall_op = metrics.recall_at_thresholds(
1753        predictions=array_ops.ones((10, 1)),
1754        labels=array_ops.ones((10, 1)),
1755        thresholds=[0, 0.5, 1.0],
1756        updates_collections=[my_collection_name])
1757    self.assertListEqual(
1758        ops.get_collection(my_collection_name), [precision_op, recall_op])
1759
1760  @test_util.run_deprecated_v1
1761  def testValueTensorIsIdempotent(self):
1762    predictions = random_ops.random_uniform(
1763        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
1764    labels = random_ops.random_uniform(
1765        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
1766    thresholds = [0, 0.5, 1.0]
1767    prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
1768                                                    thresholds)
1769    rec, rec_op = metrics.recall_at_thresholds(labels, predictions, thresholds)
1770
1771    with self.cached_session():
1772      self.evaluate(variables.local_variables_initializer())
1773
1774      # Run several updates, then verify idempotency.
1775      self.evaluate([prec_op, rec_op])
1776      initial_prec = self.evaluate(prec)
1777      initial_rec = self.evaluate(rec)
1778      for _ in range(10):
1779        self.evaluate([prec_op, rec_op])
1780        self.assertAllClose(initial_prec, prec)
1781        self.assertAllClose(initial_rec, rec)
1782
1783  # TODO(nsilberman): fix tests (passing but incorrect).
1784  @test_util.run_deprecated_v1
1785  def testAllCorrect(self):
1786    inputs = np.random.randint(0, 2, size=(100, 1))
1787
1788    with self.cached_session():
1789      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
1790      labels = constant_op.constant(inputs)
1791      thresholds = [0.5]
1792      prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
1793                                                      thresholds)
1794      rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
1795                                                 thresholds)
1796
1797      self.evaluate(variables.local_variables_initializer())
1798      self.evaluate([prec_op, rec_op])
1799
1800      self.assertEqual(1, self.evaluate(prec))
1801      self.assertEqual(1, self.evaluate(rec))
1802
1803  @test_util.run_deprecated_v1
1804  def testSomeCorrect_multipleLabelDtypes(self):
1805    with self.cached_session():
1806      for label_dtype in (
1807          dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
1808        predictions = constant_op.constant(
1809            [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
1810        labels = math_ops.cast(
1811            constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=label_dtype)
1812        thresholds = [0.5]
1813        prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
1814                                                        thresholds)
1815        rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
1816                                                   thresholds)
1817
1818        self.evaluate(variables.local_variables_initializer())
1819        self.evaluate([prec_op, rec_op])
1820
1821        self.assertAlmostEqual(0.5, self.evaluate(prec))
1822        self.assertAlmostEqual(0.5, self.evaluate(rec))
1823
1824  @test_util.run_deprecated_v1
1825  def testAllIncorrect(self):
1826    inputs = np.random.randint(0, 2, size=(100, 1))
1827
1828    with self.cached_session():
1829      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
1830      labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
1831      thresholds = [0.5]
1832      prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
1833                                                      thresholds)
1834      rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
1835                                                 thresholds)
1836
1837      self.evaluate(variables.local_variables_initializer())
1838      self.evaluate([prec_op, rec_op])
1839
1840      self.assertAlmostEqual(0, self.evaluate(prec))
1841      self.assertAlmostEqual(0, self.evaluate(rec))
1842
1843  @test_util.run_deprecated_v1
1844  def testWeights1d(self):
1845    with self.cached_session():
1846      predictions = constant_op.constant(
1847          [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
1848      labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
1849      weights = constant_op.constant(
1850          [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32)
1851      thresholds = [0.5, 1.1]
1852      prec, prec_op = metrics.precision_at_thresholds(
1853          labels, predictions, thresholds, weights=weights)
1854      rec, rec_op = metrics.recall_at_thresholds(
1855          labels, predictions, thresholds, weights=weights)
1856
1857      [prec_low, prec_high] = array_ops.split(
1858          value=prec, num_or_size_splits=2, axis=0)
1859      prec_low = array_ops.reshape(prec_low, shape=())
1860      prec_high = array_ops.reshape(prec_high, shape=())
1861      [rec_low, rec_high] = array_ops.split(
1862          value=rec, num_or_size_splits=2, axis=0)
1863      rec_low = array_ops.reshape(rec_low, shape=())
1864      rec_high = array_ops.reshape(rec_high, shape=())
1865
1866      self.evaluate(variables.local_variables_initializer())
1867      self.evaluate([prec_op, rec_op])
1868
1869      self.assertAlmostEqual(1.0, self.evaluate(prec_low), places=5)
1870      self.assertAlmostEqual(0.0, self.evaluate(prec_high), places=5)
1871      self.assertAlmostEqual(1.0, self.evaluate(rec_low), places=5)
1872      self.assertAlmostEqual(0.0, self.evaluate(rec_high), places=5)
1873
1874  @test_util.run_deprecated_v1
1875  def testWeights2d(self):
1876    with self.cached_session():
1877      predictions = constant_op.constant(
1878          [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
1879      labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
1880      weights = constant_op.constant(
1881          [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32)
1882      thresholds = [0.5, 1.1]
1883      prec, prec_op = metrics.precision_at_thresholds(
1884          labels, predictions, thresholds, weights=weights)
1885      rec, rec_op = metrics.recall_at_thresholds(
1886          labels, predictions, thresholds, weights=weights)
1887
1888      [prec_low, prec_high] = array_ops.split(
1889          value=prec, num_or_size_splits=2, axis=0)
1890      prec_low = array_ops.reshape(prec_low, shape=())
1891      prec_high = array_ops.reshape(prec_high, shape=())
1892      [rec_low, rec_high] = array_ops.split(
1893          value=rec, num_or_size_splits=2, axis=0)
1894      rec_low = array_ops.reshape(rec_low, shape=())
1895      rec_high = array_ops.reshape(rec_high, shape=())
1896
1897      self.evaluate(variables.local_variables_initializer())
1898      self.evaluate([prec_op, rec_op])
1899
1900      self.assertAlmostEqual(1.0, self.evaluate(prec_low), places=5)
1901      self.assertAlmostEqual(0.0, self.evaluate(prec_high), places=5)
1902      self.assertAlmostEqual(1.0, self.evaluate(rec_low), places=5)
1903      self.assertAlmostEqual(0.0, self.evaluate(rec_high), places=5)
1904
1905  @test_util.run_deprecated_v1
1906  def testExtremeThresholds(self):
1907    with self.cached_session():
1908      predictions = constant_op.constant(
1909          [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
1910      labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
1911      thresholds = [-1.0, 2.0]  # lower/higher than any values
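      # At threshold -1.0 every example is predicted positive, so precision is
      # 3/4 and recall is 3/3; at threshold 2.0 nothing is predicted positive,
      # so both metrics are 0.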
1912      prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
1913                                                      thresholds)
1914      rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
1915                                                 thresholds)
1916
1917      [prec_low, prec_high] = array_ops.split(
1918          value=prec, num_or_size_splits=2, axis=0)
1919      [rec_low, rec_high] = array_ops.split(
1920          value=rec, num_or_size_splits=2, axis=0)
1921
1922      self.evaluate(variables.local_variables_initializer())
1923      self.evaluate([prec_op, rec_op])
1924
1925      self.assertAlmostEqual(0.75, self.evaluate(prec_low))
1926      self.assertAlmostEqual(0.0, self.evaluate(prec_high))
1927      self.assertAlmostEqual(1.0, self.evaluate(rec_low))
1928      self.assertAlmostEqual(0.0, self.evaluate(rec_high))
1929
1930  @test_util.run_deprecated_v1
1931  def testZeroLabelsPredictions(self):
1932    with self.cached_session():
1933      predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
1934      labels = array_ops.zeros([4])
1935      thresholds = [0.5]
1936      prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
1937                                                      thresholds)
1938      rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
1939                                                 thresholds)
1940
1941      self.evaluate(variables.local_variables_initializer())
1942      self.evaluate([prec_op, rec_op])
1943
1944      self.assertAlmostEqual(0, self.evaluate(prec), 6)
1945      self.assertAlmostEqual(0, self.evaluate(rec), 6)
1946
1947  @test_util.run_deprecated_v1
1948  def testWithMultipleUpdates(self):
1949    num_samples = 1000
1950    batch_size = 10
1951    num_batches = int(num_samples / batch_size)
1952
1953    # Create the labels and data.
1954    labels = np.random.randint(0, 2, size=(num_samples, 1))
1955    noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1))
1956    predictions = 0.4 + 0.2 * labels + noise
1957    predictions[predictions > 1] = 1
1958    predictions[predictions < 0] = 0
1959    thresholds = [0.3]
1960
1961    tp = 0
1962    fp = 0
1963    fn = 0
1964    tn = 0
1965    for i in range(num_samples):
1966      if predictions[i] > thresholds[0]:
1967        if labels[i] == 1:
1968          tp += 1
1969        else:
1970          fp += 1
1971      else:
1972        if labels[i] == 1:
1973          fn += 1
1974        else:
1975          tn += 1
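    # Brute-force precision and recall at the single threshold; epsilon keeps
    # the expected values finite when a denominator would otherwise be zero.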
1976    epsilon = 1e-7
1977    expected_prec = tp / (epsilon + tp + fp)
1978    expected_rec = tp / (epsilon + tp + fn)
1979
1980    labels = labels.astype(np.float32)
1981    predictions = predictions.astype(np.float32)
1982
1983    with self.cached_session() as sess:
      # Reshape the data so it's easy to queue up:
1985      predictions_batches = predictions.reshape((batch_size, num_batches))
1986      labels_batches = labels.reshape((batch_size, num_batches))
1987
1988      # Enqueue the data:
1989      predictions_queue = data_flow_ops.FIFOQueue(
1990          num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
1991      labels_queue = data_flow_ops.FIFOQueue(
1992          num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
1993
1994      for i in range(int(num_batches)):
1995        tf_prediction = constant_op.constant(predictions_batches[:, i])
1996        tf_label = constant_op.constant(labels_batches[:, i])
1997        sess.run([
1998            predictions_queue.enqueue(tf_prediction),
1999            labels_queue.enqueue(tf_label)
2000        ])
2001
2002      tf_predictions = predictions_queue.dequeue()
2003      tf_labels = labels_queue.dequeue()
2004
2005      prec, prec_op = metrics.precision_at_thresholds(tf_labels, tf_predictions,
2006                                                      thresholds)
2007      rec, rec_op = metrics.recall_at_thresholds(tf_labels, tf_predictions,
2008                                                 thresholds)
2009
2010      self.evaluate(variables.local_variables_initializer())
2011      for _ in range(int(num_samples / batch_size)):
2012        self.evaluate([prec_op, rec_op])
      # Since this is only approximate, we can't expect a 6-digit match,
      # although with a higher number of samples/thresholds we should see
      # the accuracy improve.
2016      self.assertAlmostEqual(expected_prec, self.evaluate(prec), 2)
2017      self.assertAlmostEqual(expected_rec, self.evaluate(rec), 2)
2018
2019
2020def _test_precision_at_k(predictions,
2021                         labels,
2022                         k,
2023                         expected,
2024                         class_id=None,
2025                         weights=None,
2026                         test_case=None):
2027  with ops.Graph().as_default() as g, test_case.test_session(g):
2028    if weights is not None:
2029      weights = constant_op.constant(weights, dtypes_lib.float32)
2030    metric, update = metrics.precision_at_k(
2031        predictions=constant_op.constant(predictions, dtypes_lib.float32),
2032        labels=labels,
2033        k=k,
2034        class_id=class_id,
2035        weights=weights)
2036
2037    # Fails without initialized vars.
2038    test_case.assertRaises(errors_impl.OpError, metric.eval)
2039    test_case.assertRaises(errors_impl.OpError, update.eval)
2040    variables.variables_initializer(variables.local_variables()).run()
2041
2042    # Run per-step op and assert expected values.
2043    if math.isnan(expected):
2044      _assert_nan(test_case, update.eval())
2045      _assert_nan(test_case, metric.eval())
2046    else:
2047      test_case.assertEqual(expected, update.eval())
2048      test_case.assertEqual(expected, metric.eval())
2049
2050
2051def _test_precision_at_top_k(
2052    predictions_idx,
2053    labels,
2054    expected,
2055    k=None,
2056    class_id=None,
2057    weights=None,
2058    test_case=None):
2059  with ops.Graph().as_default() as g, test_case.test_session(g):
2060    if weights is not None:
2061      weights = constant_op.constant(weights, dtypes_lib.float32)
2062    metric, update = metrics.precision_at_top_k(
2063        predictions_idx=constant_op.constant(predictions_idx, dtypes_lib.int32),
2064        labels=labels,
2065        k=k,
2066        class_id=class_id,
2067        weights=weights)
2068
2069    # Fails without initialized vars.
2070    test_case.assertRaises(errors_impl.OpError, metric.eval)
2071    test_case.assertRaises(errors_impl.OpError, update.eval)
2072    variables.variables_initializer(variables.local_variables()).run()
2073
2074    # Run per-step op and assert expected values.
2075    if math.isnan(expected):
      _assert_nan(test_case, update.eval())
      _assert_nan(test_case, metric.eval())
2078    else:
2079      test_case.assertEqual(expected, update.eval())
2080      test_case.assertEqual(expected, metric.eval())
2081
2082
2083def _test_average_precision_at_k(predictions,
2084                                 labels,
2085                                 k,
2086                                 expected,
2087                                 weights=None,
2088                                 test_case=None):
2089  with ops.Graph().as_default() as g, test_case.test_session(g):
2090    if weights is not None:
2091      weights = constant_op.constant(weights, dtypes_lib.float32)
2092    predictions = constant_op.constant(predictions, dtypes_lib.float32)
2093    metric, update = metrics.average_precision_at_k(
2094        labels, predictions, k, weights=weights)
2095
2096    # Fails without initialized vars.
2097    test_case.assertRaises(errors_impl.OpError, metric.eval)
2098    test_case.assertRaises(errors_impl.OpError, update.eval)
2099    variables.variables_initializer(variables.local_variables()).run()
2100
2101    # Run per-step op and assert expected values.
2102    if math.isnan(expected):
2103      _assert_nan(test_case, update.eval())
2104      _assert_nan(test_case, metric.eval())
2105    else:
2106      test_case.assertAlmostEqual(expected, update.eval())
2107      test_case.assertAlmostEqual(expected, metric.eval())
2108
2109
2110class SingleLabelPrecisionAtKTest(test.TestCase):
2111
2112  def setUp(self):
2113    self._predictions = ((0.1, 0.3, 0.2, 0.4), (0.1, 0.2, 0.3, 0.4))
2114    self._predictions_idx = [[3], [3]]
2115    indicator_labels = ((0, 0, 0, 1), (0, 0, 1, 0))
2116    class_labels = (3, 2)
2117    # Sparse vs dense, and 1d vs 2d labels should all be handled the same.
2118    self._labels = (
2119        _binary_2d_label_to_1d_sparse_value(indicator_labels),
2120        _binary_2d_label_to_2d_sparse_value(indicator_labels), np.array(
2121            class_labels, dtype=np.int64), np.array(
2122                [[class_id] for class_id in class_labels], dtype=np.int64))
2123    self._test_precision_at_k = functools.partial(
2124        _test_precision_at_k, test_case=self)
2125    self._test_precision_at_top_k = functools.partial(
2126        _test_precision_at_top_k, test_case=self)
2127    self._test_average_precision_at_k = functools.partial(
2128        _test_average_precision_at_k, test_case=self)
2129
2130  @test_util.run_deprecated_v1
2131  def test_at_k1_nan(self):
2132    for labels in self._labels:
2133      # Classes 0,1,2 have 0 predictions, classes -1 and 4 are out of range.
2134      for class_id in (-1, 0, 1, 2, 4):
2135        self._test_precision_at_k(
2136            self._predictions, labels, k=1, expected=NAN, class_id=class_id)
2137        self._test_precision_at_top_k(
2138            self._predictions_idx, labels, k=1, expected=NAN, class_id=class_id)
2139
2140  @test_util.run_deprecated_v1
2141  def test_at_k1(self):
2142    for labels in self._labels:
2143      # Class 3: 1 label, 2 predictions, 1 correct.
2144      self._test_precision_at_k(
2145          self._predictions, labels, k=1, expected=1.0 / 2, class_id=3)
2146      self._test_precision_at_top_k(
2147          self._predictions_idx, labels, k=1, expected=1.0 / 2, class_id=3)
2148
2149      # All classes: 2 labels, 2 predictions, 1 correct.
2150      self._test_precision_at_k(
2151          self._predictions, labels, k=1, expected=1.0 / 2)
2152      self._test_precision_at_top_k(
2153          self._predictions_idx, labels, k=1, expected=1.0 / 2)
2154      self._test_average_precision_at_k(
2155          self._predictions, labels, k=1, expected=1.0 / 2)
2156
2157
2158class MultiLabelPrecisionAtKTest(test.TestCase):
2159
2160  def setUp(self):
2161    self._test_precision_at_k = functools.partial(
2162        _test_precision_at_k, test_case=self)
2163    self._test_precision_at_top_k = functools.partial(
2164        _test_precision_at_top_k, test_case=self)
2165    self._test_average_precision_at_k = functools.partial(
2166        _test_average_precision_at_k, test_case=self)
2167
2168  @test_util.run_deprecated_v1
2169  def test_average_precision(self):
2170    # Example 1.
2171    # Matches example here:
2172    # fastml.com/what-you-wanted-to-know-about-mean-average-precision
2173    labels_ex1 = (0, 1, 2, 3, 4)
2174    labels = np.array([labels_ex1], dtype=np.int64)
2175    predictions_ex1 = (0.2, 0.1, 0.0, 0.4, 0.0, 0.5, 0.3)
2176    predictions = (predictions_ex1,)
2177    predictions_idx_ex1 = (5, 3, 6, 0, 1)
2178    precision_ex1 = (0.0 / 1, 1.0 / 2, 1.0 / 3, 2.0 / 4)
2179    avg_precision_ex1 = (0.0 / 1, precision_ex1[1] / 2, precision_ex1[1] / 3,
2180                         (precision_ex1[1] + precision_ex1[3]) / 4)
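    # avg_precision_ex1[i] is the running sum of precision@j at the positions
    # j <= i + 1 that hit a label, divided by k = i + 1.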
2181    for i in xrange(4):
2182      k = i + 1
2183      self._test_precision_at_k(
2184          predictions, labels, k, expected=precision_ex1[i])
2185      self._test_precision_at_top_k(
2186          (predictions_idx_ex1[:k],), labels, k=k, expected=precision_ex1[i])
2187      self._test_average_precision_at_k(
2188          predictions, labels, k, expected=avg_precision_ex1[i])
2189
2190    # Example 2.
2191    labels_ex2 = (0, 2, 4, 5, 6)
2192    labels = np.array([labels_ex2], dtype=np.int64)
2193    predictions_ex2 = (0.3, 0.5, 0.0, 0.4, 0.0, 0.1, 0.2)
2194    predictions = (predictions_ex2,)
2195    predictions_idx_ex2 = (1, 3, 0, 6, 5)
2196    precision_ex2 = (0.0 / 1, 0.0 / 2, 1.0 / 3, 2.0 / 4)
2197    avg_precision_ex2 = (0.0 / 1, 0.0 / 2, precision_ex2[2] / 3,
2198                         (precision_ex2[2] + precision_ex2[3]) / 4)
2199    for i in xrange(4):
2200      k = i + 1
2201      self._test_precision_at_k(
2202          predictions, labels, k, expected=precision_ex2[i])
2203      self._test_precision_at_top_k(
2204          (predictions_idx_ex2[:k],), labels, k=k, expected=precision_ex2[i])
2205      self._test_average_precision_at_k(
2206          predictions, labels, k, expected=avg_precision_ex2[i])
2207
    # With both examples together, we expect both precision and average
    # precision to be the average of the two examples.
2210    labels = np.array([labels_ex1, labels_ex2], dtype=np.int64)
2211    predictions = (predictions_ex1, predictions_ex2)
2212    streaming_precision = [(ex1 + ex2) / 2
2213                           for ex1, ex2 in zip(precision_ex1, precision_ex2)]
2214    streaming_average_precision = [
2215        (ex1 + ex2) / 2
2216        for ex1, ex2 in zip(avg_precision_ex1, avg_precision_ex2)
2217    ]
2218    for i in xrange(4):
2219      k = i + 1
2220      predictions_idx = (predictions_idx_ex1[:k], predictions_idx_ex2[:k])
2221      self._test_precision_at_k(
2222          predictions, labels, k, expected=streaming_precision[i])
2223      self._test_precision_at_top_k(
2224          predictions_idx, labels, k=k, expected=streaming_precision[i])
2225      self._test_average_precision_at_k(
2226          predictions, labels, k, expected=streaming_average_precision[i])
2227
    # With weighted examples, we expect the streaming average precision to be
    # the weighted average of the two examples.
2230    weights = (0.3, 0.6)
2231    streaming_average_precision = [
2232        (weights[0] * ex1 + weights[1] * ex2) / (weights[0] + weights[1])
2233        for ex1, ex2 in zip(avg_precision_ex1, avg_precision_ex2)
2234    ]
2235    for i in xrange(4):
2236      k = i + 1
2237      self._test_average_precision_at_k(
2238          predictions,
2239          labels,
2240          k,
2241          expected=streaming_average_precision[i],
2242          weights=weights)
2243
2244  @test_util.run_deprecated_v1
2245  def test_average_precision_some_labels_out_of_range(self):
2246    """Tests that labels outside the [0, n_classes) range are ignored."""
2247    labels_ex1 = (-1, 0, 1, 2, 3, 4, 7)
2248    labels = np.array([labels_ex1], dtype=np.int64)
2249    predictions_ex1 = (0.2, 0.1, 0.0, 0.4, 0.0, 0.5, 0.3)
2250    predictions = (predictions_ex1,)
2251    predictions_idx_ex1 = (5, 3, 6, 0, 1)
2252    precision_ex1 = (0.0 / 1, 1.0 / 2, 1.0 / 3, 2.0 / 4)
2253    avg_precision_ex1 = (0.0 / 1, precision_ex1[1] / 2, precision_ex1[1] / 3,
2254                         (precision_ex1[1] + precision_ex1[3]) / 4)
2255    for i in xrange(4):
2256      k = i + 1
2257      self._test_precision_at_k(
2258          predictions, labels, k, expected=precision_ex1[i])
2259      self._test_precision_at_top_k(
2260          (predictions_idx_ex1[:k],), labels, k=k, expected=precision_ex1[i])
2261      self._test_average_precision_at_k(
2262          predictions, labels, k, expected=avg_precision_ex1[i])
2263
2264  @test_util.run_deprecated_v1
2265  def test_average_precision_different_num_labels(self):
2266    """Tests the case where the numbers of labels differ across examples."""
2267    predictions = [[0.4, 0.3, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4]]
2268    sparse_labels = _binary_2d_label_to_2d_sparse_value(
2269        [[0, 0, 1, 1], [0, 0, 0, 1]])
2270    dense_labels = np.array([[2, 3], [3, -1]], dtype=np.int64)
2271    predictions_idx_ex1 = np.array(((0, 1, 2, 3), (3, 2, 1, 0)))
2272    precision_ex1 = ((0.0 / 1, 0.0 / 2, 1.0 / 3, 2.0 / 4),
2273                     (1.0 / 1, 1.0 / 2, 1.0 / 3, 1.0 / 4))
2274    mean_precision_ex1 = np.mean(precision_ex1, axis=0)
2275    avg_precision_ex1 = (
2276        (0.0 / 1, 0.0 / 2, 1.0 / 3 / 2, (1.0 / 3 + 2.0 / 4) / 2),
2277        (1.0 / 1, 1.0 / 1, 1.0 / 1, 1.0 / 1))
2278    mean_avg_precision_ex1 = np.mean(avg_precision_ex1, axis=0)
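    # The per-example normalizer is min(k, number of valid labels): the first
    # example has 2 labels, the second only 1 (its -1 entry is ignored), which
    # is why its average precision is 1.0 at every k.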
2279    for labels in (sparse_labels, dense_labels):
2280      for i in xrange(4):
2281        k = i + 1
2282        self._test_precision_at_k(
2283            predictions, labels, k, expected=mean_precision_ex1[i])
2284        self._test_precision_at_top_k(
2285            predictions_idx_ex1[:, :k], labels, k=k,
2286            expected=mean_precision_ex1[i])
2287        self._test_average_precision_at_k(
2288            predictions, labels, k, expected=mean_avg_precision_ex1[i])
2289
2290  @test_util.run_deprecated_v1
2291  def test_three_labels_at_k5_no_predictions(self):
2292    predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
2293                   [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
2294    predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
2295    sparse_labels = _binary_2d_label_to_2d_sparse_value(
2296        [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
2297    dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
2298
2299    for labels in (sparse_labels, dense_labels):
2300      # Classes 1,3,8 have 0 predictions, classes -1 and 10 are out of range.
2301      for class_id in (-1, 1, 3, 8, 10):
2302        self._test_precision_at_k(
2303            predictions, labels, k=5, expected=NAN, class_id=class_id)
2304        self._test_precision_at_top_k(
2305            predictions_idx, labels, k=5, expected=NAN, class_id=class_id)
2306
2307  @test_util.run_deprecated_v1
2308  def test_three_labels_at_k5_no_labels(self):
2309    predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
2310                   [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
2311    predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
2312    sparse_labels = _binary_2d_label_to_2d_sparse_value(
2313        [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
2314    dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
2315
2316    for labels in (sparse_labels, dense_labels):
2317      # Classes 0,4,6,9: 0 labels, >=1 prediction.
2318      for class_id in (0, 4, 6, 9):
2319        self._test_precision_at_k(
2320            predictions, labels, k=5, expected=0.0, class_id=class_id)
2321        self._test_precision_at_top_k(
2322            predictions_idx, labels, k=5, expected=0.0, class_id=class_id)
2323
2324  @test_util.run_deprecated_v1
2325  def test_three_labels_at_k5(self):
2326    predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
2327                   [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
2328    predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
2329    sparse_labels = _binary_2d_label_to_2d_sparse_value(
2330        [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
2331    dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
2332
2333    for labels in (sparse_labels, dense_labels):
2334      # Class 2: 2 labels, 2 correct predictions.
2335      self._test_precision_at_k(
2336          predictions, labels, k=5, expected=2.0 / 2, class_id=2)
2337      self._test_precision_at_top_k(
2338          predictions_idx, labels, k=5, expected=2.0 / 2, class_id=2)
2339
2340      # Class 5: 1 label, 1 correct prediction.
2341      self._test_precision_at_k(
2342          predictions, labels, k=5, expected=1.0 / 1, class_id=5)
2343      self._test_precision_at_top_k(
2344          predictions_idx, labels, k=5, expected=1.0 / 1, class_id=5)
2345
2346      # Class 7: 1 label, 1 incorrect prediction.
2347      self._test_precision_at_k(
2348          predictions, labels, k=5, expected=0.0 / 1, class_id=7)
2349      self._test_precision_at_top_k(
2350          predictions_idx, labels, k=5, expected=0.0 / 1, class_id=7)
2351
2352      # All classes: 10 predictions, 3 correct.
2353      self._test_precision_at_k(
2354          predictions, labels, k=5, expected=3.0 / 10)
2355      self._test_precision_at_top_k(
2356          predictions_idx, labels, k=5, expected=3.0 / 10)
2357
2358  @test_util.run_deprecated_v1
2359  def test_three_labels_at_k5_some_out_of_range(self):
2360    """Tests that labels outside the [0, n_classes) range are ignored."""
2361    predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
2362                   [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
2363    predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
2364    sp_labels = sparse_tensor.SparseTensorValue(
2365        indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
2366                 [1, 3]],
2367        # values -1 and 10 are outside the [0, n_classes) range and are ignored.
2368        values=np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64),
2369        dense_shape=[2, 4])
2370
2371    # Class 2: 2 labels, 2 correct predictions.
2372    self._test_precision_at_k(
2373        predictions, sp_labels, k=5, expected=2.0 / 2, class_id=2)
2374    self._test_precision_at_top_k(
2375        predictions_idx, sp_labels, k=5, expected=2.0 / 2, class_id=2)
2376
2377    # Class 5: 1 label, 1 correct prediction.
2378    self._test_precision_at_k(
2379        predictions, sp_labels, k=5, expected=1.0 / 1, class_id=5)
2380    self._test_precision_at_top_k(
2381        predictions_idx, sp_labels, k=5, expected=1.0 / 1, class_id=5)
2382
2383    # Class 7: 1 label, 1 incorrect prediction.
2384    self._test_precision_at_k(
2385        predictions, sp_labels, k=5, expected=0.0 / 1, class_id=7)
2386    self._test_precision_at_top_k(
2387        predictions_idx, sp_labels, k=5, expected=0.0 / 1, class_id=7)
2388
2389    # All classes: 10 predictions, 3 correct.
2390    self._test_precision_at_k(
2391        predictions, sp_labels, k=5, expected=3.0 / 10)
2392    self._test_precision_at_top_k(
2393        predictions_idx, sp_labels, k=5, expected=3.0 / 10)
2394
2395  @test_util.run_deprecated_v1
2396  def test_3d_nan(self):
2397    predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
2398                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
2399                   [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
2400                    [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
2401    predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
2402                       [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
2403    labels = _binary_3d_label_to_sparse_value(
2404        [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
2405         [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
2406
2407    # Classes 1,3,8 have 0 predictions, classes -1 and 10 are out of range.
2408    for class_id in (-1, 1, 3, 8, 10):
2409      self._test_precision_at_k(
2410          predictions, labels, k=5, expected=NAN, class_id=class_id)
2411      self._test_precision_at_top_k(
2412          predictions_idx, labels, k=5, expected=NAN, class_id=class_id)
2413
2414  @test_util.run_deprecated_v1
2415  def test_3d_no_labels(self):
2416    predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
2417                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
2418                   [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
2419                    [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
2420    predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
2421                       [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
2422    labels = _binary_3d_label_to_sparse_value(
2423        [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
2424         [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
2425
2426    # Classes 0,4,6,9: 0 labels, >=1 prediction.
2427    for class_id in (0, 4, 6, 9):
2428      self._test_precision_at_k(
2429          predictions, labels, k=5, expected=0.0, class_id=class_id)
2430      self._test_precision_at_top_k(
2431          predictions_idx, labels, k=5, expected=0.0, class_id=class_id)
2432
2433  @test_util.run_deprecated_v1
2434  def test_3d(self):
2435    predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
2436                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
2437                   [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
2438                    [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
2439    predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
2440                       [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
2441    labels = _binary_3d_label_to_sparse_value(
2442        [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
2443         [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
2444
2445    # Class 2: 4 predictions, all correct.
2446    self._test_precision_at_k(
2447        predictions, labels, k=5, expected=4.0 / 4, class_id=2)
2448    self._test_precision_at_top_k(
2449        predictions_idx, labels, k=5, expected=4.0 / 4, class_id=2)
2450
2451    # Class 5: 2 predictions, both correct.
2452    self._test_precision_at_k(
2453        predictions, labels, k=5, expected=2.0 / 2, class_id=5)
2454    self._test_precision_at_top_k(
2455        predictions_idx, labels, k=5, expected=2.0 / 2, class_id=5)
2456
2457    # Class 7: 2 predictions, 1 correct.
2458    self._test_precision_at_k(
2459        predictions, labels, k=5, expected=1.0 / 2, class_id=7)
2460    self._test_precision_at_top_k(
2461        predictions_idx, labels, k=5, expected=1.0 / 2, class_id=7)
2462
2463    # All classes: 20 predictions, 7 correct.
2464    self._test_precision_at_k(
2465        predictions, labels, k=5, expected=7.0 / 20)
2466    self._test_precision_at_top_k(
2467        predictions_idx, labels, k=5, expected=7.0 / 20)
2468
2469  @test_util.run_deprecated_v1
2470  def test_3d_ignore_some(self):
2471    predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
2472                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
2473                   [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
2474                    [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
2475    predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
2476                       [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
2477    labels = _binary_3d_label_to_sparse_value(
2478        [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
2479         [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
2480
2481    # Class 2: 2 predictions, both correct.
2482    self._test_precision_at_k(
2483        predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
2484        weights=[[1], [0]])
2485    self._test_precision_at_top_k(
2486        predictions_idx, labels, k=5, expected=2.0 / 2.0, class_id=2,
2487        weights=[[1], [0]])
2488
2489    # Class 2: 2 predictions, both correct.
2490    self._test_precision_at_k(
2491        predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
2492        weights=[[0], [1]])
2493    self._test_precision_at_top_k(
2494        predictions_idx, labels, k=5, expected=2.0 / 2.0, class_id=2,
2495        weights=[[0], [1]])
2496
2497    # Class 7: 1 incorrect prediction.
2498    self._test_precision_at_k(
2499        predictions, labels, k=5, expected=0.0 / 1.0, class_id=7,
2500        weights=[[1], [0]])
2501    self._test_precision_at_top_k(
2502        predictions_idx, labels, k=5, expected=0.0 / 1.0, class_id=7,
2503        weights=[[1], [0]])
2504
2505    # Class 7: 1 correct prediction.
2506    self._test_precision_at_k(
2507        predictions, labels, k=5, expected=1.0 / 1.0, class_id=7,
2508        weights=[[0], [1]])
2509    self._test_precision_at_top_k(
2510        predictions_idx, labels, k=5, expected=1.0 / 1.0, class_id=7,
2511        weights=[[0], [1]])
2512
2513    # Class 7: no predictions.
2514    self._test_precision_at_k(
2515        predictions, labels, k=5, expected=NAN, class_id=7,
2516        weights=[[1, 0], [0, 1]])
2517    self._test_precision_at_top_k(
2518        predictions_idx, labels, k=5, expected=NAN, class_id=7,
2519        weights=[[1, 0], [0, 1]])
2520
2521    # Class 7: 2 predictions, 1 correct.
2522    self._test_precision_at_k(
2523        predictions, labels, k=5, expected=1.0 / 2.0, class_id=7,
2524        weights=[[0, 1], [1, 0]])
2525    self._test_precision_at_top_k(
2526        predictions_idx, labels, k=5, expected=1.0 / 2.0, class_id=7,
2527        weights=[[0, 1], [1, 0]])
2528
2529
2530def _test_recall_at_k(predictions,
2531                      labels,
2532                      k,
2533                      expected,
2534                      class_id=None,
2535                      weights=None,
2536                      test_case=None):
2537  with ops.Graph().as_default() as g, test_case.test_session(g):
2538    if weights is not None:
2539      weights = constant_op.constant(weights, dtypes_lib.float32)
2540    metric, update = metrics.recall_at_k(
2541        predictions=constant_op.constant(predictions, dtypes_lib.float32),
2542        labels=labels,
2543        k=k,
2544        class_id=class_id,
2545        weights=weights)
2546
2547    # Fails without initialized vars.
2548    test_case.assertRaises(errors_impl.OpError, metric.eval)
2549    test_case.assertRaises(errors_impl.OpError, update.eval)
2550    variables.variables_initializer(variables.local_variables()).run()
2551
2552    # Run per-step op and assert expected values.
2553    if math.isnan(expected):
2554      _assert_nan(test_case, update.eval())
2555      _assert_nan(test_case, metric.eval())
2556    else:
2557      test_case.assertEqual(expected, update.eval())
2558      test_case.assertEqual(expected, metric.eval())
2559
2560
2561def _test_recall_at_top_k(
2562    predictions_idx,
2563    labels,
2564    expected,
2565    k=None,
2566    class_id=None,
2567    weights=None,
2568    test_case=None):
2569  with ops.Graph().as_default() as g, test_case.test_session(g):
2570    if weights is not None:
2571      weights = constant_op.constant(weights, dtypes_lib.float32)
2572    metric, update = metrics.recall_at_top_k(
2573        predictions_idx=constant_op.constant(predictions_idx, dtypes_lib.int32),
2574        labels=labels,
2575        k=k,
2576        class_id=class_id,
2577        weights=weights)
2578
2579    # Fails without initialized vars.
2580    test_case.assertRaises(errors_impl.OpError, metric.eval)
2581    test_case.assertRaises(errors_impl.OpError, update.eval)
2582    variables.variables_initializer(variables.local_variables()).run()
2583
2584    # Run per-step op and assert expected values.
2585    if math.isnan(expected):
2586      _assert_nan(test_case, update.eval())
2587      _assert_nan(test_case, metric.eval())
2588    else:
2589      test_case.assertEqual(expected, update.eval())
2590      test_case.assertEqual(expected, metric.eval())
2591
2592
2593class SingleLabelRecallAtKTest(test.TestCase):
2594
2595  def setUp(self):
2596    self._predictions = ((0.1, 0.3, 0.2, 0.4), (0.1, 0.2, 0.3, 0.4))
2597    self._predictions_idx = [[3], [3]]
2598    indicator_labels = ((0, 0, 0, 1), (0, 0, 1, 0))
2599    class_labels = (3, 2)
2600    # Sparse vs dense, and 1d vs 2d labels should all be handled the same.
2601    self._labels = (
2602        _binary_2d_label_to_1d_sparse_value(indicator_labels),
2603        _binary_2d_label_to_2d_sparse_value(indicator_labels), np.array(
2604            class_labels, dtype=np.int64), np.array(
2605                [[class_id] for class_id in class_labels], dtype=np.int64))
2606    self._test_recall_at_k = functools.partial(
2607        _test_recall_at_k, test_case=self)
2608    self._test_recall_at_top_k = functools.partial(
2609        _test_recall_at_top_k, test_case=self)
2610
2611  @test_util.run_deprecated_v1
2612  def test_at_k1_nan(self):
2613    # Classes 0,1 have 0 labels, 0 predictions, classes -1 and 4 are out of
2614    # range.
2615    for labels in self._labels:
2616      for class_id in (-1, 0, 1, 4):
2617        self._test_recall_at_k(
2618            self._predictions, labels, k=1, expected=NAN, class_id=class_id)
2619        self._test_recall_at_top_k(
2620            self._predictions_idx, labels, k=1, expected=NAN, class_id=class_id)
2621
2622  @test_util.run_deprecated_v1
2623  def test_at_k1_no_predictions(self):
2624    for labels in self._labels:
2625      # Class 2: 1 label, 0 predictions.
2626      self._test_recall_at_k(
2627          self._predictions, labels, k=1, expected=0.0, class_id=2)
2628      self._test_recall_at_top_k(
2629          self._predictions_idx, labels, k=1, expected=0.0, class_id=2)
2630
2631  @test_util.run_deprecated_v1
2632  def test_one_label_at_k1(self):
2633    for labels in self._labels:
2634      # Class 3: 1 label, 2 predictions, 1 correct.
2635      self._test_recall_at_k(
2636          self._predictions, labels, k=1, expected=1.0 / 1, class_id=3)
2637      self._test_recall_at_top_k(
2638          self._predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3)
2639
2640      # All classes: 2 labels, 2 predictions, 1 correct.
2641      self._test_recall_at_k(self._predictions, labels, k=1, expected=1.0 / 2)
2642      self._test_recall_at_top_k(
2643          self._predictions_idx, labels, k=1, expected=1.0 / 2)
2644
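  # In the weighted tests below, `weights` broadcasts against the two
  # examples: a single entry applies to both, while a pair weights them
  # individually. The only class-3 label is in the first example, so zeroing
  # its weight leaves no weighted class-3 labels (recall is NaN), while any
  # positive weight scales numerator and denominator alike (recall 1.0).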
2645  @test_util.run_deprecated_v1
2646  def test_one_label_at_k1_weighted_class_id3(self):
2647    predictions = self._predictions
2648    predictions_idx = self._predictions_idx
2649    for labels in self._labels:
2650      # Class 3: 1 label, 2 predictions, 1 correct.
2651      self._test_recall_at_k(
2652          predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,))
2653      self._test_recall_at_top_k(
2654          predictions_idx, labels, k=1, expected=NAN, class_id=3,
2655          weights=(0.0,))
2656      self._test_recall_at_k(
2657          predictions, labels, k=1, expected=1.0 / 1, class_id=3,
2658          weights=(1.0,))
2659      self._test_recall_at_top_k(
2660          predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3,
2661          weights=(1.0,))
2662      self._test_recall_at_k(
2663          predictions, labels, k=1, expected=1.0 / 1, class_id=3,
2664          weights=(2.0,))
2665      self._test_recall_at_top_k(
2666          predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3,
2667          weights=(2.0,))
2668      self._test_recall_at_k(
2669          predictions, labels, k=1, expected=NAN, class_id=3,
2670          weights=(0.0, 1.0))
2671      self._test_recall_at_top_k(
2672          predictions_idx, labels, k=1, expected=NAN, class_id=3,
2673          weights=(0.0, 1.0))
2674      self._test_recall_at_k(
2675          predictions, labels, k=1, expected=1.0 / 1, class_id=3,
2676          weights=(1.0, 0.0))
2677      self._test_recall_at_top_k(
2678          predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3,
2679          weights=(1.0, 0.0))
2680      self._test_recall_at_k(
2681          predictions, labels, k=1, expected=2.0 / 2, class_id=3,
2682          weights=(2.0, 3.0))
2683      self._test_recall_at_top_k(
2684          predictions_idx, labels, k=1, expected=2.0 / 2, class_id=3,
2685          weights=(2.0, 3.0))
2686
2687  @test_util.run_deprecated_v1
2688  def test_one_label_at_k1_weighted(self):
2689    predictions = self._predictions
2690    predictions_idx = self._predictions_idx
2691    for labels in self._labels:
2692      # All classes: 2 labels, 2 predictions, 1 correct.
2693      self._test_recall_at_k(
2694          predictions, labels, k=1, expected=NAN, weights=(0.0,))
2695      self._test_recall_at_top_k(
2696          predictions_idx, labels, k=1, expected=NAN, weights=(0.0,))
2697      self._test_recall_at_k(
2698          predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,))
2699      self._test_recall_at_top_k(
2700          predictions_idx, labels, k=1, expected=1.0 / 2, weights=(1.0,))
2701      self._test_recall_at_k(
2702          predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,))
2703      self._test_recall_at_top_k(
2704          predictions_idx, labels, k=1, expected=1.0 / 2, weights=(2.0,))
2705      self._test_recall_at_k(
2706          predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
2707      self._test_recall_at_top_k(
2708          predictions_idx, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
2709      self._test_recall_at_k(
2710          predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
2711      self._test_recall_at_top_k(
2712          predictions_idx, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
2713      self._test_recall_at_k(
2714          predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
2715      self._test_recall_at_top_k(
2716          predictions_idx, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
2717
2718
2719class MultiLabel2dRecallAtKTest(test.TestCase):
2720
2721  def setUp(self):
2722    self._predictions = ((0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9),
2723                         (0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6))
2724    self._predictions_idx = ((9, 4, 6, 2, 0), (5, 7, 2, 9, 6))
2725    indicator_labels = ((0, 0, 1, 0, 0, 0, 0, 1, 1, 0),
2726                        (0, 1, 1, 0, 0, 1, 0, 0, 0, 0))
2727    class_labels = ((2, 7, 8), (1, 2, 5))
2728    # Sparse vs dense labels should be handled the same.
2729    self._labels = (_binary_2d_label_to_2d_sparse_value(indicator_labels),
2730                    np.array(
2731                        class_labels, dtype=np.int64))
2732    self._test_recall_at_k = functools.partial(
2733        _test_recall_at_k, test_case=self)
2734    self._test_recall_at_top_k = functools.partial(
2735        _test_recall_at_top_k, test_case=self)
2736
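  # Fixture summary: the top-5 predicted classes are (9, 4, 6, 2, 0) for the
  # first example and (5, 7, 2, 9, 6) for the second; the label sets are
  # {2, 7, 8} and {1, 2, 5}. At k=5 only labels 2 (example 1) and 2, 5
  # (example 2) are retrieved, i.e. 3 of 6 labels overall.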
2737  @test_util.run_deprecated_v1
2738  def test_at_k5_nan(self):
2739    for labels in self._labels:
2740      # Classes 0,3,4,6,9 have 0 labels, class 10 is out of range.
2741      for class_id in (0, 3, 4, 6, 9, 10):
2742        self._test_recall_at_k(
2743            self._predictions, labels, k=5, expected=NAN, class_id=class_id)
2744        self._test_recall_at_top_k(
2745            self._predictions_idx, labels, k=5, expected=NAN, class_id=class_id)
2746
2747  @test_util.run_deprecated_v1
2748  def test_at_k5_no_predictions(self):
2749    for labels in self._labels:
2750      # Class 8: 1 label, no predictions.
2751      self._test_recall_at_k(
2752          self._predictions, labels, k=5, expected=0.0 / 1, class_id=8)
2753      self._test_recall_at_top_k(
2754          self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=8)
2755
2756  @test_util.run_deprecated_v1
2757  def test_at_k5(self):
2758    for labels in self._labels:
2759      # Class 2: 2 labels, both correct.
2760      self._test_recall_at_k(
2761          self._predictions, labels, k=5, expected=2.0 / 2, class_id=2)
2762      self._test_recall_at_top_k(
2763          self._predictions_idx, labels, k=5, expected=2.0 / 2, class_id=2)
2764
2765      # Class 5: 1 label, correct.
2766      self._test_recall_at_k(
2767          self._predictions, labels, k=5, expected=1.0 / 1, class_id=5)
2768      self._test_recall_at_top_k(
2769          self._predictions_idx, labels, k=5, expected=1.0 / 1, class_id=5)
2770
2771      # Class 7: 1 label, incorrect.
2772      self._test_recall_at_k(
2773          self._predictions, labels, k=5, expected=0.0 / 1, class_id=7)
2774      self._test_recall_at_top_k(
2775          self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=7)
2776
2777      # All classes: 6 labels, 3 correct.
2778      self._test_recall_at_k(self._predictions, labels, k=5, expected=3.0 / 6)
2779      self._test_recall_at_top_k(
2780          self._predictions_idx, labels, k=5, expected=3.0 / 6)
2781
2782  @test_util.run_deprecated_v1
2783  def test_at_k5_some_out_of_range(self):
2784    """Tests that labels outside the [0, n_classes) count in denominator."""
2785    labels = sparse_tensor.SparseTensorValue(
2786        indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
2787                 [1, 3]],
2788        # values -1 and 10 are outside the [0, n_classes) range.
2789        values=np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64),
2790        dense_shape=[2, 4])
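    # The -1 and 10 entries can never be matched by a prediction, but they
    # still count as labels for recall, so the overall denominator grows from
    # 6 to 8 while the numerator stays at 3.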
2791
2792    # Class 2: 2 labels, both correct.
2793    self._test_recall_at_k(
2794        self._predictions, labels, k=5, expected=2.0 / 2, class_id=2)
2795    self._test_recall_at_top_k(
2796        self._predictions_idx, labels, k=5, expected=2.0 / 2, class_id=2)
2797
2798    # Class 5: 1 label, correct.
2799    self._test_recall_at_k(
2800        self._predictions, labels, k=5, expected=1.0 / 1, class_id=5)
2801    self._test_recall_at_top_k(
2802        self._predictions_idx, labels, k=5, expected=1.0 / 1, class_id=5)
2803
2804    # Class 7: 1 label, incorrect.
2805    self._test_recall_at_k(
2806        self._predictions, labels, k=5, expected=0.0 / 1, class_id=7)
2807    self._test_recall_at_top_k(
2808        self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=7)
2809
2810    # All classes: 8 labels, 3 correct.
2811    self._test_recall_at_k(self._predictions, labels, k=5, expected=3.0 / 8)
2812    self._test_recall_at_top_k(
2813        self._predictions_idx, labels, k=5, expected=3.0 / 8)
2814
2815
2816class MultiLabel3dRecallAtKTest(test.TestCase):
2817
2818  def setUp(self):
2819    self._predictions = (((0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9),
2820                          (0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6)),
2821                         ((0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6),
2822                          (0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9)))
2823    self._predictions_idx = (((9, 4, 6, 2, 0), (5, 7, 2, 9, 6)),
2824                             ((5, 7, 2, 9, 6), (9, 4, 6, 2, 0)))
2825    # Note: We don't test dense labels here, since examples have different
2826    # numbers of labels.
2827    self._labels = _binary_3d_label_to_sparse_value(
2828        (((0, 0, 1, 0, 0, 0, 0, 1, 1, 0), (0, 1, 1, 0, 0, 1, 0, 0, 0, 0)),
2829         ((0, 1, 1, 0, 0, 1, 0, 1, 0, 0), (0, 0, 1, 0, 0, 0, 0, 0, 1, 0))))
2830    self._test_recall_at_k = functools.partial(
2831        _test_recall_at_k, test_case=self)
2832    self._test_recall_at_top_k = functools.partial(
2833        _test_recall_at_top_k, test_case=self)
2834
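  # Fixture summary: predictions and labels have shape [2, 2, 10]. The label
  # sets per (example, row) are {2, 7, 8}, {1, 2, 5}, {1, 2, 5, 7} and {2, 8},
  # i.e. 12 labels in total, of which 7 appear in the corresponding top-5
  # prediction rows. This yields the per-class and overall expectations below.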
2835  @test_util.run_deprecated_v1
2836  def test_3d_nan(self):
2837    # Classes 0,3,4,6,9 have 0 labels, class 10 is out of range.
2838    for class_id in (0, 3, 4, 6, 9, 10):
2839      self._test_recall_at_k(
2840          self._predictions, self._labels, k=5, expected=NAN, class_id=class_id)
2841      self._test_recall_at_top_k(
2842          self._predictions_idx, self._labels, k=5, expected=NAN,
2843          class_id=class_id)
2844
2845  @test_util.run_deprecated_v1
2846  def test_3d_no_predictions(self):
2847    # Classes 1,8 have 0 predictions, >=1 label.
2848    for class_id in (1, 8):
2849      self._test_recall_at_k(
2850          self._predictions, self._labels, k=5, expected=0.0, class_id=class_id)
2851      self._test_recall_at_top_k(
2852          self._predictions_idx, self._labels, k=5, expected=0.0,
2853          class_id=class_id)
2854
2855  @test_util.run_deprecated_v1
2856  def test_3d(self):
2857    # Class 2: 4 labels, all correct.
2858    self._test_recall_at_k(
2859        self._predictions, self._labels, k=5, expected=4.0 / 4, class_id=2)
2860    self._test_recall_at_top_k(
2861        self._predictions_idx, self._labels, k=5, expected=4.0 / 4,
2862        class_id=2)
2863
2864    # Class 5: 2 labels, both correct.
2865    self._test_recall_at_k(
2866        self._predictions, self._labels, k=5, expected=2.0 / 2, class_id=5)
2867    self._test_recall_at_top_k(
2868        self._predictions_idx, self._labels, k=5, expected=2.0 / 2,
2869        class_id=5)
2870
2871    # Class 7: 2 labels, 1 incorrect.
2872    self._test_recall_at_k(
2873        self._predictions, self._labels, k=5, expected=1.0 / 2, class_id=7)
2874    self._test_recall_at_top_k(
2875        self._predictions_idx, self._labels, k=5, expected=1.0 / 2,
2876        class_id=7)
2877
2878    # All classes: 12 labels, 7 correct.
2879    self._test_recall_at_k(
2880        self._predictions, self._labels, k=5, expected=7.0 / 12)
2881    self._test_recall_at_top_k(
2882        self._predictions_idx, self._labels, k=5, expected=7.0 / 12)
2883
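  # For the 3-D tests below, weights of shape [2, 1] broadcast across both
  # rows of each example, while shape [2, 2] weights each (example, row) pair
  # separately; the expected values follow from which labels remain weighted.
  # Zeroing every weight leaves no weighted labels at all, so recall is NaN.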
2884  @test_util.run_deprecated_v1
2885  def test_3d_ignore_all(self):
2886    for class_id in xrange(10):
2887      self._test_recall_at_k(
2888          self._predictions, self._labels, k=5, expected=NAN, class_id=class_id,
2889          weights=[[0], [0]])
2890      self._test_recall_at_top_k(
2891          self._predictions_idx, self._labels, k=5, expected=NAN,
2892          class_id=class_id, weights=[[0], [0]])
2893      self._test_recall_at_k(
2894          self._predictions, self._labels, k=5, expected=NAN, class_id=class_id,
2895          weights=[[0, 0], [0, 0]])
2896      self._test_recall_at_top_k(
2897          self._predictions_idx, self._labels, k=5, expected=NAN,
2898          class_id=class_id, weights=[[0, 0], [0, 0]])
2899    self._test_recall_at_k(
2900        self._predictions, self._labels, k=5, expected=NAN, weights=[[0], [0]])
2901    self._test_recall_at_top_k(
2902        self._predictions_idx, self._labels, k=5, expected=NAN,
2903        weights=[[0], [0]])
2904    self._test_recall_at_k(
2905        self._predictions, self._labels, k=5, expected=NAN,
2906        weights=[[0, 0], [0, 0]])
2907    self._test_recall_at_top_k(
2908        self._predictions_idx, self._labels, k=5, expected=NAN,
2909        weights=[[0, 0], [0, 0]])
2910
2911  @test_util.run_deprecated_v1
2912  def test_3d_ignore_some(self):
2913    # Class 2: 2 labels, both correct.
2914    self._test_recall_at_k(
2915        self._predictions, self._labels, k=5, expected=2.0 / 2.0, class_id=2,
2916        weights=[[1], [0]])
2917    self._test_recall_at_top_k(
2918        self._predictions_idx, self._labels, k=5, expected=2.0 / 2.0,
2919        class_id=2, weights=[[1], [0]])
2920
2921    # Class 2: 2 labels, both correct.
2922    self._test_recall_at_k(
2923        self._predictions, self._labels, k=5, expected=2.0 / 2.0, class_id=2,
2924        weights=[[0], [1]])
2925    self._test_recall_at_top_k(
2926        self._predictions_idx, self._labels, k=5, expected=2.0 / 2.0,
2927        class_id=2, weights=[[0], [1]])
2928
2929    # Class 7: 1 label, correct.
2930    self._test_recall_at_k(
2931        self._predictions, self._labels, k=5, expected=1.0 / 1.0, class_id=7,
2932        weights=[[0], [1]])
2933    self._test_recall_at_top_k(
2934        self._predictions_idx, self._labels, k=5, expected=1.0 / 1.0,
2935        class_id=7, weights=[[0], [1]])
2936
2937    # Class 7: 1 label, incorrect.
2938    self._test_recall_at_k(
2939        self._predictions, self._labels, k=5, expected=0.0 / 1.0, class_id=7,
2940        weights=[[1], [0]])
2941    self._test_recall_at_top_k(
2942        self._predictions_idx, self._labels, k=5, expected=0.0 / 1.0,
2943        class_id=7, weights=[[1], [0]])
2944
2945    # Class 7: 2 labels, 1 correct.
2946    self._test_recall_at_k(
2947        self._predictions, self._labels, k=5, expected=1.0 / 2.0, class_id=7,
2948        weights=[[1, 0], [1, 0]])
2949    self._test_recall_at_top_k(
2950        self._predictions_idx, self._labels, k=5, expected=1.0 / 2.0,
2951        class_id=7, weights=[[1, 0], [1, 0]])
2952
2953    # Class 7: No labels.
2954    self._test_recall_at_k(
2955        self._predictions, self._labels, k=5, expected=NAN, class_id=7,
2956        weights=[[0, 1], [0, 1]])
2957    self._test_recall_at_top_k(
2958        self._predictions_idx, self._labels, k=5, expected=NAN, class_id=7,
2959        weights=[[0, 1], [0, 1]])
2960
2961
2962class MeanAbsoluteErrorTest(test.TestCase):
2963
2964  def setUp(self):
2965    ops.reset_default_graph()
2966
2967  @test_util.run_deprecated_v1
2968  def testVars(self):
2969    metrics.mean_absolute_error(
2970        predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
2971    _assert_metric_variables(
2972        self, ('mean_absolute_error/count:0', 'mean_absolute_error/total:0'))
2973
2974  @test_util.run_deprecated_v1
2975  def testMetricsCollection(self):
2976    my_collection_name = '__metrics__'
2977    mean, _ = metrics.mean_absolute_error(
2978        predictions=array_ops.ones((10, 1)),
2979        labels=array_ops.ones((10, 1)),
2980        metrics_collections=[my_collection_name])
2981    self.assertListEqual(ops.get_collection(my_collection_name), [mean])
2982
2983  @test_util.run_deprecated_v1
2984  def testUpdatesCollection(self):
2985    my_collection_name = '__updates__'
2986    _, update_op = metrics.mean_absolute_error(
2987        predictions=array_ops.ones((10, 1)),
2988        labels=array_ops.ones((10, 1)),
2989        updates_collections=[my_collection_name])
2990    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
2991
2992  @test_util.run_deprecated_v1
2993  def testValueTensorIsIdempotent(self):
2994    predictions = random_ops.random_normal((10, 3), seed=1)
2995    labels = random_ops.random_normal((10, 3), seed=2)
2996    error, update_op = metrics.mean_absolute_error(labels, predictions)
2997
2998    with self.cached_session():
2999      self.evaluate(variables.local_variables_initializer())
3000
3001      # Run several updates.
3002      for _ in range(10):
3003        self.evaluate(update_op)
3004
3005      # Then verify idempotency.
3006      initial_error = self.evaluate(error)
3007      for _ in range(10):
3008        self.assertEqual(initial_error, self.evaluate(error))
3009
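  # Worked expectation for the weighted test below: absolute errors are
  # (1, 1, 4, 5); with weights (0, 1, 0, 1) only the 1 and 5 count, so the
  # weighted mean is (1 + 5) / 2 = 3.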
3010  @test_util.run_deprecated_v1
3011  def testSingleUpdateWithErrorAndWeights(self):
3012    predictions = constant_op.constant(
3013        [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
3014    labels = constant_op.constant(
3015        [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
3016    weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
3017
3018    error, update_op = metrics.mean_absolute_error(labels, predictions, weights)
3019
3020    with self.cached_session():
3021      self.evaluate(variables.local_variables_initializer())
3022      self.assertEqual(3, self.evaluate(update_op))
3023      self.assertEqual(3, self.evaluate(error))
3024
3025
3026class MeanRelativeErrorTest(test.TestCase):
3027
3028  def setUp(self):
3029    ops.reset_default_graph()
3030
3031  @test_util.run_deprecated_v1
3032  def testVars(self):
3033    metrics.mean_relative_error(
3034        predictions=array_ops.ones((10, 1)),
3035        labels=array_ops.ones((10, 1)),
3036        normalizer=array_ops.ones((10, 1)))
3037    _assert_metric_variables(
3038        self, ('mean_relative_error/count:0', 'mean_relative_error/total:0'))
3039
3040  @test_util.run_deprecated_v1
3041  def testMetricsCollection(self):
3042    my_collection_name = '__metrics__'
3043    mean, _ = metrics.mean_relative_error(
3044        predictions=array_ops.ones((10, 1)),
3045        labels=array_ops.ones((10, 1)),
3046        normalizer=array_ops.ones((10, 1)),
3047        metrics_collections=[my_collection_name])
3048    self.assertListEqual(ops.get_collection(my_collection_name), [mean])
3049
3050  @test_util.run_deprecated_v1
3051  def testUpdatesCollection(self):
3052    my_collection_name = '__updates__'
3053    _, update_op = metrics.mean_relative_error(
3054        predictions=array_ops.ones((10, 1)),
3055        labels=array_ops.ones((10, 1)),
3056        normalizer=array_ops.ones((10, 1)),
3057        updates_collections=[my_collection_name])
3058    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
3059
3060  @test_util.run_deprecated_v1
3061  def testValueTensorIsIdempotent(self):
3062    predictions = random_ops.random_normal((10, 3), seed=1)
3063    labels = random_ops.random_normal((10, 3), seed=2)
3064    normalizer = random_ops.random_normal((10, 3), seed=3)
3065    error, update_op = metrics.mean_relative_error(labels, predictions,
3066                                                   normalizer)
3067
3068    with self.cached_session():
3069      self.evaluate(variables.local_variables_initializer())
3070
3071      # Run several updates.
3072      for _ in range(10):
3073        self.evaluate(update_op)
3074
3075      # Then verify idempotency.
3076      initial_error = self.evaluate(error)
3077      for _ in range(10):
3078        self.assertEqual(initial_error, self.evaluate(error))
3079
3080  @test_util.run_deprecated_v1
3081  def testSingleUpdateNormalizedByLabels(self):
3082    np_predictions = np.asarray([2, 4, 6, 8], dtype=np.float32)
3083    np_labels = np.asarray([1, 3, 2, 3], dtype=np.float32)
3084    expected_error = np.mean(
3085        np.divide(np.absolute(np_predictions - np_labels), np_labels))
3086
3087    predictions = constant_op.constant(
3088        np_predictions, shape=(1, 4), dtype=dtypes_lib.float32)
3089    labels = constant_op.constant(np_labels, shape=(1, 4))
3090
3091    error, update_op = metrics.mean_relative_error(
3092        labels, predictions, normalizer=labels)
3093
3094    with self.cached_session():
3095      self.evaluate(variables.local_variables_initializer())
3096      self.assertEqual(expected_error, self.evaluate(update_op))
3097      self.assertEqual(expected_error, self.evaluate(error))
3098
3099  @test_util.run_deprecated_v1
3100  def testSingleUpdateNormalizedByZeros(self):
3101    np_predictions = np.asarray([2, 4, 6, 8], dtype=np.float32)
3102
3103    predictions = constant_op.constant(
3104        np_predictions, shape=(1, 4), dtype=dtypes_lib.float32)
3105    labels = constant_op.constant(
3106        [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
3107
3108    error, update_op = metrics.mean_relative_error(
3109        labels, predictions, normalizer=array_ops.zeros_like(labels))
3110
3111    with self.cached_session():
3112      self.evaluate(variables.local_variables_initializer())
3113      self.assertEqual(0.0, self.evaluate(update_op))
3114      self.assertEqual(0.0, self.evaluate(error))
3115
3116
3117class MeanSquaredErrorTest(test.TestCase):
3118
3119  def setUp(self):
3120    ops.reset_default_graph()
3121
3122  @test_util.run_deprecated_v1
3123  def testVars(self):
3124    metrics.mean_squared_error(
3125        predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
3126    _assert_metric_variables(
3127        self, ('mean_squared_error/count:0', 'mean_squared_error/total:0'))
3128
3129  @test_util.run_deprecated_v1
3130  def testMetricsCollection(self):
3131    my_collection_name = '__metrics__'
3132    mean, _ = metrics.mean_squared_error(
3133        predictions=array_ops.ones((10, 1)),
3134        labels=array_ops.ones((10, 1)),
3135        metrics_collections=[my_collection_name])
3136    self.assertListEqual(ops.get_collection(my_collection_name), [mean])
3137
3138  @test_util.run_deprecated_v1
3139  def testUpdatesCollection(self):
3140    my_collection_name = '__updates__'
3141    _, update_op = metrics.mean_squared_error(
3142        predictions=array_ops.ones((10, 1)),
3143        labels=array_ops.ones((10, 1)),
3144        updates_collections=[my_collection_name])
3145    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
3146
3147  @test_util.run_deprecated_v1
3148  def testValueTensorIsIdempotent(self):
3149    predictions = random_ops.random_normal((10, 3), seed=1)
3150    labels = random_ops.random_normal((10, 3), seed=2)
3151    error, update_op = metrics.mean_squared_error(labels, predictions)
3152
3153    with self.cached_session():
3154      self.evaluate(variables.local_variables_initializer())
3155
3156      # Run several updates.
3157      for _ in range(10):
3158        self.evaluate(update_op)
3159
3160      # Then verify idempotency.
3161      initial_error = self.evaluate(error)
3162      for _ in range(10):
3163        self.assertEqual(initial_error, self.evaluate(error))
3164
3165  @test_util.run_deprecated_v1
3166  def testSingleUpdateZeroError(self):
3167    predictions = array_ops.zeros((1, 3), dtype=dtypes_lib.float32)
3168    labels = array_ops.zeros((1, 3), dtype=dtypes_lib.float32)
3169
3170    error, update_op = metrics.mean_squared_error(labels, predictions)
3171
3172    with self.cached_session():
3173      self.evaluate(variables.local_variables_initializer())
3174      self.assertEqual(0, self.evaluate(update_op))
3175      self.assertEqual(0, self.evaluate(error))
3176
3177  @test_util.run_deprecated_v1
3178  def testSingleUpdateWithError(self):
3179    predictions = constant_op.constant(
3180        [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
3181    labels = constant_op.constant(
3182        [1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32)
3183
3184    error, update_op = metrics.mean_squared_error(labels, predictions)
3185
3186    with self.cached_session():
3187      self.evaluate(variables.local_variables_initializer())
3188      self.assertEqual(6, self.evaluate(update_op))
3189      self.assertEqual(6, self.evaluate(error))
3190
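  # Worked expectation for the weighted test below: squared errors are
  # (1, 1, 16, 25); with weights (0, 1, 0, 1) only the 1 and 25 count, so the
  # weighted mean is (1 + 25) / 2 = 13.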
3191  @test_util.run_deprecated_v1
3192  def testSingleUpdateWithErrorAndWeights(self):
3193    predictions = constant_op.constant(
3194        [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
3195    labels = constant_op.constant(
3196        [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
3197    weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
3198
3199    error, update_op = metrics.mean_squared_error(labels, predictions, weights)
3200
3201    with self.cached_session():
3202      self.evaluate(variables.local_variables_initializer())
3203      self.assertEqual(13, self.evaluate(update_op))
3204      self.assertEqual(13, self.evaluate(error))
3205
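  # Worked expectation for the two batches below: squared errors sum to
  # (81 + 25 + 16) + (36 + 1 + 49) = 208 over 6 elements, hence 208 / 6.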
3206  @test_util.run_deprecated_v1
3207  def testMultipleBatchesOfSizeOne(self):
3208    with self.cached_session() as sess:
3209      # Create the queue that populates the predictions.
3210      preds_queue = data_flow_ops.FIFOQueue(
3211          2, dtypes=dtypes_lib.float32, shapes=(1, 3))
3212      _enqueue_vector(sess, preds_queue, [10, 8, 6])
3213      _enqueue_vector(sess, preds_queue, [-4, 3, -1])
3214      predictions = preds_queue.dequeue()
3215
3216      # Create the queue that populates the labels.
3217      labels_queue = data_flow_ops.FIFOQueue(
3218          2, dtypes=dtypes_lib.float32, shapes=(1, 3))
3219      _enqueue_vector(sess, labels_queue, [1, 3, 2])
3220      _enqueue_vector(sess, labels_queue, [2, 4, 6])
3221      labels = labels_queue.dequeue()
3222
3223      error, update_op = metrics.mean_squared_error(labels, predictions)
3224
3225      self.evaluate(variables.local_variables_initializer())
3226      self.evaluate(update_op)
3227      self.assertAlmostEqual(208.0 / 6, self.evaluate(update_op), 5)
3228
3229      self.assertAlmostEqual(208.0 / 6, self.evaluate(error), 5)
3230
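  # Worked expectation for the second metric below: squared errors sum to
  # (25 + 16 + 4) + (16 + 9 + 9) = 79 over 6 elements, hence 79 / 6; the
  # first metric reuses the same data as the previous test, giving 208 / 6.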
3231  @test_util.run_deprecated_v1
3232  def testMetricsComputedConcurrently(self):
3233    with self.cached_session() as sess:
3234      # Create the queue that populates one set of predictions.
3235      preds_queue0 = data_flow_ops.FIFOQueue(
3236          2, dtypes=dtypes_lib.float32, shapes=(1, 3))
3237      _enqueue_vector(sess, preds_queue0, [10, 8, 6])
3238      _enqueue_vector(sess, preds_queue0, [-4, 3, -1])
3239      predictions0 = preds_queue0.dequeue()
3240
3241      # Create the queue that populates another set of predictions.
3242      preds_queue1 = data_flow_ops.FIFOQueue(
3243          2, dtypes=dtypes_lib.float32, shapes=(1, 3))
3244      _enqueue_vector(sess, preds_queue1, [0, 1, 1])
3245      _enqueue_vector(sess, preds_queue1, [1, 1, 0])
3246      predictions1 = preds_queue1.dequeue()
3247
3248      # Create the queue that populates one set of labels.
3249      labels_queue0 = data_flow_ops.FIFOQueue(
3250          2, dtypes=dtypes_lib.float32, shapes=(1, 3))
3251      _enqueue_vector(sess, labels_queue0, [1, 3, 2])
3252      _enqueue_vector(sess, labels_queue0, [2, 4, 6])
3253      labels0 = labels_queue0.dequeue()
3254
3255      # Create the queue that populates another set of labels.
3256      labels_queue1 = data_flow_ops.FIFOQueue(
3257          2, dtypes=dtypes_lib.float32, shapes=(1, 3))
3258      _enqueue_vector(sess, labels_queue1, [-5, -3, -1])
3259      _enqueue_vector(sess, labels_queue1, [5, 4, 3])
3260      labels1 = labels_queue1.dequeue()
3261
3262      mse0, update_op0 = metrics.mean_squared_error(
3263          labels0, predictions0, name='msd0')
3264      mse1, update_op1 = metrics.mean_squared_error(
3265          labels1, predictions1, name='msd1')
3266
3267      self.evaluate(variables.local_variables_initializer())
3268      self.evaluate([update_op0, update_op1])
3269      self.evaluate([update_op0, update_op1])
3270
3271      mse0, mse1 = self.evaluate([mse0, mse1])
3272      self.assertAlmostEqual(208.0 / 6, mse0, 5)
3273      self.assertAlmostEqual(79.0 / 6, mse1, 5)
3274
3275  @test_util.run_deprecated_v1
3276  def testMultipleMetricsOnMultipleBatchesOfSizeOne(self):
3277    with self.cached_session() as sess:
3278      # Create the queue that populates the predictions.
3279      preds_queue = data_flow_ops.FIFOQueue(
3280          2, dtypes=dtypes_lib.float32, shapes=(1, 3))
3281      _enqueue_vector(sess, preds_queue, [10, 8, 6])
3282      _enqueue_vector(sess, preds_queue, [-4, 3, -1])
3283      predictions = preds_queue.dequeue()
3284
3285      # Create the queue that populates the labels.
3286      labels_queue = data_flow_ops.FIFOQueue(
3287          2, dtypes=dtypes_lib.float32, shapes=(1, 3))
3288      _enqueue_vector(sess, labels_queue, [1, 3, 2])
3289      _enqueue_vector(sess, labels_queue, [2, 4, 6])
3290      labels = labels_queue.dequeue()
3291
3292      mae, ma_update_op = metrics.mean_absolute_error(labels, predictions)
3293      mse, ms_update_op = metrics.mean_squared_error(labels, predictions)
3294
3295      self.evaluate(variables.local_variables_initializer())
3296      self.evaluate([ma_update_op, ms_update_op])
3297      self.evaluate([ma_update_op, ms_update_op])
3298
3299      self.assertAlmostEqual(32.0 / 6, self.evaluate(mae), 5)
3300      self.assertAlmostEqual(208.0 / 6, self.evaluate(mse), 5)
3301
3302
3303class RootMeanSquaredErrorTest(test.TestCase):
3304
3305  def setUp(self):
3306    ops.reset_default_graph()
3307
3308  @test_util.run_deprecated_v1
3309  def testVars(self):
3310    metrics.root_mean_squared_error(
3311        predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
3312    _assert_metric_variables(
3313        self,
3314        ('root_mean_squared_error/count:0', 'root_mean_squared_error/total:0'))
3315
3316  @test_util.run_deprecated_v1
3317  def testMetricsCollection(self):
3318    my_collection_name = '__metrics__'
3319    mean, _ = metrics.root_mean_squared_error(
3320        predictions=array_ops.ones((10, 1)),
3321        labels=array_ops.ones((10, 1)),
3322        metrics_collections=[my_collection_name])
3323    self.assertListEqual(ops.get_collection(my_collection_name), [mean])
3324
3325  @test_util.run_deprecated_v1
3326  def testUpdatesCollection(self):
3327    my_collection_name = '__updates__'
3328    _, update_op = metrics.root_mean_squared_error(
3329        predictions=array_ops.ones((10, 1)),
3330        labels=array_ops.ones((10, 1)),
3331        updates_collections=[my_collection_name])
3332    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
3333
3334  @test_util.run_deprecated_v1
3335  def testValueTensorIsIdempotent(self):
3336    predictions = random_ops.random_normal((10, 3), seed=1)
3337    labels = random_ops.random_normal((10, 3), seed=2)
3338    error, update_op = metrics.root_mean_squared_error(labels, predictions)
3339
3340    with self.cached_session():
3341      self.evaluate(variables.local_variables_initializer())
3342
3343      # Run several updates.
3344      for _ in range(10):
3345        self.evaluate(update_op)
3346
3347      # Then verify idempotency.
3348      initial_error = self.evaluate(error)
3349      for _ in range(10):
3350        self.assertEqual(initial_error, self.evaluate(error))
3351
3352  @test_util.run_deprecated_v1
3353  def testSingleUpdateZeroError(self):
3354    with self.cached_session():
3355      predictions = constant_op.constant(
3356          0.0, shape=(1, 3), dtype=dtypes_lib.float32)
3357      labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32)
3358
3359      rmse, update_op = metrics.root_mean_squared_error(labels, predictions)
3360
3361      self.evaluate(variables.local_variables_initializer())
3362      self.assertEqual(0, self.evaluate(update_op))
3363
3364      self.assertEqual(0, self.evaluate(rmse))
3365
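  # Worked expectations for the tests below: squared errors (1, 1, 16) give a
  # mean of 6, so the RMSE is sqrt(6); with weights (0, 1, 0, 1) the weighted
  # mean of (1, 1, 16, 25) is 13, so the weighted RMSE is sqrt(13).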
3366  @test_util.run_deprecated_v1
3367  def testSingleUpdateWithError(self):
3368    with self.cached_session():
3369      predictions = constant_op.constant(
3370          [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
3371      labels = constant_op.constant(
3372          [1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32)
3373
3374      rmse, update_op = metrics.root_mean_squared_error(labels, predictions)
3375
3376      self.evaluate(variables.local_variables_initializer())
3377      self.assertAlmostEqual(math.sqrt(6), self.evaluate(update_op), 5)
3378      self.assertAlmostEqual(math.sqrt(6), self.evaluate(rmse), 5)
3379
3380  @test_util.run_deprecated_v1
3381  def testSingleUpdateWithErrorAndWeights(self):
3382    with self.cached_session():
3383      predictions = constant_op.constant(
3384          [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
3385      labels = constant_op.constant(
3386          [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
3387      weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
3388
3389      rmse, update_op = metrics.root_mean_squared_error(labels, predictions,
3390                                                        weights)
3391
3392      self.evaluate(variables.local_variables_initializer())
3393      self.assertAlmostEqual(math.sqrt(13), self.evaluate(update_op))
3394
3395      self.assertAlmostEqual(math.sqrt(13), self.evaluate(rmse), 5)
3396
3397
3398def _reweight(predictions, labels, weights):
3399  return (np.concatenate([[p] * int(w) for p, w in zip(predictions, weights)]),
3400          np.concatenate([[l] * int(w) for l, w in zip(labels, weights)]))
3401
3402
3403class MeanCosineDistanceTest(test.TestCase):
3404
3405  def setUp(self):
3406    ops.reset_default_graph()
3407
3408  @test_util.run_deprecated_v1
3409  def testVars(self):
3410    metrics.mean_cosine_distance(
3411        predictions=array_ops.ones((10, 3)),
3412        labels=array_ops.ones((10, 3)),
3413        dim=1)
3414    _assert_metric_variables(self, (
3415        'mean_cosine_distance/count:0',
3416        'mean_cosine_distance/total:0',
3417    ))
3418
3419  @test_util.run_deprecated_v1
3420  def testMetricsCollection(self):
3421    my_collection_name = '__metrics__'
3422    mean, _ = metrics.mean_cosine_distance(
3423        predictions=array_ops.ones((10, 3)),
3424        labels=array_ops.ones((10, 3)),
3425        dim=1,
3426        metrics_collections=[my_collection_name])
3427    self.assertListEqual(ops.get_collection(my_collection_name), [mean])
3428
3429  @test_util.run_deprecated_v1
3430  def testUpdatesCollection(self):
3431    my_collection_name = '__updates__'
3432    _, update_op = metrics.mean_cosine_distance(
3433        predictions=array_ops.ones((10, 3)),
3434        labels=array_ops.ones((10, 3)),
3435        dim=1,
3436        updates_collections=[my_collection_name])
3437    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
3438
3439  @test_util.run_deprecated_v1
3440  def testValueTensorIsIdempotent(self):
3441    predictions = random_ops.random_normal((10, 3), seed=1)
3442    labels = random_ops.random_normal((10, 3), seed=2)
3443    error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=1)
3444
3445    with self.cached_session():
3446      self.evaluate(variables.local_variables_initializer())
3447
3448      # Run several updates.
3449      for _ in range(10):
3450        self.evaluate(update_op)
3451
3452      # Then verify idempotency.
3453      initial_error = self.evaluate(error)
3454      for _ in range(10):
3455        self.assertEqual(initial_error, self.evaluate(error))
3456
3457  @test_util.run_deprecated_v1
3458  def testSingleUpdateZeroError(self):
3459    np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0'))
3460
3461    predictions = constant_op.constant(
3462        np_labels, shape=(1, 3, 3), dtype=dtypes_lib.float32)
3463    labels = constant_op.constant(
3464        np_labels, shape=(1, 3, 3), dtype=dtypes_lib.float32)
3465
3466    error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
3467
3468    with self.cached_session():
3469      self.evaluate(variables.local_variables_initializer())
3470      self.assertEqual(0, self.evaluate(update_op))
3471      self.assertEqual(0, self.evaluate(error))
3472
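  # The metric averages (1 - cosine similarity) over the vector pairs. For
  # the test below the row-wise similarities are 1, -1 and 0, giving
  # distances 0, 2 and 1 and a mean of 1.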
3473  @test_util.run_deprecated_v1
3474  def testSingleUpdateWithError1(self):
3475    np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0'))
3476    np_predictions = np.matrix(('1 0 0;' '0 0 -1;' '1 0 0'))
3477
3478    predictions = constant_op.constant(
3479        np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32)
3480    labels = constant_op.constant(
3481        np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)
3482
3483    error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
3484
3485    with self.cached_session():
3486      self.evaluate(variables.local_variables_initializer())
3487      self.assertAlmostEqual(1, self.evaluate(update_op), 5)
3488      self.assertAlmostEqual(1, self.evaluate(error), 5)
3489
3490  @test_util.run_deprecated_v1
3491  def testSingleUpdateWithError2(self):
3492    np_predictions = np.matrix(
3493        ('0.819031913261206 0.567041924552012 0.087465312324590;'
3494         '-0.665139432070255 -0.739487441769973 -0.103671883216994;'
3495         '0.707106781186548 -0.707106781186548 0'))
3496    np_labels = np.matrix(
3497        ('0.819031913261206 0.567041924552012 0.087465312324590;'
3498         '0.665139432070255 0.739487441769973 0.103671883216994;'
3499         '0.707106781186548 0.707106781186548 0'))
3500
3501    predictions = constant_op.constant(
3502        np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32)
3503    labels = constant_op.constant(
3504        np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)
3505    error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
3506
3507    with self.cached_session():
3508      self.evaluate(variables.local_variables_initializer())
3509      self.assertAlmostEqual(1.0, self.evaluate(update_op), 5)
3510      self.assertAlmostEqual(1.0, self.evaluate(error), 5)
3511
3512  @test_util.run_deprecated_v1
3513  def testSingleUpdateWithErrorAndWeights1(self):
3514    np_predictions = np.matrix(('1 0 0;' '0 0 -1;' '1 0 0'))
3515    np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0'))
3516
3517    predictions = constant_op.constant(
3518        np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32)
3519    labels = constant_op.constant(
3520        np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)
3521    weights = constant_op.constant(
3522        [1, 0, 0], shape=(3, 1, 1), dtype=dtypes_lib.float32)
3523
3524    error, update_op = metrics.mean_cosine_distance(
3525        labels, predictions, dim=2, weights=weights)
3526
3527    with self.cached_session():
3528      self.evaluate(variables.local_variables_initializer())
3529      self.assertEqual(0, self.evaluate(update_op))
3530      self.assertEqual(0, self.evaluate(error))
3531
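  # With weights (0, 1, 1) only the second and third rows count; their
  # distances are 2 and 1, so the weighted mean is 3 / 2 = 1.5.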
3532  @test_util.run_deprecated_v1
3533  def testSingleUpdateWithErrorAndWeights2(self):
3534    np_predictions = np.matrix(('1 0 0;' '0 0 -1;' '1 0 0'))
3535    np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0'))
3536
3537    predictions = constant_op.constant(
3538        np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32)
3539    labels = constant_op.constant(
3540        np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)
3541    weights = constant_op.constant(
3542        [0, 1, 1], shape=(3, 1, 1), dtype=dtypes_lib.float32)
3543
3544    error, update_op = metrics.mean_cosine_distance(
3545        labels, predictions, dim=2, weights=weights)
3546
3547    with self.cached_session():
3548      self.evaluate(variables.local_variables_initializer())
3549      self.assertEqual(1.5, self.evaluate(update_op))
3550      self.assertEqual(1.5, self.evaluate(error))
3551
3552
3553class PcntBelowThreshTest(test.TestCase):
3554
3555  def setUp(self):
3556    ops.reset_default_graph()
3557
3558  @test_util.run_deprecated_v1
3559  def testVars(self):
3560    metrics.percentage_below(values=array_ops.ones((10,)), threshold=2)
3561    _assert_metric_variables(self, (
3562        'percentage_below_threshold/count:0',
3563        'percentage_below_threshold/total:0',
3564    ))
3565
3566  @test_util.run_deprecated_v1
3567  def testMetricsCollection(self):
3568    my_collection_name = '__metrics__'
3569    mean, _ = metrics.percentage_below(
3570        values=array_ops.ones((10,)),
3571        threshold=2,
3572        metrics_collections=[my_collection_name])
3573    self.assertListEqual(ops.get_collection(my_collection_name), [mean])
3574
3575  @test_util.run_deprecated_v1
3576  def testUpdatesCollection(self):
3577    my_collection_name = '__updates__'
3578    _, update_op = metrics.percentage_below(
3579        values=array_ops.ones((10,)),
3580        threshold=2,
3581        updates_collections=[my_collection_name])
3582    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
3583
3584  @test_util.run_deprecated_v1
3585  def testOneUpdate(self):
3586    with self.cached_session():
3587      values = constant_op.constant(
3588          [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
3589
3590      pcnt0, update_op0 = metrics.percentage_below(values, 100, name='high')
3591      pcnt1, update_op1 = metrics.percentage_below(values, 7, name='medium')
3592      pcnt2, update_op2 = metrics.percentage_below(values, 1, name='low')
3593
3594      self.evaluate(variables.local_variables_initializer())
3595      self.evaluate([update_op0, update_op1, update_op2])
3596
3597      pcnt0, pcnt1, pcnt2 = self.evaluate([pcnt0, pcnt1, pcnt2])
3598      self.assertAlmostEqual(1.0, pcnt0, 5)
3599      self.assertAlmostEqual(0.75, pcnt1, 5)
3600      self.assertAlmostEqual(0.0, pcnt2, 5)
3601
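  # With weights (1, 0, 0, 1) only the values 2 and 8 count, so the weighted
  # fractions below the thresholds 100, 7 and 1 are 1.0, 0.5 and 0.0.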
3602  @test_util.run_deprecated_v1
3603  def testSomePresentOneUpdate(self):
3604    with self.cached_session():
3605      values = constant_op.constant(
3606          [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
3607      weights = constant_op.constant(
3608          [1, 0, 0, 1], shape=(1, 4), dtype=dtypes_lib.float32)
3609
3610      pcnt0, update_op0 = metrics.percentage_below(
3611          values, 100, weights=weights, name='high')
3612      pcnt1, update_op1 = metrics.percentage_below(
3613          values, 7, weights=weights, name='medium')
3614      pcnt2, update_op2 = metrics.percentage_below(
3615          values, 1, weights=weights, name='low')
3616
3617      self.evaluate(variables.local_variables_initializer())
3618      self.assertListEqual([1.0, 0.5, 0.0],
3619                           self.evaluate([update_op0, update_op1, update_op2]))
3620
3621      pcnt0, pcnt1, pcnt2 = self.evaluate([pcnt0, pcnt1, pcnt2])
3622      self.assertAlmostEqual(1.0, pcnt0, 5)
3623      self.assertAlmostEqual(0.5, pcnt1, 5)
3624      self.assertAlmostEqual(0.0, pcnt2, 5)
3625
3626
3627class MeanIOUTest(test.TestCase):
3628
3629  def setUp(self):
3630    np.random.seed(1)
3631    ops.reset_default_graph()
3632
3633  @test_util.run_deprecated_v1
3634  def testVars(self):
3635    metrics.mean_iou(
3636        predictions=array_ops.ones([10, 1]),
3637        labels=array_ops.ones([10, 1]),
3638        num_classes=2)
3639    _assert_metric_variables(self, ('mean_iou/total_confusion_matrix:0',))
3640
3641  @test_util.run_deprecated_v1
3642  def testMetricsCollections(self):
3643    my_collection_name = '__metrics__'
3644    mean_iou, _ = metrics.mean_iou(
3645        predictions=array_ops.ones([10, 1]),
3646        labels=array_ops.ones([10, 1]),
3647        num_classes=2,
3648        metrics_collections=[my_collection_name])
3649    self.assertListEqual(ops.get_collection(my_collection_name), [mean_iou])
3650
3651  @test_util.run_deprecated_v1
3652  def testUpdatesCollection(self):
3653    my_collection_name = '__updates__'
3654    _, update_op = metrics.mean_iou(
3655        predictions=array_ops.ones([10, 1]),
3656        labels=array_ops.ones([10, 1]),
3657        num_classes=2,
3658        updates_collections=[my_collection_name])
3659    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
3660
3661  @test_util.run_deprecated_v1
3662  def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self):
3663    predictions = array_ops.ones([10, 3])
3664    labels = array_ops.ones([10, 4])
3665    with self.assertRaises(ValueError):
3666      metrics.mean_iou(labels, predictions, num_classes=2)
3667
3668  @test_util.run_deprecated_v1
3669  def testLabelsAndWeightsOfDifferentSizeRaisesValueError(self):
3670    predictions = array_ops.ones([10])
3671    labels = array_ops.ones([10])
3672    weights = array_ops.zeros([9])
3673    with self.assertRaises(ValueError):
3674      metrics.mean_iou(labels, predictions, num_classes=2, weights=weights)
3675
3676  @test_util.run_deprecated_v1
3677  def testValueTensorIsIdempotent(self):
3678    num_classes = 3
3679    predictions = random_ops.random_uniform(
3680        [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1)
3681    labels = random_ops.random_uniform(
3682        [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1)
3683    mean_iou, update_op = metrics.mean_iou(
3684        labels, predictions, num_classes=num_classes)
3685
3686    with self.cached_session():
3687      self.evaluate(variables.local_variables_initializer())
3688
3689      # Run several updates.
3690      for _ in range(10):
3691        self.evaluate(update_op)
3692
3693      # Then verify idempotency.
3694      initial_mean_iou = self.evaluate(mean_iou)
3695      for _ in range(10):
3696        self.assertEqual(initial_mean_iou, self.evaluate(mean_iou))
3697
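  # mean_iou averages the per-class IoU, TP / (TP + FP + FN), computed from
  # the accumulated confusion matrix. For the queue below that gives
  # 1/2, 1/4 and 0 for classes 0, 1 and 2 respectively.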
3698  @test_util.run_deprecated_v1
3699  def testMultipleUpdates(self):
3700    num_classes = 3
3701    with self.cached_session() as sess:
3702      # Create the queue that populates the predictions.
3703      preds_queue = data_flow_ops.FIFOQueue(
3704          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
3705      _enqueue_vector(sess, preds_queue, [0])
3706      _enqueue_vector(sess, preds_queue, [1])
3707      _enqueue_vector(sess, preds_queue, [2])
3708      _enqueue_vector(sess, preds_queue, [1])
3709      _enqueue_vector(sess, preds_queue, [0])
3710      predictions = preds_queue.dequeue()
3711
3712      # Create the queue that populates the labels.
3713      labels_queue = data_flow_ops.FIFOQueue(
3714          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
3715      _enqueue_vector(sess, labels_queue, [0])
3716      _enqueue_vector(sess, labels_queue, [1])
3717      _enqueue_vector(sess, labels_queue, [1])
3718      _enqueue_vector(sess, labels_queue, [2])
3719      _enqueue_vector(sess, labels_queue, [1])
3720      labels = labels_queue.dequeue()
3721
3722      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
3723
3724      self.evaluate(variables.local_variables_initializer())
3725      for _ in range(5):
3726        self.evaluate(update_op)
3727      desired_output = np.mean([1.0 / 2.0, 1.0 / 4.0, 0.])
3728      self.assertEqual(desired_output, self.evaluate(miou))
3729
3730  @test_util.run_deprecated_v1
3731  def testMultipleUpdatesWithWeights(self):
3732    num_classes = 2
3733    with self.cached_session() as sess:
3734      # Create the queue that populates the predictions.
3735      preds_queue = data_flow_ops.FIFOQueue(
3736          6, dtypes=dtypes_lib.int32, shapes=(1, 1))
3737      _enqueue_vector(sess, preds_queue, [0])
3738      _enqueue_vector(sess, preds_queue, [1])
3739      _enqueue_vector(sess, preds_queue, [0])
3740      _enqueue_vector(sess, preds_queue, [1])
3741      _enqueue_vector(sess, preds_queue, [0])
3742      _enqueue_vector(sess, preds_queue, [1])
3743      predictions = preds_queue.dequeue()
3744
3745      # Create the queue that populates the labels.
3746      labels_queue = data_flow_ops.FIFOQueue(
3747          6, dtypes=dtypes_lib.int32, shapes=(1, 1))
3748      _enqueue_vector(sess, labels_queue, [0])
3749      _enqueue_vector(sess, labels_queue, [1])
3750      _enqueue_vector(sess, labels_queue, [1])
3751      _enqueue_vector(sess, labels_queue, [0])
3752      _enqueue_vector(sess, labels_queue, [0])
3753      _enqueue_vector(sess, labels_queue, [1])
3754      labels = labels_queue.dequeue()
3755
3756      # Create the queue that populates the weights.
3757      weights_queue = data_flow_ops.FIFOQueue(
3758          6, dtypes=dtypes_lib.float32, shapes=(1, 1))
3759      _enqueue_vector(sess, weights_queue, [1.0])
3760      _enqueue_vector(sess, weights_queue, [1.0])
3761      _enqueue_vector(sess, weights_queue, [1.0])
3762      _enqueue_vector(sess, weights_queue, [0.0])
3763      _enqueue_vector(sess, weights_queue, [1.0])
3764      _enqueue_vector(sess, weights_queue, [0.0])
3765      weights = weights_queue.dequeue()
3766
3767      mean_iou, update_op = metrics.mean_iou(
3768          labels, predictions, num_classes, weights=weights)
3769
3770      variables.local_variables_initializer().run()
3771      for _ in range(6):
3772        self.evaluate(update_op)
3773      desired_output = np.mean([2.0 / 3.0, 1.0 / 2.0])
3774      self.assertAlmostEqual(desired_output, self.evaluate(mean_iou))
3775
3776  @test_util.run_deprecated_v1
3777  def testMultipleUpdatesWithMissingClass(self):
3778    # Test the case where there are no predictions and no labels for
3779    # one class, and thus there is one row and one column with
3780    # zero entries in the confusion matrix.
3781    num_classes = 3
3782    with self.cached_session() as sess:
3783      # Create the queue that populates the predictions.
3784      # There is no prediction for class 2.
3785      preds_queue = data_flow_ops.FIFOQueue(
3786          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
3787      _enqueue_vector(sess, preds_queue, [0])
3788      _enqueue_vector(sess, preds_queue, [1])
3789      _enqueue_vector(sess, preds_queue, [1])
3790      _enqueue_vector(sess, preds_queue, [1])
3791      _enqueue_vector(sess, preds_queue, [0])
3792      predictions = preds_queue.dequeue()
3793
3794      # Create the queue that populates the labels.
3795      # There is no label for class 2.
3796      labels_queue = data_flow_ops.FIFOQueue(
3797          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
3798      _enqueue_vector(sess, labels_queue, [0])
3799      _enqueue_vector(sess, labels_queue, [1])
3800      _enqueue_vector(sess, labels_queue, [1])
3801      _enqueue_vector(sess, labels_queue, [0])
3802      _enqueue_vector(sess, labels_queue, [1])
3803      labels = labels_queue.dequeue()
3804
3805      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
3806
3807      self.evaluate(variables.local_variables_initializer())
3808      for _ in range(5):
3809        self.evaluate(update_op)
3810      desired_output = np.mean([1.0 / 3.0, 2.0 / 4.0])
3811      self.assertAlmostEqual(desired_output, self.evaluate(miou))
3812
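  # The update op returns the accumulated confusion matrix itself; for the
  # data below it is [[3, 0], [2, 5]], giving IoUs of 3/5 and 5/7.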
3813  @test_util.run_deprecated_v1
3814  def testUpdateOpEvalIsAccumulatedConfusionMatrix(self):
3815    predictions = array_ops.concat(
3816        [
3817            constant_op.constant(
3818                0, shape=[5]), constant_op.constant(
3819                    1, shape=[5])
3820        ],
3821        0)
3822    labels = array_ops.concat(
3823        [
3824            constant_op.constant(
3825                0, shape=[3]), constant_op.constant(
3826                    1, shape=[7])
3827        ],
3828        0)
3829    num_classes = 2
3830    with self.cached_session():
3831      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
3832      self.evaluate(variables.local_variables_initializer())
3833      confusion_matrix = self.evaluate(update_op)
3834      self.assertAllEqual([[3, 0], [2, 5]], confusion_matrix)
3835      desired_miou = np.mean([3. / 5., 5. / 7.])
3836      self.assertAlmostEqual(desired_miou, self.evaluate(miou))
3837
3838  @test_util.run_deprecated_v1
3839  def testAllCorrect(self):
3840    predictions = array_ops.zeros([40])
3841    labels = array_ops.zeros([40])
3842    num_classes = 1
3843    with self.cached_session():
3844      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
3845      self.evaluate(variables.local_variables_initializer())
3846      self.assertEqual(40, self.evaluate(update_op)[0])
3847      self.assertEqual(1.0, self.evaluate(miou))
3848
3849  @test_util.run_deprecated_v1
3850  def testAllWrong(self):
3851    predictions = array_ops.zeros([40])
3852    labels = array_ops.ones([40])
3853    num_classes = 2
3854    with self.cached_session():
3855      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
3856      self.evaluate(variables.local_variables_initializer())
3857      self.assertAllEqual([[0, 0], [40, 0]], update_op)
3858      self.assertEqual(0., self.evaluate(miou))
3859
3860  @test_util.run_deprecated_v1
3861  def testResultsWithSomeMissing(self):
3862    predictions = array_ops.concat(
3863        [
3864            constant_op.constant(
3865                0, shape=[5]), constant_op.constant(
3866                    1, shape=[5])
3867        ],
3868        0)
3869    labels = array_ops.concat(
3870        [
3871            constant_op.constant(
3872                0, shape=[3]), constant_op.constant(
3873                    1, shape=[7])
3874        ],
3875        0)
3876    num_classes = 2
3877    weights = array_ops.concat(
3878        [
3879            constant_op.constant(
3880                0, shape=[1]), constant_op.constant(
3881                    1, shape=[8]), constant_op.constant(
3882                        0, shape=[1])
3883        ],
3884        0)
3885    with self.cached_session():
3886      miou, update_op = metrics.mean_iou(
3887          labels, predictions, num_classes, weights=weights)
3888      self.evaluate(variables.local_variables_initializer())
3889      self.assertAllEqual([[2, 0], [2, 4]], update_op)
3890      desired_miou = np.mean([2. / 4., 4. / 6.])
3891      self.assertAlmostEqual(desired_miou, self.evaluate(miou))
3892
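  # Class 2 never appears in the labels below but is predicted, so its IoU is
  # 0 and it still contributes to the mean; classes missing from both labels
  # and predictions (see the following tests) are excluded from the mean
  # instead.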
3893  @test_util.run_deprecated_v1
3894  def testMissingClassInLabels(self):
3895    labels = constant_op.constant([
3896        [[0, 0, 1, 1, 0, 0],
3897         [1, 0, 0, 0, 0, 1]],
3898        [[1, 1, 1, 1, 1, 1],
3899         [0, 0, 0, 0, 0, 0]]])
3900    predictions = constant_op.constant([
3901        [[0, 0, 2, 1, 1, 0],
3902         [0, 1, 2, 2, 0, 1]],
3903        [[0, 0, 2, 1, 1, 1],
3904         [1, 1, 2, 0, 0, 0]]])
3905    num_classes = 3
3906    with self.cached_session():
3907      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
3908      self.evaluate(variables.local_variables_initializer())
3909      self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op)
3910      self.assertAlmostEqual(
3911          1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 / (0 + 5 + 0)),
3912          self.evaluate(miou))
3913
3914  @test_util.run_deprecated_v1
3915  def testMissingClassOverallSmall(self):
3916    labels = constant_op.constant([0])
3917    predictions = constant_op.constant([0])
3918    num_classes = 2
3919    with self.cached_session():
3920      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
3921      self.evaluate(variables.local_variables_initializer())
3922      self.assertAllEqual([[1, 0], [0, 0]], update_op)
3923      self.assertAlmostEqual(1, self.evaluate(miou))
3924
3925  @test_util.run_deprecated_v1
3926  def testMissingClassOverallLarge(self):
3927    labels = constant_op.constant([
3928        [[0, 0, 1, 1, 0, 0],
3929         [1, 0, 0, 0, 0, 1]],
3930        [[1, 1, 1, 1, 1, 1],
3931         [0, 0, 0, 0, 0, 0]]])
3932    predictions = constant_op.constant([
3933        [[0, 0, 1, 1, 0, 0],
3934         [1, 1, 0, 0, 1, 1]],
3935        [[0, 0, 0, 1, 1, 1],
3936         [1, 1, 1, 0, 0, 0]]])
3937    num_classes = 3
3938    with self.cached_session():
3939      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
3940      self.evaluate(variables.local_variables_initializer())
3941      self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op)
3942      self.assertAlmostEqual(1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)),
3943                             self.evaluate(miou))
3944
3945
3946class MeanPerClassAccuracyTest(test.TestCase):
3947
3948  def setUp(self):
3949    np.random.seed(1)
3950    ops.reset_default_graph()
3951
3952  @test_util.run_deprecated_v1
3953  def testVars(self):
3954    metrics.mean_per_class_accuracy(
3955        predictions=array_ops.ones([10, 1]),
3956        labels=array_ops.ones([10, 1]),
3957        num_classes=2)
3958    _assert_metric_variables(self, ('mean_accuracy/count:0',
3959                                    'mean_accuracy/total:0'))
3960
3961  @test_util.run_deprecated_v1
3962  def testMetricsCollections(self):
3963    my_collection_name = '__metrics__'
3964    mean_accuracy, _ = metrics.mean_per_class_accuracy(
3965        predictions=array_ops.ones([10, 1]),
3966        labels=array_ops.ones([10, 1]),
3967        num_classes=2,
3968        metrics_collections=[my_collection_name])
3969    self.assertListEqual(
3970        ops.get_collection(my_collection_name), [mean_accuracy])
3971
3972  @test_util.run_deprecated_v1
3973  def testUpdatesCollection(self):
3974    my_collection_name = '__updates__'
3975    _, update_op = metrics.mean_per_class_accuracy(
3976        predictions=array_ops.ones([10, 1]),
3977        labels=array_ops.ones([10, 1]),
3978        num_classes=2,
3979        updates_collections=[my_collection_name])
3980    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
3981
3982  @test_util.run_deprecated_v1
3983  def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self):
3984    predictions = array_ops.ones([10, 3])
3985    labels = array_ops.ones([10, 4])
3986    with self.assertRaises(ValueError):
3987      metrics.mean_per_class_accuracy(labels, predictions, num_classes=2)
3988
3989  @test_util.run_deprecated_v1
3990  def testLabelsAndWeightsOfDifferentSizeRaisesValueError(self):
3991    predictions = array_ops.ones([10])
3992    labels = array_ops.ones([10])
3993    weights = array_ops.zeros([9])
3994    with self.assertRaises(ValueError):
3995      metrics.mean_per_class_accuracy(
3996          labels, predictions, num_classes=2, weights=weights)
3997
3998  @test_util.run_deprecated_v1
3999  def testValueTensorIsIdempotent(self):
4000    num_classes = 3
4001    predictions = random_ops.random_uniform(
4002        [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1)
4003    labels = random_ops.random_uniform(
4004        [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1)
4005    mean_accuracy, update_op = metrics.mean_per_class_accuracy(
4006        labels, predictions, num_classes=num_classes)
4007
4008    with self.cached_session():
4009      self.evaluate(variables.local_variables_initializer())
4010
4011      # Run several updates.
4012      for _ in range(10):
4013        self.evaluate(update_op)
4014
4015      # Then verify idempotency.
4016      initial_mean_accuracy = self.evaluate(mean_accuracy)
4017      for _ in range(10):
4018        self.assertEqual(initial_mean_accuracy, self.evaluate(mean_accuracy))

  @test_util.run_deprecated_v1
  def testMultipleUpdates(self):
    num_classes = 3
4021    with self.cached_session() as sess:
4022      # Create the queue that populates the predictions.
4023      preds_queue = data_flow_ops.FIFOQueue(
4024          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
4025      _enqueue_vector(sess, preds_queue, [0])
4026      _enqueue_vector(sess, preds_queue, [1])
4027      _enqueue_vector(sess, preds_queue, [2])
4028      _enqueue_vector(sess, preds_queue, [1])
4029      _enqueue_vector(sess, preds_queue, [0])
4030      predictions = preds_queue.dequeue()
4031
4032      # Create the queue that populates the labels.
4033      labels_queue = data_flow_ops.FIFOQueue(
4034          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
4035      _enqueue_vector(sess, labels_queue, [0])
4036      _enqueue_vector(sess, labels_queue, [1])
4037      _enqueue_vector(sess, labels_queue, [1])
4038      _enqueue_vector(sess, labels_queue, [2])
4039      _enqueue_vector(sess, labels_queue, [1])
4040      labels = labels_queue.dequeue()
4041
4042      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
4043          labels, predictions, num_classes)
4044
4045      self.evaluate(variables.local_variables_initializer())
4046      for _ in range(5):
4047        self.evaluate(update_op)
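      # Per-class accuracy: class 0 -> 1/1, class 1 -> 1/3, class 2 -> 0/1.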
4048      desired_output = np.mean([1.0, 1.0 / 3.0, 0.0])
4049      self.assertAlmostEqual(desired_output, self.evaluate(mean_accuracy))
4050
4051  @test_util.run_deprecated_v1
4052  def testMultipleUpdatesWithWeights(self):
4053    num_classes = 2
4054    with self.cached_session() as sess:
4055      # Create the queue that populates the predictions.
4056      preds_queue = data_flow_ops.FIFOQueue(
4057          6, dtypes=dtypes_lib.int32, shapes=(1, 1))
4058      _enqueue_vector(sess, preds_queue, [0])
4059      _enqueue_vector(sess, preds_queue, [1])
4060      _enqueue_vector(sess, preds_queue, [0])
4061      _enqueue_vector(sess, preds_queue, [1])
4062      _enqueue_vector(sess, preds_queue, [0])
4063      _enqueue_vector(sess, preds_queue, [1])
4064      predictions = preds_queue.dequeue()
4065
4066      # Create the queue that populates the labels.
4067      labels_queue = data_flow_ops.FIFOQueue(
4068          6, dtypes=dtypes_lib.int32, shapes=(1, 1))
4069      _enqueue_vector(sess, labels_queue, [0])
4070      _enqueue_vector(sess, labels_queue, [1])
4071      _enqueue_vector(sess, labels_queue, [1])
4072      _enqueue_vector(sess, labels_queue, [0])
4073      _enqueue_vector(sess, labels_queue, [0])
4074      _enqueue_vector(sess, labels_queue, [1])
4075      labels = labels_queue.dequeue()
4076
4077      # Create the queue that populates the weights.
4078      weights_queue = data_flow_ops.FIFOQueue(
4079          6, dtypes=dtypes_lib.float32, shapes=(1, 1))
4080      _enqueue_vector(sess, weights_queue, [1.0])
4081      _enqueue_vector(sess, weights_queue, [0.5])
4082      _enqueue_vector(sess, weights_queue, [1.0])
4083      _enqueue_vector(sess, weights_queue, [0.0])
4084      _enqueue_vector(sess, weights_queue, [1.0])
4085      _enqueue_vector(sess, weights_queue, [0.0])
4086      weights = weights_queue.dequeue()
4087
4088      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
4089          labels, predictions, num_classes, weights=weights)
4090
      self.evaluate(variables.local_variables_initializer())
4092      for _ in range(6):
4093        self.evaluate(update_op)
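      # Weighted per-class accuracy: class 0 -> (1. + 1.) / (1. + 0. + 1.),
      # class 1 -> 0.5 / (0.5 + 1. + 0.).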
4094      desired_output = np.mean([2.0 / 2.0, 0.5 / 1.5])
4095      self.assertAlmostEqual(desired_output, self.evaluate(mean_accuracy))
4096
4097  @test_util.run_deprecated_v1
4098  def testMultipleUpdatesWithMissingClass(self):
    # Test the case where there are no predictions and no labels for one
    # class, so one row and one column of the confusion matrix contain
    # only zeros.
4102    num_classes = 3
4103    with self.cached_session() as sess:
4104      # Create the queue that populates the predictions.
4105      # There is no prediction for class 2.
4106      preds_queue = data_flow_ops.FIFOQueue(
4107          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
4108      _enqueue_vector(sess, preds_queue, [0])
4109      _enqueue_vector(sess, preds_queue, [1])
4110      _enqueue_vector(sess, preds_queue, [1])
4111      _enqueue_vector(sess, preds_queue, [1])
4112      _enqueue_vector(sess, preds_queue, [0])
4113      predictions = preds_queue.dequeue()
4114
4115      # Create the queue that populates the labels.
      # There is no label for class 2.
4117      labels_queue = data_flow_ops.FIFOQueue(
4118          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
4119      _enqueue_vector(sess, labels_queue, [0])
4120      _enqueue_vector(sess, labels_queue, [1])
4121      _enqueue_vector(sess, labels_queue, [1])
4122      _enqueue_vector(sess, labels_queue, [0])
4123      _enqueue_vector(sess, labels_queue, [1])
4124      labels = labels_queue.dequeue()
4125
4126      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
4127          labels, predictions, num_classes)
4128
4129      self.evaluate(variables.local_variables_initializer())
4130      for _ in range(5):
4131        self.evaluate(update_op)
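      # Per-class accuracy: class 0 -> 1/2, class 1 -> 2/3; class 2 has no
      # labels and is counted as 0 in the mean.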
4132      desired_output = np.mean([1.0 / 2.0, 2.0 / 3.0, 0.])
4133      self.assertAlmostEqual(desired_output, self.evaluate(mean_accuracy))
4134
4135  @test_util.run_deprecated_v1
4136  def testAllCorrect(self):
4137    predictions = array_ops.zeros([40])
4138    labels = array_ops.zeros([40])
4139    num_classes = 1
4140    with self.cached_session():
4141      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
4142          labels, predictions, num_classes)
4143      self.evaluate(variables.local_variables_initializer())
4144      self.assertEqual(1.0, self.evaluate(update_op)[0])
4145      self.assertEqual(1.0, self.evaluate(mean_accuracy))
4146
4147  @test_util.run_deprecated_v1
4148  def testAllWrong(self):
4149    predictions = array_ops.zeros([40])
4150    labels = array_ops.ones([40])
4151    num_classes = 2
4152    with self.cached_session():
4153      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
4154          labels, predictions, num_classes)
4155      self.evaluate(variables.local_variables_initializer())
4156      self.assertAllEqual([0.0, 0.0], update_op)
4157      self.assertEqual(0., self.evaluate(mean_accuracy))
4158
4159  @test_util.run_deprecated_v1
4160  def testResultsWithSomeMissing(self):
4161    predictions = array_ops.concat([
4162        constant_op.constant(0, shape=[5]), constant_op.constant(1, shape=[5])
4163    ], 0)
4164    labels = array_ops.concat([
4165        constant_op.constant(0, shape=[3]), constant_op.constant(1, shape=[7])
4166    ], 0)
4167    num_classes = 2
4168    weights = array_ops.concat([
4169        constant_op.constant(0, shape=[1]), constant_op.constant(1, shape=[8]),
4170        constant_op.constant(0, shape=[1])
4171    ], 0)
4172    with self.cached_session():
4173      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
4174          labels, predictions, num_classes, weights=weights)
4175      self.evaluate(variables.local_variables_initializer())
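      # With the first and last examples masked out by the weights:
      # class 0 -> 2/2 correct, class 1 -> 4/6 correct.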
4176      desired_accuracy = np.array([2. / 2., 4. / 6.], dtype=np.float32)
4177      self.assertAllEqual(desired_accuracy, update_op)
4178      desired_mean_accuracy = np.mean(desired_accuracy)
4179      self.assertAlmostEqual(desired_mean_accuracy,
4180                             self.evaluate(mean_accuracy))
4181
4182
4183class FalseNegativesTest(test.TestCase):
4184
4185  def setUp(self):
4186    np.random.seed(1)
4187    ops.reset_default_graph()
4188
4189  @test_util.run_deprecated_v1
4190  def testVars(self):
4191    metrics.false_negatives(
4192        labels=(0, 1, 0, 1),
4193        predictions=(0, 0, 1, 1))
4194    _assert_metric_variables(self, ('false_negatives/count:0',))
4195
4196  @test_util.run_deprecated_v1
4197  def testUnweighted(self):
4198    labels = constant_op.constant(((0, 1, 0, 1, 0),
4199                                   (0, 0, 1, 1, 1),
4200                                   (1, 1, 1, 1, 0),
4201                                   (0, 0, 0, 0, 1)))
4202    predictions = constant_op.constant(((0, 0, 1, 1, 0),
4203                                        (1, 1, 1, 1, 1),
4204                                        (0, 1, 0, 1, 0),
4205                                        (1, 1, 1, 1, 1)))
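    # A false negative is a (label=1, prediction=0) pair; the data above
    # contains 3 of them.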
    fn, fn_update_op = metrics.false_negatives(
        labels=labels, predictions=predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllClose(0., fn)
      self.assertAllClose(3., fn_update_op)
      self.assertAllClose(3., fn)
4214
4215  @test_util.run_deprecated_v1
4216  def testWeighted(self):
4217    labels = constant_op.constant(((0, 1, 0, 1, 0),
4218                                   (0, 0, 1, 1, 1),
4219                                   (1, 1, 1, 1, 0),
4220                                   (0, 0, 0, 0, 1)))
4221    predictions = constant_op.constant(((0, 0, 1, 1, 0),
4222                                        (1, 1, 1, 1, 1),
4223                                        (0, 1, 0, 1, 0),
4224                                        (1, 1, 1, 1, 1)))
4225    weights = constant_op.constant((1., 1.5, 2., 2.5))
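    # The false negatives sit in rows 0 and 2, so the weighted total is
    # 1 * 1. + 2 * 2. = 5.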
    fn, fn_update_op = metrics.false_negatives(
        labels=labels, predictions=predictions, weights=weights)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllClose(0., fn)
      self.assertAllClose(5., fn_update_op)
      self.assertAllClose(5., fn)
4234
4235
4236class FalseNegativesAtThresholdsTest(test.TestCase):
4237
4238  def setUp(self):
4239    np.random.seed(1)
4240    ops.reset_default_graph()
4241
4242  @test_util.run_deprecated_v1
4243  def testVars(self):
4244    metrics.false_negatives_at_thresholds(
4245        predictions=array_ops.ones((10, 1)),
4246        labels=array_ops.ones((10, 1)),
4247        thresholds=[0.15, 0.5, 0.85])
4248    _assert_metric_variables(self, ('false_negatives/false_negatives:0',))
4249
4250  @test_util.run_deprecated_v1
4251  def testUnweighted(self):
4252    predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
4253                                        (0.2, 0.9, 0.7, 0.6),
4254                                        (0.1, 0.2, 0.4, 0.3)))
4255    labels = constant_op.constant(((0, 1, 1, 0),
4256                                   (1, 0, 0, 0),
4257                                   (0, 0, 0, 0)))
4258    fn, fn_update_op = metrics.false_negatives_at_thresholds(
4259        predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
4260
4261    with self.cached_session():
4262      self.evaluate(variables.local_variables_initializer())
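      # A positive label whose score falls below the threshold is a false
      # negative: 0, 2 and 3 of them at thresholds 0.15, 0.5 and 0.85.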
4263      self.assertAllEqual((0, 0, 0), fn)
4264      self.assertAllEqual((0, 2, 3), fn_update_op)
4265      self.assertAllEqual((0, 2, 3), fn)
4266
4267  @test_util.run_deprecated_v1
4268  def testWeighted(self):
4269    predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
4270                                        (0.2, 0.9, 0.7, 0.6),
4271                                        (0.1, 0.2, 0.4, 0.3)))
4272    labels = constant_op.constant(((0, 1, 1, 0),
4273                                   (1, 0, 0, 0),
4274                                   (0, 0, 0, 0)))
4275    fn, fn_update_op = metrics.false_negatives_at_thresholds(
4276        predictions=predictions,
4277        labels=labels,
4278        weights=((3.0,), (5.0,), (7.0,)),
4279        thresholds=[0.15, 0.5, 0.85])
4280
4281    with self.cached_session():
4282      self.evaluate(variables.local_variables_initializer())
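      # The per-row weights broadcast across columns, so the false
      # negatives are worth 0, 3 + 5 = 8 and 3 + 3 + 5 = 11.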
4283      self.assertAllEqual((0.0, 0.0, 0.0), fn)
4284      self.assertAllEqual((0.0, 8.0, 11.0), fn_update_op)
4285      self.assertAllEqual((0.0, 8.0, 11.0), fn)
4286
4287
4288class FalsePositivesTest(test.TestCase):
4289
4290  def setUp(self):
4291    np.random.seed(1)
4292    ops.reset_default_graph()
4293
4294  @test_util.run_deprecated_v1
4295  def testVars(self):
4296    metrics.false_positives(
4297        labels=(0, 1, 0, 1),
4298        predictions=(0, 0, 1, 1))
4299    _assert_metric_variables(self, ('false_positives/count:0',))
4300
4301  @test_util.run_deprecated_v1
4302  def testUnweighted(self):
4303    labels = constant_op.constant(((0, 1, 0, 1, 0),
4304                                   (0, 0, 1, 1, 1),
4305                                   (1, 1, 1, 1, 0),
4306                                   (0, 0, 0, 0, 1)))
4307    predictions = constant_op.constant(((0, 0, 1, 1, 0),
4308                                        (1, 1, 1, 1, 1),
4309                                        (0, 1, 0, 1, 0),
4310                                        (1, 1, 1, 1, 1)))
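    # A false positive is a (label=0, prediction=1) pair; the data above
    # contains 7 of them.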
    fp, fp_update_op = metrics.false_positives(
        labels=labels, predictions=predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllClose(0., fp)
      self.assertAllClose(7., fp_update_op)
      self.assertAllClose(7., fp)
4319
4320  @test_util.run_deprecated_v1
4321  def testWeighted(self):
4322    labels = constant_op.constant(((0, 1, 0, 1, 0),
4323                                   (0, 0, 1, 1, 1),
4324                                   (1, 1, 1, 1, 0),
4325                                   (0, 0, 0, 0, 1)))
4326    predictions = constant_op.constant(((0, 0, 1, 1, 0),
4327                                        (1, 1, 1, 1, 1),
4328                                        (0, 1, 0, 1, 0),
4329                                        (1, 1, 1, 1, 1)))
4330    weights = constant_op.constant((1., 1.5, 2., 2.5))
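    # Weighted by row: 1 * 1. + 2 * 1.5 + 0 * 2. + 4 * 2.5 = 14.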
    fp, fp_update_op = metrics.false_positives(
        labels=labels, predictions=predictions, weights=weights)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllClose(0., fp)
      self.assertAllClose(14., fp_update_op)
      self.assertAllClose(14., fp)
4339
4340
4341class FalsePositivesAtThresholdsTest(test.TestCase):
4342
4343  def setUp(self):
4344    np.random.seed(1)
4345    ops.reset_default_graph()
4346
4347  @test_util.run_deprecated_v1
4348  def testVars(self):
4349    metrics.false_positives_at_thresholds(
4350        predictions=array_ops.ones((10, 1)),
4351        labels=array_ops.ones((10, 1)),
4352        thresholds=[0.15, 0.5, 0.85])
4353    _assert_metric_variables(self, ('false_positives/false_positives:0',))
4354
4355  @test_util.run_deprecated_v1
4356  def testUnweighted(self):
4357    predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
4358                                        (0.2, 0.9, 0.7, 0.6),
4359                                        (0.1, 0.2, 0.4, 0.3)))
4360    labels = constant_op.constant(((0, 1, 1, 0),
4361                                   (1, 0, 0, 0),
4362                                   (0, 0, 0, 0)))
4363    fp, fp_update_op = metrics.false_positives_at_thresholds(
4364        predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
4365
4366    with self.cached_session():
4367      self.evaluate(variables.local_variables_initializer())
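      # A negative label whose score is above the threshold is a false
      # positive: 7, 4 and 2 of them at thresholds 0.15, 0.5 and 0.85.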
4368      self.assertAllEqual((0, 0, 0), fp)
4369      self.assertAllEqual((7, 4, 2), fp_update_op)
4370      self.assertAllEqual((7, 4, 2), fp)
4371
4372  @test_util.run_deprecated_v1
4373  def testWeighted(self):
4374    predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
4375                                        (0.2, 0.9, 0.7, 0.6),
4376                                        (0.1, 0.2, 0.4, 0.3)))
4377    labels = constant_op.constant(((0, 1, 1, 0),
4378                                   (1, 0, 0, 0),
4379                                   (0, 0, 0, 0)))
4380    fp, fp_update_op = metrics.false_positives_at_thresholds(
4381        predictions=predictions,
4382        labels=labels,
4383        weights=((1.0, 2.0, 3.0, 5.0),
4384                 (7.0, 11.0, 13.0, 17.0),
4385                 (19.0, 23.0, 29.0, 31.0)),
4386        thresholds=[0.15, 0.5, 0.85])
4387
4388    with self.cached_session():
4389      self.evaluate(variables.local_variables_initializer())
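      # Summing the per-element weights of the false positives at each
      # threshold gives 125, 42 and 12.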
4390      self.assertAllEqual((0.0, 0.0, 0.0), fp)
4391      self.assertAllEqual((125.0, 42.0, 12.0), fp_update_op)
4392      self.assertAllEqual((125.0, 42.0, 12.0), fp)
4393
4394
4395class TrueNegativesTest(test.TestCase):
4396
4397  def setUp(self):
4398    np.random.seed(1)
4399    ops.reset_default_graph()
4400
4401  @test_util.run_deprecated_v1
4402  def testVars(self):
4403    metrics.true_negatives(
4404        labels=(0, 1, 0, 1),
4405        predictions=(0, 0, 1, 1))
4406    _assert_metric_variables(self, ('true_negatives/count:0',))
4407
4408  @test_util.run_deprecated_v1
4409  def testUnweighted(self):
4410    labels = constant_op.constant(((0, 1, 0, 1, 0),
4411                                   (0, 0, 1, 1, 1),
4412                                   (1, 1, 1, 1, 0),
4413                                   (0, 0, 0, 0, 1)))
4414    predictions = constant_op.constant(((0, 0, 1, 1, 0),
4415                                        (1, 1, 1, 1, 1),
4416                                        (0, 1, 0, 1, 0),
4417                                        (1, 1, 1, 1, 1)))
4418    tn, tn_update_op = metrics.true_negatives(
4419        labels=labels, predictions=predictions)
4420
4421    with self.cached_session():
4422      self.evaluate(variables.local_variables_initializer())
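      # A true negative is a (label=0, prediction=0) pair; the data above
      # contains 3 of them.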
4423      self.assertAllClose(0., tn)
4424      self.assertAllClose(3., tn_update_op)
4425      self.assertAllClose(3., tn)
4426
4427  @test_util.run_deprecated_v1
4428  def testWeighted(self):
4429    labels = constant_op.constant(((0, 1, 0, 1, 0),
4430                                   (0, 0, 1, 1, 1),
4431                                   (1, 1, 1, 1, 0),
4432                                   (0, 0, 0, 0, 1)))
4433    predictions = constant_op.constant(((0, 0, 1, 1, 0),
4434                                        (1, 1, 1, 1, 1),
4435                                        (0, 1, 0, 1, 0),
4436                                        (1, 1, 1, 1, 1)))
4437    weights = constant_op.constant((1., 1.5, 2., 2.5))
4438    tn, tn_update_op = metrics.true_negatives(
4439        labels=labels, predictions=predictions, weights=weights)
4440
4441    with self.cached_session():
4442      self.evaluate(variables.local_variables_initializer())
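      # The two row-0 true negatives score 1. each and the row-2 true
      # negative scores 2., for a weighted total of 4.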
4443      self.assertAllClose(0., tn)
4444      self.assertAllClose(4., tn_update_op)
4445      self.assertAllClose(4., tn)
4446
4447
4448class TrueNegativesAtThresholdsTest(test.TestCase):
4449
4450  def setUp(self):
4451    np.random.seed(1)
4452    ops.reset_default_graph()
4453
4454  @test_util.run_deprecated_v1
4455  def testVars(self):
4456    metrics.true_negatives_at_thresholds(
4457        predictions=array_ops.ones((10, 1)),
4458        labels=array_ops.ones((10, 1)),
4459        thresholds=[0.15, 0.5, 0.85])
4460    _assert_metric_variables(self, ('true_negatives/true_negatives:0',))
4461
4462  @test_util.run_deprecated_v1
4463  def testUnweighted(self):
4464    predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
4465                                        (0.2, 0.9, 0.7, 0.6),
4466                                        (0.1, 0.2, 0.4, 0.3)))
4467    labels = constant_op.constant(((0, 1, 1, 0),
4468                                   (1, 0, 0, 0),
4469                                   (0, 0, 0, 0)))
4470    tn, tn_update_op = metrics.true_negatives_at_thresholds(
4471        predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
4472
4473    with self.cached_session():
4474      self.evaluate(variables.local_variables_initializer())
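      # A negative label whose score falls below the threshold is a true
      # negative: 2, 5 and 7 of them at thresholds 0.15, 0.5 and 0.85.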
4475      self.assertAllEqual((0, 0, 0), tn)
4476      self.assertAllEqual((2, 5, 7), tn_update_op)
4477      self.assertAllEqual((2, 5, 7), tn)
4478
4479  @test_util.run_deprecated_v1
4480  def testWeighted(self):
4481    predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
4482                                        (0.2, 0.9, 0.7, 0.6),
4483                                        (0.1, 0.2, 0.4, 0.3)))
4484    labels = constant_op.constant(((0, 1, 1, 0),
4485                                   (1, 0, 0, 0),
4486                                   (0, 0, 0, 0)))
4487    tn, tn_update_op = metrics.true_negatives_at_thresholds(
4488        predictions=predictions,
4489        labels=labels,
4490        weights=((0.0, 2.0, 3.0, 5.0),),
4491        thresholds=[0.15, 0.5, 0.85])
4492
4493    with self.cached_session():
4494      self.evaluate(variables.local_variables_initializer())
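      # The single weight row broadcasts across all three examples; summing
      # the column weights of the true negatives gives 5, 15 and 23.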
4495      self.assertAllEqual((0.0, 0.0, 0.0), tn)
4496      self.assertAllEqual((5.0, 15.0, 23.0), tn_update_op)
4497      self.assertAllEqual((5.0, 15.0, 23.0), tn)
4498
4499
4500class TruePositivesTest(test.TestCase):
4501
4502  def setUp(self):
4503    np.random.seed(1)
4504    ops.reset_default_graph()
4505
4506  @test_util.run_deprecated_v1
4507  def testVars(self):
4508    metrics.true_positives(
4509        labels=(0, 1, 0, 1),
4510        predictions=(0, 0, 1, 1))
4511    _assert_metric_variables(self, ('true_positives/count:0',))
4512
4513  @test_util.run_deprecated_v1
4514  def testUnweighted(self):
4515    labels = constant_op.constant(((0, 1, 0, 1, 0),
4516                                   (0, 0, 1, 1, 1),
4517                                   (1, 1, 1, 1, 0),
4518                                   (0, 0, 0, 0, 1)))
4519    predictions = constant_op.constant(((0, 0, 1, 1, 0),
4520                                        (1, 1, 1, 1, 1),
4521                                        (0, 1, 0, 1, 0),
4522                                        (1, 1, 1, 1, 1)))
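    # A true positive is a (label=1, prediction=1) pair; the data above
    # contains 7 of them.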
    tp, tp_update_op = metrics.true_positives(
        labels=labels, predictions=predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllClose(0., tp)
      self.assertAllClose(7., tp_update_op)
      self.assertAllClose(7., tp)
4531
4532  @test_util.run_deprecated_v1
4533  def testWeighted(self):
4534    labels = constant_op.constant(((0, 1, 0, 1, 0),
4535                                   (0, 0, 1, 1, 1),
4536                                   (1, 1, 1, 1, 0),
4537                                   (0, 0, 0, 0, 1)))
4538    predictions = constant_op.constant(((0, 0, 1, 1, 0),
4539                                        (1, 1, 1, 1, 1),
4540                                        (0, 1, 0, 1, 0),
4541                                        (1, 1, 1, 1, 1)))
4542    weights = constant_op.constant((1., 1.5, 2., 2.5))
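    # Weighted by row: 1 * 1. + 3 * 1.5 + 2 * 2. + 1 * 2.5 = 12.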
    tp, tp_update_op = metrics.true_positives(
        labels=labels, predictions=predictions, weights=weights)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllClose(0., tp)
      self.assertAllClose(12., tp_update_op)
      self.assertAllClose(12., tp)
4551
4552
4553class TruePositivesAtThresholdsTest(test.TestCase):
4554
4555  def setUp(self):
4556    np.random.seed(1)
4557    ops.reset_default_graph()
4558
4559  @test_util.run_deprecated_v1
4560  def testVars(self):
4561    metrics.true_positives_at_thresholds(
4562        predictions=array_ops.ones((10, 1)),
4563        labels=array_ops.ones((10, 1)),
4564        thresholds=[0.15, 0.5, 0.85])
4565    _assert_metric_variables(self, ('true_positives/true_positives:0',))
4566
4567  @test_util.run_deprecated_v1
4568  def testUnweighted(self):
4569    predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
4570                                        (0.2, 0.9, 0.7, 0.6),
4571                                        (0.1, 0.2, 0.4, 0.3)))
4572    labels = constant_op.constant(((0, 1, 1, 0),
4573                                   (1, 0, 0, 0),
4574                                   (0, 0, 0, 0)))
4575    tp, tp_update_op = metrics.true_positives_at_thresholds(
4576        predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
4577
4578    with self.cached_session():
4579      self.evaluate(variables.local_variables_initializer())
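      # A positive label whose score is above the threshold is a true
      # positive: 3, 1 and 0 of them at thresholds 0.15, 0.5 and 0.85.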
4580      self.assertAllEqual((0, 0, 0), tp)
4581      self.assertAllEqual((3, 1, 0), tp_update_op)
4582      self.assertAllEqual((3, 1, 0), tp)
4583
4584  @test_util.run_deprecated_v1
4585  def testWeighted(self):
4586    predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
4587                                        (0.2, 0.9, 0.7, 0.6),
4588                                        (0.1, 0.2, 0.4, 0.3)))
4589    labels = constant_op.constant(((0, 1, 1, 0),
4590                                   (1, 0, 0, 0),
4591                                   (0, 0, 0, 0)))
4592    tp, tp_update_op = metrics.true_positives_at_thresholds(
4593        predictions=predictions, labels=labels, weights=37.0,
4594        thresholds=[0.15, 0.5, 0.85])
4595
4596    with self.cached_session():
4597      self.evaluate(variables.local_variables_initializer())
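      # The scalar weight 37.0 scales the unweighted counts: 3 * 37, 1 * 37
      # and 0.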
4598      self.assertAllEqual((0.0, 0.0, 0.0), tp)
4599      self.assertAllEqual((111.0, 37.0, 0.0), tp_update_op)
4600      self.assertAllEqual((111.0, 37.0, 0.0), tp)
4601
4602
4603if __name__ == '__main__':
4604  test.main()
4605