1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for connected component analysis.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import logging 22 23import numpy as np 24 25from tensorflow.contrib.image.python.ops import image_ops 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import math_ops 30from tensorflow.python.platform import googletest 31 32# Image for testing connected_components, with a single, winding component. 
# Single, winding 4-connected component whose two endpoints are (1, 1) and
# (6, 3); removing any other foreground pixel splits it into two components.
SNAKE = np.asarray(
    [[0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 1, 1, 1, 1, 0, 0, 0, 0],
     [0, 0, 0, 0, 1, 1, 1, 1, 0],
     [0, 0, 0, 0, 0, 0, 0, 1, 0],
     [0, 1, 1, 1, 1, 1, 1, 1, 0],
     [0, 1, 0, 0, 0, 0, 0, 0, 0],
     [0, 1, 0, 1, 1, 1, 1, 1, 0],
     [0, 1, 0, 0, 0, 0, 0, 1, 0],
     [0, 1, 1, 1, 1, 1, 1, 1, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 0]])  # pyformat: disable


class SegmentationTest(test_util.TensorFlowTestCase):
  """Tests for `image_ops.connected_components`."""

  def _connected_components(self, images):
    """Casts `images` to bool, runs connected_components, returns an ndarray.

    Args:
      images: anything `math_ops.cast` accepts (nested lists, ndarrays or
        tensors) holding one image or a batch of images.

    Returns:
      The evaluated component-id array.
    """
    with self.cached_session():
      return image_ops.connected_components(
          math_ops.cast(images, dtypes.bool)).eval()

  def testDisconnected(self):
    """Isolated pixels each get their own id, numbered in raster order."""
    arr = [[1, 0, 0, 1, 0, 0, 0, 0, 1],
           [0, 1, 0, 0, 0, 1, 0, 1, 0],
           [1, 0, 1, 0, 0, 0, 1, 0, 0],
           [0, 0, 0, 0, 1, 0, 0, 0, 0],
           [0, 0, 1, 0, 0, 0, 0, 0, 0]]  # pyformat: disable
    expected = [[1, 0, 0, 2, 0, 0, 0, 0, 3],
                [0, 4, 0, 0, 0, 5, 0, 6, 0],
                [7, 0, 8, 0, 0, 0, 9, 0, 0],
                [0, 0, 0, 0, 10, 0, 0, 0, 0],
                [0, 0, 11, 0, 0, 0, 0, 0, 0]]  # pyformat: disable
    self.assertAllEqual(self._connected_components(arr), expected)

  def testSimple(self):
    """A plus-shaped image is a single component with id 1."""
    arr = [[0, 1, 0], [1, 1, 1], [0, 1, 0]]
    self.assertAllEqual(self._connected_components(arr), arr)

  def testSnake(self):
    """The whole snake is a single component with id 1."""
    self.assertAllEqual(self._connected_components(SNAKE), SNAKE)

  def testSnake_disconnected(self):
    """Cutting the snake anywhere but an endpoint leaves 2 components."""
    endpoints = {(1, 1), (6, 3)}
    # np.argwhere yields the foreground pixels in raster order, so background
    # pixels are never visited.
    for i, j in np.argwhere(SNAKE):
      i, j = int(i), int(j)
      if (i, j) in endpoints:
        continue
      disconnected_snake = SNAKE.copy()
      disconnected_snake[i, j] = 0
      components = self._connected_components(disconnected_snake)
      self.assertEqual(components.max(), 2, 'disconnect (%d, %d)' % (i, j))
      bins = np.bincount(components.ravel())
      # Nonzero number of pixels labeled 0, 1, or 2.
      self.assertGreater(bins[0], 0)
      self.assertGreater(bins[1], 0)
      self.assertGreater(bins[2], 0)

  def testMultipleImages(self):
    """Ids keep increasing across the images of a batch."""
    images = [[[1, 1, 1, 1],
               [1, 0, 0, 1],
               [1, 0, 0, 1],
               [1, 1, 1, 1]],
              [[1, 0, 0, 1],
               [0, 0, 0, 0],
               [0, 0, 0, 0],
               [1, 0, 0, 1]],
              [[1, 1, 0, 1],
               [0, 1, 1, 0],
               [1, 0, 1, 0],
               [0, 0, 1, 1]]]  # pyformat: disable
    expected = [[[1, 1, 1, 1],
                 [1, 0, 0, 1],
                 [1, 0, 0, 1],
                 [1, 1, 1, 1]],
                [[2, 0, 0, 3],
                 [0, 0, 0, 0],
                 [0, 0, 0, 0],
                 [4, 0, 0, 5]],
                [[6, 6, 0, 7],
                 [0, 6, 6, 0],
                 [8, 0, 6, 0],
                 [0, 0, 6, 6]]]  # pyformat: disable
    self.assertAllEqual(self._connected_components(images), expected)

  def testZeros(self):
    """An all-background batch contains no components."""
    self.assertAllEqual(
        self._connected_components(array_ops.zeros((100, 20, 50),
                                                   dtypes.bool)),
        np.zeros((100, 20, 50)))

  def testOnes(self):
    """Each all-foreground image in a batch is one component (ids 1..100)."""
    self.assertAllEqual(
        self._connected_components(array_ops.ones((100, 20, 50),
                                                  dtypes.bool)),
        np.tile(np.arange(100)[:, None, None] + 1, [1, 20, 50]))

  def testOnes_small(self):
    """A single all-foreground 2-D image is labeled entirely with id 1."""
    self.assertAllEqual(
        self._connected_components(array_ops.ones((3, 5), dtypes.bool)),
        np.ones((3, 5)))

  def testRandom_scipy(self):
    """Random binary images agree with the SciPy reference implementation."""
    # A local RandomState(42) draws the same values as np.random.seed(42)
    # without mutating global RNG state. Plain `bool` is used because the
    # `np.bool` alias was removed in NumPy 1.24.
    images = np.random.RandomState(42).randint(
        0, 2, size=(10, 100, 200)).astype(bool)
    expected = connected_components_reference_implementation(images)
    if expected is None:
      # Report a skip instead of silently passing without checking anything.
      self.skipTest('scipy could not be loaded')
    self.assertAllEqual(self._connected_components(images), expected)


