• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License.  You may obtain a copy of
5# the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
12# License for the specific language governing permissions and limitations under
13# the License.
14# ==============================================================================
15"""Tests for clustering_ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import clustering_ops
25from tensorflow.python.platform import test
26
27
@test_util.run_all_in_graph_and_eager_modes
class KmeansPlusPlusInitializationTest(test.TestCase):
  """Checks that k-means++ seeding selects the single far-away point."""

  def setUp(self):
    # Nine of the ten points lie close to (101, 1); (-1, -1) is the lone
    # outlier. With uniform random sampling it would be highly improbable for
    # (-1, -1) to be selected.
    clustered = [[100., 0.], [101., 2.], [102., 0.],
                 [100., 1.], [100., 2.], [101., 0.],
                 [101., 0.], [101., 1.], [102., 0.]]
    outlier = [[-1., -1.]]
    self._points = np.array(clustered + outlier).astype(np.float32)

  def runTestWithSeed(self, seed):
    """Samples three points with `seed` and compares them to the expectation.

    The sorted sample must match the outlier plus two points that are within
    1.0 of (101, 1) in every coordinate.
    """
    num_to_sample = 3
    # Vary the retry argument with the seed; it ranges over -1..3.
    num_retries = (seed % 5) - 1
    expected = [[-1., -1.], [101., 1.], [101., 1.]]
    with self.cached_session():
      sampled_points = clustering_ops.kmeans_plus_plus_initialization(
          self._points, num_to_sample, seed, num_retries)
      actual = sorted(self.evaluate(sampled_points).tolist())
      self.assertAllClose(actual, expected, atol=1.0)

  def testBasic(self):
    """Runs the check for seeds 0 through 99."""
    for seed in range(100):
      self.runTestWithSeed(seed)
57
58
@test_util.run_all_in_graph_and_eager_modes
class KMC2InitializationTest(test.TestCase):
  """Checks that the KMC2 chain returns the index of the dominant distance."""

  def runTestWithSeed(self, seed):
    """Expects index 6 first, then index 4 once entry 6 has been zeroed."""
    with self.cached_session():
      distances = np.zeros(1000).astype(np.float32)
      distances[4] = 10e3
      distances[6] = 10e7

      first_pick = clustering_ops.kmc2_chain_initialization(distances, seed)
      self.assertAllEqual(first_pick, 6)

      # Remove the largest entry so that index 4 becomes the dominant one.
      distances[6] = 0.0
      second_pick = clustering_ops.kmc2_chain_initialization(distances, seed)
      self.assertAllEqual(second_pick, 4)

  def testBasic(self):
    """Runs the check for seeds 0 through 99."""
    for seed in range(100):
      self.runTestWithSeed(seed)
77
78
@test_util.run_all_in_graph_and_eager_modes
class KMC2InitializationLargeTest(test.TestCase):
  """Samples the KMC2 chain many times on an input with two non-zero entries."""

  def setUp(self):
    # Only indices 500 and 1000 carry a non-zero distance.
    self._distances = np.zeros(1001)
    self._distances[500] = 100.0
    self._distances[1000] = 50.0

  def testBasic(self):
    """Checks that 50 seeded samples hit both non-zero indices, and only those.

    Each of the two indices must be returned at least five times.
    """
    with self.cached_session():
      counts = {}
      for seed in range(50):
        sample = self.evaluate(
            clustering_ops.kmc2_chain_initialization(self._distances, seed))
        counts[sample] = counts.get(sample, 0) + 1
      # assertEqual / assertIn replace the deprecated assertEquals alias and
      # the assertTrue(x in y) form, and report the offending values on failure.
      self.assertEqual(len(counts), 2)
      self.assertIn(500, counts)
      self.assertIn(1000, counts)
      self.assertGreaterEqual(counts[500], 5)
      self.assertGreaterEqual(counts[1000], 5)
100
101
@test_util.run_all_in_graph_and_eager_modes
class KMC2InitializationCornercaseTest(test.TestCase):
  """Checks the KMC2 chain on an all-zero distance vector."""

  def setUp(self):
    self._distances = np.zeros(10)

  def runTestWithSeed(self, seed):
    """Expects index 0 to be returned when every distance is zero."""
    expected_index = 0
    with self.cached_session():
      picked = clustering_ops.kmc2_chain_initialization(self._distances, seed)
      self.assertAllEqual(picked, expected_index)

  def testBasic(self):
    """Runs the check for seeds 0 through 99."""
    for seed in range(100):
      self.runTestWithSeed(seed)
117
118
@test_util.run_all_in_graph_and_eager_modes
class NearestCentersTest(test.TestCase):
  """A simple nearest-neighbors test that can be verified by hand."""

  def setUp(self):
    points = [[100., 0.], [101., 2.], [99., 2.], [1., 1.]]
    centers = [[100., 0.], [99., 1.], [50., 50.], [0., 0.], [1., 1.]]
    self._points = np.array(points).astype(np.float32)
    self._centers = np.array(centers).astype(np.float32)

  def testNearest1(self):
    """Checks the single closest center and its distance for every point."""
    expected_indices = [[0], [0], [1], [4]]
    expected_distances = [[0.], [5.], [1.], [0.]]
    with self.cached_session():
      indices, distances = clustering_ops.nearest_neighbors(
          self._points, self._centers, 1)
      self.assertAllClose(indices, expected_indices)
      self.assertAllClose(distances, expected_distances)

  def testNearest2(self):
    """Checks the two closest centers and their distances for every point."""
    expected_indices = [[0, 1], [0, 1], [1, 0], [4, 3]]
    expected_distances = [[0., 2.], [5., 5.], [1., 5.], [0., 2.]]
    with self.cached_session():
      indices, distances = clustering_ops.nearest_neighbors(
          self._points, self._centers, 2)
      self.assertAllClose(indices, expected_indices)
      self.assertAllClose(distances, expected_distances)
148
149
@test_util.run_all_in_graph_and_eager_modes
class NearestCentersLargeTest(test.TestCase):
  """A nearest-neighbors test with large inputs."""

  def setUp(self):
    num_points = 1000
    num_centers = 2000
    num_dim = 100
    max_k = 5
    # Only a handful of distinct random points are generated; they are tiled
    # afterwards to reach num_points rows.
    points_per_tile = 10
    assert num_points % points_per_tile == 0
    points = np.random.standard_normal(
        [points_per_tile, num_dim]).astype(np.float32)
    # Random centers.
    self._centers = np.random.standard_normal(
        [num_centers, num_dim]).astype(np.float32)

    def squared_distance(x, y):
      return np.linalg.norm(x - y, ord=2)**2

    # Brute force: rank every center for every distinct point and keep the
    # max_k closest ones.
    index_rows = []
    distance_rows = []
    for point in points:
      ranked = sorted(
          (squared_distance(point, center), j)
          for j, center in enumerate(self._centers))
      closest = ranked[:max_k]
      index_rows.append([j for _, j in closest])
      distance_rows.append([dist for dist, _ in closest])
    expected_indices = np.array(index_rows)
    expected_squared_distances = np.array(distance_rows)

    # Tile the points and the expected results up to num_points rows.
    reps = (num_points // points_per_tile, 1)
    self._points = np.tile(points, reps)
    self._expected_nearest_neighbor_indices = np.tile(expected_indices, reps)
    self._expected_nearest_neighbor_squared_distances = np.tile(
        expected_squared_distances, reps)

  def testNearest1(self):
    """Compares the op's single nearest neighbor with the brute-force result."""
    want_indices = self._expected_nearest_neighbor_indices[:, :1]
    want_distances = self._expected_nearest_neighbor_squared_distances[:, :1]
    with self.cached_session():
      indices, distances = clustering_ops.nearest_neighbors(
          self._points, self._centers, 1)
      self.assertAllClose(indices, want_indices)
      self.assertAllClose(distances, want_distances)

  def testNearest5(self):
    """Compares the op's five nearest neighbors with the brute-force result."""
    want_indices = self._expected_nearest_neighbor_indices[:, :5]
    want_distances = self._expected_nearest_neighbor_squared_distances[:, :5]
    with self.cached_session():
      indices, distances = clustering_ops.nearest_neighbors(
          self._points, self._centers, 5)
      self.assertAllClose(indices, want_indices)
      self.assertAllClose(distances, want_distances)
208
209
if __name__ == "__main__":
  # Seed NumPy's global RNG before handing control to the test runner, so
  # fixtures drawn from np.random are reproducible when run as a script.
  np.random.seed(0)
  test.main()
213