1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Synchronize replicas for training.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.core.framework import types_pb2 22from tensorflow.python.distribute import distribution_strategy_context 23from tensorflow.python.framework import ops 24from tensorflow.python.ops import array_ops 25from tensorflow.python.ops import control_flow_ops 26from tensorflow.python.ops import data_flow_ops 27from tensorflow.python.ops import state_ops 28from tensorflow.python.ops import variable_scope 29from tensorflow.python.ops import variables 30from tensorflow.python.platform import tf_logging as logging 31from tensorflow.python.training import optimizer 32from tensorflow.python.training import queue_runner 33from tensorflow.python.training import session_manager 34from tensorflow.python.training import session_run_hook 35from tensorflow.python.util import deprecation 36from tensorflow.python.util.tf_export import tf_export 37 38 39# Please note that the gradients from replicas are averaged instead of summed 40# (as in the old sync_replicas_optimizer) so you need to increase the learning 41# rate according to the number of replicas. 
# This change is introduced to be consistent with how gradients are aggregated
# (averaged) within a batch in a replica.
@tf_export(v1=["train.SyncReplicasOptimizer"])
class SyncReplicasOptimizer(optimizer.Optimizer):
  """Class to synchronize, aggregate gradients and pass them to the optimizer.

  This class is deprecated. For synchronous training, please use [Distribution
  Strategies](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute).

  In a typical asynchronous training environment, it's common to have some
  stale gradients. For example, with an N-replica asynchronous training,
  gradients will be applied to the variables N times independently. Depending
  on each replica's training speed, some gradients might be calculated from
  copies of the variable from several steps back (N-1 steps on average). This
  optimizer avoids stale gradients by collecting gradients from all replicas,
  averaging them, then applying them to the variables in one shot, after
  which replicas can fetch the new variables and continue.

  The following accumulators/queue are created:

  * N `gradient accumulators`, one per variable to train. Gradients are pushed
    to them and the chief worker will wait until enough gradients are collected
    and then average them before applying to variables. The accumulator will
    drop all stale gradients (more details in the accumulator op).
  * 1 `token` queue where the optimizer pushes the new global_step value after
    all variables are updated.

  The following local variable is created:
  * `sync_rep_local_step`, one per replica. Compared against the global_step in
    each accumulator to check for staleness of the gradients.

  The optimizer adds nodes to the graph to collect gradients and pause the
  trainers until variables are updated.
  For the Parameter Server job:

  1. An accumulator is created for each variable, and each replica pushes the
     gradients into the accumulators instead of directly applying them to the
     variables.
  2. Each accumulator averages once enough gradients (replicas_to_aggregate)
     have been accumulated.
  3. Apply the averaged gradients to the variables.
  4. Only after all variables have been updated, increment the global step.
  5. Only after step 4, push `global_step` into the `token_queue`, once for
     each worker replica. The workers can now fetch the global step, use it to
     update their local_step variable and start the next batch. Please note
     that some workers can consume multiple minibatches, while some may not
     consume even one. This is because each worker fetches minibatches as long
     as a token exists. If one worker is stuck for some reason and does not
     consume a token, another worker can use it.

  For the replicas:

  1. Start a step: fetch variables and compute gradients.
  2. Once the gradients have been computed, push them into gradient
     accumulators. Each accumulator will check the staleness and drop the stale.
  3. After pushing all the gradients, dequeue an updated value of global_step
     from the token queue and record that step to its local_step variable. Note
     that this is effectively a barrier.
  4. Start the next batch.

  ### Usage

  ```python
  # Create any optimizer to update the variables, say a simple SGD:
  opt = GradientDescentOptimizer(learning_rate=0.1)

  # Wrap the optimizer with sync_replicas_optimizer with 50 replicas: at each
  # step the optimizer collects 50 gradients before applying to variables.
  # Note that if you want to have 2 backup replicas, you can change
  # total_num_replicas=52 and make sure this number matches how many physical
  # replicas you started in your job.
  opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
                                       total_num_replicas=50)

  # Some models have startup_delays to help stabilize the model but when using
  # sync_replicas training, set it to 0.

  # Now you can call `minimize()` or `compute_gradients()` and
  # `apply_gradients()` normally
  training_op = opt.minimize(total_loss, global_step=self.global_step)


  # You can create the hook which handles initialization and queues.
  sync_replicas_hook = opt.make_session_run_hook(is_chief)
  ```

  In the training program, every worker will run the train_op as if not
  synchronized.

  ```python
  with training.MonitoredTrainingSession(
      master=workers[worker_id].target, is_chief=is_chief,
      hooks=[sync_replicas_hook]) as mon_sess:
    while not mon_sess.should_stop():
      mon_sess.run(training_op)
  ```

  To use SyncReplicasOptimizer with an `Estimator`, you need to send
  sync_replicas_hook while calling the fit.
  ```python
  my_estimator = DNNClassifier(..., optimizer=opt)
  my_estimator.fit(..., hooks=[sync_replicas_hook])
  ```
  """

  @deprecation.deprecated(
      None,
      "The `SyncReplicasOptimizer` class is deprecated. For synchronous "
      "training, please use [Distribution Strategies](https://github.com/"
      "tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute).",
      warn_once=True)
  def __init__(self,
               opt,
               replicas_to_aggregate,
               total_num_replicas=None,
               variable_averages=None,
               variables_to_average=None,
               use_locking=False,
               name="sync_replicas"):
    """Construct a sync_replicas optimizer.

    Args:
      opt: The actual optimizer that will be used to compute and apply the
        gradients. Must be one of the Optimizer classes.
      replicas_to_aggregate: number of replicas to aggregate for each variable
        update.
      total_num_replicas: Total number of tasks/workers/replicas, could be
        different from replicas_to_aggregate. Defaults to
        replicas_to_aggregate when None.
        If total_num_replicas > replicas_to_aggregate: it is backup_replicas +
        replicas_to_aggregate.
        If total_num_replicas < replicas_to_aggregate: Replicas compute
        multiple batches per update to variables.
      variable_averages: Optional `ExponentialMovingAverage` object, used to
        maintain moving averages for the variables passed in
        `variables_to_average`.
      variables_to_average: a list of variables that need to be averaged. Only
        needed if variable_averages is passed in.
      use_locking: If True use locks for update operation.
      name: string. Name passed to the base `Optimizer`; also used as the
        default name scope for the ops created by `apply_gradients()`.
    """
    if total_num_replicas is None:
      total_num_replicas = replicas_to_aggregate

    super(SyncReplicasOptimizer, self).__init__(use_locking, name)
    logging.info(
        "SyncReplicasV2: replicas_to_aggregate=%s; total_num_replicas=%s",
        replicas_to_aggregate, total_num_replicas)
    self._opt = opt
    self._replicas_to_aggregate = replicas_to_aggregate
    self._gradients_applied = False
    self._variable_averages = variable_averages
    self._variables_to_average = variables_to_average
    self._total_num_replicas = total_num_replicas
    # Number of tokens enqueued after every variable update (see
    # apply_gradients): enough for each replica to dequeue one, and never
    # fewer than the number of gradients the accumulators wait for.
    self._tokens_per_step = max(total_num_replicas, replicas_to_aggregate)
    self._global_step = None
    self._sync_token_queue = None

    # The synchronization op will be executed in a queue runner which should
    # only be executed by one of the replicas (usually the chief).
    self._chief_queue_runner = None

    # Remember which accumulator is on which device to set the initial step in
    # the accumulator to be global step. This list contains list of the
    # following format: (accumulator, device).
    self._accumulator_list = []

  def compute_gradients(self, *args, **kwargs):
    """Compute gradients of "loss" for the variables in "var_list".

    This simply wraps the compute_gradients() from the real optimizer. The
    gradients will be aggregated in the apply_gradients() so that user can
    modify the gradients like clipping with per replica global norm if needed.
    The global norm with aggregated gradients can be bad as one replica's huge
    gradients can hurt the gradients from other replicas.

    Args:
      *args: Arguments for compute_gradients().
      **kwargs: Keyword arguments for compute_gradients().

    Returns:
      A list of (gradient, variable) pairs.
    """
    return self._opt.compute_gradients(*args, **kwargs)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This contains most of the synchronization implementation and also wraps the
    apply_gradients() from the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients(). Any iterable of pairs is accepted; it is
        materialized into a tuple before use.
      global_step: Variable to increment by one after the variables have been
        updated. Although the signature defaults it to None, it is required:
        it is used to detect stale gradients.
      name: Optional name scope for the ops created here. Defaults to the
        name passed to the Optimizer constructor.

    Returns:
      train_op: The op to dequeue a token so the replicas can exit this batch
      and start the next one. This is executed by each replica.

    Raises:
      ValueError: If the grads_and_vars is empty.
      ValueError: If global step is not provided, the staleness cannot be
        checked.
      ValueError: If a gradient is neither None, a `Tensor`, nor an
        `IndexedSlices`.
    """
    # Materialize generators / Python 3 `zip` objects so that the emptiness
    # check below is meaningful and the pairs can be iterated safely.
    if grads_and_vars is not None:
      grads_and_vars = tuple(grads_and_vars)
    if not grads_and_vars:
      raise ValueError("Must supply at least one variable")

    if global_step is None:
      raise ValueError("Global step is required to check staleness")

    self._global_step = global_step
    train_ops = []
    aggregated_grad = []
    var_list = []

    # local_anchor op will be placed on this worker task by default.
    local_anchor = control_flow_ops.no_op()
    # Colocating local_step variable prevents it being placed on the PS.
    distribution_strategy = distribution_strategy_context.get_strategy()
    with distribution_strategy.extended.colocate_vars_with(local_anchor):
      self._local_step = variable_scope.variable(
          initial_value=0,
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          dtype=global_step.dtype.base_dtype,
          name="sync_rep_local_step")

    self.local_step_init_op = state_ops.assign(self._local_step, global_step)
    chief_init_ops = [self.local_step_init_op]
    self.ready_for_local_init_op = variables.report_uninitialized_variables(
        variables.global_variables())

    # Use the caller-supplied name when given, otherwise the constructor name
    # (same convention as the base Optimizer.apply_gradients).
    with ops.name_scope(name, self._name):
      for grad, var in grads_and_vars:
        var_list.append(var)
        with ops.device(var.device):
          # Dense gradients.
          if grad is None:
            aggregated_grad.append(None)  # pass-through.
            continue
          elif isinstance(grad, ops.Tensor):
            grad_accum = data_flow_ops.ConditionalAccumulator(
                grad.dtype,
                shape=var.get_shape(),
                shared_name=var.name + "/grad_accum")
            train_ops.append(grad_accum.apply_grad(
                grad, local_step=self._local_step))
            aggregated_grad.append(grad_accum.take_grad(
                self._replicas_to_aggregate))
          else:
            if not isinstance(grad, ops.IndexedSlices):
              raise ValueError("Unknown grad type!")
            grad_accum = data_flow_ops.SparseConditionalAccumulator(
                grad.dtype, shape=(), shared_name=var.name + "/grad_accum")
            train_ops.append(grad_accum.apply_indexed_slices_grad(
                grad, local_step=self._local_step))
            aggregated_grad.append(grad_accum.take_indexed_slices_grad(
                self._replicas_to_aggregate))

          self._accumulator_list.append((grad_accum, var.device))

      # A list rather than a one-shot iterator, so the wrapped optimizer may
      # iterate it more than once.
      aggregated_grads_and_vars = list(zip(aggregated_grad, var_list))

      # sync_op will be assigned to the same device as the global step.
      with ops.device(global_step.device), ops.name_scope(""):
        update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
                                              global_step)

      # Create token queue.
      with ops.device(global_step.device), ops.name_scope(""):
        sync_token_queue = (
            data_flow_ops.FIFOQueue(-1,
                                    global_step.dtype.base_dtype,
                                    shapes=(),
                                    name="sync_token_q",
                                    shared_name="sync_token_q"))
        self._sync_token_queue = sync_token_queue

        # dummy_queue is passed to the queue runner. Don't use the real queues
        # because the queue runner doesn't automatically reopen it once it
        # closed queues in PS devices.
        dummy_queue = (
            data_flow_ops.FIFOQueue(1,
                                    types_pb2.DT_INT32,
                                    shapes=(),
                                    name="dummy_queue",
                                    shared_name="dummy_queue"))

      with ops.device(global_step.device), ops.name_scope(""):
        # Replicas have to wait until they can get a token from the token queue.
        with ops.control_dependencies(train_ops):
          token = sync_token_queue.dequeue()
        train_op = state_ops.assign(self._local_step, token)

        with ops.control_dependencies([update_op]):
          # Sync_op needs to insert tokens to the token queue at the end of the
          # step so the replicas can fetch them to start the next step.
          tokens = array_ops.fill([self._tokens_per_step], global_step)
          sync_op = sync_token_queue.enqueue_many((tokens,))

        if self._variable_averages is not None:
          with ops.control_dependencies([sync_op]), ops.name_scope(""):
            sync_op = self._variable_averages.apply(
                self._variables_to_average)

        self._chief_queue_runner = queue_runner.QueueRunner(dummy_queue,
                                                            [sync_op])
      # The chief seeds every accumulator with the current global step so that
      # gradients computed at older steps are recognized as stale.
      for accum, dev in self._accumulator_list:
        with ops.device(dev):
          chief_init_ops.append(
              accum.set_global_step(
                  global_step, name="SetGlobalStep"))
      self.chief_init_op = control_flow_ops.group(*chief_init_ops)
      self._gradients_applied = True
      return train_op

  def get_chief_queue_runner(self):
    """Returns the QueueRunner for the chief to execute.

    This includes the operations to synchronize replicas: aggregate gradients,
    apply to variables, increment global step, insert tokens to token queue.

    Note that this can only be called after calling apply_gradients() which
    actually generates this queuerunner.

    Returns:
      A `QueueRunner` for chief to execute.

    Raises:
      ValueError: If this is called before apply_gradients().
    """
    if not self._gradients_applied:
      raise ValueError("Should be called after apply_gradients().")

    return self._chief_queue_runner

  def get_slot(self, *args, **kwargs):
    """Return a slot named "name" created for "var" by the Optimizer.

    This simply wraps the get_slot() from the actual optimizer.

    Args:
      *args: Arguments for get_slot().
      **kwargs: Keyword arguments for get_slot().

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    return self._opt.get_slot(*args, **kwargs)

  def variables(self):
    """Fetches a list of optimizer variables in the default graph.

    This wraps `variables()` from the actual optimizer. It does not include
    the `SyncReplicasOptimizer`'s local step.

    Returns:
      A list of variables.
    """
    return self._opt.variables()

  def get_slot_names(self, *args, **kwargs):
    """Return a list of the names of slots created by the `Optimizer`.

    This simply wraps the get_slot_names() from the actual optimizer.

    Args:
      *args: Arguments for get_slot_names().
      **kwargs: Keyword arguments for get_slot_names().

    Returns:
      A list of strings.
    """
    return self._opt.get_slot_names(*args, **kwargs)

  def get_init_tokens_op(self, num_tokens=-1):
    """Returns the op to fill the sync_token_queue with the tokens.

    This is supposed to be executed in the beginning of the chief/sync thread
    so that even if the total_num_replicas is less than replicas_to_aggregate,
    the model can still proceed as the replicas can compute multiple steps per
    variable update. Make sure:
    `num_tokens >= replicas_to_aggregate - total_num_replicas`.

    Args:
      num_tokens: Number of tokens to add to the queue. The default, -1,
        means `replicas_to_aggregate` tokens.

    Returns:
      An op for the chief/sync replica to fill the token queue, or a no-op
      when no tokens are to be added.

    Raises:
      ValueError: If this is called before apply_gradients().
      ValueError: If num_tokens are smaller than replicas_to_aggregate -
        total_num_replicas.
    """
    if not self._gradients_applied:
      raise ValueError(
          "get_init_tokens_op() should be called after apply_gradients().")

    tokens_needed = self._replicas_to_aggregate - self._total_num_replicas
    if num_tokens == -1:
      num_tokens = self._replicas_to_aggregate
    elif num_tokens < tokens_needed:
      raise ValueError(
          "Too few tokens to finish the first step: %d (given) vs %d (needed)" %
          (num_tokens, tokens_needed))

    if num_tokens > 0:
      with ops.device(self._global_step.device), ops.name_scope(""):
        tokens = array_ops.fill([num_tokens], self._global_step)
        init_tokens = self._sync_token_queue.enqueue_many((tokens,))
    else:
      init_tokens = control_flow_ops.no_op(name="no_init_tokens")

    return init_tokens

  def make_session_run_hook(self, is_chief, num_tokens=-1):
    """Creates a hook to handle SyncReplicasHook ops such as initialization.

    Args:
      is_chief: Whether the hook runs on the chief replica.
      num_tokens: Forwarded to `get_init_tokens_op()` on the chief.

    Returns:
      A `SessionRunHook` that runs this optimizer's initialization ops.
    """
    return _SyncReplicasOptimizerHook(self, is_chief, num_tokens)


