• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License.  You may obtain a copy of
5# the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
12# License for the specific language governing permissions and limitations under
13# the License.
14# ==============================================================================
15"""Tests for clustering_ops."""
16
17import numpy as np
18
19from tensorflow.python.framework import test_util
20from tensorflow.python.ops import clustering_ops
21from tensorflow.python.platform import test
22
23
@test_util.run_all_in_graph_and_eager_modes
class KmeansPlusPlusInitializationTest(test.TestCase):
  """Tests for `clustering_ops.kmeans_plus_plus_initialization`."""

  def setUp(self):
    # Run the base TestCase per-test setup before building fixtures.
    super(KmeansPlusPlusInitializationTest, self).setUp()
    # All but one input point are close to (101, 1). With uniform random
    # sampling it would be highly improbable for (-1, -1) to be selected, so
    # expecting it among the sampled points checks that sampling is not
    # uniform.
    self._points = np.array([[100., 0.],
                             [101., 2.],
                             [102., 0.],
                             [100., 1.],
                             [100., 2.],
                             [101., 0.],
                             [101., 0.],
                             [101., 1.],
                             [102., 0.],
                             [-1., -1.]]).astype(np.float32)

  def runTestWithSeed(self, seed):
    """Samples 3 points from `self._points` and checks which ones were chosen.

    Args:
      seed: Integer seed forwarded to the op. `(seed % 5) - 1` is forwarded as
        the op's fourth argument, so that value cycles through -1..3 across
        seeds. NOTE(review): presumably this is the per-sample retry count;
        confirm against the op definition.
    """
    with self.cached_session():
      sampled_points = clustering_ops.kmeans_plus_plus_initialization(
          self._points, 3, seed, (seed % 5) - 1)
      # Sort so the comparison does not depend on sampling order. atol=1.0
      # accepts any point of the cluster around (101, 1).
      self.assertAllClose(
          sorted(self.evaluate(sampled_points).tolist()),
          [[-1., -1.], [101., 1.], [101., 1.]],
          atol=1.0)

  def testBasic(self):
    """Runs the sampling check for seeds 0..99."""
    for seed in range(100):
      self.runTestWithSeed(seed)
53
54
@test_util.run_all_in_graph_and_eager_modes
class KMC2InitializationTest(test.TestCase):
  """Checks that KMC2 chain sampling picks the entry with the largest distance."""

  def runTestWithSeed(self, seed):
    """Runs the two-stage dominance check for a single `seed`."""
    with self.cached_session():
      dist = np.zeros(1000, dtype=np.float32)
      dist[4] = 1e4
      dist[6] = 1e8
      # Index 6 dominates every other entry, so it is expected to be chosen.
      self.assertAllEqual(
          clustering_ops.kmc2_chain_initialization(dist, seed), 6)
      # With index 6 zeroed out, index 4 is the only nonzero entry left.
      dist[6] = 0.
      self.assertAllEqual(
          clustering_ops.kmc2_chain_initialization(dist, seed), 4)

  def testBasic(self):
    """Repeats the dominance check for seeds 0..99."""
    for seed_value in range(100):
      self.runTestWithSeed(seed_value)
73
74
@test_util.run_all_in_graph_and_eager_modes
class KMC2InitializationLargeTest(test.TestCase):
  """Checks that KMC2 chain sampling varies with the seed on a larger input."""

  def setUp(self):
    # Run the base TestCase per-test setup before building fixtures.
    super(KMC2InitializationLargeTest, self).setUp()
    # Only indices 500 and 1000 have a nonzero distance. float32, as in
    # KMC2InitializationTest.
    self._distances = np.zeros(1001).astype(np.float32)
    self._distances[500] = 100.0
    self._distances[1000] = 50.0

  def testBasic(self):
    """Over 50 seeds, exactly the two nonzero indices are each sampled often."""
    num_trials = 50
    min_count = 5
    with self.cached_session():
      counts = {}
      for seed in range(num_trials):
        # int() so the dict keys are plain Python ints, not NumPy scalars.
        sample = int(self.evaluate(
            clustering_ops.kmc2_chain_initialization(self._distances, seed)))
        counts[sample] = counts.get(sample, 0) + 1
      # The test expects only the two nonzero-distance indices to be chosen,
      # each a reasonable number of times.
      self.assertEqual(sorted(counts), [500, 1000])
      self.assertGreaterEqual(counts[500], min_count)
      self.assertGreaterEqual(counts[1000], min_count)
96
97
@test_util.run_all_in_graph_and_eager_modes
class KMC2InitializationCornercaseTest(test.TestCase):
  """Checks KMC2 chain sampling when every distance is zero."""

  def setUp(self):
    # Run the base TestCase per-test setup before building fixtures.
    super(KMC2InitializationCornercaseTest, self).setUp()
    # All-zero distances; float32, as in KMC2InitializationTest.
    self._distances = np.zeros(10).astype(np.float32)

  def runTestWithSeed(self, seed):
    """Expects index 0 to be returned for all-zero distances with `seed`."""
    with self.cached_session():
      sampled_point = clustering_ops.kmc2_chain_initialization(
          self._distances, seed)
      self.assertAllEqual(sampled_point, 0)

  def testBasic(self):
    """Runs the corner-case check for seeds 0..99."""
    for seed in range(100):
      self.runTestWithSeed(seed)
