• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of k-means clustering on top of `Estimator` API (deprecated).
16
17This module is deprecated. Please use
18`tf.contrib.factorization.KMeansClustering` instead of
19`tf.contrib.learn.KMeansClustering`. It has a similar interface, but uses the
20`tf.estimator.Estimator` API instead of `tf.contrib.learn.Estimator`.
21"""
22
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27import time
28import numpy as np
29
30from tensorflow.contrib.factorization.python.ops import clustering_ops
31from tensorflow.python.training import training_util
32from tensorflow.contrib.learn.python.learn.estimators import estimator
33from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModelFnOps
34from tensorflow.python.framework import ops
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import state_ops
38from tensorflow.python.ops.control_flow_ops import with_dependencies
39from tensorflow.python.platform import tf_logging as logging
40from tensorflow.python.summary import summary
41from tensorflow.python.training import session_run_hook
42from tensorflow.python.training.session_run_hook import SessionRunArgs
43from tensorflow.python.util.deprecation import deprecated
44
45_USE_TF_CONTRIB_FACTORIZATION = (
46    'Please use tf.contrib.factorization.KMeansClustering instead of'
47    ' tf.contrib.learn.KMeansClustering. It has a similar interface, but uses'
48    ' the tf.estimator.Estimator API instead of tf.contrib.learn.Estimator.')
49
50
class _LossRelativeChangeHook(session_run_hook.SessionRunHook):
  """Requests a stop once the loss changes by less than a tolerance.

  After every run the fetched loss is compared with the loss of the previous
  run; the comparison uses `|loss - prev| / (1 + |prev|)`.
  """

  def __init__(self, tolerance):
    """Initializes _LossRelativeChangeHook.

    Args:
      tolerance: A relative tolerance of change between iterations.
    """
    self._tolerance = tolerance
    # No loss has been observed yet, so the first run never triggers a stop.
    self._prev_loss = None

  def begin(self):
    """Looks up the loss tensor by name in the default graph."""
    graph = ops.get_default_graph()
    tensor_name = KMeansClustering.LOSS_OP_NAME + ':0'
    self._loss_tensor = graph.get_tensor_by_name(tensor_name)
    assert self._loss_tensor is not None

  def before_run(self, run_context):
    """Asks the session to additionally fetch the loss tensor."""
    del run_context  # Unused.
    loss_fetch = {KMeansClustering.LOSS_OP_NAME: self._loss_tensor}
    return SessionRunArgs(fetches=loss_fetch)

  def after_run(self, run_context, run_values):
    """Requests a stop if the relative loss change is below the tolerance.

    Args:
      run_context: context on which `request_stop()` is called.
      run_values: holds the fetched loss under `LOSS_OP_NAME` in `results`.
    """
    current = run_values.results[KMeansClustering.LOSS_OP_NAME]
    assert current is not None
    previous = self._prev_loss
    if previous is not None:
      delta = abs(current - previous)
      # The `1 +` keeps the denominator away from zero for tiny losses.
      scale = 1 + abs(previous)
      if delta / scale < self._tolerance:
        run_context.request_stop()
    # Remember this loss for the comparison in the next run.
    self._prev_loss = current
82
83
class _InitializeClustersHook(session_run_hook.SessionRunHook):
  """Runs cluster initialization on the chief; other workers wait for it."""

  def __init__(self, init_op, is_initialized_op, is_chief):
    """Stores the ops and role used by `after_create_session`.

    Args:
      init_op: op the chief runs while the clusters are not yet initialized.
      is_initialized_op: op whose run result is truthy once the clusters are
        initialized.
      is_chief: whether this process is the one that runs `init_op`.
    """
    self._is_initialized_op = is_initialized_op
    self._init_op = init_op
    self._is_chief = is_chief

  def after_create_session(self, session, _):
    """Loops until `is_initialized_op` evaluates to a truthy value.

    Args:
      session: the session used to run the initialization ops.
      _: unused second hook argument.
    """
    init_graph = self._init_op.graph
    assert init_graph == ops.get_default_graph()
    assert self._is_initialized_op.graph == init_graph
    initialized = False
    while not initialized:
      try:
        if session.run(self._is_initialized_op):
          initialized = True
        elif self._is_chief:
          # The chief initializes, then re-checks on the next iteration.
          session.run(self._init_op)
        else:
          # Non-chief workers poll once per second.
          time.sleep(1)
      except RuntimeError as error:
        # Log and retry; the loop only ends once initialization is observed.
        logging.info(error)
105
106
107def _parse_tensor_or_dict(features):
108  """Helper function to parse features."""
109  if isinstance(features, dict):
110    keys = sorted(features.keys())
111    with ops.colocate_with(features[keys[0]]):
112      features = array_ops.concat([features[k] for k in keys], 1)
113  return features
114
115
def _kmeans_clustering_model_fn(features, labels, mode, params, config):
  """Builds the k-means graph and wraps it in `ModelFnOps`.

  Args:
    features: a tensor or dict of tensors; dicts are merged by
      `_parse_tensor_or_dict`.
    labels: must be None.
    mode: forwarded unchanged to `ModelFnOps`.
    params: dict read for 'num_clusters', 'training_initial_clusters',
      'distance_metric', 'use_mini_batch', 'mini_batch_steps_per_iteration',
      'random_seed', 'kmeans_plus_plus_num_retries' and 'relative_tolerance'.
    config: object whose `is_chief` attribute selects the initializing worker.

  Returns:
    A `ModelFnOps` carrying predictions, the loss, the train op, the eval
    metric and the training hooks.
  """
  assert labels is None, labels
  kmeans = clustering_ops.KMeans(
      _parse_tensor_or_dict(features),
      params.get('num_clusters'),
      initial_clusters=params.get('training_initial_clusters'),
      distance_metric=params.get('distance_metric'),
      use_mini_batch=params.get('use_mini_batch'),
      mini_batch_steps_per_iteration=params.get(
          'mini_batch_steps_per_iteration'),
      random_seed=params.get('random_seed'),
      kmeans_plus_plus_num_retries=params.get('kmeans_plus_plus_num_retries'))
  (all_scores, model_predictions, losses, is_initialized, init_op,
   kmeans_train_op) = kmeans.training_graph()

  global_step = training_util.get_global_step()
  incr_step = state_ops.assign_add(global_step, 1)
  # The loss is named so that _LossRelativeChangeHook can find it by name.
  loss = math_ops.reduce_sum(losses, name=KMeansClustering.LOSS_OP_NAME)
  summary.scalar('loss/raw', loss)
  # Fetching the train op yields the loss after the k-means update and the
  # global-step increment have run.
  train_op = with_dependencies([kmeans_train_op, incr_step], loss)

  predictions = {
      KMeansClustering.ALL_SCORES: all_scores[0],
      KMeansClustering.CLUSTER_IDX: model_predictions[0],
  }
  eval_metric_ops = {KMeansClustering.SCORES: loss}

  hooks = [
      _InitializeClustersHook(init_op, is_initialized, config.is_chief)
  ]
  tolerance = params.get('relative_tolerance')
  if tolerance is not None:
    hooks.append(_LossRelativeChangeHook(tolerance))

  return ModelFnOps(
      mode=mode,
      predictions=predictions,
      eval_metric_ops=eval_metric_ops,
      loss=loss,
      train_op=train_op,
      training_hooks=hooks)
152
153
154# TODO(agarwal,ands): support sharded input.
class KMeansClustering(estimator.Estimator):
  """K-Means clustering built on the `tf.contrib.learn` Estimator.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """
  # Distance metrics and initialization schemes re-exported from
  # clustering_ops.
  SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE
  COSINE_DISTANCE = clustering_ops.COSINE_DISTANCE
  RANDOM_INIT = clustering_ops.RANDOM_INIT
  KMEANS_PLUS_PLUS_INIT = clustering_ops.KMEANS_PLUS_PLUS_INIT
  # Keys for the eval metric, the prediction outputs and the centers variable.
  SCORES = 'scores'
  CLUSTER_IDX = 'cluster_idx'
  CLUSTERS = 'clusters'
  ALL_SCORES = 'all_scores'
  # Name given to the loss op; hooks look the tensor up under this name.
  LOSS_OP_NAME = 'kmeans_loss'

  @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION)
  def __init__(self,
               num_clusters,
               model_dir=None,
               initial_clusters=RANDOM_INIT,
               distance_metric=SQUARED_EUCLIDEAN_DISTANCE,
               random_seed=0,
               use_mini_batch=True,
               mini_batch_steps_per_iteration=1,
               kmeans_plus_plus_num_retries=2,
               relative_tolerance=None,
               config=None):
    """Creates a model for running KMeans training and inference.

    Args:
      num_clusters: number of clusters to train.
      model_dir: the directory to save the model results and log files.
      initial_clusters: specifies how to initialize the clusters for training.
        See clustering_ops.kmeans for the possible values.
      distance_metric: the distance metric used for clustering.
        See clustering_ops.kmeans for the possible values.
      random_seed: Python integer. Seed for PRNG used to initialize centers.
      use_mini_batch: If true, use the mini-batch k-means algorithm. Else assume
        full batch.
      mini_batch_steps_per_iteration: number of steps after which the updated
        cluster centers are synced back to a master copy. See clustering_ops.py
        for more details.
      kmeans_plus_plus_num_retries: For each point that is sampled during
        kmeans++ initialization, this parameter specifies the number of
        additional points to draw from the current distribution before selecting
        the best. If a negative value is specified, a heuristic is used to
        sample O(log(num_to_sample)) additional points.
      relative_tolerance: A relative tolerance of change in the loss between
        iterations.  Stops learning if the loss changes less than this amount.
        Note that this may not work correctly if use_mini_batch=True.
      config: See Estimator
    """
    # Everything the model function needs is passed through `params`.
    hyperparams = {
        'num_clusters': num_clusters,
        'training_initial_clusters': initial_clusters,
        'distance_metric': distance_metric,
        'random_seed': random_seed,
        'use_mini_batch': use_mini_batch,
        'mini_batch_steps_per_iteration': mini_batch_steps_per_iteration,
        'kmeans_plus_plus_num_retries': kmeans_plus_plus_num_retries,
        'relative_tolerance': relative_tolerance,
    }
    super(KMeansClustering, self).__init__(
        model_fn=_kmeans_clustering_model_fn,
        params=hyperparams,
        model_dir=model_dir,
        config=config)

  @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION)
  def predict_cluster_idx(self, input_fn=None):
    """Yields predicted cluster indices.

    Args:
      input_fn: forwarded to `predict`.

    Yields:
      The `CLUSTER_IDX` entry of each result produced by `predict`.
    """
    idx_key = KMeansClustering.CLUSTER_IDX
    predictions = super(KMeansClustering, self).predict(
        input_fn=input_fn, outputs=[idx_key])
    for prediction in predictions:
      yield prediction[idx_key]

  @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION)
  def score(self, input_fn=None, steps=None):
    """Predict total sum of distances to nearest clusters.

    Note that this function is different from the corresponding one in sklearn
    which returns the negative of the sum of distances.

    Args:
      input_fn: forwarded to `evaluate`.
      steps: forwarded to `evaluate`.

    Returns:
      Total sum of distances to nearest clusters.
    """
    metrics = self.evaluate(input_fn=input_fn, steps=steps)
    return np.sum(metrics[KMeansClustering.SCORES])

  @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION)
  def transform(self, input_fn=None, as_iterable=False):
    """Transforms each element to distances to cluster centers.

    Note that this function is different from the corresponding one in sklearn.
    For SQUARED_EUCLIDEAN distance metric, sklearn transform returns the
    EUCLIDEAN distance, while this function returns the SQUARED_EUCLIDEAN
    distance.

    Args:
      input_fn: see predict.
      as_iterable: see predict

    Returns:
      When `as_iterable` is false, the `ALL_SCORES` entry of the prediction
      result: an array with one row per input element and num_clusters
      columns, containing distances to the cluster centers. When
      `as_iterable` is true, the result of `predict` is returned as-is
      (presumably an iterable of per-element results keyed by `ALL_SCORES`;
      confirm against `Estimator.predict`).
    """
    scores_key = KMeansClustering.ALL_SCORES
    predictions = super(KMeansClustering, self).predict(
        input_fn=input_fn, outputs=[scores_key], as_iterable=as_iterable)
    if as_iterable:
      return predictions
    return predictions[scores_key]

  @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION)
  def clusters(self):
    """Returns cluster centers, read from the `CLUSTERS` variable."""
    return super(KMeansClustering, self).get_variable_value(self.CLUSTERS)
282