1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Clustering Operations.""" 16 17from tensorflow.python.framework import constant_op 18from tensorflow.python.framework import dtypes 19from tensorflow.python.framework import ops 20from tensorflow.python.framework import random_seed as random_seed_ops 21from tensorflow.python.ops import array_ops 22from tensorflow.python.ops import check_ops 23from tensorflow.python.ops import control_flow_ops 24from tensorflow.python.ops import gen_clustering_ops 25from tensorflow.python.ops import math_ops 26from tensorflow.python.ops import nn_impl 27from tensorflow.python.ops import random_ops 28from tensorflow.python.ops import state_ops 29from tensorflow.python.ops import variable_scope 30from tensorflow.python.ops.embedding_ops import embedding_lookup 31# go/tf-wildcard-import 32# pylint: disable=wildcard-import 33from tensorflow.python.ops.gen_clustering_ops import * 34# pylint: enable=wildcard-import 35 36# Euclidean distance between vectors U and V is defined as \\(||U - V||_F\\) 37# which is the square root of the sum of the absolute squares of the elements 38# difference. 
SQUARED_EUCLIDEAN_DISTANCE = 'squared_euclidean'
# Cosine distance between vectors U and V is defined as
# \\(1 - (U \dot V) / (||U||_F ||V||_F)\\)
COSINE_DISTANCE = 'cosine'

RANDOM_INIT = 'random'
KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
KMC2_INIT = 'kmc2'

# The name of the variable holding the cluster centers. Used by the Estimator.
CLUSTERS_VAR_NAME = 'clusters'


class KMeans:
  """Creates the graph for k-means clustering."""

  def __init__(self,
               inputs,
               num_clusters,
               initial_clusters=RANDOM_INIT,
               distance_metric=SQUARED_EUCLIDEAN_DISTANCE,
               use_mini_batch=False,
               mini_batch_steps_per_iteration=1,
               random_seed=0,
               kmeans_plus_plus_num_retries=2,
               kmc2_chain_length=200):
    """Creates an object for generating KMeans clustering graph.

    This class implements the following variants of K-means algorithm:

    If use_mini_batch is False, it runs standard full batch K-means. Each step
    runs a single iteration of K-Means. This step can be run sharded across
    multiple workers by passing a list of sharded inputs to this class. Note
    however that a single step needs to process the full input at once.

    If use_mini_batch is True, it runs a generalization of the mini-batch
    K-means algorithm. It runs multiple iterations, where each iteration is
    composed of mini_batch_steps_per_iteration steps. Two copies of cluster
    centers are maintained: one that is updated at the end of each iteration,
    and one that is updated every step. The first copy is used to compute
    cluster allocations for each step, and for inference, while the second copy
    is the one updated each step using the mini-batch update rule. After each
    iteration is complete, this second copy is copied back to the first copy.

    Note that for use_mini_batch=True, when mini_batch_steps_per_iteration=1,
    the algorithm reduces to the standard mini-batch algorithm. Also by setting
    mini_batch_steps_per_iteration = num_inputs / batch_size, the algorithm
    becomes an asynchronous version of the full-batch algorithm. Note however
    that there is no guarantee by this implementation that each input is seen
    exactly once per iteration. Also, different updates are applied
    asynchronously without locking. So this asynchronous version may not behave
    exactly like a full-batch version.

    Args:
      inputs: An input tensor or list of input tensors. It is assumed that the
        data points have been previously randomly permuted.
      num_clusters: An integer tensor specifying the number of clusters. This
        argument is ignored if initial_clusters is a tensor or numpy array.
      initial_clusters: Specifies the clusters used during initialization. One
        of the following:
        - a tensor or numpy array with the initial cluster centers.
        - a function f(inputs, k) that returns up to k centers from `inputs`.
        - "random": Choose centers randomly from `inputs`.
        - "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`.
        - "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`.
        In the last three cases, one batch of `inputs` may not yield
        `num_clusters` centers, in which case initialization will require
        multiple batches until enough centers are chosen. In the case of
        "random" or "kmeans_plus_plus", if the input size is <= `num_clusters`
        then the entire batch is chosen to be cluster centers.
      distance_metric: Distance metric used for clustering. Supported options:
        "squared_euclidean", "cosine".
      use_mini_batch: If true, use the mini-batch k-means algorithm. Else assume
        full batch.
      mini_batch_steps_per_iteration: Number of steps after which the updated
        cluster centers are synced back to a master copy.
      random_seed: Seed for PRNG used to initialize seeds.
      kmeans_plus_plus_num_retries: For each point that is sampled during
        kmeans++ initialization, this parameter specifies the number of
        additional points to draw from the current distribution before selecting
        the best. If a negative value is specified, a heuristic is used to
        sample O(log(num_to_sample)) additional points.
      kmc2_chain_length: Determines how many candidate points are used by the
        k-MC2 algorithm to produce one new cluster center. If a (mini-)batch
        contains fewer points, one new cluster center is generated from the
        (mini-)batch.

    Raises:
      ValueError: An invalid argument was passed to initial_clusters or
        distance_metric.
    """
    initialization_algorithms = [RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT]
    # Only string values are validated here; tensors, arrays and callables are
    # accepted as-is and handled in training_graph().
    if isinstance(initial_clusters,
                  str) and initial_clusters not in initialization_algorithms:
      raise ValueError(
          f'Unsupported initialization algorithm `{initial_clusters}`, '
          f'must be one of `{initialization_algorithms}`.')

    distance_metrics = [SQUARED_EUCLIDEAN_DISTANCE, COSINE_DISTANCE]
    if distance_metric not in distance_metrics:
      raise ValueError(f'Unsupported distance metric `{distance_metric}`, '
                       f'must be one of `{distance_metrics}`.')
    # Internally the inputs are always handled as a list of shards.
    self._inputs = inputs if isinstance(inputs, list) else [inputs]
    self._num_clusters = num_clusters
    self._initial_clusters = initial_clusters
    self._distance_metric = distance_metric
    self._use_mini_batch = use_mini_batch
    self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration)
    self._seed = random_seed_ops.get_seed(random_seed)[0]
    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
    self._kmc2_chain_length = kmc2_chain_length

  @classmethod
  def _distance_graph(cls, inputs, clusters, distance_metric):
    """Computes distance between each input and each cluster center.

    Args:
      inputs: list of input Tensors.
      clusters: cluster Tensor.
      distance_metric: distance metric used for clustering

    Returns:
      list of Tensors, where each element corresponds to each element in inputs.
      The value is the distance of each row to all the cluster centers.
      Currently only Euclidean distance and cosine distance are supported.
    """
    assert isinstance(inputs, list)
    if distance_metric == SQUARED_EUCLIDEAN_DISTANCE:
      return cls._compute_euclidean_distance(inputs, clusters)
    elif distance_metric == COSINE_DISTANCE:
      # Callers are expected to pass already L2-normalized inputs and clusters.
      return cls._compute_cosine_distance(
          inputs, clusters, inputs_normalized=True)
    else:
      assert False, str(distance_metric)

  @classmethod
  def _compute_euclidean_distance(cls, inputs, clusters):
    """Computes Euclidean distance between each input and each cluster center.

    Args:
      inputs: list of input Tensors.
      clusters: cluster Tensor.

    Returns:
      list of Tensors, where each element corresponds to each element in inputs.
      The value is the squared distance of each row to all the cluster centers.
    """
    output = []
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        # Computes Euclidean distance. Note the first and third terms are
        # broadcast additions.
        squared_distance = (
            math_ops.reduce_sum(math_ops.square(inp), 1, keepdims=True) -
            2 * math_ops.matmul(inp, clusters, transpose_b=True) +
            array_ops.transpose(
                math_ops.reduce_sum(
                    math_ops.square(clusters), 1, keepdims=True)))
        output.append(squared_distance)

    return output

  @classmethod
  def _compute_cosine_distance(cls, inputs, clusters, inputs_normalized=True):
    """Computes cosine distance between each input and each cluster center.

    Args:
      inputs: list of input Tensor.
      clusters: cluster Tensor
      inputs_normalized: if True, it assumes that inp and clusters are
        normalized and computes one minus the dot product, which is then
        equivalent to the cosine distance. Else it L2 normalizes the inputs and
        the clusters first.

    Returns:
      list of Tensors, where each element corresponds to each element in inp.
      The value is the distance of each row to all the cluster centers.
    """
    output = []
    if not inputs_normalized:
      with ops.colocate_with(clusters, ignore_existing=True):
        clusters = nn_impl.l2_normalize(clusters, axis=1)
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        if not inputs_normalized:
          inp = nn_impl.l2_normalize(inp, axis=1)
        output.append(1 - math_ops.matmul(inp, clusters, transpose_b=True))
    return output

  def _infer_graph(self, inputs, clusters):
    """Maps input to closest cluster and the score.

    Args:
      inputs: list of input Tensors.
      clusters: Tensor of cluster centers.

    Returns:
      An iterable of three tuples, each holding one entry per element of
      `inputs`:
        all_scores: distance of each input to each cluster center.
        score: distance of each input to closest cluster center.
        cluster_idx: index of cluster center closest to the corresponding input.
    """
    assert isinstance(inputs, list)
    # Pairwise distances are used only by transform(). In all other cases, this
    # sub-graph is not evaluated.
    scores = self._distance_graph(inputs, clusters, self._distance_metric)
    output = []
    if (self._distance_metric == COSINE_DISTANCE and
        not self._clusters_l2_normalized()):
      # For L2-normalized vectors x and y, the squared Euclidean distance is
      # 2 * cosine_distance(x, y). We are using this fact and reusing the
      # nearest_neighbors op; the distances are halved below.
      # TODO(ands): Support COSINE distance in nearest_neighbors and remove
      # this.
      with ops.colocate_with(clusters, ignore_existing=True):
        clusters = nn_impl.l2_normalize(clusters, axis=1)
    for inp, score in zip(inputs, scores):
      with ops.colocate_with(inp, ignore_existing=True):
        (indices,
         distances) = gen_clustering_ops.nearest_neighbors(inp, clusters, 1)
        if self._distance_metric == COSINE_DISTANCE:
          # Convert squared Euclidean distance to cosine distance.
          distances *= 0.5
        output.append(
            (score, array_ops.squeeze(distances,
                                      [-1]), array_ops.squeeze(indices, [-1])))
    # Transpose the per-input triples into three per-field tuples.
    return zip(*output)

  def _clusters_l2_normalized(self):
    """Returns True if clusters centers are kept normalized."""
    return (self._distance_metric == COSINE_DISTANCE and
            (not self._use_mini_batch or
             self._mini_batch_steps_per_iteration > 1))

  def _create_variables(self, num_clusters):
    """Creates variables.

    Args:
      num_clusters: an integer Tensor providing the number of clusters.

    Returns:
      Tuple with following elements:
      - cluster_centers: a Tensor for storing cluster centers
      - cluster_centers_initialized: bool Variable indicating whether clusters
            are initialized.
      - cluster_counts: a Tensor for storing counts of points assigned to this
            cluster. This is used by mini-batch training; it is None in the
            full-batch case.
      - cluster_centers_updated: Tensor representing copy of cluster centers
            that are updated every step. It is the same object as
            cluster_centers unless mini-batch mode is used with
            mini_batch_steps_per_iteration > 1.
      - update_in_steps: numbers of steps left before we sync
            cluster_centers_updated back to cluster_centers, or None when no
            separate copy is kept.
    """
    # The centers start empty with an unknown shape; they are filled in by the
    # initialization op.
    init_value = array_ops.placeholder_with_default([], shape=None)
    cluster_centers = variable_scope.variable(
        init_value, name=CLUSTERS_VAR_NAME, validate_shape=False)
    cluster_centers_initialized = variable_scope.variable(
        False, dtype=dtypes.bool, name='initialized')

    if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
      # Copy of cluster centers actively updated each step according to
      # mini-batch update rule.
      cluster_centers_updated = variable_scope.variable(
          init_value, name='clusters_updated', validate_shape=False)
      # How many steps till we copy the updated clusters to cluster_centers.
      update_in_steps = variable_scope.variable(
          self._mini_batch_steps_per_iteration,
          dtype=dtypes.int64,
          name='update_in_steps')
      # Count of points assigned to cluster_centers_updated.
      cluster_counts = variable_scope.variable(
          array_ops.zeros([num_clusters], dtype=dtypes.int64))
    else:
      cluster_centers_updated = cluster_centers
      update_in_steps = None
      cluster_counts = (
          variable_scope.variable(
              array_ops.ones([num_clusters], dtype=dtypes.int64))
          if self._use_mini_batch else None)
    return (cluster_centers, cluster_centers_initialized, cluster_counts,
            cluster_centers_updated, update_in_steps)

  @classmethod
  def _l2_normalize_data(cls, inputs):
    """L2-normalizes each input Tensor along axis 1.

    Args:
      inputs: list of input Tensors.

    Returns:
      list of normalized Tensors, one per element of `inputs`.
    """
    output = []
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        output.append(nn_impl.l2_normalize(inp, axis=1))
    return output

  def training_graph(self):
    """Generate a training graph for kmeans algorithm.

    This returns, among other things, an op that chooses initial centers
    (init_op), a boolean variable that is set to True when the initial centers
    are chosen (cluster_centers_initialized), and an op to perform either an
    entire Lloyd iteration or a mini-batch of a Lloyd iteration (training_op).
    The caller should use these components as follows. A single worker should
    execute init_op multiple times until cluster_centers_initialized becomes
    True. Then multiple workers may execute training_op any number of times.

    Returns:
      A tuple consisting of:
      all_scores: A matrix (or list of matrices) of dimensions (num_input,
        num_clusters) where the value is the distance of an input vector and a
        cluster center.
      cluster_idx: A vector (or list of vectors). Each element in the vector
        corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      scores: Similar to cluster_idx but specifies the distance to the
        assigned cluster instead.
      cluster_centers_initialized: scalar indicating whether clusters have been
        initialized.
      init_op: an op to initialize the clusters.
      training_op: an op that runs an iteration of training.
    """
    # Implementation of kmeans.
    if (isinstance(self._initial_clusters, str) or
        callable(self._initial_clusters)):
      initial_clusters = self._initial_clusters
      num_clusters = ops.convert_to_tensor(self._num_clusters)
    else:
      # Explicit centers were given; their count overrides num_clusters.
      initial_clusters = ops.convert_to_tensor(self._initial_clusters)
      num_clusters = array_ops.shape(initial_clusters)[0]

    inputs = self._inputs
    (cluster_centers_var, cluster_centers_initialized, total_counts,
     cluster_centers_updated,
     update_in_steps) = self._create_variables(num_clusters)
    init_op = _InitializeClustersOpFactory(
        self._inputs, num_clusters, initial_clusters, self._distance_metric,
        self._seed, self._kmeans_plus_plus_num_retries, self._kmc2_chain_length,
        cluster_centers_var, cluster_centers_updated,
        cluster_centers_initialized).op()
    cluster_centers = cluster_centers_var

    if self._distance_metric == COSINE_DISTANCE:
      inputs = self._l2_normalize_data(inputs)
      if not self._clusters_l2_normalized():
        cluster_centers = nn_impl.l2_normalize(cluster_centers, axis=1)

    all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
    if self._use_mini_batch:
      sync_updates_op = self._mini_batch_sync_updates_op(
          update_in_steps, cluster_centers_var, cluster_centers_updated,
          total_counts)
      assert sync_updates_op is not None
      with ops.control_dependencies([sync_updates_op]):
        training_op = self._mini_batch_training_op(inputs, cluster_idx,
                                                   cluster_centers_updated,
                                                   total_counts)
    else:
      # In full-batch mode the centers are never re-normalized above, so this
      # must still be the variable itself. Use identity rather than `==`,
      # which builds an element-wise op when tensor equality is enabled.
      assert cluster_centers is cluster_centers_var
      training_op = self._full_batch_training_op(inputs, num_clusters,
                                                 cluster_idx,
                                                 cluster_centers_var)

    return (all_scores, cluster_idx, scores, cluster_centers_initialized,
            init_op, training_op)

  def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
                                  cluster_centers_updated, total_counts):
    """Creates the op that syncs the per-step centers to the master copy.

    Args:
      update_in_steps: variable counting the steps left before the next sync;
        must not be None when a separate updated copy is kept.
      cluster_centers_var: variable holding the master copy of the centers.
      cluster_centers_updated: variable holding the copy updated every step.
      total_counts: variable holding the per-cluster counts.

    Returns:
      When mini-batch mode is used with mini_batch_steps_per_iteration > 1, an
      op that either decrements update_in_steps or, once it has reached zero,
      resets it, copies cluster_centers_updated into cluster_centers_var and
      zeroes total_counts. Otherwise a no-op.
    """
    if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
      assert update_in_steps is not None
      with ops.colocate_with(update_in_steps, ignore_existing=True):

        def _f():
          # Note that there is a race condition here, so we do a best effort
          # updates here. We reset update_in_steps first so that other workers
          # don't duplicate the updates. Also we update cluster_center_vars
          # before resetting total_counts to avoid large updates to
          # cluster_centers_updated based on partially updated
          # cluster_center_vars.
          with ops.control_dependencies([
              state_ops.assign(update_in_steps,
                               self._mini_batch_steps_per_iteration - 1)
          ]):
            with ops.colocate_with(
                cluster_centers_updated, ignore_existing=True):
              if self._distance_metric == COSINE_DISTANCE:
                cluster_centers = nn_impl.l2_normalize(
                    cluster_centers_updated, axis=1)
              else:
                cluster_centers = cluster_centers_updated
            with ops.colocate_with(cluster_centers_var, ignore_existing=True):
              with ops.control_dependencies(
                  [state_ops.assign(cluster_centers_var, cluster_centers)]):
                with ops.colocate_with(None, ignore_existing=True):
                  with ops.control_dependencies([
                      state_ops.assign(total_counts,
                                       array_ops.zeros_like(total_counts))
                  ]):
                    return array_ops.identity(update_in_steps)

        return control_flow_ops.cond(
            update_in_steps <= 0, _f,
            lambda: state_ops.assign_sub(update_in_steps, 1))
    else:
      return control_flow_ops.no_op()

  def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
                              total_counts):
    """Creates an op for training for mini batch case.

    Args:
      inputs: list of input Tensors.
      cluster_idx_list: A vector (or list of vectors). Each element in the
        vector corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      cluster_centers: Tensor Ref of cluster centers.
      total_counts: Tensor Ref of cluster counts.

    Returns:
      An op for doing an update of mini-batch k-means.
    """
    update_ops = []
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
      with ops.colocate_with(inp, ignore_existing=True):
        assert total_counts is not None
        cluster_idx = array_ops.reshape(cluster_idx, [-1])
        # Dedupe the unique ids of cluster_centers being updated so that updates
        # can be locally aggregated.
        unique_ids, unique_idx = array_ops.unique(cluster_idx)
        num_unique_cluster_idx = array_ops.size(unique_ids)
        # Fetch the old values of counts and cluster_centers.
        with ops.colocate_with(total_counts, ignore_existing=True):
          old_counts = array_ops.gather(total_counts, unique_ids)
        # TODO(agarwal): This colocation seems to run into problems. Fix it.
        with ops.colocate_with(cluster_centers, ignore_existing=True):
          old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
        # Locally aggregate the increment to counts.
        count_updates = math_ops.unsorted_segment_sum(
            array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
            unique_idx, num_unique_cluster_idx)
        # Locally compute the sum of inputs mapped to each id.
        # For a cluster with old cluster value x, old count n, and with data
        # d_1,...d_k newly assigned to it, we recompute the new value as
        # \\(x += (sum_i(d_i) - k * x) / (n + k)\\).
        # Compute \\(sum_i(d_i)\\), see comment above.
        cluster_center_updates = math_ops.unsorted_segment_sum(
            inp, unique_idx, num_unique_cluster_idx)
        # Shape to enable broadcasting count_updates and learning_rate to inp.
        # It extends the shape with 1's to match the rank of inp.
        broadcast_shape = array_ops.concat([
            array_ops.reshape(num_unique_cluster_idx, [1]),
            array_ops.ones(
                array_ops.reshape(array_ops.rank(inp) - 1, [1]),
                dtype=dtypes.int32)
        ], 0)
        # Subtract k * x, see comment above.
        cluster_center_updates -= math_ops.cast(
            array_ops.reshape(count_updates, broadcast_shape),
            inp.dtype) * old_cluster_centers
        learning_rate = math_ops.reciprocal(
            math_ops.cast(old_counts + count_updates, inp.dtype))
        learning_rate = array_ops.reshape(learning_rate, broadcast_shape)
        # scale by 1 / (n + k), see comment above.
        cluster_center_updates *= learning_rate
        # Apply the updates.
      update_counts = state_ops.scatter_add(total_counts, unique_ids,
                                            count_updates)
      update_cluster_centers = state_ops.scatter_add(cluster_centers,
                                                     unique_ids,
                                                     cluster_center_updates)
      update_ops.extend([update_counts, update_cluster_centers])
    return control_flow_ops.group(*update_ops)

  def _full_batch_training_op(self, inputs, num_clusters, cluster_idx_list,
                              cluster_centers):
    """Creates an op for training for full batch case.

    Args:
      inputs: list of input Tensors.
      num_clusters: an integer Tensor providing the number of clusters.
      cluster_idx_list: A vector (or list of vectors). Each element in the
        vector corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      cluster_centers: Tensor Ref of cluster centers.

    Returns:
      An op for doing an update of full-batch k-means: it assigns to
      `cluster_centers` the mean of the inputs mapped to each cluster.
    """
    cluster_sums = []
    cluster_counts = []
    # Keeps the division below finite for clusters with no assigned points.
    epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
      with ops.colocate_with(inp, ignore_existing=True):
        cluster_sums.append(
            math_ops.unsorted_segment_sum(inp, cluster_idx, num_clusters))
        # A column of ones per input row, summed per cluster, gives the counts
        # with a trailing dimension of 1 so they broadcast against the sums.
        cluster_counts.append(
            math_ops.unsorted_segment_sum(
                array_ops.reshape(
                    array_ops.ones(
                        array_ops.reshape(array_ops.shape(inp)[0], [-1])),
                    [-1, 1]), cluster_idx, num_clusters))
    with ops.colocate_with(cluster_centers, ignore_existing=True):
      new_clusters_centers = math_ops.add_n(cluster_sums) / (
          math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) +
          epsilon)
      if self._clusters_l2_normalized():
        new_clusters_centers = nn_impl.l2_normalize(
            new_clusters_centers, axis=1)
    return state_ops.assign(cluster_centers, new_clusters_centers)


class _InitializeClustersOpFactory:
  """Internal class to create the op to initialize the clusters.

  The op performs this algorithm (see constructor args):

  num_remaining = num_clusters - length(cluster_centers)
  if num_remaining == 0:
    assert that cluster_centers_initialized is true
  else:
    assert that num_remaining > 0
    new_centers = choose up to num_remaining initial centers
    l2-normalize new_centers if using cosine distance
    all_centers = concat(cluster_centers, new_centers)
    cluster_centers := all_centers
    if there is a cluster_centers_updated variable:
      cluster_centers_updated := cluster_centers
    num_now_remaining = num_clusters - length(cluster_centers)
    if num_now_remaining == 0:
      cluster_centers_initialized := true
  """

  # TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case.

  def __init__(self, inputs, num_clusters, initial_clusters, distance_metric,
               random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length,
               cluster_centers, cluster_centers_updated,
               cluster_centers_initialized):
    """Creates an op factory.

    Args:
      inputs: See KMeans constructor.
      num_clusters: An integer Tensor providing the number of clusters.
      initial_clusters: See KMeans constructor.
      distance_metric: See KMeans constructor.
      random_seed: See KMeans constructor.
      kmeans_plus_plus_num_retries: See KMeans constructor.
      kmc2_chain_length: See KMeans constructor.
      cluster_centers: The TF variable holding the initial centers. It may
        already contain some centers when the op is executed.
      cluster_centers_updated: A second TF variable to hold a copy of the
        initial centers. KMeans passes a distinct variable only in mini-batch
        mode with mini_batch_steps_per_iteration > 1; otherwise
        cluster_centers_updated is the same variable as cluster_centers.
      cluster_centers_initialized: A boolean TF variable that will be set to
        true when all the initial centers have been chosen.
    """
    # All of these instance variables are constants.
    self._inputs = inputs
    self._num_clusters = num_clusters
    self._initial_clusters = initial_clusters
    self._distance_metric = distance_metric
    self._seed = random_seed
    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
    self._kmc2_chain_length = kmc2_chain_length
    self._cluster_centers = cluster_centers
    self._cluster_centers_updated = cluster_centers_updated
    self._cluster_centers_initialized = cluster_centers_initialized

    # Number of centers chosen so far, still to be chosen, and the total number
    # of data points across all input shards.
    self._num_selected = array_ops.shape(self._cluster_centers)[0]
    self._num_remaining = self._num_clusters - self._num_selected
    self._num_data = math_ops.add_n(
        [array_ops.shape(i)[0] for i in self._inputs])

  def _random(self):
    """Returns `num_remaining` rows of the inputs at uniformly random indices."""
    indices = random_ops.random_uniform(
        array_ops.reshape(self._num_remaining, [-1]),
        minval=0,
        maxval=math_ops.cast(self._num_data, dtypes.int64),
        seed=self._seed,
        dtype=dtypes.int64)
    return embedding_lookup(self._inputs, indices, partition_strategy='div')

  def _kmeans_plus_plus(self):
    """Returns up to `num_remaining` centers chosen with kmeans++."""
    # Points from only the first shard are used for initializing centers.
    # TODO(ands): Use all points.
    inp = self._inputs[0]
    if self._distance_metric == COSINE_DISTANCE:
      inp = nn_impl.l2_normalize(inp, axis=1)
    return gen_clustering_ops.kmeans_plus_plus_initialization(
        inp, math_ops.cast(self._num_remaining, dtypes.int64), self._seed,
        self._kmeans_plus_plus_num_retries)

  def _kmc2_multiple_centers(self):
    """Adds new initial cluster centers using the k-MC2 algorithm.

    In each call to the op, the provided batch is split into subsets based on
    the specified `kmc2_chain_length`. On each subset, a single Markov chain of
    the k-MC2 algorithm is used to add *one* new cluster center. If there
    are fewer than `kmc2_chain_length` points in the subset, a single center is
    added using one Markov chain on the full input. It is assumed that the
    provided batch has previously been randomly permuted. Otherwise, k-MC2 may
    return suboptimal centers.

    Returns:
      A scalar Tensor with the number of cluster centers that still remain to
      be chosen after the new centers have been added.
    """
    # The op only operates on the first shard of data.
    first_shard = self._inputs[0]
    # Number of points in the input that can be used.
    batch_size = array_ops.shape(first_shard)[0]
    # Maximum number of subsets such that the size of each subset is at least
    # `kmc2_chain_length`. Final subsets may be larger.
    max_to_sample = math_ops.cast(
        batch_size / self._kmc2_chain_length, dtype=dtypes.int32)
    # We sample at least one new center and at most all remaining centers.
    num_to_sample = math_ops.maximum(
        math_ops.minimum(self._num_remaining, max_to_sample), 1)

    def _cond(i, _):
      """Stopping condition for the while loop."""
      return math_ops.less(i, num_to_sample)

    def _body(i, _):
      """Body that adds a single new center based on a subset."""

      def _sample_random():
        """Returns a random point as a cluster center."""
        # By assumption the batch is reshuffled and _sample_random is always
        # called for i=0. Hence, we simply return the first point.
        new_center = array_ops.reshape(first_shard[0], [1, -1])
        if self._distance_metric == COSINE_DISTANCE:
          new_center = nn_impl.l2_normalize(new_center, axis=1)
        return new_center

      def _sample_kmc2_chain():
        """Returns previous centers as well as a new center sampled using k-MC2."""
        # Extract the subset from the underlying batch.
        start = i * self._kmc2_chain_length
        end = start + self._kmc2_chain_length
        subset = first_shard[start:end]
        # Compute the distances from points in the subset to previous centers.
        _, distances = gen_clustering_ops.nearest_neighbors(
            subset, self._cluster_centers, 1)
        # Sample index of new center using k-MC2 Markov chain. Only the
        # trailing (k == 1) axis is squeezed so that a subset holding a single
        # point still yields a vector rather than a scalar.
        new_center_index = gen_clustering_ops.kmc2_chain_initialization(
            array_ops.squeeze(distances, [-1]), self._seed)
        # Extract actual new center.
        newly_sampled_center = array_ops.reshape(subset[new_center_index],
                                                 [1, -1])
        # Return concatenation with previously sampled centers.
        if self._distance_metric == COSINE_DISTANCE:
          newly_sampled_center = nn_impl.l2_normalize(
              newly_sampled_center, axis=1)
        return array_ops.concat([self._cluster_centers, newly_sampled_center],
                                0)

      # Obtain a random point if there are no previously sampled centers.
      # Otherwise, construct a k-MC2 Markov chain.
      new_centers = control_flow_ops.cond(
          math_ops.equal(self._num_selected, 0), _sample_random,
          _sample_kmc2_chain)
      # Assign new cluster centers to underlying variable.
      assigned_centers = state_ops.assign(
          self._cluster_centers, new_centers, validate_shape=False)
      if self._cluster_centers_updated is not self._cluster_centers:
        assigned_centers = state_ops.assign(
            self._cluster_centers_updated,
            assigned_centers,
            validate_shape=False)
      return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0]

    # Add num_to_sample new data points.
    _, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0])
    return num_remaining

  def _greedy_batch_sampler(self, sampler):
    """Returns the whole input if it is small, else the result of `sampler`."""
    # If the input dataset size is smaller than the number of centers
    # remaining, choose the entire input dataset as centers. This can happen
    # with mini-batch. Otherwise, sample the batch according to the provided
    # sampler.
    return control_flow_ops.cond(self._num_data <= self._num_remaining,
                                 lambda: array_ops.concat(self._inputs, 0),
                                 sampler)

  def _single_batch_sampler(self, sampler):
    """Returns the result of `sampler`, asserting the batch is large enough."""
    # Enforce that there are at least as many data points as centers
    # remaining. This gives the provided sampler the chance to select all
    # remaining centers from a single batch.
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(self._num_data, self._num_remaining)]):
      return sampler()

  def _choose_initial_centers(self):
    """Returns new centers according to the configured `initial_clusters`."""
    if isinstance(self._initial_clusters, str):
      if self._initial_clusters == RANDOM_INIT:
        return self._greedy_batch_sampler(self._random)
      else:  # self._initial_clusters == KMEANS_PLUS_PLUS_INIT
        # KMC2_INIT does not reach this method; _initialize() dispatches it to
        # _kmc2_multiple_centers() instead.
        return self._single_batch_sampler(self._kmeans_plus_plus)
    elif callable(self._initial_clusters):
      return self._initial_clusters(self._inputs, self._num_remaining)
    else:
      # Explicit centers: they must supply exactly the remaining count.
      with ops.control_dependencies([
          check_ops.assert_equal(self._num_remaining,
                                 array_ops.shape(self._initial_clusters)[0])
      ]):
        return self._initial_clusters

  def _add_new_centers(self):
    """Adds some centers and returns the number of centers remaining."""
    new_centers = self._choose_initial_centers()
    if self._distance_metric == COSINE_DISTANCE:
      new_centers = nn_impl.l2_normalize(new_centers, axis=1)
    # If cluster_centers is empty, it doesn't have the right shape for concat.
    all_centers = control_flow_ops.cond(
        math_ops.equal(self._num_selected, 0), lambda: new_centers,
        lambda: array_ops.concat([self._cluster_centers, new_centers], 0))
    # TODO(ccolby): De-dupe all_centers?
    a = state_ops.assign(
        self._cluster_centers, all_centers, validate_shape=False)
    if self._cluster_centers_updated is not self._cluster_centers:
      a = state_ops.assign(
          self._cluster_centers_updated, a, validate_shape=False)
    return self._num_clusters - array_ops.shape(a)[0]

  def _initialize(self):
    """Adds centers and marks the clusters initialized once none remain."""
    with ops.control_dependencies([
        check_ops.assert_positive(self._num_remaining),
    ]):
      # Guard with isinstance: `initial_clusters` may be a Tensor or a
      # callable, for which `==` against a string is not a plain bool when
      # tensor equality is enabled.
      if (isinstance(self._initial_clusters, str) and
          self._initial_clusters == KMC2_INIT):
        num_now_remaining = self._kmc2_multiple_centers()
      else:
        num_now_remaining = self._add_new_centers()
      return control_flow_ops.cond(
          math_ops.equal(num_now_remaining, 0),
          lambda: state_ops.assign(self._cluster_centers_initialized, True),
          control_flow_ops.no_op)

  def op(self):
    """Returns the cluster initializer op."""
    # When no centers remain to be chosen, only check that the initialized flag
    # is already set; otherwise run one round of initialization.
    return control_flow_ops.cond(
        math_ops.equal(self._num_remaining, 0),
        lambda: check_ops.assert_equal(self._cluster_centers_initialized, True),
        self._initialize)