# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Operations for clipping (gradient, weight) tensors to min/max values."""
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


@tf_export("clip_by_value")
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def clip_by_value(t, clip_value_min, clip_value_max,
                  name=None):
  """Clips tensor values to a specified min and max.

  Given a tensor `t`, this operation returns a tensor of the same type and
  shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`.
  Any values less than `clip_value_min` are set to `clip_value_min`. Any values
  greater than `clip_value_max` are set to `clip_value_max`.

  Note: `clip_value_min` needs to be smaller or equal to `clip_value_max` for
  correct results.

  For example:

  Basic usage passes a scalar as the min and max value.

  >>> t = tf.constant([[-10., -1., 0.], [0., 2., 10.]])
  >>> t2 = tf.clip_by_value(t, clip_value_min=-1, clip_value_max=1)
  >>> t2.numpy()
  array([[-1., -1.,  0.],
         [ 0.,  1.,  1.]], dtype=float32)

  The min and max can be the same size as `t`, or broadcastable to that size.

  >>> t = tf.constant([[-1, 0., 10.], [-1, 0, 10]])
  >>> clip_min = [[2],[1]]
  >>> t3 = tf.clip_by_value(t, clip_value_min=clip_min, clip_value_max=100)
  >>> t3.numpy()
  array([[ 2.,  2., 10.],
         [ 1.,  1., 10.]], dtype=float32)

  Broadcasting fails, intentionally, if you would expand the dimensions of `t`

  >>> t = tf.constant([[-1, 0., 10.], [-1, 0, 10]])
  >>> clip_min = [[[2, 1]]] # Has a third axis
  >>> t4 = tf.clip_by_value(t, clip_value_min=clip_min, clip_value_max=100)
  Traceback (most recent call last):
  ...
  InvalidArgumentError: Incompatible shapes: [2,3] vs. [1,1,2]

  It throws a `TypeError` if you try to clip an `int` to a `float` value
  (`tf.cast` the input to `float` first).

  >>> t = tf.constant([[1, 2], [3, 4]], dtype=tf.int32)
  >>> t5 = tf.clip_by_value(t, clip_value_min=-3.1, clip_value_max=3.1)
  Traceback (most recent call last):
  ...
  TypeError: Cannot convert ...

  Args:
    t: A `Tensor` or `IndexedSlices`.
    clip_value_min: The minimum value to clip to. A scalar `Tensor` or one that
      is broadcastable to the shape of `t`.
    clip_value_max: The maximum value to clip to. A scalar `Tensor` or one that
      is broadcastable to the shape of `t`.
    name: A name for the operation (optional).

  Returns:
    A clipped `Tensor` or `IndexedSlices`.

  Raises:
    `tf.errors.InvalidArgumentError`: If the clip tensors would trigger array
      broadcasting that would make the returned tensor larger than the input.
    ValueError: If the clip tensors broadcast successfully against `t` but the
      statically known result shape is incompatible with the shape of `t`
      (i.e. broadcasting would enlarge `t`).
    TypeError: If dtype of the input is `int32` and dtype of
      the `clip_value_min` or `clip_value_max` is `float32`
  """
  with ops.name_scope(name, "clip_by_value",
                      [t, clip_value_min, clip_value_max]) as name:
    # For `IndexedSlices` only the stored slice values are clipped; indices
    # and dense_shape are reattached unchanged below.
    values = ops.convert_to_tensor(
        t.values if isinstance(t, indexed_slices.IndexedSlices) else t,
        name="t")

    # Clip from above first, then from below. Because the lower bound is
    # applied last, every element ends up at `clip_value_min` whenever
    # `clip_value_min > clip_value_max` (hence the note in the docstring).
    t_min = math_ops.minimum(values, clip_value_max)
    # Assert that the shape is compatible with the initial shape,
    # to prevent unintentional broadcasting.
    values.shape.assert_is_compatible_with(t_min.shape)

    t_max = math_ops.maximum(t_min, clip_value_min, name=name)
    values.shape.assert_is_compatible_with(t_max.shape)

    if isinstance(t, indexed_slices.IndexedSlices):
      t_max = indexed_slices.IndexedSlices(t_max, t.indices, t.dense_shape)

  return t_max
  # TODO(scottzhu): switch to use new implementation in 2 weeks.
  # return gen_math_ops.clip_by_value(
  #     t, clip_value_min, clip_value_max, name=name)


