# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
# ==============================================================================
"""Tests for clustering_ops."""

import numpy as np

from tensorflow.python.framework import test_util
from tensorflow.python.ops import clustering_ops
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class KmeansPlusPlusInitializationTest(test.TestCase):
  """Checks kmeans_plus_plus_initialization on a cluster plus one outlier."""

  # All but one input point are close to (101, 1). With uniform random sampling,
  # it is highly improbable for (-1, -1) to be selected.
  def setUp(self):
    """Builds the fixture: nine points near (101, 1) and one at (-1, -1)."""
    self._points = np.array([[100., 0.],
                             [101., 2.],
                             [102., 0.],
                             [100., 1.],
                             [100., 2.],
                             [101., 0.],
                             [101., 0.],
                             [101., 1.],
                             [102., 0.],
                             [-1., -1.]]).astype(np.float32)

  def runTestWithSeed(self, seed):
    """Samples 3 points with `seed` and checks the outlier is among them.

    Args:
      seed: Integer seed forwarded to the op. It also selects the fourth
        argument, which cycles through -1..3.
    """
    with self.cached_session():
      # NOTE(review): the fourth positional argument is presumably the
      # retries-per-sample setting -- confirm against clustering_ops.
      sampled_points = clustering_ops.kmeans_plus_plus_initialization(
          self._points, 3, seed, (seed % 5) - 1)
      # After sorting, the outlier comes first; the other two samples only
      # need to lie within 1.0 of (101, 1).
      self.assertAllClose(
          sorted(self.evaluate(sampled_points).tolist()),
          [[-1., -1.], [101., 1.], [101., 1.]],
          atol=1.0)

  def testBasic(self):
    """Runs the sampling check for seeds 0..99."""
    for seed in range(100):
      self.runTestWithSeed(seed)


@test_util.run_all_in_graph_and_eager_modes
class KMC2InitializationTest(test.TestCase):
  """Checks kmc2_chain_initialization picks the dominant distance entry."""

  def runTestWithSeed(self, seed):
    """Expects index 6 to be sampled, then index 4 once entry 6 is zeroed.

    Args:
      seed: Integer seed forwarded to the op.
    """
    with self.cached_session():
      # 1000 distances, all zero except two entries of very different size.
      distances = np.zeros(1000).astype(np.float32)
      distances[6] = 10e7
      distances[4] = 10e3

      sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed)
      self.assertAllEqual(sampled_point, 6)
      # With entry 6 removed, entry 4 is the only non-zero distance left.
      distances[6] = 0.0
      sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed)
      self.assertAllEqual(sampled_point, 4)

  def testBasic(self):
    """Runs the check for seeds 0..99."""
    for seed in range(100):
      self.runTestWithSeed(seed)


@test_util.run_all_in_graph_and_eager_modes
class KMC2InitializationLargeTest(test.TestCase):
  """Checks kmc2_chain_initialization samples both non-zero entries."""

  def setUp(self):
    """Builds 1001 distances that are zero except at indices 500 and 1000."""
    self._distances = np.zeros(1001)
    self._distances[500] = 100.0
    self._distances[1000] = 50.0

  def testBasic(self):
    """Samples with seeds 0..49 and checks how often each index is chosen."""
    with self.cached_session():
      counts = {}
      for seed in range(50):
        sample = self.evaluate(
            clustering_ops.kmc2_chain_initialization(self._distances, seed))
        counts[sample] = counts.get(sample, 0) + 1
      # Only the two non-zero entries may ever be sampled, and each must be
      # chosen at least 5 times out of the 50 draws.
      self.assertEqual(len(counts), 2)
      self.assertIn(500, counts)
      self.assertIn(1000, counts)
      self.assertGreaterEqual(counts[500], 5)
      self.assertGreaterEqual(counts[1000], 5)


@test_util.run_all_in_graph_and_eager_modes
class KMC2InitializationCornercaseTest(test.TestCase):
  """Checks kmc2_chain_initialization when every distance is zero."""

  def setUp(self):
    """Builds ten all-zero distances."""
    self._distances = np.zeros(10)

  def runTestWithSeed(self, seed):
    """Expects index 0 to be returned for the all-zero input.

    Args:
      seed: Integer seed forwarded to the op.
    """
    with self.cached_session():
      sampled_point = clustering_ops.kmc2_chain_initialization(
          self._distances, seed)
      self.assertAllEqual(sampled_point, 0)

  def testBasic(self):
    """Runs the check for seeds 0..99."""
    for seed in range(100):
      self.runTestWithSeed(seed)


