1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15 16"""Optimizer that implements cross-shard gradient reduction for TPU.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22 23from tensorflow.python.framework import ops 24from tensorflow.python.keras.optimizer_v2 import optimizer_v2 25from tensorflow.python.ops.losses import losses 26from tensorflow.python.platform import tf_logging as logging 27from tensorflow.python.tpu import tpu_function 28from tensorflow.python.tpu.ops import tpu_ops 29from tensorflow.python.training import optimizer 30from tensorflow.python.util.tf_export import tf_export 31 32 33@tf_export(v1=["tpu.CrossShardOptimizer"]) 34class CrossShardOptimizer(optimizer.Optimizer): 35 """An optimizer that averages gradients across TPU shards.""" 36 37 def __init__(self, 38 opt, 39 reduction=losses.Reduction.MEAN, 40 name="CrossShardOptimizer", 41 group_assignment=None): 42 """Construct a new cross-shard optimizer. 43 44 Args: 45 opt: An existing `Optimizer` to encapsulate. 46 reduction: The reduction to apply to the shard losses. 47 name: Optional name prefix for the operations created when applying 48 gradients. Defaults to "CrossShardOptimizer". 49 group_assignment: Optional 2d int32 lists with shape 50 [num_groups, num_replicas_per_group] which describles how to apply 51 optimizer to subgroups. 52 53 Raises: 54 ValueError: If reduction is not a valid cross-shard reduction. 55 """ 56 if reduction not in (losses.Reduction.SUM, losses.Reduction.MEAN): 57 raise ValueError("Unsupported reduction: %s." % reduction) 58 if isinstance(opt, optimizer_v2.OptimizerV2): 59 raise TypeError( 60 "CrossShardOptimizer does not work with OptimizerV2. If you are " 61 "using TPUStrategy, OptimizerV2 will sum gradients across replicas." 62 "If you are using TPUEstimator, you may instead sum your gradients " 63 "with: grads = [tf.compat.v1.tpu.cross_replica_sum(g) for g in grads]" 64 ". If you want to average your gradients, rescale your loss with: " 65 "loss /= global_batch_size") 66 67 super(CrossShardOptimizer, self).__init__(False, name) 68 self._opt = opt 69 self._reduction = reduction 70 self._group_assignment = group_assignment 71 72 def _verify_and_get_subgroup_size(self, group_assignment, num_shards): 73 """Verify group_assignment and get the subgroup size". 74 75 Args: 76 group_assignment: list of group ids for applying the optimizer 77 to subgroups. 78 num_shards: The number of TPU shards. 79 80 Returns: 81 The size of one subgroup in group_assignment. 82 83 Raises: 84 ValueError: If group_assignment is invalid. 85 """ 86 if not group_assignment: 87 return None 88 if not (isinstance(group_assignment, list) and 89 all(isinstance(i, list) for i in group_assignment)): 90 raise ValueError("group_assignment must be a list of list. Got {}".format( 91 group_assignment)) 92 93 replica_ids = set() 94 for g in group_assignment: 95 for i in g: 96 replica_ids.add(i) 97 98 if set(range(num_shards)) != replica_ids: 99 raise ValueError("group_assignment must be a permutation of range({0})." 100 " Got group_assignment={1}".format( 101 num_shards, group_assignment)) 102 103 subgroup_size_list = [len(group) for group in group_assignment] 104 if all(subgroup_size_list[0] == size for size in subgroup_size_list): 105 return subgroup_size_list[0] 106 else: 107 raise ValueError("The size of each subgroup in group_assignment must " 108 "be equal. Got group_assignment={}".format( 109 self._group_assignment)) 110 111 def compute_gradients(self, loss, var_list=None, **kwargs): 112 """Compute gradients of "loss" for the variables in "var_list". 113 114 This simply wraps `compute_gradients()` from the real optimizer. The 115 gradients will be aggregated in `apply_gradients()` so that user can 116 modify the gradients like clipping with per replica global norm if needed. 117 The global norm with aggregated gradients can be bad as one replica's huge 118 gradients can hurt the gradients from other replicas. 119 120 When the CrossShardOptimizer is constructed with 121 `reduction == losses.Reduction.MEAN` (default), this function scales the 122 loss by `1.0 / num_shards` before computing the gradients. Assuming the 123 optimizer uses the default implementation of `compute_gradients()`, the 124 gradients of the scaled loss are scaled by `1.0 / num_shards` compared to 125 the gradients of the original loss. This scaling factor is important because 126 `apply_gradients()` sums gradients across shards, rather than averaging 127 them. However, the scaling factor must be taken into account when clipping 128 the norm of the gradients or performing other postprocessing. 129 130 Args: 131 loss: A Tensor containing the value to minimize. 132 var_list: Optional list or tuple of `tf.Variable` to update to minimize 133 `loss`. Defaults to the list of variables collected in the graph 134 under the key `GraphKey.TRAINABLE_VARIABLES`. 135 **kwargs: Keyword arguments for compute_gradients(). 136 137 Returns: 138 A list of (gradient, variable) pairs. 139 140 Raises: 141 ValueError: If not within a tpu_shard_context or group_assignment is 142 invalid. 143 """ 144 num_shards = tpu_function.get_tpu_context().number_of_shards 145 if num_shards is None: 146 logging.warning( 147 "CrossShardOptimizer should be used within a tpu_shard_context, but " 148 "got unset number_of_shards. Assuming 1.") 149 num_shards = 1 150 151 subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment, 152 num_shards) 153 154 if num_shards > 1 and self._reduction == losses.Reduction.MEAN: 155 if self._group_assignment: 156 scale = 1.0 / subgroup_size 157 else: 158 scale = 1.0 / num_shards 159 loss *= scale 160 161 return self._opt.compute_gradients(loss, var_list=var_list, **kwargs) 162 163 def apply_gradients(self, grads_and_vars, global_step=None, name=None): 164 """Apply gradients to variables. 165 166 Calls tpu_ops.cross_replica_sum() to sum gradient contributions across 167 replicas, and then applies the real optimizer. 168 169 Args: 170 grads_and_vars: List of (gradient, variable) pairs as returned by 171 compute_gradients(). 172 global_step: Optional Variable to increment by one after the 173 variables have been updated. 174 name: Optional name for the returned operation. Default to the 175 name passed to the Optimizer constructor. 176 177 Returns: 178 An `Operation` that applies the gradients. If `global_step` was not None, 179 that operation also increments `global_step`. 180 181 Raises: 182 ValueError: If the grads_and_vars is malformed. 183 """ 184 summed_grads_and_vars = [] 185 for (grad, var) in grads_and_vars: 186 if grad is None: 187 summed_grads_and_vars.append((grad, var)) 188 else: 189 with ops.colocate_with(grad): 190 summed_grads_and_vars.append((tpu_ops.cross_replica_sum( 191 grad, self._group_assignment), var)) 192 return self._opt.apply_gradients(summed_grads_and_vars, global_step, name) 193 194 def get_slot(self, *args, **kwargs): 195 """Return a slot named "name" created for "var" by the Optimizer. 196 197 This simply wraps the get_slot() from the actual optimizer. 198 199 Args: 200 *args: Arguments for get_slot(). 201 **kwargs: Keyword arguments for get_slot(). 202 203 Returns: 204 The `Variable` for the slot if it was created, `None` otherwise. 205 """ 206 return self._opt.get_slot(*args, **kwargs) 207 208 def get_slot_names(self, *args, **kwargs): 209 """Return a list of the names of slots created by the `Optimizer`. 210 211 This simply wraps the get_slot_names() from the actual optimizer. 212 213 Args: 214 *args: Arguments for get_slot(). 215 **kwargs: Keyword arguments for get_slot(). 216 217 Returns: 218 A list of strings. 219 """ 220 return self._opt.get_slot_names(*args, **kwargs) 221 222 def variables(self): 223 """Forwarding the variables from the underlying optimizer.""" 224 return self._opt.variables() 225