# TODO(scottzhu): switch to use new implementation in 2 weeks.
# @ops.RegisterGradient("ClipByValue")
def _clip_by_value_grad(op, grad):
  """Returns grad of clip_by_value.

  Computes the gradients of a `ClipByValue` op with respect to its three
  inputs `(t, clip_value_min, clip_value_max)`:

  * `t` receives zero wherever `t < clip_value_min` or `t > clip_value_max`,
    and `grad` elsewhere.
  * `clip_value_min` receives `grad` wherever `t < clip_value_min`.
  * `clip_value_max` receives `grad` wherever `t > clip_value_max`.

  Each gradient is summed over its broadcast dimensions and reshaped back to
  the shape of the corresponding input.

  Note: this function is not registered (the `RegisterGradient` decorator
  above is commented out), because `clip_by_value` currently builds
  `minimum`/`maximum` ops rather than a `ClipByValue` op.

  Args:
    op: The `ClipByValue` operation being differentiated.
    grad: Gradient with respect to the output of `op`.

  Returns:
    A tuple `(grad_t, grad_clip_value_min, grad_clip_value_max)`.
  """
  x = op.inputs[0]
  y = op.inputs[1]
  z = op.inputs[2]
  gdtype = grad.dtype
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  sz = array_ops.shape(z)
  gradshape = array_ops.shape(grad)
  zeros = array_ops.zeros(gradshape, gdtype)
  xymask = math_ops.less(x, y)
  xzmask = math_ops.greater(x, z)
  # The `x` reduction axes from the (sx, sy) call were always overwritten by
  # the (sx, sz) call below, so they are discarded explicitly.
  # NOTE(review): `rx` therefore ignores broadcasting against `y`; confirm
  # this is acceptable before registering this gradient.
  _, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  rx, rz = gen_array_ops.broadcast_gradient_args(sx, sz)
  # `t` gets the gradient only where it was not clipped; each bound gets it
  # only where it was the active clip.
  xgrad = array_ops.where(math_ops.logical_or(xymask, xzmask), zeros, grad)
  ygrad = array_ops.where(xymask, grad, zeros)
  zgrad = array_ops.where(xzmask, grad, zeros)
  # Sum over broadcast dimensions and restore each input's shape.
  gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
  gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
  gz = array_ops.reshape(math_ops.reduce_sum(zgrad, rz), sz)
  return (gx, gy, gz)