113
114
@test_util.run_all_in_graph_and_eager_modes
class NearestCentersTest(test.TestCase):
  """A simple nearest-neighbors test whose expectations can be verified by hand.

  Squared distances behind the expected values:
    (100, 0): 0 to center 0, 2 to center 1.
    (101, 2): 5 to center 0, 5 to center 1 (a tie; center 0 is expected first).
    (99, 2):  5 to center 0, 1 to center 1.
    (1, 1):   0 to center 4, 2 to center 3.
  """

  def setUp(self):
    # Run the base TestCase per-test setup before building fixtures.
    super(NearestCentersTest, self).setUp()
    self._points = np.array([[100., 0.],
                             [101., 2.],
                             [99., 2.],
                             [1., 1.]]).astype(np.float32)

    self._centers = np.array([[100., 0.],
                              [99., 1.],
                              [50., 50.],
                              [0., 0.],
                              [1., 1.]]).astype(np.float32)

  def testNearest1(self):
    """Checks the single nearest center and its squared distance."""
    with self.cached_session():
      [indices, distances] = clustering_ops.nearest_neighbors(self._points,
                                                              self._centers, 1)
      # Indices are integers, so compare them exactly.
      self.assertAllEqual(indices, [[0], [0], [1], [4]])
      self.assertAllClose(distances, [[0.], [5.], [1.], [0.]])

  def testNearest2(self):
    """Checks the two nearest centers, ordered by squared distance."""
    with self.cached_session():
      [indices, distances] = clustering_ops.nearest_neighbors(self._points,
                                                              self._centers, 2)
      self.assertAllEqual(indices, [[0, 1], [0, 1], [1, 0], [4, 3]])
      self.assertAllClose(distances, [[0., 2.], [5., 5.], [1., 5.], [0., 2.]])
144
145
@test_util.run_all_in_graph_and_eager_modes
class NearestCentersLargeTest(test.TestCase):
  """Compares nearest_neighbors with a brute-force NumPy reference on large inputs."""

  def setUp(self):
    # Run the base TestCase per-test setup before building fixtures.
    super(NearestCentersLargeTest, self).setUp()
    num_points = 1000
    num_centers = 2000
    num_dim = 100
    max_k = 5
    # Construct a small number of random points and later tile them, so the
    # brute-force reference below is only computed for one tile.
    points_per_tile = 10
    # Use a TestCase assertion rather than a bare assert (stripped under -O).
    self.assertEqual(num_points % points_per_tile, 0)
    num_tiles = num_points // points_per_tile
    points = np.random.standard_normal(
        [points_per_tile, num_dim]).astype(np.float32)
    # Construct random centers.
    self._centers = np.random.standard_normal(
        [num_centers, num_dim]).astype(np.float32)

    # Exhaustively compute expected nearest neighbors.
    def squared_distance(x, y):
      return np.linalg.norm(x - y, ord=2)**2

    # For each point: the max_k smallest (squared distance, center index)
    # pairs. Tuple sorting breaks distance ties by the lower center index.
    nearest_neighbors = [
        sorted((squared_distance(point, self._centers[j]), j)
               for j in range(num_centers))[:max_k]
        for point in points
    ]
    expected_nearest_neighbor_indices = np.array(
        [[i for _, i in nn] for nn in nearest_neighbors])
    expected_nearest_neighbor_squared_distances = np.array(
        [[dist for dist, _ in nn] for nn in nearest_neighbors])
    # Tile points and expected results to reach requested size (num_points).
    (self._points, self._expected_nearest_neighbor_indices,
     self._expected_nearest_neighbor_squared_distances) = (
         np.tile(x, (num_tiles, 1))
         for x in (points, expected_nearest_neighbor_indices,
                   expected_nearest_neighbor_squared_distances))

  def testNearest1(self):
    """Checks the single nearest center against the reference."""
    with self.cached_session():
      [indices, distances] = clustering_ops.nearest_neighbors(self._points,
                                                              self._centers, 1)
      # Indices are integers, so compare them exactly.
      self.assertAllEqual(
          indices,
          self._expected_nearest_neighbor_indices[:, [0]])
      self.assertAllClose(
          distances,
          self._expected_nearest_neighbor_squared_distances[:, [0]])

  def testNearest5(self):
    """Checks the five nearest centers against the reference."""
    with self.cached_session():
      [indices, distances] = clustering_ops.nearest_neighbors(self._points,
                                                              self._centers, 5)
      self.assertAllEqual(
          indices,
          self._expected_nearest_neighbor_indices[:, 0:5])
      self.assertAllClose(
          distances,
          self._expected_nearest_neighbor_squared_distances[:, 0:5])
204
205
if __name__ == "__main__":
  # Fix NumPy's global RNG state before handing control to the test runner.
  np.random.seed(seed=0)
  test.main()
209