# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Operations for clipping (gradient, weight) tensors to min/max values."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import six

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numerics
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@tf_export("clip_by_value")
@dispatch.add_dispatch_support
def clip_by_value(t, clip_value_min, clip_value_max,
                  name=None):
  """Clips tensor values to a specified min and max.

  Given a tensor `t`, this operation returns a tensor of the same type and
  shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`.
  Any values less than `clip_value_min` are set to `clip_value_min`. Any values
  greater than `clip_value_max` are set to `clip_value_max`.

  Note: `clip_value_min` must be less than or equal to `clip_value_max` for
  correct results.

  Args:
    t: A `Tensor`.
    clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape
      as `t`. The minimum value to clip by.
    clip_value_max: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape
      as `t`. The maximum value to clip by.
    name: A name for the operation (optional).

  Returns:
    A clipped `Tensor`.

  Raises:
    ValueError: If the clip tensors would trigger array broadcasting
      that would make the returned tensor larger than the input.
  """
  with ops.name_scope(name, "clip_by_value",
                      [t, clip_value_min, clip_value_max]) as name:
    t = ops.convert_to_tensor(t, name="t")

    # Clip each element of `t`: first cap values at clip_value_max, then
    # raise any values below clip_value_min.
    t_min = math_ops.minimum(t, clip_value_max)
    # Assert that the shape is compatible with the initial shape,
    # to prevent unintentional broadcasting.
    _ = t.shape.merge_with(t_min.shape)

    t_max = math_ops.maximum(t_min, clip_value_min, name=name)
    _ = t.shape.merge_with(t_max.shape)

  return t_max
  # TODO(scottzhu): switch to use new implementation in 2 weeks.
  # return gen_math_ops.clip_by_value(
  #     t, clip_value_min, clip_value_max, name=name)


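# A minimal usage sketch, not part of TensorFlow's API: the helper below is a
# hypothetical, never-called illustration of clip_by_value. In graph mode it
# returns a symbolic `Tensor`; under eager execution it returns the values.
def _clip_by_value_example_sketch():
  t = constant_op.constant([[-2.0, 0.0], [3.0, 5.0]])
  # Values below -1.0 are raised to -1.0 and values above 4.0 are lowered to
  # 4.0, giving [[-1.0, 0.0], [3.0, 4.0]].
  return clip_by_value(t, clip_value_min=-1.0, clip_value_max=4.0)

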
# TODO(scottzhu): switch to use new implementation in 2 weeks.
# @ops.RegisterGradient("ClipByValue")
def _clip_by_value_grad(op, grad):
  """Returns grad of clip_by_value."""
  # op.inputs holds the clipped tensor x and the clip bounds y (min) and
  # z (max), which may have been broadcast against x.
  x = op.inputs[0]
  y = op.inputs[1]
  z = op.inputs[2]
  gdtype = grad.dtype
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  sz = array_ops.shape(z)
  gradshape = array_ops.shape(grad)
  zeros = array_ops.zeros(gradshape, gdtype)
  # Masks of the positions clipped from below (x < y) and from above (x > z).
  xymask = math_ops.less(x, y)
  xzmask = math_ops.greater(x, z)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  rx, rz = gen_array_ops.broadcast_gradient_args(sx, sz)
  # The incoming gradient flows to x only where x was not clipped, and to each
  # bound only where that bound determined the output.
  xgrad = array_ops.where(math_ops.logical_or(xymask, xzmask), zeros, grad)
  ygrad = array_ops.where(xymask, grad, zeros)
  zgrad = array_ops.where(xzmask, grad, zeros)
  # Sum over broadcast dimensions so each gradient matches its input's shape.
  gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
  gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
  gz = array_ops.reshape(math_ops.reduce_sum(zgrad, rz), sz)
  return (gx, gy, gz)


@tf_export("clip_by_norm")
def clip_by_norm(t, clip_norm, axes=None, name=None):
  """Clips tensor values to a maximum L2-norm.

  Given a tensor `t`, and a maximum clip value `clip_norm`, this operation
  normalizes `t` so that its L2-norm is less than or equal to `clip_norm`,
  along the dimensions given in `axes`. Specifically, in the default case
  where all dimensions are used for calculation, if the L2-norm of `t` is
  already less than or equal to `clip_norm`, then `t` is not modified. If
  the L2-norm is greater than `clip_norm`, then this operation returns a
  tensor of the same type and shape as `t` with its values set to:

  `t * clip_norm / l2norm(t)`

  In this case, the L2-norm of the output tensor is `clip_norm`.

  As another example, if `t` is a matrix and `axes == [1]`, then each row
  of the output will have L2-norm less than or equal to `clip_norm`. If
  `axes == [0]` instead, each column of the output will be clipped.

  This operation is typically used to clip gradients before applying them with
  an optimizer.

  Args:
    t: A `Tensor` or `IndexedSlices`.
    clip_norm: A 0-D (scalar) `Tensor` > 0. A maximum clipping value.
    axes: A 1-D (vector) `Tensor` of type int32 containing the dimensions
      to use for computing the L2-norm. If `None` (the default), uses all
      dimensions.
    name: A name for the operation (optional).

  Returns:
    A clipped `Tensor` or `IndexedSlices`.
  """
  with ops.name_scope(name, "clip_by_norm", [t, clip_norm]) as name:
    values = ops.convert_to_tensor(
        t.values if isinstance(t, ops.IndexedSlices) else t, name="t")

    # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
    l2sum = math_ops.reduce_sum(values * values, axes, keepdims=True)
    pred = l2sum > 0
    # Two-tap tf.where trick to bypass NaN gradients
    l2sum_safe = array_ops.where(pred, l2sum, array_ops.ones_like(l2sum))
    l2norm = array_ops.where(pred, math_ops.sqrt(l2sum_safe), l2sum)
    intermediate = values * clip_norm
    # Assert that the shape is compatible with the initial shape,
    # to prevent unintentional broadcasting.
    _ = values.shape.merge_with(intermediate.shape)
    values_clip = array_ops.identity(
        intermediate / math_ops.maximum(l2norm, clip_norm), name=name)

    if isinstance(t, ops.IndexedSlices):
      return ops.IndexedSlices(values_clip, t.indices, t.dense_shape)

    return values_clip


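# A minimal usage sketch, not part of TensorFlow's API: the helper below is a
# hypothetical, never-called illustration of clip_by_norm with per-row
# clipping. The first row [3.0, 4.0] has L2-norm 5.0 > 1.0 and is rescaled to
# [0.6, 0.8]; the second row's norm is 0.5 <= 1.0 and is left unchanged.
def _clip_by_norm_example_sketch():
  t = constant_op.constant([[3.0, 4.0], [0.3, 0.4]])
  # axes=[1] clips each row independently: [[0.6, 0.8], [0.3, 0.4]].
  return clip_by_norm(t, clip_norm=1.0, axes=[1])

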
@tf_export("linalg.global_norm", v1=["linalg.global_norm", "global_norm"])
@deprecation.deprecated_endpoints("global_norm")
def global_norm(t_list, name=None):
  """Computes the global norm of multiple tensors.

  Given a tuple or list of tensors `t_list`, this operation returns the
  global norm of the elements in all tensors in `t_list`. The global norm is
  computed as:

  `global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))`

  Any entries in `t_list` that are of type None are ignored.

  Args:
    t_list: A tuple or list of mixed `Tensors`, `IndexedSlices`, or None.
    name: A name for the operation (optional).

  Returns:
    A 0-D (scalar) `Tensor` of type `float`.

  Raises:
    TypeError: If `t_list` is not a sequence.
  """
  if (not isinstance(t_list, collections.Sequence)
      or isinstance(t_list, six.string_types)):
    raise TypeError("t_list should be a sequence")
  t_list = list(t_list)
  with ops.name_scope(name, "global_norm", t_list) as name:
    values = [
        ops.convert_to_tensor(
            t.values if isinstance(t, ops.IndexedSlices) else t,
            name="t_%d" % i)
        if t is not None else t
        for i, t in enumerate(t_list)]
    half_squared_norms = []
    for v in values:
      if v is not None:
        with ops.colocate_with(v):
          half_squared_norms.append(gen_nn_ops.l2_loss(v))

    half_squared_norm = math_ops.reduce_sum(array_ops.stack(half_squared_norms))

    norm = math_ops.sqrt(
        half_squared_norm *
        constant_op.constant(2.0, dtype=half_squared_norm.dtype),
        name="global_norm")

  return norm


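# A minimal numeric sketch, not part of TensorFlow's API: the helper below is
# a hypothetical, never-called illustration of global_norm. The per-tensor
# L2-norms are 5.0 and 12.0, so the result is sqrt(5.0**2 + 12.0**2) = 13.0.
def _global_norm_example_sketch():
  a = constant_op.constant([3.0, 4.0])
  b = constant_op.constant([12.0])
  return global_norm([a, b])

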
@tf_export("clip_by_global_norm")
def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
  """Clips values of multiple tensors by the ratio of the sum of their norms.

  Given a tuple or list of tensors `t_list`, and a clipping ratio `clip_norm`,
  this operation returns a list of clipped tensors `list_clipped`
  and the global norm (`global_norm`) of all tensors in `t_list`. Optionally,
  if you've already computed the global norm for `t_list`, you can specify
  the global norm with `use_norm`.

  To perform the clipping, the values `t_list[i]` are set to:

      t_list[i] * clip_norm / max(global_norm, clip_norm)

  where:

      global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))

  If `clip_norm > global_norm` then the entries in `t_list` remain as they are,
  otherwise they're all shrunk by the global ratio.

  Any of the entries of `t_list` that are of type `None` are ignored.

  This is the correct way to perform gradient clipping (for example, see
  [Pascanu et al., 2012](http://arxiv.org/abs/1211.5063)
  ([pdf](http://arxiv.org/pdf/1211.5063.pdf))).

  However, it is slower than `clip_by_norm()` because all the parameters must be
  ready before the clipping operation can be performed.

  Args:
    t_list: A tuple or list of mixed `Tensors`, `IndexedSlices`, or None.
    clip_norm: A 0-D (scalar) `Tensor` > 0. The clipping ratio.
    use_norm: A 0-D (scalar) `Tensor` of type `float` (optional). The global
      norm to use. If not provided, `global_norm()` is used to compute the norm.
    name: A name for the operation (optional).

  Returns:
    list_clipped: A list of `Tensors` of the same type as `t_list`.
    global_norm: A 0-D (scalar) `Tensor` representing the global norm.

  Raises:
    TypeError: If `t_list` is not a sequence.
    InvalidArgumentError: If global norm is not finite.
  """
  if (not isinstance(t_list, collections.Sequence)
      or isinstance(t_list, six.string_types)):
    raise TypeError("t_list should be a sequence")
  t_list = list(t_list)
  if use_norm is None:
    use_norm = global_norm(t_list, name)
  use_norm = numerics.verify_tensor_all_finite(use_norm,
                                               "Found Inf or NaN global norm.")

  with ops.name_scope(name, "clip_by_global_norm",
                      t_list + [clip_norm]) as name:
    # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
    scale = clip_norm * math_ops.minimum(
        1.0 / use_norm,
        constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

    values = [
        ops.convert_to_tensor(
            t.values if isinstance(t, ops.IndexedSlices) else t,
            name="t_%d" % i)
        if t is not None else t
        for i, t in enumerate(t_list)]

    values_clipped = []
    for i, v in enumerate(values):
      if v is None:
        values_clipped.append(None)
      else:
        with ops.colocate_with(v):
          values_clipped.append(
              array_ops.identity(v * scale, name="%s_%d" % (name, i)))

    list_clipped = [
        ops.IndexedSlices(c_v, t.indices, t.dense_shape)
        if isinstance(t, ops.IndexedSlices)
        else c_v
        for (c_v, t) in zip(values_clipped, t_list)]

  return list_clipped, use_norm


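# A minimal usage sketch, not part of TensorFlow's API: the helper below is a
# hypothetical, never-called illustration of clip_by_global_norm. The global
# norm of the two tensors is sqrt(3.0**2 + 4.0**2 + 12.0**2) = 13.0, so with
# clip_norm=6.5 every tensor is scaled by 6.5 / 13.0 = 0.5.
def _clip_by_global_norm_example_sketch():
  grads = [constant_op.constant([3.0, 4.0]), constant_op.constant([12.0])]
  # clipped evaluates to [[1.5, 2.0], [6.0]] and norm to 13.0.
  clipped, norm = clip_by_global_norm(grads, clip_norm=6.5)
  return clipped, norm

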
@deprecation.deprecated(
    date=None,
    instructions="clip_by_average_norm is deprecated in TensorFlow 2.0. Please "
    "use clip_by_norm(t, clip_norm * tf.cast(tf.size(t), tf.float32), name) "
    "instead.")
@tf_export(v1=["clip_by_average_norm"])
def clip_by_average_norm(t, clip_norm, name=None):
  """Clips tensor values to a maximum average L2-norm.

  Given a tensor `t`, and a maximum clip value `clip_norm`, this operation
  normalizes `t` so that its average L2-norm is less than or equal to
  `clip_norm`. Specifically, if the average L2-norm is already less than or
  equal to `clip_norm`, then `t` is not modified. If the average L2-norm is
  greater than `clip_norm`, then this operation returns a tensor of the same
  type and shape as `t` with its values set to:

  `t * clip_norm / l2norm_avg(t)`

  In this case, the average L2-norm of the output tensor is `clip_norm`.

  This operation is typically used to clip gradients before applying them with
  an optimizer.

  Args:
    t: A `Tensor`.
    clip_norm: A 0-D (scalar) `Tensor` > 0. A maximum clipping value.
    name: A name for the operation (optional).

  Returns:
    A clipped `Tensor`.
  """
  with ops.name_scope(name, "clip_by_average_norm", [t, clip_norm]) as name:
    t = ops.convert_to_tensor(t, name="t")

    # Calculate the average L2-norm and clip elements by the ratio of
    # clip_norm to that average.
    n_element = math_ops.cast(array_ops.size(t), dtypes.float32)
    l2norm_inv = math_ops.rsqrt(
        math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t))))
    tclip = array_ops.identity(
        t * clip_norm * math_ops.minimum(
            l2norm_inv * n_element, constant_op.constant(1.0) / clip_norm),
        name=name)

  return tclip

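# A minimal usage sketch of the deprecated v1 endpoint, not part of
# TensorFlow's API: the helper below is a hypothetical, never-called
# illustration of clip_by_average_norm. The tensor [3.0, 4.0] has L2-norm 5.0
# over 2 elements, so its average L2-norm is 2.5; clipping with clip_norm=1.0
# rescales it by 1.0 / 2.5 to [1.2, 1.6].
def _clip_by_average_norm_example_sketch():
  t = constant_op.constant([3.0, 4.0])
  return clip_by_average_norm(t, clip_norm=1.0)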