• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for connected component analysis."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import logging
22
23import numpy as np
24
25from tensorflow.contrib.image.python.ops import image_ops
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.platform import googletest
31
# Image for testing connected_components, with a single, winding component.
# The grid is 10 rows by 9 columns with an all-zero border; the component's two
# endpoints (pixels with exactly one foreground neighbor) are at
# (row, col) = (1, 1) and (6, 3).
SNAKE = np.asarray(
    [[0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 1, 1, 1, 1, 0, 0, 0, 0],
     [0, 0, 0, 0, 1, 1, 1, 1, 0],
     [0, 0, 0, 0, 0, 0, 0, 1, 0],
     [0, 1, 1, 1, 1, 1, 1, 1, 0],
     [0, 1, 0, 0, 0, 0, 0, 0, 0],
     [0, 1, 0, 1, 1, 1, 1, 1, 0],
     [0, 1, 0, 0, 0, 0, 0, 1, 0],
     [0, 1, 1, 1, 1, 1, 1, 1, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 0]])  # pyformat: disable
44
45
class SegmentationTest(test_util.TensorFlowTestCase):
  """Tests for `image_ops.connected_components` on boolean images."""

  def testDisconnected(self):
    """Pixels that touch only diagonally get separate component ids."""
    arr = math_ops.cast(
        [[1, 0, 0, 1, 0, 0, 0, 0, 1],
         [0, 1, 0, 0, 0, 1, 0, 1, 0],
         [1, 0, 1, 0, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0, 0]],
        dtypes.bool)  # pyformat: disable
    # Every foreground pixel is its own component; the expected ids increase
    # in row-major order.
    expected = (
        [[1, 0, 0, 2, 0, 0, 0, 0, 3],
         [0, 4, 0, 0, 0, 5, 0, 6, 0],
         [7, 0, 8, 0, 0, 0, 9, 0, 0],
         [0, 0, 0, 0, 10, 0, 0, 0, 0],
         [0, 0, 11, 0, 0, 0, 0, 0, 0]])  # pyformat: disable
    with self.cached_session():
      self.assertAllEqual(image_ops.connected_components(arr).eval(), expected)

  def testSimple(self):
    """A plus-shaped blob is labeled as a single component with id 1."""
    arr = [[0, 1, 0], [1, 1, 1], [0, 1, 0]]
    with self.cached_session():
      # Single component with id 1, so the labels equal the input itself.
      self.assertAllEqual(
          image_ops.connected_components(math_ops.cast(
              arr, dtypes.bool)).eval(), arr)

  def testSnake(self):
    """The winding SNAKE image is labeled as a single component with id 1."""
    with self.cached_session():
      # Single component with id 1, so the labels equal the input itself.
      self.assertAllEqual(
          image_ops.connected_components(math_ops.cast(
              SNAKE, dtypes.bool)).eval(), SNAKE)

  def testSnake_disconnected(self):
    """Removing any non-endpoint snake pixel yields exactly two components."""
    # Removing an endpoint only shortens the snake, so those pixels are skipped.
    endpoints = ((1, 1), (6, 3))
    # The cached session is the same object on every entry, so it is entered
    # once instead of once per pixel.
    with self.cached_session():
      # Visit foreground pixels only, in row-major order.
      for i, j in np.argwhere(SNAKE).tolist():
        if (i, j) in endpoints:
          continue
        # If we disconnect any part of the snake except for the endpoints,
        # there will be 2 components.
        disconnected_snake = SNAKE.copy()
        disconnected_snake[i, j] = 0
        components = image_ops.connected_components(
            math_ops.cast(disconnected_snake, dtypes.bool)).eval()
        self.assertEqual(components.max(), 2, 'disconnect (%d, %d)' % (i, j))
        bins = np.bincount(components.ravel())
        # Nonzero number of pixels labeled 0, 1, or 2.
        self.assertGreater(bins[0], 0)
        self.assertGreater(bins[1], 0)
        self.assertGreater(bins[2], 0)

  def testMultipleImages(self):
    """Component ids keep increasing from one image of a batch to the next."""
    images = [[[1, 1, 1, 1],
               [1, 0, 0, 1],
               [1, 0, 0, 1],
               [1, 1, 1, 1]],
              [[1, 0, 0, 1],
               [0, 0, 0, 0],
               [0, 0, 0, 0],
               [1, 0, 0, 1]],
              [[1, 1, 0, 1],
               [0, 1, 1, 0],
               [1, 0, 1, 0],
               [0, 0, 1, 1]]]  # pyformat: disable
    # Image 0 uses id 1, image 1 uses ids 2-5 and image 2 uses ids 6-8.
    expected = [[[1, 1, 1, 1],
                 [1, 0, 0, 1],
                 [1, 0, 0, 1],
                 [1, 1, 1, 1]],
                [[2, 0, 0, 3],
                 [0, 0, 0, 0],
                 [0, 0, 0, 0],
                 [4, 0, 0, 5]],
                [[6, 6, 0, 7],
                 [0, 6, 6, 0],
                 [8, 0, 6, 0],
                 [0, 0, 6, 6]]]  # pyformat: disable
    with self.cached_session():
      self.assertAllEqual(
          image_ops.connected_components(math_ops.cast(
              images, dtypes.bool)).eval(), expected)

  def testZeros(self):
    """An all-background batch produces all-zero labels."""
    with self.cached_session():
      self.assertAllEqual(
          image_ops.connected_components(
              array_ops.zeros((100, 20, 50), dtypes.bool)).eval(),
          np.zeros((100, 20, 50)))

  def testOnes(self):
    """Each all-foreground image in a batch is one component, ids 1..100."""
    with self.cached_session():
      self.assertAllEqual(
          image_ops.connected_components(
              array_ops.ones((100, 20, 50), dtypes.bool)).eval(),
          np.tile(np.arange(100)[:, None, None] + 1, [1, 20, 50]))

  def testOnes_small(self):
    """A single all-foreground 2-D image is one component with id 1."""
    with self.cached_session():
      self.assertAllEqual(
          image_ops.connected_components(array_ops.ones((3, 5),
                                                        dtypes.bool)).eval(),
          np.ones((3, 5)))

  def testRandom_scipy(self):
    """Random images are labeled identically to the scipy-based reference."""
    np.random.seed(42)
    # `np.bool` was removed in NumPy 1.24; `np.bool_` works on all versions.
    images = np.random.randint(0, 2, size=(10, 100, 200)).astype(np.bool_)
    expected = connected_components_reference_implementation(images)
    if expected is None:
      # Report the test as skipped instead of letting it silently pass.
      self.skipTest('scipy could not be loaded')
    with self.cached_session():
      self.assertAllEqual(
          image_ops.connected_components(images).eval(), expected)