@tf_export("clip_by_norm")
@dispatch.add_dispatch_support
def clip_by_norm(t, clip_norm, axes=None, name=None):
  """Clips tensor values to a maximum L2-norm.

  Given a tensor `t`, and a maximum clip value `clip_norm`, this operation
  normalizes `t` so that its L2-norm is less than or equal to `clip_norm`,
  along the dimensions given in `axes`. Specifically, in the default case
  where all dimensions are used for calculation, if the L2-norm of `t` is
  already less than or equal to `clip_norm`, then `t` is not modified. If
  the L2-norm is greater than `clip_norm`, then this operation returns a
  tensor of the same type and shape as `t` with its values set to:

  `t * clip_norm / l2norm(t)`

  In this case, the L2-norm of the output tensor is `clip_norm`.

  As another example, if `t` is a matrix and `axes == [1]`, then each row
  of the output will have L2-norm less than or equal to `clip_norm`. If
  `axes == [0]` instead, each column of the output will be clipped.

  Code example:

  >>> some_nums = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.float32)
  >>> tf.clip_by_norm(some_nums, 2.0).numpy()
  array([[0.26967996, 0.5393599 , 0.80903983, 1.0787199 , 1.3483998 ]],
        dtype=float32)

  This operation is typically used to clip gradients before applying them with
  an optimizer. Most gradient data is a collection of different shaped tensors
  for different parts of the model. Thus, this is a common usage:

  ```
  # Get your gradients after training
  loss_value, grads = grad(model, features, labels)

  # Apply some clipping
  grads = [tf.clip_by_norm(g, norm) for g in grads]

  # Continue on with training; `apply_gradients` expects (gradient, variable)
  # pairs.
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  ```

  Args:
    t: A `Tensor` or `IndexedSlices`. This must be a floating point type.
    clip_norm: A 0-D (scalar) `Tensor` > 0. A maximum clipping value, also
      floating point
    axes: A 1-D (vector) `Tensor` of type int32 containing the dimensions
      to use for computing the L2-norm. If `None` (the default), uses all
      dimensions.
    name: A name for the operation (optional).

  Returns:
    A clipped `Tensor` or `IndexedSlices`.

  Raises:
    ValueError: If the statically known shape of `t * clip_norm` is
      incompatible with the shape of `t`, i.e. `clip_norm` would broadcast
      `t` to a larger shape. `clip_norm` is meant to be a 0-D scalar; a
      non-scalar `clip_norm` that does not enlarge `t` is not rejected.
    TypeError: If dtype of the input is not a floating point or
      complex type.
  """
  with ops.name_scope(name, "clip_by_norm", [t, clip_norm]) as name:
    # For `IndexedSlices` the norm is computed over the stored slice values.
    values = ops.convert_to_tensor(
        t.values if isinstance(t, indexed_slices.IndexedSlices) else t,
        name="t")

    # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
    l2sum = math_ops.reduce_sum(values * values, axes, keepdims=True)
    pred = l2sum > 0
    # Two-tap tf.where trick to bypass NaN gradients: where l2sum == 0 the
    # sqrt is fed 1 instead of 0 (whose sqrt gradient is infinite), and the
    # outer where selects l2sum (== 0) as the norm there.
    l2sum_safe = array_ops.where(pred, l2sum, array_ops.ones_like(l2sum))
    l2norm = array_ops.where(pred, math_ops.sqrt(l2sum_safe), l2sum)
    intermediate = values * clip_norm
    # Assert that the shape is compatible with the initial shape,
    # to prevent unintentional broadcasting.
    values.shape.assert_is_compatible_with(intermediate.shape)
    # values * clip_norm / max(l2norm, clip_norm): unchanged when
    # l2norm <= clip_norm, otherwise rescaled to norm clip_norm.
    values_clip = array_ops.identity(
        intermediate / math_ops.maximum(l2norm, clip_norm), name=name)

    if isinstance(t, indexed_slices.IndexedSlices):
      return indexed_slices.IndexedSlices(values_clip, t.indices, t.dense_shape)

    return values_clip


def _validate_t_list(t_list):
  """Checks that `t_list` is a non-string sequence and returns it as a list.

  Shared argument check of `global_norm` and `clip_by_global_norm`.

  Args:
    t_list: The user-supplied collection of tensors.

  Returns:
    `list(t_list)`.

  Raises:
    TypeError: If `t_list` is not a `Sequence`, or is a `str`.
  """
  if (not isinstance(t_list, collections_abc.Sequence) or
      isinstance(t_list, str)):
    raise TypeError("`t_list` should be a sequence of tensors. Received "
                    f"{type(t_list)}.")
  return list(t_list)


