# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Synchronize replicas for training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import types_pb2
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer
from tensorflow.python.training import queue_runner
from tensorflow.python.training import session_manager
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# Please note that the gradients from replicas are averaged instead of summed
# (as in the old sync_replicas_optimizer) so you need to increase the learning
# rate according to the number of replicas. This change is introduced to be
# consistent with how gradients are aggregated (averaged) within a batch in a
# replica.
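# For example, a job that used a learning rate of 0.1 with 8 replicas under
# the old summing behavior corresponds to roughly a learning rate of
# 0.1 * 8 = 0.8 under the averaging behavior implemented here.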
@tf_export(v1=["train.SyncReplicasOptimizer"])
class SyncReplicasOptimizer(optimizer.Optimizer):
  """Class to synchronize, aggregate gradients and pass them to the optimizer.

  This class is deprecated. For synchronous training, please use [Distribution
  Strategies](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute).

  In a typical asynchronous training environment, it's common to have some
  stale gradients. For example, with N-replica asynchronous training,
  gradients will be applied to the variables N times independently. Depending
  on each replica's training speed, some gradients might be calculated from
  copies of the variable from several steps back (N-1 steps on average). This
  optimizer avoids stale gradients by collecting gradients from all replicas,
  averaging them, then applying them to the variables in one shot, after
  which replicas can fetch the new variables and continue.

  The following accumulators/queue are created:

  * N `gradient accumulators`, one per variable to train. Gradients are pushed
    to them and the chief worker will wait until enough gradients are collected
    and then average them before applying to variables. The accumulator will
    drop all stale gradients (more details in the accumulator op).
  * 1 `token` queue where the optimizer pushes the new global_step value after
    all variables are updated.

  The following local variable is created:
  * `sync_rep_local_step`, one per replica. Compared against the global_step in
    each accumulator to check for staleness of the gradients.

  The optimizer adds nodes to the graph to collect gradients and pause the
  trainers until variables are updated.
  For the Parameter Server job:

  1. An accumulator is created for each variable, and each replica pushes the
     gradients into the accumulators instead of directly applying them to the
     variables.
  2. Each accumulator averages the gradients once enough of them
     (replicas_to_aggregate) have been accumulated.
  3. Apply the averaged gradients to the variables.
  4. Only after all variables have been updated, increment the global step.
  5. Only after step 4, push `global_step` into the `token_queue`, once for
     each worker replica. The workers can now fetch the global step, use it to
     update their local_step variables and start the next batch. Please note
     that some workers can consume multiple minibatches, while some may not
     consume even one. This is because each worker fetches minibatches as long
     as a token exists. If one worker is stuck for some reason and does not
     consume a token, another worker can use it.

  For the replicas:

  1. Start a step: fetch variables and compute gradients.
  2. Once the gradients have been computed, push them into the gradient
     accumulators. Each accumulator will check the staleness and drop the
     stale ones.
  3. After pushing all the gradients, dequeue an updated value of global_step
     from the token queue and record that step to its local_step variable. Note
     that this is effectively a barrier.
  4. Start the next batch.

  ### Usage

  ```python
  # Create any optimizer to update the variables, say a simple SGD:
  opt = GradientDescentOptimizer(learning_rate=0.1)

  # Wrap the optimizer with sync_replicas_optimizer with 50 replicas: at each
  # step the optimizer collects 50 gradients before applying to variables.
  # Note that if you want to have 2 backup replicas, you can change
  # total_num_replicas=52 and make sure this number matches how many physical
  # replicas you started in your job.
  opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
                                       total_num_replicas=50)

  # Some models have startup_delays to help stabilize the model but when using
  # sync_replicas training, set it to 0.

  # Now you can call `minimize()` or `compute_gradients()` and
  # `apply_gradients()` normally
  training_op = opt.minimize(total_loss, global_step=self.global_step)


  # You can create the hook which handles initialization and queues.
  sync_replicas_hook = opt.make_session_run_hook(is_chief)
  ```

  In the training program, every worker will run the train_op as if not
  synchronized.

  ```python
  with training.MonitoredTrainingSession(
      master=workers[worker_id].target, is_chief=is_chief,
      hooks=[sync_replicas_hook]) as mon_sess:
    while not mon_sess.should_stop():
      mon_sess.run(training_op)
  ```

  To use SyncReplicasOptimizer with an `Estimator`, you need to pass in the
  sync_replicas_hook when calling `fit`.
  ```python
  my_estimator = DNNClassifier(..., optimizer=opt)
  my_estimator.fit(..., hooks=[sync_replicas_hook])
  ```
  """

  @deprecation.deprecated(
      None,
      "The `SyncReplicasOptimizer` class is deprecated. For synchronous "
      "training, please use [Distribution Strategies](https://github.com/"
      "tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute).",
      warn_once=True)
  def __init__(self,
               opt,
               replicas_to_aggregate,
               total_num_replicas=None,
               variable_averages=None,
               variables_to_average=None,
               use_locking=False,
               name="sync_replicas"):
    """Construct a sync_replicas optimizer.

    Args:
      opt: The actual optimizer that will be used to compute and apply the
        gradients. Must be one of the Optimizer classes.
      replicas_to_aggregate: number of replicas to aggregate for each variable
        update.
      total_num_replicas: Total number of tasks/workers/replicas, which can be
        different from replicas_to_aggregate.
        If total_num_replicas > replicas_to_aggregate: the difference is the
        number of backup replicas.
        If total_num_replicas < replicas_to_aggregate: replicas compute
        multiple batches per update to variables.
      variable_averages: Optional `ExponentialMovingAverage` object, used to
        maintain moving averages for the variables passed in
        `variables_to_average`.
      variables_to_average: a list of variables that need to be averaged. Only
        needed if variable_averages is passed in.
      use_locking: If True, use locks for update operations.
      name: string. Optional name of the returned operation.
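
    Example (a minimal sketch; only the constructor call is shown, and the
    wrapped `GradientDescentOptimizer` and replica counts are assumptions):

    ```python
    # 50 physical workers, 2 of which act as backups: each update uses the
    # first 48 gradients that arrive.
    opt = tf.train.SyncReplicasOptimizer(
        tf.train.GradientDescentOptimizer(0.1),
        replicas_to_aggregate=48,
        total_num_replicas=50)
    ```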
    """
    if total_num_replicas is None:
      total_num_replicas = replicas_to_aggregate

    super(SyncReplicasOptimizer, self).__init__(use_locking, name)
    logging.info(
        "SyncReplicasV2: replicas_to_aggregate=%s; total_num_replicas=%s",
        replicas_to_aggregate, total_num_replicas)
    self._opt = opt
    self._replicas_to_aggregate = replicas_to_aggregate
    self._gradients_applied = False
    self._variable_averages = variable_averages
    self._variables_to_average = variables_to_average
    self._total_num_replicas = total_num_replicas
    self._tokens_per_step = max(total_num_replicas, replicas_to_aggregate)
    self._global_step = None
    self._sync_token_queue = None

    # The synchronization op will be executed in a queue runner which should
    # only be executed by one of the replicas (usually the chief).
    self._chief_queue_runner = None

    # Remember which accumulator is on which device so that the initial step
    # in each accumulator can be set to the global step. This list contains
    # (accumulator, device) pairs.
    self._accumulator_list = []

  def compute_gradients(self, *args, **kwargs):
    """Compute gradients of "loss" for the variables in "var_list".

    This simply wraps the compute_gradients() from the real optimizer. The
    gradients will be aggregated in apply_gradients(), so that users can
    modify the gradients, e.g. clip them with a per-replica global norm if
    needed. Computing the global norm over aggregated gradients can be
    problematic, as one replica's huge gradients can distort the gradients
    from other replicas.

    Args:
      *args: Arguments for compute_gradients().
      **kwargs: Keyword arguments for compute_gradients().

    Returns:
      A list of (gradient, variable) pairs.
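
    Example of per-replica clipping between `compute_gradients()` and
    `apply_gradients()` (a sketch; `loss`, `clip_norm` and `global_step` are
    assumed to be defined by the caller, and no gradient is expected to be
    `None`):

    ```python
    grads_and_vars = opt.compute_gradients(loss)
    grads, tvars = zip(*grads_and_vars)
    clipped, _ = tf.clip_by_global_norm(list(grads), clip_norm)
    train_op = opt.apply_gradients(list(zip(clipped, tvars)), global_step)
    ```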
    """
    return self._opt.compute_gradients(*args, **kwargs)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This contains most of the synchronization implementation and also wraps the
    apply_gradients() from the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the Optimizer constructor.

    Returns:
      train_op: The op to dequeue a token so the replicas can exit this batch
        and start the next one. This is executed by each replica.

    Raises:
      ValueError: If grads_and_vars is empty.
      ValueError: If global_step is not provided, as the staleness cannot be
        checked without it.
    """
    if not grads_and_vars:
      raise ValueError("Must supply at least one variable")

    if global_step is None:
      raise ValueError("Global step is required to check staleness")

    self._global_step = global_step
    train_ops = []
    aggregated_grad = []
    var_list = []

    # local_anchor op will be placed on this worker task by default.
    local_anchor = control_flow_ops.no_op()
    # Colocating local_step variable prevents it being placed on the PS.
    distribution_strategy = distribution_strategy_context.get_strategy()
    with distribution_strategy.extended.colocate_vars_with(local_anchor):
      self._local_step = variable_scope.variable(
          initial_value=0,
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          dtype=global_step.dtype.base_dtype,
          name="sync_rep_local_step")

    self.local_step_init_op = state_ops.assign(self._local_step, global_step)
    chief_init_ops = [self.local_step_init_op]
    self.ready_for_local_init_op = variables.report_uninitialized_variables(
        variables.global_variables())

    with ops.name_scope(None, self._name):
      for grad, var in grads_and_vars:
        var_list.append(var)
        with ops.device(var.device):
          # Dense gradients.
          if grad is None:
            aggregated_grad.append(None)  # pass-through.
            continue
          elif isinstance(grad, ops.Tensor):
            grad_accum = data_flow_ops.ConditionalAccumulator(
                grad.dtype,
                shape=var.get_shape(),
                shared_name=var.name + "/grad_accum")
            train_ops.append(grad_accum.apply_grad(
                grad, local_step=self._local_step))
            aggregated_grad.append(grad_accum.take_grad(
                self._replicas_to_aggregate))
          else:
            if not isinstance(grad, ops.IndexedSlices):
              raise ValueError("Unknown grad type!")
            grad_accum = data_flow_ops.SparseConditionalAccumulator(
                grad.dtype, shape=(), shared_name=var.name + "/grad_accum")
            train_ops.append(grad_accum.apply_indexed_slices_grad(
                grad, local_step=self._local_step))
            aggregated_grad.append(grad_accum.take_indexed_slices_grad(
                self._replicas_to_aggregate))

          self._accumulator_list.append((grad_accum, var.device))

      aggregated_grads_and_vars = zip(aggregated_grad, var_list)

      # sync_op will be assigned to the same device as the global step.
      with ops.device(global_step.device), ops.name_scope(""):
        update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
                                              global_step)

      # Create token queue.
      with ops.device(global_step.device), ops.name_scope(""):
        sync_token_queue = (
            data_flow_ops.FIFOQueue(-1,
                                    global_step.dtype.base_dtype,
                                    shapes=(),
                                    name="sync_token_q",
                                    shared_name="sync_token_q"))
        self._sync_token_queue = sync_token_queue

        # dummy_queue is passed to the queue runner. Don't use the real queues
        # because the queue runner doesn't automatically reopen queues once it
        # has closed them on the PS devices.
        dummy_queue = (
            data_flow_ops.FIFOQueue(1,
                                    types_pb2.DT_INT32,
                                    shapes=(),
                                    name="dummy_queue",
                                    shared_name="dummy_queue"))

      with ops.device(global_step.device), ops.name_scope(""):
        # Replicas have to wait until they can get a token from the token queue.
        with ops.control_dependencies(train_ops):
          token = sync_token_queue.dequeue()
        train_op = state_ops.assign(self._local_step, token)

        with ops.control_dependencies([update_op]):
          # Sync_op needs to insert tokens to the token queue at the end of the
          # step so the replicas can fetch them to start the next step.
          tokens = array_ops.fill([self._tokens_per_step], global_step)
          sync_op = sync_token_queue.enqueue_many((tokens,))

        if self._variable_averages is not None:
          with ops.control_dependencies([sync_op]), ops.name_scope(""):
            sync_op = self._variable_averages.apply(
                self._variables_to_average)

        self._chief_queue_runner = queue_runner.QueueRunner(dummy_queue,
                                                            [sync_op])
      for accum, dev in self._accumulator_list:
        with ops.device(dev):
          chief_init_ops.append(
              accum.set_global_step(
                  global_step, name="SetGlobalStep"))
      self.chief_init_op = control_flow_ops.group(*chief_init_ops)
      self._gradients_applied = True
      return train_op

  def get_chief_queue_runner(self):
    """Returns the QueueRunner for the chief to execute.

    This includes the operations to synchronize replicas: aggregate gradients,
    apply them to the variables, increment the global step, and insert tokens
    into the token queue.

    Note that this can only be called after calling apply_gradients(), which
    actually generates this queue runner.

    Returns:
      A `QueueRunner` for the chief to execute.

    Raises:
      ValueError: If this is called before apply_gradients().
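
    Example for a chief that manages the queue runner directly instead of
    using `make_session_run_hook()` (a sketch; `sess` is assumed to be an
    already-initialized session):

    ```python
    coord = tf.train.Coordinator()
    threads = opt.get_chief_queue_runner().create_threads(
        sess, coord=coord, daemon=True, start=True)
    # ... training loop runs elsewhere ...
    coord.request_stop()
    coord.join(threads)
    ```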
    """
    if not self._gradients_applied:
      raise ValueError("Should be called after apply_gradients().")

    return self._chief_queue_runner

  def get_slot(self, *args, **kwargs):
    """Return a slot named "name" created for "var" by the Optimizer.

    This simply wraps the get_slot() from the actual optimizer.

    Args:
      *args: Arguments for get_slot().
      **kwargs: Keyword arguments for get_slot().

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    return self._opt.get_slot(*args, **kwargs)

  def variables(self):
    """Fetches a list of optimizer variables in the default graph.

    This wraps `variables()` from the actual optimizer. It does not include
    the `SyncReplicasOptimizer`'s local step.

    Returns:
      A list of variables.
    """
    return self._opt.variables()

  def get_slot_names(self, *args, **kwargs):
    """Return a list of the names of slots created by the `Optimizer`.

    This simply wraps the get_slot_names() from the actual optimizer.

    Args:
      *args: Arguments for get_slot_names().
      **kwargs: Keyword arguments for get_slot_names().

    Returns:
      A list of strings.
    """
    return self._opt.get_slot_names(*args, **kwargs)

  def get_init_tokens_op(self, num_tokens=-1):
    """Returns the op to fill the sync_token_queue with the tokens.

    This is supposed to be executed at the beginning of the chief/sync thread
    so that even if total_num_replicas is less than replicas_to_aggregate,
    the model can still proceed as the replicas can compute multiple steps per
    variable update. Make sure:
    `num_tokens >= replicas_to_aggregate - total_num_replicas`.

    Args:
      num_tokens: Number of tokens to add to the queue.

    Returns:
      An op for the chief/sync replica to fill the token queue.

    Raises:
      ValueError: If this is called before apply_gradients().
      ValueError: If num_tokens is smaller than replicas_to_aggregate -
        total_num_replicas.
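
    Example for the under-provisioned case (a sketch): with
    `replicas_to_aggregate=8` and `total_num_replicas=4`, at least
    8 - 4 = 4 tokens are needed so that the first step can finish:

    ```python
    init_tokens_op = opt.get_init_tokens_op(num_tokens=4)
    # Run once on the chief after variables have been initialized:
    sess.run(init_tokens_op)
    ```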
    """
    if not self._gradients_applied:
      raise ValueError(
          "get_init_tokens_op() should be called after apply_gradients().")

    tokens_needed = self._replicas_to_aggregate - self._total_num_replicas
    if num_tokens == -1:
      num_tokens = self._replicas_to_aggregate
    elif num_tokens < tokens_needed:
      raise ValueError(
          "Too few tokens to finish the first step: %d (given) vs %d (needed)" %
          (num_tokens, tokens_needed))

    if num_tokens > 0:
      with ops.device(self._global_step.device), ops.name_scope(""):
        tokens = array_ops.fill([num_tokens], self._global_step)
        init_tokens = self._sync_token_queue.enqueue_many((tokens,))
    else:
      init_tokens = control_flow_ops.no_op(name="no_init_tokens")

    return init_tokens

  def make_session_run_hook(self, is_chief, num_tokens=-1):
    """Creates a hook that handles SyncReplicasOptimizer initialization ops."""
    return _SyncReplicasOptimizerHook(self, is_chief, num_tokens)


class _SyncReplicasOptimizerHook(session_run_hook.SessionRunHook):
  """A SessionRunHook that handles ops related to SyncReplicasOptimizer."""

  def __init__(self, sync_optimizer, is_chief, num_tokens):
    """Creates a hook to handle SyncReplicasOptimizer initialization ops.

    Args:
      sync_optimizer: `SyncReplicasOptimizer` which this hook will initialize.
      is_chief: `Bool`, whether this is a chief replica or not.
      num_tokens: Number of tokens to add to the queue.
    """
    self._sync_optimizer = sync_optimizer
    self._is_chief = is_chief
    self._num_tokens = num_tokens

  def begin(self):
    if not self._sync_optimizer._gradients_applied:  # pylint: disable=protected-access
      raise ValueError(
          "SyncReplicasOptimizer.apply_gradients should be called before using "
          "the hook.")
    if self._is_chief:
      self._local_init_op = self._sync_optimizer.chief_init_op
      self._ready_for_local_init_op = (
          self._sync_optimizer.ready_for_local_init_op)
      self._q_runner = self._sync_optimizer.get_chief_queue_runner()
      self._init_tokens_op = self._sync_optimizer.get_init_tokens_op(
          self._num_tokens)
    else:
      self._local_init_op = self._sync_optimizer.local_step_init_op
      self._ready_for_local_init_op = (
          self._sync_optimizer.ready_for_local_init_op)
      self._q_runner = None
      self._init_tokens_op = None

  def after_create_session(self, session, coord):
    """Runs SyncReplicasOptimizer initialization ops."""
    local_init_success, msg = session_manager._ready(  # pylint: disable=protected-access
        self._ready_for_local_init_op, session,
        "Model is not ready for SyncReplicasOptimizer local init.")
    if not local_init_success:
      raise RuntimeError(
          "Init operations did not make model ready for SyncReplicasOptimizer "
          "local_init. Init op: %s, error: %s" %
          (self._local_init_op.name, msg))
    session.run(self._local_init_op)
    if self._init_tokens_op is not None:
      session.run(self._init_tokens_op)
    if self._q_runner is not None:
      self._q_runner.create_threads(
          session, coord=coord, daemon=True, start=True)