159
160
def connected_components_reference_implementation(images):
  """Labels connected components with scipy, as a reference for the TF op.

  Every image is labeled independently with `scipy.ndimage.label` using its
  default structuring element. The positive ids of each image are then shifted
  by the total number of components found in all preceding images, so that ids
  are unique across the batch. Background pixels stay 0.

  Args:
    images: Array-like holding either one 2-D image `[height, width]` or a 3-D
      batch `[batch, height, width]`. Nonzero entries are foreground.

  Returns:
    An integer array of component ids with the same shape as `images`, or
    `None` if scipy could not be imported.

  Raises:
    ValueError: If `images` is neither 2-D nor 3-D.
  """
  try:
    # `scipy.ndimage.measurements` is a deprecated namespace; `label` is
    # available directly from `scipy.ndimage`.
    # pylint: disable=g-import-not-at-top
    from scipy import ndimage
  except ImportError:
    logging.exception('Skipping test method because scipy could not be loaded')
    return None

  image_or_images = np.asarray(images)
  rank = image_or_images.ndim
  if rank == 2:
    # Treat a single image as a batch of one; the batch axis is dropped again
    # before returning.
    batch = image_or_images[None, :, :]
  elif rank == 3:
    batch = image_or_images
  else:
    raise ValueError(
        'images must be 2-D or 3-D, got shape %s' % (image_or_images.shape,))

  if batch.shape[0] == 0:
    # Nothing to label; `ndimage.label` produces int32 labels by default.
    return np.zeros(batch.shape, dtype=np.int32)

  # `label` returns (labeled_image, number_of_components) for each image.
  labeled = [ndimage.label(image) for image in batch]
  components = np.asarray([labels for labels, _ in labeled])
  counts = np.asarray([count for _, count in labeled])

  # Exclusive cumulative sum: the number of ids used by all earlier images.
  offsets = np.cumsum(counts) - counts
  # Shift only the foreground pixels so the background remains 0.
  components += np.where(
      components > 0, offsets[:, None, None], 0).astype(components.dtype)

  return components[0] if rank == 2 else components
186
187
# Hand control to the googletest runner when this file is executed directly.
if __name__ == '__main__':
  googletest.main()
190