def _t_list_to_values(t_list):
  """Converts the non-`None` entries of `t_list` to dense value tensors.

  `IndexedSlices` entries are replaced by their `values` tensor; every entry
  goes through `ops.convert_to_tensor` with the name `t_<i>`, so callers
  should invoke this inside their own name scope. `None` entries are kept as
  `None` so positions line up with `t_list`.

  Args:
    t_list: A list of `Tensor`s, `IndexedSlices`, or `None`.

  Returns:
    A list of the same length as `t_list`.
  """
  return [
      ops.convert_to_tensor(
          t.values if isinstance(t, indexed_slices.IndexedSlices) else t,
          name="t_%d" % i) if t is not None else t
      for i, t in enumerate(t_list)
  ]


@tf_export("linalg.global_norm", v1=["linalg.global_norm", "global_norm"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("global_norm")
def global_norm(t_list, name=None):
  """Computes the global norm of multiple tensors.

  Given a tuple or list of tensors `t_list`, this operation returns the
  global norm of the elements in all tensors in `t_list`. The global norm is
  computed as:

  `global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))`

  Any entries in `t_list` that are of type None are ignored. For
  `IndexedSlices` entries only the `values` tensor contributes; duplicate
  indices are not combined first.

  Args:
    t_list: A tuple or list of mixed `Tensors`, `IndexedSlices`, or None.
    name: A name for the operation (optional).

  Returns:
    A 0-D (scalar) `Tensor` of type `float`.

  Raises:
    TypeError: If `t_list` is not a sequence.
  """
  t_list = _validate_t_list(t_list)
  with ops.name_scope(name, "global_norm", t_list) as name:
    values = _t_list_to_values(t_list)
    # Compute each per-tensor term next to the tensor itself.
    half_squared_norms = []
    for v in values:
      if v is not None:
        with ops.colocate_with(v):
          half_squared_norms.append(gen_nn_ops.l2_loss(v))

    # NOTE(review): the per-tensor terms are stacked, so all non-None entries
    # presumably need to share a dtype.
    half_squared_norm = math_ops.reduce_sum(array_ops.stack(half_squared_norms))

    # `l2_loss(v)` is sum(v ** 2) / 2, so double the sum before the sqrt.
    norm = math_ops.sqrt(
        half_squared_norm *
        constant_op.constant(2.0, dtype=half_squared_norm.dtype),
        name="global_norm")

  return norm


@tf_export("clip_by_global_norm")
@dispatch.add_dispatch_support
def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
  """Clips values of multiple tensors by the ratio of the sum of their norms.

  Given a tuple or list of tensors `t_list`, and a clipping ratio `clip_norm`,
  this operation returns a list of clipped tensors `list_clipped`
  and the global norm (`global_norm`) of all tensors in `t_list`. Optionally,
  if you've already computed the global norm for `t_list`, you can specify
  the global norm with `use_norm`.

  To perform the clipping, the values `t_list[i]` are set to:

      t_list[i] * clip_norm / max(global_norm, clip_norm)

  where:

      global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))

  If `clip_norm > global_norm` then the entries in `t_list` remain as they are,
  otherwise they're all shrunk by the global ratio.

  If `global_norm` is infinite (or NaN) then the entries in `t_list` are all
  set to `NaN` to signal that an error occurred.

  Any of the entries of `t_list` that are of type `None` are ignored; the
  corresponding entries of `list_clipped` are `None`.

  This is the correct way to perform gradient clipping (Pascanu et al., 2012).

  However, it is slower than `clip_by_norm()` because all the parameters must be
  ready before the clipping operation can be performed.

  Args:
    t_list: A tuple or list of mixed `Tensors`, `IndexedSlices`, or None.
    clip_norm: A 0-D (scalar) `Tensor` > 0. The clipping ratio.
    use_norm: A 0-D (scalar) `Tensor` of type `float` (optional). The global
      norm to use. If not provided, `global_norm()` is used to compute the norm.
    name: A name for the operation (optional).

  Returns:
    list_clipped: A list of `Tensors` of the same type as `t_list`.
    global_norm: A 0-D (scalar) `Tensor` representing the global norm.

  Raises:
    TypeError: If `t_list` is not a sequence.

  References:
    On the difficulty of training Recurrent Neural Networks:
      [Pascanu et al., 2012](http://proceedings.mlr.press/v28/pascanu13.html)
      ([pdf](http://proceedings.mlr.press/v28/pascanu13.pdf))
  """
  t_list = _validate_t_list(t_list)
  if use_norm is None:
    use_norm = global_norm(t_list, name)

  with ops.name_scope(name, "clip_by_global_norm",
                      t_list + [clip_norm]) as name:
    # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm:
    # clip_norm * min(1 / use_norm, 1 / clip_norm)
    #   == clip_norm / max(use_norm, clip_norm).
    scale_for_finite = clip_norm * math_ops.minimum(
        1.0 / use_norm,
        constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)
    # If use_norm is any finite number, this is a no-op. For inf/-inf/NaN,
    # this will make scale NaN.
    scale = scale_for_finite + (use_norm - use_norm)

    values = _t_list_to_values(t_list)

    # A loop rather than a comprehension because each product is built under
    # `colocate_with` its input tensor.
    values_clipped = []
    for i, v in enumerate(values):
      if v is None:
        values_clipped.append(None)
      else:
        with ops.colocate_with(v):
          values_clipped.append(
              array_ops.identity(v * scale, name="%s_%d" % (name, i)))

    # Re-wrap clipped values of `IndexedSlices` inputs with their indices.
    list_clipped = [
        indexed_slices.IndexedSlices(c_v, t.indices, t.dense_shape)
        if isinstance(t, indexed_slices.IndexedSlices) else c_v
        for (c_v, t) in zip(values_clipped, t_list)
    ]

  return list_clipped, use_norm


@deprecation.deprecated(
    date=None,
    instructions="clip_by_average_norm is deprecated in TensorFlow 2.0. Please "
    "use clip_by_norm(t, clip_norm * tf.cast(tf.size(t), tf.float32), name) "
    "instead.")
@tf_export(v1=["clip_by_average_norm"])
@dispatch.add_dispatch_support
def clip_by_average_norm(t, clip_norm, name=None):
  """Clips tensor values to a maximum average L2-norm.

  Given a tensor `t`, and a maximum clip value `clip_norm`, this operation
  normalizes `t` so that its average L2-norm is less than or equal to
  `clip_norm`. Specifically, if the average L2-norm is already less than or
  equal to `clip_norm`, then `t` is not modified. If the average L2-norm is
  greater than `clip_norm`, then this operation returns a tensor of the same
  type and shape as `t` with its values set to:

  `t * clip_norm / l2norm_avg(t)`

  In this case, the average L2-norm of the output tensor is `clip_norm`.

  Here `l2norm_avg(t)` is the L2-norm of `t` divided by its number of
  elements.

  This operation is typically used to clip gradients before applying them with
  an optimizer.

  Args:
    t: A `Tensor`.
    clip_norm: A 0-D (scalar) `Tensor` > 0. A maximum clipping value.
    name: A name for the operation (optional).

  Returns:
    A clipped `Tensor`.
  """
  with ops.name_scope(name, "clip_by_average_norm", [t, clip_norm]) as name:
    t = ops.convert_to_tensor(t, name="t")

    # l2norm_avg(t) = l2norm(t) / n_element, so
    # l2norm_inv * n_element == 1 / l2norm_avg(t), and the factor below is
    # clip_norm * min(1 / l2norm_avg, 1 / clip_norm)
    #   == clip_norm / max(l2norm_avg, clip_norm).
    # NOTE(review): `n_element` and the `1.0` constant are float32, so this
    # presumably only type-checks for float32 `t`; confirm before relying on
    # other dtypes.
    n_element = math_ops.cast(array_ops.size(t), dtypes.float32)
    # Reduce over every axis of `t`.
    l2norm_inv = math_ops.rsqrt(
        math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t))))
    tclip = array_ops.identity(
        t * clip_norm * math_ops.minimum(
            l2norm_inv * n_element, constant_op.constant(1.0) / clip_norm),
        name=name)

  return tclip