# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizer utilities."""

from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.ops import clip_ops
from tensorflow.python.platform import tf_logging as logging


def all_reduce_sum_gradients(grads_and_vars):
  """Returns all-reduced gradients aggregated via summation.

  Args:
    grads_and_vars: List of (gradient, variable) pairs.

  Returns:
    List of (gradient, variable) pairs where gradients have been all-reduced.
  """
  grads_and_vars = list(grads_and_vars)
  filtered_grads_and_vars = filter_empty_gradients(grads_and_vars)
  if filtered_grads_and_vars:
    if strategy_supports_no_merge_call():
      # All-reduce directly in replica context when the strategy supports it.
      grads = [pair[0] for pair in filtered_grads_and_vars]
      reduced = distribute_ctx.get_strategy().extended._replica_ctx_all_reduce(  # pylint: disable=protected-access
          ds_reduce_util.ReduceOp.SUM, grads)
    else:
      # TODO(b/183257003): Remove this branch
      reduced = distribute_ctx.get_replica_context().merge_call(
          _all_reduce_sum_fn, args=(filtered_grads_and_vars,))
  else:
    reduced = []
  # Copy 'reduced' but add None gradients back in.
  reduced_with_nones = []
  reduced_pos = 0
  for g, v in grads_and_vars:
    if g is None:
      reduced_with_nones.append((None, v))
    else:
      reduced_with_nones.append((reduced[reduced_pos], v))
      reduced_pos += 1
  assert reduced_pos == len(reduced), "Failed to add all gradients"
  return reduced_with_nones


def filter_empty_gradients(grads_and_vars):
  """Filters out `(grad, var)` pairs whose gradient is `None`."""
  grads_and_vars = tuple(grads_and_vars)
  if not grads_and_vars:
    return grads_and_vars

  filtered = []
  vars_with_empty_grads = []
  for grad, var in grads_and_vars:
    if grad is None:
      vars_with_empty_grads.append(var)
    else:
      filtered.append((grad, var))
  filtered = tuple(filtered)

  if not filtered:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([v.name for _, v in grads_and_vars],))
  if vars_with_empty_grads:
    logging.warning(
        "Gradients do not exist for variables %s when minimizing the loss.",
        [v.name for v in vars_with_empty_grads])
  return filtered


def make_gradient_clipnorm_fn(clipnorm):
  """Creates a gradient transformation function for clipping by norm."""
  if clipnorm is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipnorm_fn(grads_and_vars):

    if isinstance(distribute_ctx.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`clipnorm` is not supported with `CentralStorageStrategy`")

    clipped_grads_and_vars = [
        (clip_ops.clip_by_norm(g, clipnorm), v) for g, v in grads_and_vars
    ]
    return clipped_grads_and_vars

  return gradient_clipnorm_fn


def make_global_gradient_clipnorm_fn(clipnorm):
  """Creates a gradient transformation function for clipping by global norm."""
  if clipnorm is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipnorm_fn(grads_and_vars):

    if isinstance(distribute_ctx.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`global_clipnorm` is not supported with `CentralStorageStrategy`")

    grads, variables = zip(*grads_and_vars)
    clipped_grads, _ = clip_ops.clip_by_global_norm(grads, clipnorm)
    clipped_grads_and_vars = list(zip(clipped_grads, variables))
    return clipped_grads_and_vars

  return gradient_clipnorm_fn


def make_gradient_clipvalue_fn(clipvalue):
  """Creates a gradient transformation function for clipping by value."""
  if clipvalue is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipvalue_fn(grads_and_vars):

    if isinstance(distribute_ctx.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`clipvalue` is not supported with `CentralStorageStrategy`")

    clipped_grads_and_vars = [(clip_ops.clip_by_value(g, -clipvalue,
                                                      clipvalue), v)
                              for g, v in grads_and_vars]
    return clipped_grads_and_vars

  return gradient_clipvalue_fn


def _all_reduce_sum_fn(distribution, grads_and_vars):
  return distribution.extended.batch_reduce_to(ds_reduce_util.ReduceOp.SUM,
                                               grads_and_vars)


def strategy_supports_no_merge_call():
  """Returns whether the current `Strategy` can operate in pure replica context."""
  if not distribute_ctx.has_strategy():
    return True
  strategy = distribute_ctx.get_strategy()
  return not strategy.extended._use_merge_call()  # pylint: disable=protected-access
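

# Illustrative usage sketch (not part of the public API): the helpers above are
# typically chained inside a custom training step executed under `strategy.run`.
# The names `tape`, `loss`, `model`, and `optimizer` below are assumptions made
# for illustration only. Gradients are aggregated first and clipped second,
# roughly mirroring the aggregate-then-transform order used by the Keras
# optimizer when it calls these utilities itself.
#
#   grads = tape.gradient(loss, model.trainable_variables)
#   grads_and_vars = list(zip(grads, model.trainable_variables))
#   grads_and_vars = all_reduce_sum_gradients(grads_and_vars)
#   clip_fn = make_global_gradient_clipnorm_fn(clipnorm=1.0)
#   grads_and_vars = clip_fn(grads_and_vars)
#   optimizer.apply_gradients(
#       grads_and_vars, experimental_aggregate_gradients=False)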