def connected_components_reference_implementation(images):
  """Labels connected components with SciPy, numbering ids across the batch.

  Each image is labeled independently with `scipy.ndimage.label`. The nonzero
  ids of image `i` are then shifted by the number of components found in
  images `0..i-1`, so that ids are unique across the whole batch.

  Args:
    images: array-like of rank 2 (one image) or rank 3 (a batch of images).
      Nonzero entries are foreground.

  Returns:
    An integer ndarray with the same shape as `images`. In it, 0 is
    background and each component has a distinct positive id. Returns None
    if SciPy cannot be imported.

  Raises:
    ValueError: if `images` is neither rank 2 nor rank 3.
  """
  try:
    # `label` is exported from the package top level; the
    # `scipy.ndimage.measurements` namespace is deprecated.
    # pylint: disable=g-import-not-at-top
    from scipy.ndimage import label
  except ImportError:
    logging.warning('Skipping test method because scipy could not be loaded')
    return None
  image_or_images = np.asarray(images)
  if image_or_images.ndim == 2:
    batch = image_or_images[None, :, :]
  elif image_or_images.ndim == 3:
    batch = image_or_images
  else:
    raise ValueError('Expected a rank 2 or rank 3 input, got shape %s' %
                     (image_or_images.shape,))

  labeled_images = []
  # Total number of components in the images labeled so far. It is used to
  # offset the next image's ids so they do not collide.
  id_offset = 0
  for image in batch:
    labeled, num_components = label(image)
    labeled[labeled > 0] += id_offset
    labeled_images.append(labeled)
    id_offset += num_components
  if labeled_images:
    components = np.stack(labeled_images)
  else:
    # Empty batch: there is nothing to label.
    components = np.zeros(batch.shape, dtype=np.int32)

  if image_or_images.ndim == 2:
    return components[0]
  return components


if __name__ == '__main__':
  googletest.main()