# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements the graph generation for computation of gradients."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import image_grad  # pylint: disable=unused-import
from tensorflow.python.ops import linalg_grad  # pylint: disable=unused-import
from tensorflow.python.ops import linalg_ops  # pylint: disable=unused-import
from tensorflow.python.ops import logging_ops  # pylint: disable=unused-import
from tensorflow.python.ops import manip_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import optional_grad  # pylint: disable=unused-import
from tensorflow.python.ops import random_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["gradients"])
def gradients(ys,
              xs,
              grad_ys=None,
              name="gradients",
              colocate_gradients_with_ops=False,
              gate_gradients=False,
              aggregation_method=None,
              stop_gradients=None,
              unconnected_gradients=UnconnectedGradients.NONE):
  """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.

  `ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys`
  is a list of `Tensor`, holding the gradients received by the
  `ys`. The list must be the same length as `ys`.

  `gradients()` adds ops to the graph to output the derivatives of `ys` with
  respect to `xs`. It returns a list of `Tensor` of length `len(xs)` where
  each tensor is the `sum(dy/dx)` for y in `ys`.

  `grad_ys` is a list of tensors of the same length as `ys` that holds
  the initial gradients for each y in `ys`. When `grad_ys` is None,
  we fill in a tensor of '1's of the shape of y for each y in `ys`. A
  user can provide their own initial `grad_ys` to compute the
  derivatives using a different initial gradient for each y (e.g., if
  one wanted to weight the gradient differently for each value in
  each y).
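
  For example, to weight the contribution of each `y` differently, one can
  pass explicit initial gradients (a minimal illustrative sketch; the tensors
  and weighting values here are arbitrary):

  ```python
  x = tf.constant(3.)
  y1 = 2. * x
  y2 = 3. * x
  # dy1/dx is weighted by 2 and dy2/dx by 1.
  g = tf.gradients([y1, y2], x, grad_ys=[tf.constant(2.), tf.constant(1.)])
  # g evaluates to [2. * 2. + 1. * 3.] = [7.0]
  ```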

  `stop_gradients` is a `Tensor` or a list of tensors to be considered constant
  with respect to all `xs`. These tensors will not be backpropagated through,
  as though they had been explicitly disconnected using `stop_gradient`. Among
  other things, this allows computation of partial derivatives as opposed to
  total derivatives. For example:

  ```python
  a = tf.constant(0.)
  b = 2 * a
  g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])
  ```

  Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the
  total derivatives `tf.gradients(a + b, [a, b])`, which take into account the
  influence of `a` on `b` and evaluate to `[3.0, 1.0]`. Note that the above is
  equivalent to:

  ```python
  a = tf.stop_gradient(tf.constant(0.))
  b = tf.stop_gradient(2 * a)
  g = tf.gradients(a + b, [a, b])
  ```

  `stop_gradients` provides a way of stopping gradients after the graph has
  already been constructed, as compared to `tf.stop_gradient` which is used
  during graph construction. When the two approaches are combined,
  backpropagation stops at both `tf.stop_gradient` nodes and nodes in
  `stop_gradients`, whichever is encountered first.

  All integer tensors are considered constant with respect to all `xs`, as if
  they were included in `stop_gradients`.
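
  For instance, gradients do not flow back into an integer-valued input (a
  minimal illustrative sketch; the tensors here are arbitrary):

  ```python
  x = tf.constant([1, 2, 3])        # integer dtype
  y = tf.cast(x, tf.float32) * 2.
  g = tf.gradients(y, x)  # x is treated as constant, as if in `stop_gradients`
  ```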

  `unconnected_gradients` determines the value returned for each x in xs if it
  is unconnected in the graph to ys. By default this is None to safeguard
  against errors. Mathematically these gradients are zero which can be
  requested using the `'zero'` option. `tf.UnconnectedGradients` provides the
  following options and behaviors:

  ```python
  a = tf.ones([1, 2])
  b = tf.ones([3, 1])
  g1 = tf.gradients([b], [a], unconnected_gradients='none')
  sess.run(g1)  # [None]

  g2 = tf.gradients([b], [a], unconnected_gradients='zero')
  sess.run(g2)  # [array([[0., 0.]], dtype=float32)]
  ```

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grad_ys: Optional. A `Tensor` or list of tensors the same size as
      `ys` and holding the gradients computed for each y in `ys`.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'gradients'.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gate_gradients: If True, add a tuple around the gradients returned
      for an operation. This avoids some race conditions.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.
    stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
      through.
    unconnected_gradients: Optional. Specifies the gradient value returned when
      the given input tensors are unconnected. Accepted values are constants
      defined in the class `tf.UnconnectedGradients` and the default value is
      `none`.

  Returns:
    A list of `sum(dy/dx)` for each x in `xs`.

  Raises:
    LookupError: if one of the operations between `x` and `y` does not
      have a registered gradient function.
    ValueError: if the arguments are invalid.
    RuntimeError: if called in Eager mode.

  """
  # Creating the gradient graph for control flow mutates Operations.
  # _mutation_lock ensures a Session.run call cannot occur between creating and
  # mutating new ops.
  # pylint: disable=protected-access
  with ops.get_default_graph()._mutation_lock():
    return gradients_util._GradientsHelper(
        ys, xs, grad_ys, name, colocate_gradients_with_ops,
        gate_gradients, aggregation_method, stop_gradients,
        unconnected_gradients)
  # pylint: enable=protected-access


@tf_export("gradients", v1=[])
def gradients_v2(ys,  # pylint: disable=invalid-name
                 xs,
                 grad_ys=None,
                 name="gradients",
                 gate_gradients=False,
                 aggregation_method=None,
                 stop_gradients=None,
                 unconnected_gradients=UnconnectedGradients.NONE):
  """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.

  `ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys`
  is a list of `Tensor`, holding the gradients received by the
  `ys`. The list must be the same length as `ys`.

  `gradients()` adds ops to the graph to output the derivatives of `ys` with
  respect to `xs`. It returns a list of `Tensor` of length `len(xs)` where
  each tensor is the `sum(dy/dx)` for y in `ys`.

  `grad_ys` is a list of tensors of the same length as `ys` that holds
  the initial gradients for each y in `ys`. When `grad_ys` is None,
  we fill in a tensor of '1's of the shape of y for each y in `ys`. A
  user can provide their own initial `grad_ys` to compute the
  derivatives using a different initial gradient for each y (e.g., if
  one wanted to weight the gradient differently for each value in
  each y).

  `stop_gradients` is a `Tensor` or a list of tensors to be considered constant
  with respect to all `xs`. These tensors will not be backpropagated through,
  as though they had been explicitly disconnected using `stop_gradient`. Among
  other things, this allows computation of partial derivatives as opposed to
  total derivatives. For example:

  ```python
  a = tf.constant(0.)
  b = 2 * a
  g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])
  ```

  Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the
  total derivatives `tf.gradients(a + b, [a, b])`, which take into account the
  influence of `a` on `b` and evaluate to `[3.0, 1.0]`. Note that the above is
  equivalent to:

  ```python
  a = tf.stop_gradient(tf.constant(0.))
  b = tf.stop_gradient(2 * a)
  g = tf.gradients(a + b, [a, b])
  ```

  `stop_gradients` provides a way of stopping gradients after the graph has
  already been constructed, as compared to `tf.stop_gradient` which is used
  during graph construction. When the two approaches are combined,
  backpropagation stops at both `tf.stop_gradient` nodes and nodes in
  `stop_gradients`, whichever is encountered first.

  All integer tensors are considered constant with respect to all `xs`, as if
  they were included in `stop_gradients`.

  `unconnected_gradients` determines the value returned for each x in xs if it
  is unconnected in the graph to ys. By default this is None to safeguard
  against errors. Mathematically these gradients are zero which can be
  requested using the `'zero'` option.
  `tf.UnconnectedGradients` provides the following options and behaviors:

  ```python
  a = tf.ones([1, 2])
  b = tf.ones([3, 1])
  g1 = tf.gradients([b], [a], unconnected_gradients='none')
  sess.run(g1)  # [None]

  g2 = tf.gradients([b], [a], unconnected_gradients='zero')
  sess.run(g2)  # [array([[0., 0.]], dtype=float32)]
  ```

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grad_ys: Optional. A `Tensor` or list of tensors the same size as
      `ys` and holding the gradients computed for each y in `ys`.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'gradients'.
    gate_gradients: If True, add a tuple around the gradients returned
      for an operation. This avoids some race conditions.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.
    stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
      through.
    unconnected_gradients: Optional. Specifies the gradient value returned when
      the given input tensors are unconnected. Accepted values are constants
      defined in the class `tf.UnconnectedGradients` and the default value is
      `none`.

  Returns:
    A list of `sum(dy/dx)` for each x in `xs`.

  Raises:
    LookupError: if one of the operations between `x` and `y` does not
      have a registered gradient function.
    ValueError: if the arguments are invalid.
    RuntimeError: if called in Eager mode.

  """
  # Creating the gradient graph for control flow mutates Operations.
  # _mutation_lock ensures a Session.run call cannot occur between creating and
  # mutating new ops.
  # pylint: disable=protected-access
  with ops.get_default_graph()._mutation_lock():
    return gradients_util._GradientsHelper(
        ys, xs, grad_ys, name, True, gate_gradients,
        aggregation_method, stop_gradients,
        unconnected_gradients)
  # pylint: enable=protected-access


# TODO(vrv): Make this available when we want to make it public.
def _hessian_vector_product(ys, xs, v):
  """Multiply the Hessian of `ys` wrt `xs` by `v`.

  This is an efficient construction that uses a backprop-like approach
  to compute the product between the Hessian and another vector. The
  Hessian is usually too large to be explicitly computed or even
  represented, but this method allows us to at least multiply by it
  for the same big-O cost as backprop.

  Implicit Hessian-vector products are the main practical, scalable way
  of using second derivatives with neural networks. They allow us to
  do things like construct Krylov subspaces and approximate conjugate
  gradient descent.

  Example: if `y` = 1/2 `x`^T A `x`, then `hessian_vector_product(y, x, v)`
  will return an expression that evaluates to the same values as
  1/2 (A + A.T) `v`.
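
  A minimal usage sketch (illustrative only; the tensors here are arbitrary):

  ```python
  x = tf.constant([1., 2., 3.])
  v = tf.constant([1., 0., 1.])
  y = tf.reduce_sum(x * x)                    # Hessian of y w.r.t. x is 2 * I
  hvp = _hessian_vector_product(y, [x], [v])  # evaluates to [2. * v]
  ```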

  Args:
    ys: A scalar value, or a tensor or list of tensors to be summed to
        yield a scalar.
    xs: A list of tensors that we should construct the Hessian over.
    v: A list of tensors, with the same shapes as xs, that we want to
       multiply by the Hessian.

  Returns:
    A list of tensors (or if the list would be length 1, a single tensor)
    containing the product between the Hessian and `v`.

  Raises:
    ValueError: `xs` and `v` have different lengths.

  """

  # Validate the input
  length = len(xs)
  if len(v) != length:
    raise ValueError("xs and v must have the same length.")

  # First backprop
  grads = gradients(ys, xs)

  assert len(grads) == length
  elemwise_products = [
      math_ops.multiply(grad_elem, array_ops.stop_gradient(v_elem))
      for grad_elem, v_elem in zip(grads, v)
      if grad_elem is not None
  ]

  # Second backprop
  return gradients(elemwise_products, xs)


@tf_export(v1=["hessians"])
def hessians(ys,
             xs,
             name="hessians",
             colocate_gradients_with_ops=False,
             gate_gradients=False,
             aggregation_method=None):
  """Constructs the Hessian of sum of `ys` with respect to `x` in `xs`.

  `hessians()` adds ops to the graph to output the Hessian matrix of `ys`
  with respect to `xs`. It returns a list of `Tensor` of length `len(xs)`
  where each tensor is the Hessian of `sum(ys)`.

  The Hessian is a matrix of second-order partial derivatives of a scalar
  tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details).

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'hessians'.
    colocate_gradients_with_ops: See `gradients()` documentation for details.
    gate_gradients: See `gradients()` documentation for details.
    aggregation_method: See `gradients()` documentation for details.

  Returns:
    A list of Hessian matrices of `sum(ys)` for each `x` in `xs`.

  Raises:
    LookupError: if one of the operations between `xs` and `ys` does not
      have a registered gradient function.
  """
  xs = gradients_util._AsList(xs)  # pylint: disable=protected-access
  kwargs = {
      "colocate_gradients_with_ops": colocate_gradients_with_ops,
      "gate_gradients": gate_gradients,
      "aggregation_method": aggregation_method
  }
  # Compute first-order derivatives and iterate for each x in xs.
  hessians = []
  _gradients = gradients(ys, xs, **kwargs)
  for gradient, x in zip(_gradients, xs):
    # Change the gradient's shape to one dimension without graph branching.
    gradient = array_ops.reshape(gradient, [-1])

    # Declare an iterator and tensor array loop variables for the gradients.
    n = array_ops.size(x)
    loop_vars = [
        array_ops.constant(0, dtypes.int32),
        tensor_array_ops.TensorArray(x.dtype, n)
    ]
    # Iterate over all elements of the gradient and compute second order
    # derivatives.
    _, hessian = control_flow_ops.while_loop(
        lambda j, _: j < n,
        lambda j, result: (j + 1,
                           result.write(j, gradients(gradient[j], x)[0])),
        loop_vars
    )

    _shape = array_ops.shape(x)
    _reshaped_hessian = array_ops.reshape(hessian.stack(),
                                          array_ops.concat((_shape, _shape), 0))
    hessians.append(_reshaped_hessian)
  return hessians


@tf_export("hessians", v1=[])
def HessiansV2(ys,
               xs,
               gate_gradients=False,
               aggregation_method=None,
               name="hessians"):
  return hessians(ys, xs, name=name, gate_gradients=gate_gradients,
                  aggregation_method=aggregation_method)


HessiansV2.__doc__ = hessians.__doc__
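

# A minimal usage sketch for `hessians` (illustrative only, assuming TF1-style
# graph execution; not part of the module's public surface):
#
#   x = tf.constant([1., 2.])
#   y = 3. * x[0] ** 2 + x[0] * x[1]
#   h = tf.hessians(y, x)[0]  # evaluates to [[6., 1.], [1., 0.]]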