class _SyncReplicasOptimizerHook(session_run_hook.SessionRunHook):
  """A SessionRunHook handles ops related to SyncReplicasOptimizer."""

  def __init__(self, sync_optimizer, is_chief, num_tokens):
    """Creates hook to handle SyncReplicasOptimizer initialization ops.

    Args:
      sync_optimizer: `SyncReplicasOptimizer` which this hook will initialize.
      is_chief: `Bool`, whether this is a chief replica or not.
      num_tokens: Number of tokens to add to the queue.
    """
    self._sync_optimizer = sync_optimizer
    self._is_chief = is_chief
    self._num_tokens = num_tokens

  def begin(self):
    """Selects the init ops, queue runner and token op for this replica.

    Raises:
      ValueError: If `apply_gradients()` has not been called on the optimizer.
    """
    if not self._sync_optimizer._gradients_applied:  # pylint: disable=protected-access
      raise ValueError(
          "SyncReplicasOptimizer.apply_gradients should be called before using "
          "the hook.")
    if self._is_chief:
      self._local_init_op = self._sync_optimizer.chief_init_op
      self._ready_for_local_init_op = (
          self._sync_optimizer.ready_for_local_init_op)
      self._q_runner = self._sync_optimizer.get_chief_queue_runner()
      self._init_tokens_op = self._sync_optimizer.get_init_tokens_op(
          self._num_tokens)
    else:
      self._local_init_op = self._sync_optimizer.local_step_init_op
      self._ready_for_local_init_op = (
          self._sync_optimizer.ready_for_local_init_op)
      self._q_runner = None
      self._init_tokens_op = None

  def after_create_session(self, session, coord):
    """Runs SyncReplicasOptimizer initialization ops.

    Args:
      session: The newly created session.
      coord: Coordinator passed to the chief queue runner's threads.

    Raises:
      RuntimeError: If the model is not ready for the local init op.
    """
    local_init_success, msg = session_manager._ready(  # pylint: disable=protected-access
        self._ready_for_local_init_op, session,
        "Model is not ready for SyncReplicasOptimizer local init.")
    if not local_init_success:
      raise RuntimeError(
          "Init operations did not make model ready for SyncReplicasOptimizer "
          "local_init. Init op: %s, error: %s" %
          (self._local_init_op.name, msg))
    session.run(self._local_init_op)
    if self._init_tokens_op is not None:
      session.run(self._init_tokens_op)
    if self._q_runner is not None:
      self._q_runner.create_threads(
          session, coord=coord, daemon=True, start=True)