@test_util.run_all_in_graph_and_eager_modes
class NearestCentersTest(test.TestCase):
  """A simple nearest_neighbors test that can be verified by hand."""

  def setUp(self):
    """Builds four 2-D points and five 2-D centers."""
    self._points = np.array([[100., 0.],
                             [101., 2.],
                             [99., 2.],
                             [1., 1.]]).astype(np.float32)

    self._centers = np.array([[100., 0.],
                              [99., 1.],
                              [50., 50.],
                              [0., 0.],
                              [1., 1.]]).astype(np.float32)

  def testNearest1(self):
    """Checks the single nearest center and its squared distance per point."""
    with self.cached_session():
      [indices, distances] = clustering_ops.nearest_neighbors(self._points,
                                                              self._centers, 1)
      self.assertAllClose(indices, [[0], [0], [1], [4]])
      self.assertAllClose(distances, [[0.], [5.], [1.], [0.]])

  def testNearest2(self):
    """Checks the two nearest centers and squared distances per point."""
    with self.cached_session():
      [indices, distances] = clustering_ops.nearest_neighbors(self._points,
                                                              self._centers, 2)
      self.assertAllClose(indices, [[0, 1], [0, 1], [1, 0], [4, 3]])
      self.assertAllClose(distances, [[0., 2.], [5., 5.], [1., 5.], [0., 2.]])


@test_util.run_all_in_graph_and_eager_modes
class NearestCentersLargeTest(test.TestCase):
  """A nearest_neighbors test with large inputs."""

  def setUp(self):
    """Builds random points/centers and brute-force expected neighbors."""
    num_points = 1000
    num_centers = 2000
    num_dim = 100
    max_k = 5
    # Construct a small number of random points and later tile them.
    points_per_tile = 10
    assert num_points % points_per_tile == 0
    points = np.random.standard_normal(
        [points_per_tile, num_dim]).astype(np.float32)
    # Construct random centers.
    self._centers = np.random.standard_normal(
        [num_centers, num_dim]).astype(np.float32)

    # Exhaustively compute expected nearest neighbors.
    def squared_distance(x, y):
      return np.linalg.norm(x - y, ord=2)**2

    # For each point, keep the max_k smallest (squared distance, center index)
    # pairs in ascending order.
    nearest_neighbors = [
        sorted((squared_distance(point, self._centers[j]), j)
               for j in range(num_centers))[:max_k] for point in points
    ]
    expected_nearest_neighbor_indices = np.array(
        [[i for _, i in nn] for nn in nearest_neighbors])
    expected_nearest_neighbor_squared_distances = np.array(
        [[dist for dist, _ in nn] for nn in nearest_neighbors])
    # Tile points and expected results to reach requested size (num_points)
    num_tiles = num_points // points_per_tile
    (self._points, self._expected_nearest_neighbor_indices,
     self._expected_nearest_neighbor_squared_distances) = (
         np.tile(x, (num_tiles, 1))
         for x in (points, expected_nearest_neighbor_indices,
                   expected_nearest_neighbor_squared_distances))

  def testNearest1(self):
    """Compares the op's top-1 result with the brute-force expectation."""
    with self.cached_session():
      [indices, distances] = clustering_ops.nearest_neighbors(self._points,
                                                              self._centers, 1)
      self.assertAllClose(
          indices,
          self._expected_nearest_neighbor_indices[:, [0]])
      self.assertAllClose(
          distances,
          self._expected_nearest_neighbor_squared_distances[:, [0]])

  def testNearest5(self):
    """Compares the op's top-5 result with the brute-force expectation."""
    with self.cached_session():
      [indices, distances] = clustering_ops.nearest_neighbors(self._points,
                                                              self._centers, 5)
      self.assertAllClose(
          indices,
          self._expected_nearest_neighbor_indices[:, 0:5])
      self.assertAllClose(
          distances,
          self._expected_nearest_neighbor_squared_distances[:, 0:5])


if __name__ == "__main__":
  np.random.seed(0)
  test.main()