1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Implementation of k-means clustering on top of `Estimator` API (deprecated). 16 17This module is deprecated. Please use 18`tf.contrib.factorization.KMeansClustering` instead of 19`tf.contrib.learn.KMeansClustering`. It has a similar interface, but uses the 20`tf.estimator.Estimator` API instead of `tf.contrib.learn.Estimator`. 21""" 22 23from __future__ import absolute_import 24from __future__ import division 25from __future__ import print_function 26 27import time 28import numpy as np 29 30from tensorflow.contrib.factorization.python.ops import clustering_ops 31from tensorflow.python.training import training_util 32from tensorflow.contrib.learn.python.learn.estimators import estimator 33from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModelFnOps 34from tensorflow.python.framework import ops 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import math_ops 37from tensorflow.python.ops import state_ops 38from tensorflow.python.ops.control_flow_ops import with_dependencies 39from tensorflow.python.platform import tf_logging as logging 40from tensorflow.python.summary import summary 41from tensorflow.python.training import session_run_hook 42from tensorflow.python.training.session_run_hook import SessionRunArgs 43from tensorflow.python.util.deprecation import deprecated 44 45_USE_TF_CONTRIB_FACTORIZATION = ( 46 'Please use tf.contrib.factorization.KMeansClustering instead of' 47 ' tf.contrib.learn.KMeansClustering. It has a similar interface, but uses' 48 ' the tf.estimator.Estimator API instead of tf.contrib.learn.Estimator.') 49 50 51class _LossRelativeChangeHook(session_run_hook.SessionRunHook): 52 """Stops when the change in loss goes below a tolerance.""" 53 54 def __init__(self, tolerance): 55 """Initializes _LossRelativeChangeHook. 56 57 Args: 58 tolerance: A relative tolerance of change between iterations. 59 """ 60 self._tolerance = tolerance 61 self._prev_loss = None 62 63 def begin(self): 64 self._loss_tensor = ops.get_default_graph().get_tensor_by_name( 65 KMeansClustering.LOSS_OP_NAME + ':0') 66 assert self._loss_tensor is not None 67 68 def before_run(self, run_context): 69 del run_context 70 return SessionRunArgs( 71 fetches={KMeansClustering.LOSS_OP_NAME: self._loss_tensor}) 72 73 def after_run(self, run_context, run_values): 74 loss = run_values.results[KMeansClustering.LOSS_OP_NAME] 75 assert loss is not None 76 if self._prev_loss is not None: 77 relative_change = (abs(loss - self._prev_loss) / 78 (1 + abs(self._prev_loss))) 79 if relative_change < self._tolerance: 80 run_context.request_stop() 81 self._prev_loss = loss 82 83 84class _InitializeClustersHook(session_run_hook.SessionRunHook): 85 """Initializes clusters or waits for cluster initialization.""" 86 87 def __init__(self, init_op, is_initialized_op, is_chief): 88 self._init_op = init_op 89 self._is_chief = is_chief 90 self._is_initialized_op = is_initialized_op 91 92 def after_create_session(self, session, _): 93 assert self._init_op.graph == ops.get_default_graph() 94 assert self._is_initialized_op.graph == self._init_op.graph 95 while True: 96 try: 97 if session.run(self._is_initialized_op): 98 break 99 elif self._is_chief: 100 session.run(self._init_op) 101 else: 102 time.sleep(1) 103 except RuntimeError as e: 104 logging.info(e) 105 106 107def _parse_tensor_or_dict(features): 108 """Helper function to parse features.""" 109 if isinstance(features, dict): 110 keys = sorted(features.keys()) 111 with ops.colocate_with(features[keys[0]]): 112 features = array_ops.concat([features[k] for k in keys], 1) 113 return features 114 115 116def _kmeans_clustering_model_fn(features, labels, mode, params, config): 117 """Model function for KMeansClustering estimator.""" 118 assert labels is None, labels 119 (all_scores, model_predictions, losses, 120 is_initialized, init_op, training_op) = clustering_ops.KMeans( 121 _parse_tensor_or_dict(features), 122 params.get('num_clusters'), 123 initial_clusters=params.get('training_initial_clusters'), 124 distance_metric=params.get('distance_metric'), 125 use_mini_batch=params.get('use_mini_batch'), 126 mini_batch_steps_per_iteration=params.get( 127 'mini_batch_steps_per_iteration'), 128 random_seed=params.get('random_seed'), 129 kmeans_plus_plus_num_retries=params.get( 130 'kmeans_plus_plus_num_retries')).training_graph() 131 incr_step = state_ops.assign_add(training_util.get_global_step(), 1) 132 loss = math_ops.reduce_sum(losses, name=KMeansClustering.LOSS_OP_NAME) 133 summary.scalar('loss/raw', loss) 134 training_op = with_dependencies([training_op, incr_step], loss) 135 predictions = { 136 KMeansClustering.ALL_SCORES: all_scores[0], 137 KMeansClustering.CLUSTER_IDX: model_predictions[0], 138 } 139 eval_metric_ops = {KMeansClustering.SCORES: loss} 140 training_hooks = [_InitializeClustersHook( 141 init_op, is_initialized, config.is_chief)] 142 relative_tolerance = params.get('relative_tolerance') 143 if relative_tolerance is not None: 144 training_hooks.append(_LossRelativeChangeHook(relative_tolerance)) 145 return ModelFnOps( 146 mode=mode, 147 predictions=predictions, 148 eval_metric_ops=eval_metric_ops, 149 loss=loss, 150 train_op=training_op, 151 training_hooks=training_hooks) 152 153 154# TODO(agarwal,ands): support sharded input. 155class KMeansClustering(estimator.Estimator): 156 """An Estimator for K-Means clustering. 157 158 THIS CLASS IS DEPRECATED. See 159 [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) 160 for general migration instructions. 161 """ 162 SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE 163 COSINE_DISTANCE = clustering_ops.COSINE_DISTANCE 164 RANDOM_INIT = clustering_ops.RANDOM_INIT 165 KMEANS_PLUS_PLUS_INIT = clustering_ops.KMEANS_PLUS_PLUS_INIT 166 SCORES = 'scores' 167 CLUSTER_IDX = 'cluster_idx' 168 CLUSTERS = 'clusters' 169 ALL_SCORES = 'all_scores' 170 LOSS_OP_NAME = 'kmeans_loss' 171 172 @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) 173 def __init__(self, 174 num_clusters, 175 model_dir=None, 176 initial_clusters=RANDOM_INIT, 177 distance_metric=SQUARED_EUCLIDEAN_DISTANCE, 178 random_seed=0, 179 use_mini_batch=True, 180 mini_batch_steps_per_iteration=1, 181 kmeans_plus_plus_num_retries=2, 182 relative_tolerance=None, 183 config=None): 184 """Creates a model for running KMeans training and inference. 185 186 Args: 187 num_clusters: number of clusters to train. 188 model_dir: the directory to save the model results and log files. 189 initial_clusters: specifies how to initialize the clusters for training. 190 See clustering_ops.kmeans for the possible values. 191 distance_metric: the distance metric used for clustering. 192 See clustering_ops.kmeans for the possible values. 193 random_seed: Python integer. Seed for PRNG used to initialize centers. 194 use_mini_batch: If true, use the mini-batch k-means algorithm. Else assume 195 full batch. 196 mini_batch_steps_per_iteration: number of steps after which the updated 197 cluster centers are synced back to a master copy. See clustering_ops.py 198 for more details. 199 kmeans_plus_plus_num_retries: For each point that is sampled during 200 kmeans++ initialization, this parameter specifies the number of 201 additional points to draw from the current distribution before selecting 202 the best. If a negative value is specified, a heuristic is used to 203 sample O(log(num_to_sample)) additional points. 204 relative_tolerance: A relative tolerance of change in the loss between 205 iterations. Stops learning if the loss changes less than this amount. 206 Note that this may not work correctly if use_mini_batch=True. 207 config: See Estimator 208 """ 209 params = {} 210 params['num_clusters'] = num_clusters 211 params['training_initial_clusters'] = initial_clusters 212 params['distance_metric'] = distance_metric 213 params['random_seed'] = random_seed 214 params['use_mini_batch'] = use_mini_batch 215 params['mini_batch_steps_per_iteration'] = mini_batch_steps_per_iteration 216 params['kmeans_plus_plus_num_retries'] = kmeans_plus_plus_num_retries 217 params['relative_tolerance'] = relative_tolerance 218 super(KMeansClustering, self).__init__( 219 model_fn=_kmeans_clustering_model_fn, 220 params=params, 221 model_dir=model_dir, 222 config=config) 223 224 @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) 225 def predict_cluster_idx(self, input_fn=None): 226 """Yields predicted cluster indices.""" 227 key = KMeansClustering.CLUSTER_IDX 228 results = super(KMeansClustering, self).predict( 229 input_fn=input_fn, outputs=[key]) 230 for result in results: 231 yield result[key] 232 233 @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) 234 def score(self, input_fn=None, steps=None): 235 """Predict total sum of distances to nearest clusters. 236 237 Note that this function is different from the corresponding one in sklearn 238 which returns the negative of the sum of distances. 239 240 Args: 241 input_fn: see predict. 242 steps: see predict. 243 244 Returns: 245 Total sum of distances to nearest clusters. 246 """ 247 return np.sum( 248 self.evaluate( 249 input_fn=input_fn, steps=steps)[KMeansClustering.SCORES]) 250 251 @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) 252 def transform(self, input_fn=None, as_iterable=False): 253 """Transforms each element to distances to cluster centers. 254 255 Note that this function is different from the corresponding one in sklearn. 256 For SQUARED_EUCLIDEAN distance metric, sklearn transform returns the 257 EUCLIDEAN distance, while this function returns the SQUARED_EUCLIDEAN 258 distance. 259 260 Args: 261 input_fn: see predict. 262 as_iterable: see predict 263 264 Returns: 265 Array with same number of rows as x, and num_clusters columns, containing 266 distances to the cluster centers. 267 """ 268 key = KMeansClustering.ALL_SCORES 269 results = super(KMeansClustering, self).predict( 270 input_fn=input_fn, 271 outputs=[key], 272 as_iterable=as_iterable) 273 if not as_iterable: 274 return results[key] 275 else: 276 return results 277 278 @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) 279 def clusters(self): 280 """Returns cluster centers.""" 281 return super(KMeansClustering, self).get_variable_value(self.CLUSTERS) 282