1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); you may not 4# use this file except in compliance with the License. You may obtain a copy of 5# the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12# License for the specific language governing permissions and limitations under 13# the License. 14# ============================================================================== 15"""Tests for clustering_ops.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import test_util 24from tensorflow.python.ops import clustering_ops 25from tensorflow.python.platform import test 26 27 28@test_util.run_all_in_graph_and_eager_modes 29class KmeansPlusPlusInitializationTest(test.TestCase): 30 31 # All but one input point are close to (101, 1). With uniform random sampling, 32 # it is highly improbable for (-1, -1) to be selected. 33 def setUp(self): 34 self._points = np.array([[100., 0.], 35 [101., 2.], 36 [102., 0.], 37 [100., 1.], 38 [100., 2.], 39 [101., 0.], 40 [101., 0.], 41 [101., 1.], 42 [102., 0.], 43 [-1., -1.]]).astype(np.float32) 44 45 def runTestWithSeed(self, seed): 46 with self.cached_session(): 47 sampled_points = clustering_ops.kmeans_plus_plus_initialization( 48 self._points, 3, seed, (seed % 5) - 1) 49 self.assertAllClose( 50 sorted(self.evaluate(sampled_points).tolist()), 51 [[-1., -1.], [101., 1.], [101., 1.]], 52 atol=1.0) 53 54 def testBasic(self): 55 for seed in range(100): 56 self.runTestWithSeed(seed) 57 58 59@test_util.run_all_in_graph_and_eager_modes 60class KMC2InitializationTest(test.TestCase): 61 62 def runTestWithSeed(self, seed): 63 with self.cached_session(): 64 distances = np.zeros(1000).astype(np.float32) 65 distances[6] = 10e7 66 distances[4] = 10e3 67 68 sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed) 69 self.assertAllEqual(sampled_point, 6) 70 distances[6] = 0.0 71 sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed) 72 self.assertAllEqual(sampled_point, 4) 73 74 def testBasic(self): 75 for seed in range(100): 76 self.runTestWithSeed(seed) 77 78 79@test_util.run_all_in_graph_and_eager_modes 80class KMC2InitializationLargeTest(test.TestCase): 81 82 def setUp(self): 83 self._distances = np.zeros(1001) 84 self._distances[500] = 100.0 85 self._distances[1000] = 50.0 86 87 def testBasic(self): 88 with self.cached_session(): 89 counts = {} 90 seed = 0 91 for i in range(50): 92 sample = self.evaluate( 93 clustering_ops.kmc2_chain_initialization(self._distances, seed + i)) 94 counts[sample] = counts.get(sample, 0) + 1 95 self.assertEquals(len(counts), 2) 96 self.assertTrue(500 in counts) 97 self.assertTrue(1000 in counts) 98 self.assertGreaterEqual(counts[500], 5) 99 self.assertGreaterEqual(counts[1000], 5) 100 101 102@test_util.run_all_in_graph_and_eager_modes 103class KMC2InitializationCornercaseTest(test.TestCase): 104 105 def setUp(self): 106 self._distances = np.zeros(10) 107 108 def runTestWithSeed(self, seed): 109 with self.cached_session(): 110 sampled_point = clustering_ops.kmc2_chain_initialization( 111 self._distances, seed) 112 self.assertAllEqual(sampled_point, 0) 113 114 def testBasic(self): 115 for seed in range(100): 116 self.runTestWithSeed(seed) 117 118 119@test_util.run_all_in_graph_and_eager_modes 120# A simple test that can be verified by hand. 121class NearestCentersTest(test.TestCase): 122 123 def setUp(self): 124 self._points = np.array([[100., 0.], 125 [101., 2.], 126 [99., 2.], 127 [1., 1.]]).astype(np.float32) 128 129 self._centers = np.array([[100., 0.], 130 [99., 1.], 131 [50., 50.], 132 [0., 0.], 133 [1., 1.]]).astype(np.float32) 134 135 def testNearest1(self): 136 with self.cached_session(): 137 [indices, distances] = clustering_ops.nearest_neighbors(self._points, 138 self._centers, 1) 139 self.assertAllClose(indices, [[0], [0], [1], [4]]) 140 self.assertAllClose(distances, [[0.], [5.], [1.], [0.]]) 141 142 def testNearest2(self): 143 with self.cached_session(): 144 [indices, distances] = clustering_ops.nearest_neighbors(self._points, 145 self._centers, 2) 146 self.assertAllClose(indices, [[0, 1], [0, 1], [1, 0], [4, 3]]) 147 self.assertAllClose(distances, [[0., 2.], [5., 5.], [1., 5.], [0., 2.]]) 148 149 150@test_util.run_all_in_graph_and_eager_modes 151# A test with large inputs. 152class NearestCentersLargeTest(test.TestCase): 153 154 def setUp(self): 155 num_points = 1000 156 num_centers = 2000 157 num_dim = 100 158 max_k = 5 159 # Construct a small number of random points and later tile them. 160 points_per_tile = 10 161 assert num_points % points_per_tile == 0 162 points = np.random.standard_normal( 163 [points_per_tile, num_dim]).astype(np.float32) 164 # Construct random centers. 165 self._centers = np.random.standard_normal( 166 [num_centers, num_dim]).astype(np.float32) 167 168 # Exhaustively compute expected nearest neighbors. 169 def squared_distance(x, y): 170 return np.linalg.norm(x - y, ord=2)**2 171 172 nearest_neighbors = [ 173 sorted([(squared_distance(point, self._centers[j]), j) 174 for j in range(num_centers)])[:max_k] for point in points 175 ] 176 expected_nearest_neighbor_indices = np.array( 177 [[i for _, i in nn] for nn in nearest_neighbors]) 178 expected_nearest_neighbor_squared_distances = np.array( 179 [[dist for dist, _ in nn] for nn in nearest_neighbors]) 180 # Tile points and expected results to reach requested size (num_points) 181 (self._points, self._expected_nearest_neighbor_indices, 182 self._expected_nearest_neighbor_squared_distances) = ( 183 np.tile(x, (int(num_points / points_per_tile), 1)) 184 for x in (points, expected_nearest_neighbor_indices, 185 expected_nearest_neighbor_squared_distances)) 186 187 def testNearest1(self): 188 with self.cached_session(): 189 [indices, distances] = clustering_ops.nearest_neighbors(self._points, 190 self._centers, 1) 191 self.assertAllClose( 192 indices, 193 self._expected_nearest_neighbor_indices[:, [0]]) 194 self.assertAllClose( 195 distances, 196 self._expected_nearest_neighbor_squared_distances[:, [0]]) 197 198 def testNearest5(self): 199 with self.cached_session(): 200 [indices, distances] = clustering_ops.nearest_neighbors(self._points, 201 self._centers, 5) 202 self.assertAllClose( 203 indices, 204 self._expected_nearest_neighbor_indices[:, 0:5]) 205 self.assertAllClose( 206 distances, 207 self._expected_nearest_neighbor_squared_distances[:, 0:5]) 208 209 210if __name__ == "__main__": 211 np.random.seed(0) 212 test.main() 213