# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Clustering Operations."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed as random_seed_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_clustering_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.embedding_ops import embedding_lookup
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_clustering_ops import *
# pylint: enable=wildcard-import

# Euclidean distance between vectors U and V is defined as \\(||U - V||_F\\),
# the square root of the sum of the squared element-wise differences.
SQUARED_EUCLIDEAN_DISTANCE = 'squared_euclidean'
# Cosine distance between vectors U and V is defined as
# \\(1 - (U \cdot V) / (||U||_F ||V||_F)\\)
COSINE_DISTANCE = 'cosine'
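
# A minimal NumPy sketch of what these two metrics compute for one pair of
# rows (illustrative only; `u` and `v` are arbitrary example vectors, not part
# of this module's API):
#
#   import numpy as np
#   u, v = np.array([3., 4.]), np.array([4., 3.])
#   squared_euclidean = np.sum((u - v) ** 2)          # 2.0
#   cosine = 1. - np.dot(u, v) / (np.linalg.norm(u) *
#                                 np.linalg.norm(v))  # 1 - 24/25 = 0.04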

RANDOM_INIT = 'random'
KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
KMC2_INIT = 'kmc2'

# The name of the variable holding the cluster centers. Used by the Estimator.
CLUSTERS_VAR_NAME = 'clusters'


class KMeans:
  """Creates the graph for k-means clustering."""

  def __init__(self,
               inputs,
               num_clusters,
               initial_clusters=RANDOM_INIT,
               distance_metric=SQUARED_EUCLIDEAN_DISTANCE,
               use_mini_batch=False,
               mini_batch_steps_per_iteration=1,
               random_seed=0,
               kmeans_plus_plus_num_retries=2,
               kmc2_chain_length=200):
    """Creates an object for generating KMeans clustering graph.

    This class implements the following variants of the K-means algorithm:

    If use_mini_batch is False, it runs standard full-batch K-means. Each step
    runs a single iteration of K-means. This step can be run sharded across
    multiple workers by passing a list of sharded inputs to this class. Note
    however that a single step needs to process the full input at once.

    If use_mini_batch is True, it runs a generalization of the mini-batch
    K-means algorithm. It runs multiple iterations, where each iteration is
    composed of mini_batch_steps_per_iteration steps. Two copies of cluster
    centers are maintained: one that is updated at the end of each iteration,
    and one that is updated every step. The first copy is used to compute
    cluster allocations for each step, and for inference, while the second copy
    is the one updated each step using the mini-batch update rule. After each
    iteration is complete, this second copy is copied back to the first copy.

    Note that for use_mini_batch=True, when mini_batch_steps_per_iteration=1,
    the algorithm reduces to the standard mini-batch algorithm. Also, by setting
    mini_batch_steps_per_iteration = num_inputs / batch_size, the algorithm
    becomes an asynchronous version of the full-batch algorithm. Note however
    that there is no guarantee by this implementation that each input is seen
    exactly once per iteration. Also, different updates are applied
    asynchronously without locking. So this asynchronous version may not behave
    exactly like a full-batch version.

    Args:
      inputs: An input tensor or list of input tensors. It is assumed that the
        data points have been previously randomly permuted.
      num_clusters: An integer tensor specifying the number of clusters. This
        argument is ignored if initial_clusters is a tensor or numpy array.
      initial_clusters: Specifies the clusters used during initialization. One
        of the following:
        - a tensor or numpy array with the initial cluster centers.
        - a function f(inputs, k) that returns up to k centers from `inputs`.
        - "random": Choose centers randomly from `inputs`.
        - "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`.
        - "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`.
        In the last three cases, one batch of `inputs` may not yield
        `num_clusters` centers, in which case initialization will require
        multiple batches until enough centers are chosen. In the case of
        "random" or "kmeans_plus_plus", if the input size is <= `num_clusters`
        then the entire batch is chosen to be cluster centers.
      distance_metric: Distance metric used for clustering. Supported options:
        "squared_euclidean", "cosine".
      use_mini_batch: If true, use the mini-batch k-means algorithm. Else assume
        full batch.
      mini_batch_steps_per_iteration: Number of steps after which the updated
        cluster centers are synced back to a master copy.
      random_seed: Seed for the PRNG used to initialize seeds.
      kmeans_plus_plus_num_retries: For each point that is sampled during
        kmeans++ initialization, this parameter specifies the number of
        additional points to draw from the current distribution before selecting
        the best. If a negative value is specified, a heuristic is used to
        sample O(log(num_to_sample)) additional points.
      kmc2_chain_length: Determines how many candidate points are used by the
        k-MC2 algorithm to produce one new cluster center. If a (mini-)batch
        contains fewer points, one new cluster center is generated from the
        (mini-)batch.

    Raises:
      ValueError: An invalid argument was passed to initial_clusters or
        distance_metric.
    """
    initialization_algorithms = [RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT]
    if isinstance(initial_clusters,
                  str) and initial_clusters not in initialization_algorithms:
      raise ValueError(
          f'Unsupported initialization algorithm `{initial_clusters}`, '
          f'must be one of `{initialization_algorithms}`.')

    distance_metrics = [SQUARED_EUCLIDEAN_DISTANCE, COSINE_DISTANCE]
    if distance_metric not in distance_metrics:
      raise ValueError(f'Unsupported distance metric `{distance_metric}`, '
                       f'must be one of `{distance_metrics}`.')
    self._inputs = inputs if isinstance(inputs, list) else [inputs]
    self._num_clusters = num_clusters
    self._initial_clusters = initial_clusters
    self._distance_metric = distance_metric
    self._use_mini_batch = use_mini_batch
    self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration)
    self._seed = random_seed_ops.get_seed(random_seed)[0]
    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
    self._kmc2_chain_length = kmc2_chain_length

  @classmethod
  def _distance_graph(cls, inputs, clusters, distance_metric):
    """Computes distance between each input and each cluster center.

    Args:
      inputs: list of input Tensors.
      clusters: cluster Tensor.
      distance_metric: distance metric used for clustering.

    Returns:
      list of Tensors, where each element corresponds to each element in inputs.
      The value is the distance of each row to all the cluster centers.
      Currently only Euclidean distance and cosine distance are supported.
    """
    assert isinstance(inputs, list)
    if distance_metric == SQUARED_EUCLIDEAN_DISTANCE:
      return cls._compute_euclidean_distance(inputs, clusters)
    elif distance_metric == COSINE_DISTANCE:
      return cls._compute_cosine_distance(
          inputs, clusters, inputs_normalized=True)
    else:
      assert False, str(distance_metric)

  @classmethod
  def _compute_euclidean_distance(cls, inputs, clusters):
    """Computes Euclidean distance between each input and each cluster center.

    Args:
      inputs: list of input Tensors.
      clusters: cluster Tensor.

    Returns:
      list of Tensors, where each element corresponds to each element in inputs.
      The value is the distance of each row to all the cluster centers.
    """
    output = []
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        # Computes Euclidean distance. Note the first and third terms are
        # broadcast additions.
        squared_distance = (
            math_ops.reduce_sum(math_ops.square(inp), 1, keepdims=True) -
            2 * math_ops.matmul(inp, clusters, transpose_b=True) +
            array_ops.transpose(
                math_ops.reduce_sum(
                    math_ops.square(clusters), 1, keepdims=True)))
        output.append(squared_distance)

    return output

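  # The computation above relies on the row-wise expansion
  # \\(||u - v||^2 = ||u||^2 - 2 u \cdot v + ||v||^2\\) applied with
  # broadcasting. A minimal NumPy sketch of the same expansion (illustrative
  # only; the names below are arbitrary):
  #
  #   import numpy as np
  #   inp = np.random.rand(5, 3)       # 5 input points
  #   clusters = np.random.rand(4, 3)  # 4 cluster centers
  #   d = (np.sum(inp ** 2, 1, keepdims=True)            # column broadcast
  #        - 2. * inp @ clusters.T
  #        + np.sum(clusters ** 2, 1, keepdims=True).T)  # row broadcast
  #   # d[i, j] == np.sum((inp[i] - clusters[j]) ** 2), up to float error.
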
  @classmethod
  def _compute_cosine_distance(cls, inputs, clusters, inputs_normalized=True):
    """Computes cosine distance between each input and each cluster center.

    Args:
      inputs: list of input Tensors.
      clusters: cluster Tensor.
      inputs_normalized: if True, it assumes that the inputs and clusters are
        normalized, so that `1 - dot product` directly gives the cosine
        distance. Else it L2-normalizes the inputs first.

    Returns:
      list of Tensors, where each element corresponds to each element in inputs.
      The value is the distance of each row to all the cluster centers.
    """
    output = []
    if not inputs_normalized:
      with ops.colocate_with(clusters, ignore_existing=True):
        clusters = nn_impl.l2_normalize(clusters, axis=1)
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        if not inputs_normalized:
          inp = nn_impl.l2_normalize(inp, axis=1)
        output.append(1 - math_ops.matmul(inp, clusters, transpose_b=True))
    return output

  def _infer_graph(self, inputs, clusters):
    """Maps input to closest cluster and the score.

    Args:
      inputs: list of input Tensors.
      clusters: Tensor of cluster centers.

    Returns:
      List of tuples, where each tuple corresponds to an element in inputs.
      Each tuple has the following three elements:
      all_scores: distance of each input to each cluster center.
      score: distance of each input to the closest cluster center.
      cluster_idx: index of the cluster center closest to the corresponding
        input.
    """
    assert isinstance(inputs, list)
    # Pairwise distances are used only by transform(). In all other cases, this
    # sub-graph is not evaluated.
    scores = self._distance_graph(inputs, clusters, self._distance_metric)
    output = []
    if (self._distance_metric == COSINE_DISTANCE and
        not self._clusters_l2_normalized()):
      # The cosine distance between normalized vectors x and y is half the
      # squared Euclidean distance \\(||x - y||^2\\). We are using this fact
      # and reusing the nearest_neighbors op.
      # TODO(ands): Support COSINE distance in nearest_neighbors and remove
      # this.
      with ops.colocate_with(clusters, ignore_existing=True):
        clusters = nn_impl.l2_normalize(clusters, axis=1)
    for inp, score in zip(inputs, scores):
      with ops.colocate_with(inp, ignore_existing=True):
        (indices,
         distances) = gen_clustering_ops.nearest_neighbors(inp, clusters, 1)
        if self._distance_metric == COSINE_DISTANCE:
          distances *= 0.5
        output.append(
            (score, array_ops.squeeze(distances,
                                      [-1]), array_ops.squeeze(indices, [-1])))
    return zip(*output)

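  # For unit vectors x and y, \\(||x - y||^2 = 2 - 2 x \cdot y\\), i.e. the
  # squared Euclidean distance is exactly twice the cosine distance. That is
  # why the squared-Euclidean nearest_neighbors op can be reused above with
  # its distances halved. A quick NumPy check (illustrative only):
  #
  #   import numpy as np
  #   x, y = np.array([1., 0.]), np.array([0.6, 0.8])  # both unit-norm
  #   assert np.isclose(np.sum((x - y) ** 2), 2. * (1. - np.dot(x, y)))
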
  def _clusters_l2_normalized(self):
    """Returns True if cluster centers are kept normalized."""
    return (self._distance_metric == COSINE_DISTANCE and
            (not self._use_mini_batch or
             self._mini_batch_steps_per_iteration > 1))

  def _create_variables(self, num_clusters):
    """Creates variables.

    Args:
      num_clusters: an integer Tensor providing the number of clusters.

    Returns:
      Tuple with following elements:
      - cluster_centers: a Tensor for storing cluster centers.
      - cluster_centers_initialized: bool Variable indicating whether clusters
            are initialized.
      - cluster_counts: a Tensor for storing counts of points assigned to this
            cluster. This is used by mini-batch training.
      - cluster_centers_updated: Tensor representing the copy of cluster
            centers that is updated every step.
      - update_in_steps: number of steps left before we sync
            cluster_centers_updated back to cluster_centers.
    """
    init_value = array_ops.placeholder_with_default([], shape=None)
    cluster_centers = variable_scope.variable(
        init_value, name=CLUSTERS_VAR_NAME, validate_shape=False)
    cluster_centers_initialized = variable_scope.variable(
        False, dtype=dtypes.bool, name='initialized')

    if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
      # Copy of cluster centers actively updated each step according to the
      # mini-batch update rule.
      cluster_centers_updated = variable_scope.variable(
          init_value, name='clusters_updated', validate_shape=False)
      # How many steps till we copy the updated clusters to cluster_centers.
      update_in_steps = variable_scope.variable(
          self._mini_batch_steps_per_iteration,
          dtype=dtypes.int64,
          name='update_in_steps')
      # Count of points assigned to cluster_centers_updated.
      cluster_counts = variable_scope.variable(
          array_ops.zeros([num_clusters], dtype=dtypes.int64))
    else:
      cluster_centers_updated = cluster_centers
      update_in_steps = None
      cluster_counts = (
          variable_scope.variable(
              array_ops.ones([num_clusters], dtype=dtypes.int64))
          if self._use_mini_batch else None)
    return (cluster_centers, cluster_centers_initialized, cluster_counts,
            cluster_centers_updated, update_in_steps)

  @classmethod
  def _l2_normalize_data(cls, inputs):
    """Normalizes the input data."""
    output = []
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        output.append(nn_impl.l2_normalize(inp, axis=1))
    return output

  def training_graph(self):
    """Generates a training graph for the k-means algorithm.

    This returns, among other things, an op that chooses initial centers
    (init_op), a boolean variable that is set to True when the initial centers
    are chosen (cluster_centers_initialized), and an op to perform either an
    entire Lloyd iteration or a mini-batch of a Lloyd iteration (training_op).
    The caller should use these components as follows. A single worker should
    execute init_op multiple times until cluster_centers_initialized becomes
    True. Then multiple workers may execute training_op any number of times.

    Returns:
      A tuple consisting of:
      all_scores: A matrix (or list of matrices) of dimensions (num_input,
        num_clusters) where the value is the distance between an input vector
        and a cluster center.
      cluster_idx: A vector (or list of vectors). Each element in the vector
        corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      scores: Similar to cluster_idx but specifies the distance to the
        assigned cluster instead.
      cluster_centers_initialized: scalar indicating whether clusters have been
        initialized.
      init_op: an op to initialize the clusters.
      training_op: an op that runs an iteration of training.
    """
    # Implementation of kmeans.
    if (isinstance(self._initial_clusters, str) or
        callable(self._initial_clusters)):
      initial_clusters = self._initial_clusters
      num_clusters = ops.convert_to_tensor(self._num_clusters)
    else:
      initial_clusters = ops.convert_to_tensor(self._initial_clusters)
      num_clusters = array_ops.shape(initial_clusters)[0]

    inputs = self._inputs
    (cluster_centers_var, cluster_centers_initialized, total_counts,
     cluster_centers_updated,
     update_in_steps) = self._create_variables(num_clusters)
    init_op = _InitializeClustersOpFactory(
        self._inputs, num_clusters, initial_clusters, self._distance_metric,
        self._seed, self._kmeans_plus_plus_num_retries, self._kmc2_chain_length,
        cluster_centers_var, cluster_centers_updated,
        cluster_centers_initialized).op()
    cluster_centers = cluster_centers_var

    if self._distance_metric == COSINE_DISTANCE:
      inputs = self._l2_normalize_data(inputs)
      if not self._clusters_l2_normalized():
        cluster_centers = nn_impl.l2_normalize(cluster_centers, axis=1)

    all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
    if self._use_mini_batch:
      sync_updates_op = self._mini_batch_sync_updates_op(
          update_in_steps, cluster_centers_var, cluster_centers_updated,
          total_counts)
      assert sync_updates_op is not None
      with ops.control_dependencies([sync_updates_op]):
        training_op = self._mini_batch_training_op(inputs, cluster_idx,
                                                   cluster_centers_updated,
                                                   total_counts)
    else:
      assert cluster_centers == cluster_centers_var
      training_op = self._full_batch_training_op(inputs, num_clusters,
                                                 cluster_idx,
                                                 cluster_centers_var)

    return (all_scores, cluster_idx, scores, cluster_centers_initialized,
            init_op, training_op)

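  # A sketch of how a caller might drive the ops returned above in a TF1-style
  # session loop. Illustrative only: `session`, `variables`, `data`, and
  # `num_steps` are assumed to come from the caller's own code.
  #
  #   kmeans = KMeans(data, num_clusters=10, use_mini_batch=True)
  #   (_, cluster_idx, _, initialized, init_op,
  #    training_op) = kmeans.training_graph()
  #   with session.Session() as sess:
  #     sess.run(variables.global_variables_initializer())
  #     while not sess.run(initialized):
  #       sess.run(init_op)
  #     for _ in range(num_steps):
  #       sess.run(training_op)
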
  def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
                                  cluster_centers_updated, total_counts):
    if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
      assert update_in_steps is not None
      with ops.colocate_with(update_in_steps, ignore_existing=True):

        def _f():
          # Note that there is a race condition here, so we do best-effort
          # updates here. We reset update_in_steps first so that other workers
          # don't duplicate the updates. Also we update cluster_center_vars
          # before resetting total_counts to avoid large updates to
          # cluster_centers_updated based on partially updated
          # cluster_center_vars.
          with ops.control_dependencies([
              state_ops.assign(update_in_steps,
                               self._mini_batch_steps_per_iteration - 1)
          ]):
            with ops.colocate_with(
                cluster_centers_updated, ignore_existing=True):
              if self._distance_metric == COSINE_DISTANCE:
                cluster_centers = nn_impl.l2_normalize(
                    cluster_centers_updated, axis=1)
              else:
                cluster_centers = cluster_centers_updated
            with ops.colocate_with(cluster_centers_var, ignore_existing=True):
              with ops.control_dependencies(
                  [state_ops.assign(cluster_centers_var, cluster_centers)]):
                with ops.colocate_with(None, ignore_existing=True):
                  with ops.control_dependencies([
                      state_ops.assign(total_counts,
                                       array_ops.zeros_like(total_counts))
                  ]):
                    return array_ops.identity(update_in_steps)

        return control_flow_ops.cond(
            update_in_steps <= 0, _f,
            lambda: state_ops.assign_sub(update_in_steps, 1))
    else:
      return control_flow_ops.no_op()

  def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
                              total_counts):
    """Creates an op for training for the mini-batch case.

    Args:
      inputs: list of input Tensors.
      cluster_idx_list: A vector (or list of vectors). Each element in the
        vector corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      cluster_centers: Tensor Ref of cluster centers.
      total_counts: Tensor Ref of cluster counts.

    Returns:
      An op for doing an update of mini-batch k-means.
    """
    update_ops = []
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
      with ops.colocate_with(inp, ignore_existing=True):
        assert total_counts is not None
        cluster_idx = array_ops.reshape(cluster_idx, [-1])
        # Dedupe the ids of the cluster_centers being updated so that updates
        # can be locally aggregated.
        unique_ids, unique_idx = array_ops.unique(cluster_idx)
        num_unique_cluster_idx = array_ops.size(unique_ids)
        # Fetch the old values of counts and cluster_centers.
        with ops.colocate_with(total_counts, ignore_existing=True):
          old_counts = array_ops.gather(total_counts, unique_ids)
        # TODO(agarwal): This colocation seems to run into problems. Fix it.
        with ops.colocate_with(cluster_centers, ignore_existing=True):
          old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
        # Locally aggregate the increment to counts.
        count_updates = math_ops.unsorted_segment_sum(
            array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
            unique_idx, num_unique_cluster_idx)
        # Locally compute the sum of inputs mapped to each id.
        # For a cluster with old cluster value x, old count n, and with data
        # d_1, ..., d_k newly assigned to it, we recompute the new value as
        # \\(x += (sum_i(d_i) - k * x) / (n + k)\\).
        # Compute \\(sum_i(d_i)\\), see comment above.
        cluster_center_updates = math_ops.unsorted_segment_sum(
            inp, unique_idx, num_unique_cluster_idx)
        # Shape to enable broadcasting count_updates and learning_rate to inp.
        # It extends the shape with 1's to match the rank of inp.
        broadcast_shape = array_ops.concat([
            array_ops.reshape(num_unique_cluster_idx, [1]),
            array_ops.ones(
                array_ops.reshape(array_ops.rank(inp) - 1, [1]),
                dtype=dtypes.int32)
        ], 0)
        # Subtract k * x, see comment above.
        cluster_center_updates -= math_ops.cast(
            array_ops.reshape(count_updates, broadcast_shape),
            inp.dtype) * old_cluster_centers
        learning_rate = math_ops.reciprocal(
            math_ops.cast(old_counts + count_updates, inp.dtype))
        learning_rate = array_ops.reshape(learning_rate, broadcast_shape)
        # Scale by 1 / (n + k), see comment above.
        cluster_center_updates *= learning_rate
      # Apply the updates.
      update_counts = state_ops.scatter_add(total_counts, unique_ids,
                                            count_updates)
      update_cluster_centers = state_ops.scatter_add(cluster_centers,
                                                     unique_ids,
                                                     cluster_center_updates)
      update_ops.extend([update_counts, update_cluster_centers])
    return control_flow_ops.group(*update_ops)

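  # The update above is the running-mean rule from the comment inside the
  # loop: \\(x += (sum_i(d_i) - k * x) / (n + k)\\). A single-cluster NumPy
  # sketch (illustrative only; all names are arbitrary):
  #
  #   import numpy as np
  #   x, n = np.array([1., 1.]), 3          # old center and old count
  #   d = np.array([[2., 2.], [4., 0.]])    # k = 2 newly assigned points
  #   k = len(d)
  #   x = x + (d.sum(0) - k * x) / (n + k)  # -> [1.8, 1.], the mean of all
  #                                         #    n + k points, assuming x was
  #                                         #    the mean of the first n
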
  def _full_batch_training_op(self, inputs, num_clusters, cluster_idx_list,
                              cluster_centers):
    """Creates an op for training for the full-batch case.

    Args:
      inputs: list of input Tensors.
      num_clusters: an integer Tensor providing the number of clusters.
      cluster_idx_list: A vector (or list of vectors). Each element in the
        vector corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      cluster_centers: Tensor Ref of cluster centers.

    Returns:
      An op for doing an update of full-batch k-means.
    """
    cluster_sums = []
    cluster_counts = []
    epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
      with ops.colocate_with(inp, ignore_existing=True):
        cluster_sums.append(
            math_ops.unsorted_segment_sum(inp, cluster_idx, num_clusters))
        cluster_counts.append(
            math_ops.unsorted_segment_sum(
                array_ops.reshape(
                    array_ops.ones(
                        array_ops.reshape(array_ops.shape(inp)[0], [-1])),
                    [-1, 1]), cluster_idx, num_clusters))
    with ops.colocate_with(cluster_centers, ignore_existing=True):
      new_clusters_centers = math_ops.add_n(cluster_sums) / (
          math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) +
          epsilon)
      if self._clusters_l2_normalized():
        new_clusters_centers = nn_impl.l2_normalize(
            new_clusters_centers, axis=1)
    return state_ops.assign(cluster_centers, new_clusters_centers)
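
  # The assign above implements one Lloyd iteration: each center moves to the
  # epsilon-smoothed mean of the points assigned to it. A NumPy sketch of the
  # same update (illustrative only; names are arbitrary):
  #
  #   import numpy as np
  #   points = np.random.rand(100, 2)
  #   assignments = np.random.randint(0, 3, size=100)  # 3 clusters
  #   sums, counts = np.zeros((3, 2)), np.zeros((3, 1))
  #   np.add.at(sums, assignments, points)
  #   np.add.at(counts, assignments, 1.)
  #   new_centers = sums / (counts + 1e-6)  # epsilon guards empty clusters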


class _InitializeClustersOpFactory:
  """Internal class to create the op to initialize the clusters.

    The op performs this algorithm (see constructor args):

    num_remaining = num_clusters - length(cluster_centers)
    if num_remaining == 0:
      assert that cluster_centers_initialized is true
    else:
      assert that num_remaining > 0
      new_centers = choose up to num_remaining initial centers
      l2-normalize new_centers if using cosine distance
      all_centers = concat(cluster_centers, new_centers)
      cluster_centers := all_centers
      if there is a cluster_centers_updated variable:
        cluster_centers_updated := cluster_centers
      num_now_remaining = num_clusters - length(cluster_centers)
      if num_now_remaining == 0:
        cluster_centers_initialized := true
  """

  # TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case.

  def __init__(self, inputs, num_clusters, initial_clusters, distance_metric,
               random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length,
               cluster_centers, cluster_centers_updated,
               cluster_centers_initialized):
    """Creates an op factory.

    Args:
      inputs: See KMeans constructor.
      num_clusters: An integer Tensor providing the number of clusters.
      initial_clusters: See KMeans constructor.
      distance_metric: See KMeans constructor.
      random_seed: See KMeans constructor.
      kmeans_plus_plus_num_retries: See KMeans constructor.
      kmc2_chain_length: See KMeans constructor.
      cluster_centers: The TF variable holding the initial centers. It may
        already contain some centers when the op is executed.
      cluster_centers_updated: A second TF variable to hold a copy of the
        initial centers, used when mini_batch_steps_per_iteration > 1. In
        other cases, cluster_centers_updated is the same variable as
        cluster_centers.
      cluster_centers_initialized: A boolean TF variable that will be set to
        true when all the initial centers have been chosen.
    """
    # All of these instance variables are constants.
    self._inputs = inputs
    self._num_clusters = num_clusters
    self._initial_clusters = initial_clusters
    self._distance_metric = distance_metric
    self._seed = random_seed
    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
    self._kmc2_chain_length = kmc2_chain_length
    self._cluster_centers = cluster_centers
    self._cluster_centers_updated = cluster_centers_updated
    self._cluster_centers_initialized = cluster_centers_initialized

    self._num_selected = array_ops.shape(self._cluster_centers)[0]
    self._num_remaining = self._num_clusters - self._num_selected
    self._num_data = math_ops.add_n(
        [array_ops.shape(i)[0] for i in self._inputs])

  def _random(self):
    indices = random_ops.random_uniform(
        array_ops.reshape(self._num_remaining, [-1]),
        minval=0,
        maxval=math_ops.cast(self._num_data, dtypes.int64),
        seed=self._seed,
        dtype=dtypes.int64)
    return embedding_lookup(self._inputs, indices, partition_strategy='div')

  def _kmeans_plus_plus(self):
    # Points from only the first shard are used for initializing centers.
    # TODO(ands): Use all points.
    inp = self._inputs[0]
    if self._distance_metric == COSINE_DISTANCE:
      inp = nn_impl.l2_normalize(inp, axis=1)
    return gen_clustering_ops.kmeans_plus_plus_initialization(
        inp, math_ops.cast(self._num_remaining, dtypes.int64), self._seed,
        self._kmeans_plus_plus_num_retries)

  def _kmc2_multiple_centers(self):
    """Adds new initial cluster centers using the k-MC2 algorithm.

    In each call to the op, the provided batch is split into subsets based on
    the specified `kmc2_chain_length`. On each subset, a single Markov chain of
    the k-MC2 algorithm is used to add *one* new cluster center. If there are
    fewer than `kmc2_chain_length` points in the subset, a single center is
    added using one Markov chain on the full input. It is assumed that the
    provided batch has previously been randomly permuted. Otherwise, k-MC2 may
    return suboptimal centers.

    Returns:
      An op that adds new cluster centers.
    """
    # The op only operates on the first shard of data.
    first_shard = self._inputs[0]
    # Number of points in the input that can be used.
    batch_size = array_ops.shape(first_shard)[0]
    # Maximum number of subsets such that the size of each subset is at least
    # `kmc2_chain_length`. Final subsets may be larger.
    max_to_sample = math_ops.cast(
        batch_size / self._kmc2_chain_length, dtype=dtypes.int32)
    # We sample at least one new center and at most all remaining centers.
    num_to_sample = math_ops.maximum(
        math_ops.minimum(self._num_remaining, max_to_sample), 1)

    def _cond(i, _):
      """Stopping condition for the while loop."""
      return math_ops.less(i, num_to_sample)

    def _body(i, _):
      """Body that adds a single new center based on a subset."""

      def _sample_random():
        """Returns a random point as a cluster center."""
        # By assumption the batch is reshuffled and _sample_random is always
        # called for i=0. Hence, we simply return the first point.
        new_center = array_ops.reshape(first_shard[0], [1, -1])
        if self._distance_metric == COSINE_DISTANCE:
          new_center = nn_impl.l2_normalize(new_center, axis=1)
        return new_center

      def _sample_kmc2_chain():
        """Returns previous centers as well as a new center sampled using k-MC2."""
        # Extract the subset from the underlying batch.
        start = i * self._kmc2_chain_length
        end = start + self._kmc2_chain_length
        subset = first_shard[start:end]
        # Compute the distances from points in the subset to previous centers.
        _, distances = gen_clustering_ops.nearest_neighbors(
            subset, self._cluster_centers, 1)
        # Sample index of new center using k-MC2 Markov chain.
        new_center_index = gen_clustering_ops.kmc2_chain_initialization(
            array_ops.squeeze(distances), self._seed)
        # Extract actual new center.
        newly_sampled_center = array_ops.reshape(subset[new_center_index],
                                                 [1, -1])
        # Return concatenation with previously sampled centers.
        if self._distance_metric == COSINE_DISTANCE:
          newly_sampled_center = nn_impl.l2_normalize(
              newly_sampled_center, axis=1)
        return array_ops.concat([self._cluster_centers, newly_sampled_center],
                                0)

      # Obtain a random point if there are no previously sampled centers.
      # Otherwise, construct a k-MC2 Markov chain.
      new_centers = control_flow_ops.cond(
          math_ops.equal(self._num_selected, 0), _sample_random,
          _sample_kmc2_chain)
      # Assign new cluster centers to underlying variable.
      assigned_centers = state_ops.assign(
          self._cluster_centers, new_centers, validate_shape=False)
      if self._cluster_centers_updated is not self._cluster_centers:
        assigned_centers = state_ops.assign(
            self._cluster_centers_updated,
            assigned_centers,
            validate_shape=False)
      return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0]

    # Add num_to_sample new data points.
    _, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0])
    return num_remaining

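  # A small sketch of the subset index arithmetic used by _sample_kmc2_chain
  # above (plain Python, illustrative only):
  #
  #   batch_size, chain_length = 1050, 200
  #   max_to_sample = batch_size // chain_length  # at most 5 chains
  #   subsets = [(i * chain_length, (i + 1) * chain_length)
  #              for i in range(max_to_sample)]
  #   # -> [(0, 200), (200, 400), (400, 600), (600, 800), (800, 1000)]
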
  def _greedy_batch_sampler(self, sampler):
    # If the input dataset size is smaller than the number of centers
    # remaining, choose the entire input dataset as centers. This can happen
    # with mini-batch. Otherwise, sample the batch according to the provided
    # sampler.
    return control_flow_ops.cond(self._num_data <= self._num_remaining,
                                 lambda: array_ops.concat(self._inputs, 0),
                                 sampler)

  def _single_batch_sampler(self, sampler):
    # Enforce that there are at least as many data points as centers
    # remaining. This gives the provided sampler the chance to select all
    # remaining centers from a single batch.
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(self._num_data, self._num_remaining)]):
      return sampler()

  def _choose_initial_centers(self):
    if isinstance(self._initial_clusters, str):
      if self._initial_clusters == RANDOM_INIT:
        return self._greedy_batch_sampler(self._random)
      else:  # self._initial_clusters == KMEANS_PLUS_PLUS_INIT
        return self._single_batch_sampler(self._kmeans_plus_plus)
    elif callable(self._initial_clusters):
      return self._initial_clusters(self._inputs, self._num_remaining)
    else:
      with ops.control_dependencies([
          check_ops.assert_equal(self._num_remaining,
                                 array_ops.shape(self._initial_clusters)[0])
      ]):
        return self._initial_clusters

  def _add_new_centers(self):
    """Adds some centers and returns the number of centers remaining."""
    new_centers = self._choose_initial_centers()
    if self._distance_metric == COSINE_DISTANCE:
      new_centers = nn_impl.l2_normalize(new_centers, axis=1)
    # If cluster_centers is empty, it doesn't have the right shape for concat.
    all_centers = control_flow_ops.cond(
        math_ops.equal(self._num_selected, 0), lambda: new_centers,
        lambda: array_ops.concat([self._cluster_centers, new_centers], 0))
    # TODO(ccolby): De-dupe all_centers?
    a = state_ops.assign(
        self._cluster_centers, all_centers, validate_shape=False)
    if self._cluster_centers_updated is not self._cluster_centers:
      a = state_ops.assign(
          self._cluster_centers_updated, a, validate_shape=False)
    return self._num_clusters - array_ops.shape(a)[0]

  def _initialize(self):
    with ops.control_dependencies([
        check_ops.assert_positive(self._num_remaining),
    ]):
      if self._initial_clusters == KMC2_INIT:
        num_now_remaining = self._kmc2_multiple_centers()
      else:
        num_now_remaining = self._add_new_centers()
      return control_flow_ops.cond(
          math_ops.equal(num_now_remaining, 0),
          lambda: state_ops.assign(self._cluster_centers_initialized, True),
          control_flow_ops.no_op)

  def op(self):
    """Returns the cluster initializer op."""
    return control_flow_ops.cond(
        math_ops.equal(self._num_remaining, 0),
        lambda: check_ops.assert_equal(self._cluster_centers_initialized, True),
        self._initialize)

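
# A sketch of configuring this module for cosine k-means with kmeans++
# initialization. Illustrative only: `data` and the session-driving loop are
# assumed to exist in the caller's code, as in the sketch after
# training_graph above.
#
#   kmeans = KMeans(
#       data,
#       num_clusters=8,
#       initial_clusters=KMEANS_PLUS_PLUS_INIT,
#       distance_metric=COSINE_DISTANCE,
#       use_mini_batch=False)
#   (all_scores, cluster_idx, scores, initialized, init_op,
#    training_op) = kmeans.training_graph()
#   # cluster_idx assigns each input row to its nearest (cosine) center.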