# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Optimizer utilities."""
16
17from tensorflow.python.distribute import central_storage_strategy
18from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
19from tensorflow.python.distribute import reduce_util as ds_reduce_util
20from tensorflow.python.ops import clip_ops
21from tensorflow.python.platform import tf_logging as logging
22
23
24def all_reduce_sum_gradients(grads_and_vars):
25  """Returns all-reduced gradients aggregated via summation.
26
27  Args:
28    grads_and_vars: List of (gradient, variable) pairs.
29
30  Returns:
31    List of (gradient, variable) pairs where gradients have been all-reduced.
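
  Example (illustrative sketch; assumes a `tf.distribute` replica context,
  and `model`, `loss_fn`, and `inputs` are hypothetical caller-side names):

  ```python
  with tf.GradientTape() as tape:
    loss = loss_fn(model(inputs))
  grads = tape.gradient(loss, model.trainable_variables)
  reduced = all_reduce_sum_gradients(zip(grads, model.trainable_variables))
  ```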
32  """
33  grads_and_vars = list(grads_and_vars)
34  filtered_grads_and_vars = filter_empty_gradients(grads_and_vars)
35  if filtered_grads_and_vars:
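    # Prefer an all-reduce issued directly in replica context when the
    # strategy supports it; the merge_call path below is a legacy fallback.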
    if strategy_supports_no_merge_call():
      grads = [pair[0] for pair in filtered_grads_and_vars]
      reduced = distribute_ctx.get_strategy().extended._replica_ctx_all_reduce(  # pylint: disable=protected-access
          ds_reduce_util.ReduceOp.SUM, grads)
    else:
      # TODO(b/183257003): Remove this branch
      reduced = distribute_ctx.get_replica_context().merge_call(
          _all_reduce_sum_fn, args=(filtered_grads_and_vars,))
  else:
    reduced = []
  # Copy 'reduced' but add None gradients back in.
  reduced_with_nones = []
  reduced_pos = 0
  for g, v in grads_and_vars:
    if g is None:
      reduced_with_nones.append((None, v))
    else:
      reduced_with_nones.append((reduced[reduced_pos], v))
      reduced_pos += 1
  assert reduced_pos == len(reduced), "Failed to add all gradients"
  return reduced_with_nones


def filter_empty_gradients(grads_and_vars):
60  """Filter out `(grad, var)` pairs that have a gradient equal to `None`."""
  grads_and_vars = tuple(grads_and_vars)
  if not grads_and_vars:
    return grads_and_vars

  filtered = []
  vars_with_empty_grads = []
  for grad, var in grads_and_vars:
    if grad is None:
      vars_with_empty_grads.append(var)
    else:
      filtered.append((grad, var))
  filtered = tuple(filtered)

  if not filtered:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([v.name for _, v in grads_and_vars],))
  if vars_with_empty_grads:
    logging.warning(
        "Gradients do not exist for variables %s when minimizing the loss.",
        [v.name for v in vars_with_empty_grads])
  return filtered


def make_gradient_clipnorm_fn(clipnorm):
85  """Creates a gradient transformation function for clipping by norm."""
  if clipnorm is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipnorm_fn(grads_and_vars):
    if isinstance(distribute_ctx.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`clipnorm` is not supported with `CentralStorageStrategy`")

    clipped_grads_and_vars = [
        (clip_ops.clip_by_norm(g, clipnorm), v) for g, v in grads_and_vars
    ]
    return clipped_grads_and_vars

  return gradient_clipnorm_fn


def make_global_gradient_clipnorm_fn(clipnorm):
106  """Creates a gradient transformation function for clipping by norm."""
  if clipnorm is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipnorm_fn(grads_and_vars):
    if isinstance(distribute_ctx.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`global_clipnorm` is not supported with `CentralStorageStrategy`")

    grads, variables = zip(*grads_and_vars)
    clipped_grads, _ = clip_ops.clip_by_global_norm(grads, clipnorm)
    clipped_grads_and_vars = list(zip(clipped_grads, variables))
    return clipped_grads_and_vars

  return gradient_clipnorm_fn


def make_gradient_clipvalue_fn(clipvalue):
127  """Creates a gradient transformation function for clipping by value."""
  if clipvalue is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipvalue_fn(grads_and_vars):
    if isinstance(distribute_ctx.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`clipvalue` is not supported with `CentralStorageStrategy`")

    clipped_grads_and_vars = [(clip_ops.clip_by_value(g, -clipvalue,
                                                      clipvalue), v)
                              for g, v in grads_and_vars]
    return clipped_grads_and_vars

  return gradient_clipvalue_fn


def _all_reduce_sum_fn(distribution, grads_and_vars):
  """Sums gradients across replicas; invoked via `merge_call`."""
  return distribution.extended.batch_reduce_to(ds_reduce_util.ReduceOp.SUM,
                                               grads_and_vars)


def strategy_supports_no_merge_call():
153  """Returns if the current Strategy can operate in pure replica context."""
  if not distribute_ctx.has_strategy():
    return True
  strategy = distribute_ctx.get_strategy()
  return not strategy.extended._use_merge_call()  # pylint: disable=protected-access
158