# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizer utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.ops import clip_ops
from tensorflow.python.platform import tf_logging as logging


def all_reduce_sum_gradients(grads_and_vars):
  """Returns all-reduced gradients aggregated via summation.

  Args:
    grads_and_vars: List of (gradient, variable) pairs.

  Returns:
    List of (gradient, variable) pairs where gradients have been all-reduced.
  """
  grads_and_vars = list(grads_and_vars)
  filtered_grads_and_vars = filter_empty_gradients(grads_and_vars)
  # We switch to a cross-replica context since there is a bug which causes
  # IndexedSlices to be converted to dense tensors when all-reduced in a
  # replica context.
  # TODO(b/150507409): Do not switch to a cross-replica context once the bug
  # is fixed.
  if filtered_grads_and_vars:
    reduced = distribute_ctx.get_replica_context().merge_call(
        _all_reduce_sum_fn, args=(filtered_grads_and_vars,))
  else:
    reduced = []
  # Copy 'reduced' but add None gradients back in.
  reduced_with_nones = []
  reduced_pos = 0
  for g, v in grads_and_vars:
    if g is None:
      reduced_with_nones.append((None, v))
    else:
      reduced_with_nones.append((reduced[reduced_pos], v))
      reduced_pos += 1
  assert reduced_pos == len(reduced), "Failed to add all gradients"
  return reduced_with_nones


def filter_empty_gradients(grads_and_vars):
  """Filter out `(grad, var)` pairs that have a gradient equal to `None`."""
  grads_and_vars = tuple(grads_and_vars)
  if not grads_and_vars:
    return grads_and_vars

  filtered = []
  vars_with_empty_grads = []
  for grad, var in grads_and_vars:
    if grad is None:
      vars_with_empty_grads.append(var)
    else:
      filtered.append((grad, var))
  filtered = tuple(filtered)

  if not filtered:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([v.name for _, v in grads_and_vars],))
  if vars_with_empty_grads:
    logging.warning(
        "Gradients do not exist for variables %s when minimizing the loss.",
        [v.name for v in vars_with_empty_grads])
  return filtered
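

# Illustrative sketch, not part of the original module: one way to drive
# `all_reduce_sum_gradients` from inside a replica context. The helper name
# `_demo_all_reduce_sum_gradients`, the "/cpu:0" device list, and the variable
# and gradient values are assumptions made purely for demonstration; calling
# `Strategy.run` assumes a recent TF 2.x release.
def _demo_all_reduce_sum_gradients():
  from tensorflow.python.distribute import mirrored_strategy
  from tensorflow.python.framework import constant_op
  from tensorflow.python.ops import variables

  strategy = mirrored_strategy.MirroredStrategy(["/cpu:0"])
  with strategy.scope():
    var = variables.Variable(1.0, name="demo_var")

  def replica_step():
    # Inside `Strategy.run` we are in a replica context, which is what
    # `all_reduce_sum_gradients` expects before it issues its `merge_call`.
    grads_and_vars = [(constant_op.constant(2.0), var)]
    reduced = all_reduce_sum_gradients(grads_and_vars)
    # With a single replica the sum is a no-op; with N replicas each returned
    # gradient would be the sum of the corresponding per-replica gradients.
    return [g for g, _ in reduced]

  return strategy.run(replica_step)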


def make_gradient_clipnorm_fn(clipnorm):
  """Creates a gradient transformation function for clipping by norm."""
  if clipnorm is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipnorm_fn(grads_and_vars):

    if isinstance(distribute_ctx.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`clipnorm` is not supported with `CentralStorageStrategy`")

    clipped_grads_and_vars = [
        (clip_ops.clip_by_norm(g, clipnorm), v) for g, v in grads_and_vars
    ]
    return clipped_grads_and_vars

  return gradient_clipnorm_fn


def make_global_gradient_clipnorm_fn(clipnorm):
  """Creates a gradient transformation function for clipping by global norm."""
  if clipnorm is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipnorm_fn(grads_and_vars):

    if isinstance(distribute_ctx.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`global_clipnorm` is not supported with `CentralStorageStrategy`")

    grads, variables = zip(*grads_and_vars)
    clipped_grads, _ = clip_ops.clip_by_global_norm(grads, clipnorm)
    clipped_grads_and_vars = list(zip(clipped_grads, variables))
    return clipped_grads_and_vars

  return gradient_clipnorm_fn


def make_gradient_clipvalue_fn(clipvalue):
  """Creates a gradient transformation function for clipping by value."""
  if clipvalue is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipvalue_fn(grads_and_vars):

    if isinstance(distribute_ctx.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`clipvalue` is not supported with `CentralStorageStrategy`")

    clipped_grads_and_vars = [(clip_ops.clip_by_value(g, -clipvalue,
                                                      clipvalue), v)
                              for g, v in grads_and_vars]
    return clipped_grads_and_vars

  return gradient_clipvalue_fn


def _all_reduce_sum_fn(distribution, grads_and_vars):
  return distribution.extended.batch_reduce_to(ds_reduce_util.ReduceOp.SUM,
                                               grads_and_vars)
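

# Illustrative sketch, not part of the original module: composing the helpers
# above in eager mode under the default strategy. The `__main__` guard and the
# variables, gradient values, and clipnorm of 1.0 are assumptions made purely
# for demonstration.
if __name__ == "__main__":
  from tensorflow.python.framework import constant_op
  from tensorflow.python.ops import variables

  v0 = variables.Variable([3.0, 4.0], name="v0")
  v1 = variables.Variable([1.0, 1.0], name="v1")
  grads_and_vars = [(constant_op.constant([30.0, 40.0]), v0), (None, v1)]

  # Drops the (None, v1) pair and logs a warning naming v1.
  filtered = filter_empty_gradients(grads_and_vars)

  # Clips each remaining gradient to an L2 norm of at most 1.0, so the
  # gradient for v0 becomes [0.6, 0.8].
  clip_fn = make_gradient_clipnorm_fn(1.0)
  for g, v in clip_fn(filtered):
    print(v.name, g.numpy())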