# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Optimizer that implements cross-shard gradient reduction for TPU."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import optimizer
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["tpu.CrossShardOptimizer"])
class CrossShardOptimizer(optimizer.Optimizer):
  """An optimizer that averages gradients across TPU shards.
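
  For example (a minimal sketch, typically used inside a TPUEstimator
  `model_fn`; the inner optimizer, the learning rate, and `loss` below are
  illustrative placeholders, not part of this module):

    opt = tf.compat.v1.tpu.CrossShardOptimizer(
        tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1))
    train_op = opt.minimize(loss)
  """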

  def __init__(self,
               opt,
               reduction=losses.Reduction.MEAN,
               name="CrossShardOptimizer",
               group_assignment=None):
    """Construct a new cross-shard optimizer.

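    For example (an illustrative sketch; the inner optimizer and learning rate
    are placeholders), the following aggregates gradients within replicas
    {0, 1} and within replicas {2, 3} independently on a four-replica
    topology:

      opt = tf.compat.v1.tpu.CrossShardOptimizer(
          tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1),
          group_assignment=[[0, 1], [2, 3]])
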
    Args:
      opt: An existing `Optimizer` to encapsulate.
      reduction: The reduction to apply to the shard losses.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "CrossShardOptimizer".
      group_assignment: Optional 2d list of int32s with shape
        [num_groups, num_replicas_per_group] that describes how to apply the
        optimizer to subgroups.

    Raises:
      ValueError: If reduction is not a valid cross-shard reduction.
    """
    if reduction not in (losses.Reduction.SUM, losses.Reduction.MEAN):
      raise ValueError("Unsupported reduction: %s." % reduction)
    if isinstance(opt, optimizer_v2.OptimizerV2):
      raise TypeError(
          "CrossShardOptimizer does not work with OptimizerV2. If you are "
          "using TPUStrategy, OptimizerV2 will sum gradients across replicas. "
          "If you are using TPUEstimator, you may instead sum your gradients "
          "with: grads = [tf.compat.v1.tpu.cross_replica_sum(g) for g in grads]"
          ". If you want to average your gradients, rescale your loss with: "
          "loss /= global_batch_size")

    super(CrossShardOptimizer, self).__init__(False, name)
    self._opt = opt
    self._reduction = reduction
    self._group_assignment = group_assignment

  def _verify_and_get_subgroup_size(self, group_assignment, num_shards):
    """Verify group_assignment and get the subgroup size.

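    For example (an illustrative sketch of the checks below): with
    `num_shards=4`, `group_assignment=[[0, 1], [2, 3]]` is valid and yields a
    subgroup size of 2, whereas `[[0, 1, 2], [3]]` raises `ValueError` because
    its subgroups differ in size.
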
    Args:
      group_assignment: A list of lists of replica ids that defines how the
        optimizer is applied to subgroups.
      num_shards: The number of TPU shards.

    Returns:
      The size of one subgroup in group_assignment.

    Raises:
      ValueError: If group_assignment is invalid.
    """
    if not group_assignment:
      return None
    if not (isinstance(group_assignment, list) and
            all(isinstance(i, list) for i in group_assignment)):
      raise ValueError(
          "group_assignment must be a list of lists. Got {}".format(
              group_assignment))

    replica_ids = set()
    for g in group_assignment:
      for i in g:
        replica_ids.add(i)

    if set(range(num_shards)) != replica_ids:
      raise ValueError("group_assignment must be a permutation of range({0})."
                       " Got group_assignment={1}".format(
                           num_shards, group_assignment))

    subgroup_size_list = [len(group) for group in group_assignment]
    if all(subgroup_size_list[0] == size for size in subgroup_size_list):
      return subgroup_size_list[0]
    else:
      raise ValueError("The size of each subgroup in group_assignment must "
                       "be equal. Got group_assignment={}".format(
                           self._group_assignment))

  def compute_gradients(self, loss, var_list=None, **kwargs):
    """Compute gradients of "loss" for the variables in "var_list".

    This simply wraps `compute_gradients()` from the real optimizer. The
    gradients are not aggregated until `apply_gradients()`, so that users can
    modify them first if needed, e.g. by clipping with a per-replica global
    norm. Clipping with a global norm computed over the aggregated gradients
    can behave poorly, since one replica's unusually large gradients can
    overwhelm the gradients from the other replicas.

    When the CrossShardOptimizer is constructed with
    `reduction == losses.Reduction.MEAN` (default), this function scales the
    loss by `1.0 / num_shards` before computing the gradients. Assuming the
    optimizer uses the default implementation of `compute_gradients()`, the
    gradients of the scaled loss are scaled by `1.0 / num_shards` compared to
    the gradients of the original loss. This scaling factor matters because
    `apply_gradients()` sums gradients across shards rather than averaging
    them, but it must also be taken into account when clipping the norm of
    the gradients or performing other postprocessing.

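    For example, with `num_shards=8` and the default MEAN reduction, each
    replica computes gradients of `loss / 8`; `apply_gradients()` then sums
    those gradients across the 8 replicas, so the applied update equals the
    average of the per-replica gradients of the original `loss`. Any clipping
    performed between the two calls therefore sees gradients that are 8x
    smaller than the per-replica gradients of the unscaled loss.
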
    Args:
      loss: A Tensor containing the value to minimize.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      **kwargs: Keyword arguments for compute_gradients().

    Returns:
      A list of (gradient, variable) pairs.

    Raises:
      ValueError: If not within a tpu_shard_context or group_assignment is
        invalid.
    """
    num_shards = tpu_function.get_tpu_context().number_of_shards
    if num_shards is None:
      logging.warning(
          "CrossShardOptimizer should be used within a tpu_shard_context, but "
          "got unset number_of_shards. Assuming 1.")
      num_shards = 1

    subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment,
                                                       num_shards)

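    # For MEAN reduction, pre-scale the loss so that the cross_replica_sum in
    # apply_gradients() produces an average (per subgroup if group_assignment
    # is set) rather than a sum of the per-replica gradients.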
    if num_shards > 1 and self._reduction == losses.Reduction.MEAN:
      if self._group_assignment:
        scale = 1.0 / subgroup_size
      else:
        scale = 1.0 / num_shards
      loss *= scale

    return self._opt.compute_gradients(loss, var_list=var_list, **kwargs)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    Calls tpu_ops.cross_replica_sum() to sum gradient contributions across
    replicas, and then applies the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Defaults to the
        name passed to the Optimizer constructor.

    Returns:
      An `Operation` that applies the gradients. If `global_step` was not
      None, that operation also increments `global_step`.

    Raises:
      ValueError: If grads_and_vars is malformed.
    """
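    # Sum each gradient across replicas (or within each subgroup when
    # group_assignment is set), colocating the all-reduce with the gradient it
    # consumes. None gradients are passed through unchanged.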
    summed_grads_and_vars = []
    for (grad, var) in grads_and_vars:
      if grad is None:
        summed_grads_and_vars.append((grad, var))
      else:
        with ops.colocate_with(grad):
          summed_grads_and_vars.append((tpu_ops.cross_replica_sum(
              grad, self._group_assignment), var))
    return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)

  def get_slot(self, *args, **kwargs):
    """Return a slot named "name" created for "var" by the Optimizer.

    This simply wraps get_slot() from the actual optimizer.

    Args:
      *args: Arguments for get_slot().
      **kwargs: Keyword arguments for get_slot().

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    return self._opt.get_slot(*args, **kwargs)

  def get_slot_names(self, *args, **kwargs):
    """Return a list of the names of slots created by the `Optimizer`.

    This simply wraps get_slot_names() from the actual optimizer.

    Args:
      *args: Arguments for get_slot_names().
      **kwargs: Keyword arguments for get_slot_names().

    Returns:
      A list of strings.
    """
    return self._opt.get_slot_names(*args, **kwargs)

  def variables(self):
    """Forward the variables from the underlying optimizer."""
    return self._opt.variables()