# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decorator to override the gradient for a function."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import op_selector
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


VAR_OP_TYPES = [
    "VariableV2",
    "VarHandleOp",
]


# TODO(allenl): Remove this alias and migrate callers.
copy_handle_data = handle_data_util.copy_handle_data


@tf_export("custom_gradient")
def custom_gradient(f=None):
  """Decorator to define a function with a custom gradient.

  This decorator allows fine-grained control over the gradients of a sequence
  of operations. This may be useful for multiple reasons, including providing
  a more efficient or numerically stable gradient for a sequence of operations.

  For example, consider the following function that commonly occurs in the
  computation of cross entropy and log likelihoods:

  ```python
  def log1pexp(x):
    return tf.math.log(1 + tf.exp(x))
  ```

  Due to numerical instability, the gradient of this function evaluated at
  x=100 is NaN. For example:

  ```python
  x = tf.constant(100.)
  y = log1pexp(x)
  dy = tf.gradients(y, x)  # Will be NaN when evaluated.
  ```

  The gradient expression can be analytically simplified to provide numerical
  stability:

  ```python
  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
      return dy * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad
  ```

  With this definition, the gradient at x=100 will be correctly evaluated as
  1.0.

  The variable `dy` is defined as the upstream gradient, i.e. the gradient
  from all the layers or functions originating from this layer.

  By the chain rule we know that
  `dy/dx = dy/dx_0 * dx_0/dx_1 * ... * dx_i/dx_i+1 * ... * dx_n/dx`.

  In this case the gradient of our current function is defined as
  `dx_i/dx_i+1 = (1 - 1 / (1 + e))`. The upstream gradient `dy` would be
  `dx_i+1/dx_i+2 * dx_i+2/dx_i+3 * ... * dx_n/dx`. The upstream gradient
  multiplied by the current gradient is then passed downstream.
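
  As a quick check of the stabilized definition above, the gradient can be
  evaluated with `tf.GradientTape` (a sketch; any equivalent gradient API
  works):

  ```python
  x = tf.constant(100.)
  with tf.GradientTape() as tape:
    tape.watch(x)  # x is a constant, so it must be watched explicitly.
    y = log1pexp(x)
  tape.gradient(y, x)  # ==> 1.0, rather than NaN.
  ```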

  In case the function takes multiple variables as input, the `grad`
  function must also return the same number of gradients, one per input.
  We take the function `z = x * y` as an example.

  >>> @tf.custom_gradient
  ... def bar(x, y):
  ...   def grad(upstream):
  ...     dz_dx = y
  ...     dz_dy = x
  ...     return upstream * dz_dx, upstream * dz_dy
  ...   z = x * y
  ...   return z, grad
  >>> x = tf.constant(2.0, dtype=tf.float32)
  >>> y = tf.constant(3.0, dtype=tf.float32)
  >>> with tf.GradientTape(persistent=True) as tape:
  ...   tape.watch(x)
  ...   tape.watch(y)
  ...   z = bar(x, y)
  >>> z
  <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
  >>> tape.gradient(z, x)
  <tf.Tensor: shape=(), dtype=float32, numpy=3.0>
  >>> tape.gradient(z, y)
  <tf.Tensor: shape=(), dtype=float32, numpy=2.0>

  Nesting custom gradients can lead to unintuitive results. The default
  behavior does not correspond to n-th order derivatives. For example

  ```python
  @tf.custom_gradient
  def op(x):
    y = op1(x)
    @tf.custom_gradient
    def grad_fn(dy):
      gdy = op2(x, y, dy)
      def grad_grad_fn(ddy):  # Not the 2nd order gradient of op w.r.t. x.
        return op3(x, y, dy, ddy)
      return gdy, grad_grad_fn
    return y, grad_fn
  ```

  The function `grad_grad_fn` will compute the first-order gradient of
  `grad_fn` with respect to `dy`, which is used to generate forward-mode
  gradient graphs from backward-mode gradient graphs, but is not the same as
  the second-order gradient of `op` with respect to `x`.

  Instead, wrap nested `@tf.custom_gradient`s in another function:

  ```python
  @tf.custom_gradient
  def op_with_fused_backprop(x):
    y, x_grad = fused_op(x)
    def first_order_gradient(dy):
      @tf.custom_gradient
      def first_order_custom(unused_x):
        def second_order_and_transpose(ddy):
          return second_order_for_x(...), gradient_wrt_dy(...)
        return x_grad, second_order_and_transpose
      return dy * first_order_custom(x)
    return y, first_order_gradient
  ```

  Additional arguments to the inner `@tf.custom_gradient`-decorated function
  control the expected return values of the innermost function.

  See also `tf.RegisterGradient` which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient` on the other hand
  allows for fine-grained control over the gradient computation of a sequence
  of operations.

  Note that if the decorated function uses `Variable`s, the enclosing variable
  scope must be using `ResourceVariable`s.

  Args:
    f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
       - `x` is a sequence of (nested structures of) `Tensor` inputs to the
         function.
       - `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow
         operations in `f` to `x`.
       - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
         a list of `Tensor`s the same size as (flattened) `x` - the derivatives
         of `Tensor`s in `y` with respect to the `Tensor`s in `x`. `grad_ys` is
         a sequence of `Tensor`s the same size as (flattened) `y` holding the
         initial value gradients for each `Tensor` in `y`.

       In a pure mathematical sense, a vector-argument vector-valued function
       `f`'s derivatives should be its Jacobian matrix `J`. Here we are
       expressing the Jacobian `J` as a function `grad_fn` which defines how
       `J` will transform a vector `grad_ys` when left-multiplied with it
       (`grad_ys * J`, the vector-Jacobian product, or VJP). This functional
       representation of a matrix is convenient to use for chain-rule
       calculation (in e.g. the back-propagation algorithm).

       If `f` uses `Variable`s (that are not part of the
       inputs), i.e. through `get_variable`, then `grad_fn` should have
       signature `g(*grad_ys, variables=None)`, where `variables` is a list of
       the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
       `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
       with the derivatives of `Tensor`s in `y` with respect to the variables
       (that is, grad_vars has one Tensor per variable in variables).
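
       For example, a minimal sketch of this calling convention (the variable
       `v` and the function `scale` below are illustrative, not part of the
       API):

       ```python
       v = tf.Variable(2.0)

       @tf.custom_gradient
       def scale(x):
         def grad(upstream, variables=None):
           # d(scale)/dx = v and d(scale)/dv = x; `variables` is [v] here.
           return upstream * v, [upstream * x]
         return x * v, grad
       ```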

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.
  """

  if f is None:
    return lambda f: custom_gradient(f=f)

  @Bind.decorator
  def decorated(wrapped, args, kwargs):
    """Decorated function with custom gradient."""
    if context.executing_eagerly():
      return _eager_mode_decorator(wrapped, args, kwargs)
    else:
      return _graph_mode_decorator(wrapped, args, kwargs)

  return tf_decorator.make_decorator(f, decorated(f))  # pylint: disable=no-value-for-parameter


class Bind(object):
  """When called evaluates `d(f, args, kwargs)` but supports binding `f`.

  >>> @Bind.decorator
  ... def my_decorator(f, args, kwargs):
  ...   print("my_decorator called with", args, kwargs)
  ...   return f(*args, **kwargs)

  >>> class Foo(object):
  ...   @my_decorator
  ...   def bar(self, a, b, c):
  ...     return a * b * c

  >>> Foo.bar(None, 1, 2, c=3)
  my_decorator called with (None, 1, 2) {'c': 3}
  6

  >>> foo = Foo()
  >>> foo.bar(1, 2, c=3)
  my_decorator called with (1, 2) {'c': 3}
  6
  """

  @classmethod
  def decorator(cls, d):
    return lambda f: Bind(f, d)

  def __init__(self, f, d):
    self._f = f
    self._d = d

  def __get__(self, instance, owner):
    if instance is not None:
      f = self._f.__get__(instance, owner)
      return tf_decorator.make_decorator(f, Bind(f, self._d))
    else:
      return self

  def __call__(self, *a, **k):
    return self._d(self._f, a, k)


def get_variable_by_name(var_name):
  """Given a variable name, retrieves a handle on the TensorFlow Variable."""

  candidate_vars = ops.get_collection(
      ops.GraphKeys.GLOBAL_VARIABLES, scope="{}:0".format(var_name))
  if len(candidate_vars) >= 1:
    # Filter out non-trainable variables.
    candidate_vars = [v for v in candidate_vars if v.trainable]
  else:
    raise ValueError("Unsuccessful at finding variable {}.".format(var_name))

  if len(candidate_vars) == 1:
    return candidate_vars[0]
  elif len(candidate_vars) > 1:
    raise ValueError(
        "Unsuccessful at finding trainable variable {}. "
        "Number of candidates: {}. "
        "Candidates: {}".format(var_name, len(candidate_vars), candidate_vars))
  else:
    # The variable is not trainable.
    return None


def _get_dependent_variables(input_ops, output_ops):
  """Finds variables involved in the subgraph between input_ops and output_ops.

  Args:
    input_ops: Flattened list of input ops.
    output_ops: Flattened list of output ops.

  Returns:
    A list of variables.
  """

  # Avoids the edge case when input_ops == output_ops.
  output_ops = nest.map_structure(gen_array_ops.identity, output_ops)
  inbetween_ops = op_selector.get_backward_walk_ops(
      seed_ops=output_ops,
      stop_at_ts=input_ops,
      inclusive=False,
      only_differentiable=True)
  var_ops = (op for op in inbetween_ops if op.type in VAR_OP_TYPES)
  var_names = (op.name for op in var_ops)
  tf_vars = (get_variable_by_name(var_name) for var_name in var_names)
  tf_vars = [v for v in tf_vars if v is not None]
  return tf_vars


def _graph_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keyword "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = nest.map_structure(ops.convert_to_tensor, args)

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args)

  args = nest.flatten(args)
  flat_result = nest.flatten(result)
  flat_result_len = len(flat_result)

  after_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  new_vars = after_vars - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables_in_tape = frozenset([
      v.ref() for v in variable_watcher.watched_variables()
  ])

  graphs = {getattr(o, "graph", None) for o in flat_result}
  # Not all results may be tensors. However, we want to ensure all tensor
  # outputs are from the same graph and get a list of captured inputs for
  # variable search.
  graphs.discard(None)  # Discard non-graph outputs
  if graphs:
    if len(graphs) > 1:
      raise ValueError(
          "All custom_gradient outputs should be from the same graph")
    output_graph = graphs.pop()
    filtered_input_tensors = []
    for i in args:
      if i.graph == output_graph:
        filtered_input_tensors.append(i)
  else:
    filtered_input_tensors = args

  variables_in_subgraph = frozenset([
      v.ref() for v in _get_dependent_variables(
          input_ops=filtered_input_tensors, output_ops=flat_result)
  ])
  variables = list(
      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)])

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            "variables" in grad_argspec.kwonlyargs or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                 "no ResourceVariables were used on the forward pass.")

  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    result_grads = result_grads[:flat_result_len]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * flat_result_len) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:flat_result_len])


def _eager_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for eager mode."""
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args, **kwargs)
  args = nest.flatten(args)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [
      v.deref()  # pylint: disable=g-complex-comprehension
      for v in set(v.ref() for v in variable_watcher.watched_variables())
      if all(v.deref() is not i for i in all_inputs)
  ]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  if (variables and ("variables" not in grad_argspec.args) and
      ("variables" not in grad_argspec.kwonlyargs) and
      not grad_argspec.varkw):
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]

  recorded_inputs = input_tensors
  arg_count = len(args)

  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = nest.flatten(input_grads)
    if len(flat_grads) != arg_count:
      raise ValueError(
          "custom_gradient function expected to return {} gradients, but "
          "returned {} instead.".format(arg_count, len(flat_grads)))
    return flat_grads + variable_grads

  tape_lib.record_operation(f.__name__, flat_result, recorded_inputs,
                            actual_grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(result, flat_result)


@tf_export("recompute_grad")
def recompute_grad(f):
  """An eager-compatible version of recompute_grad.

  For f(*args, **kwargs), this supports gradients with respect to args or
  kwargs, but kwargs are currently only supported in eager mode.
  Note that for Keras layer and model objects, this is handled automatically.

  Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
  be able to access the member variables of that object, because `g` returns
  through the wrapper function `inner`. When recomputing gradients through
  objects that inherit from Keras, we suggest keeping a reference to the
  underlying object around for the purpose of accessing these variables.
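
  A minimal usage sketch (assuming some differentiable function
  `expensive_fn` and input tensor `x`, which are not part of this API):

  ```python
  recomputing_fn = tf.recompute_grad(expensive_fn)
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = recomputing_fn(x)
  # Intermediate activations of `expensive_fn` were not kept; they are
  # recomputed from the saved inputs during this gradient call.
  dy_dx = tape.gradient(y, x)
  ```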

  Args:
    f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs.

  Returns:
    A function `g` that wraps `f`, but which recomputes `f` on the backwards
    pass of a gradient call.
  """
  # TODO(cdfreeman) Add is_recomputing functionality from graph mode version

  @custom_gradient
  def inner(*args, **kwargs):
    """Inner function closure for calculating gradients."""
    current_var_scope = variable_scope.get_variable_scope()
    with tape_lib.stop_recording():
      result = f(*args, **kwargs)

    def grad_wrapper(*wrapper_args, **grad_kwargs):
      """Wrapper function to accommodate lack of kwargs in graph mode decorator."""

      @custom_gradient
      def inner_recompute_grad(*dresult):
        """Nested custom gradient function for computing grads in reverse and forward mode autodiff."""
        # Gradient calculation for reverse mode autodiff.
        variables = grad_kwargs.get("variables")
        with backprop.GradientTape() as t:
          id_args = nest.map_structure(gen_array_ops.identity, args)
          t.watch(id_args)
          if variables is not None:
            t.watch(variables)
          with ops.control_dependencies(dresult):
            with variable_scope.variable_scope(current_var_scope):
              result = f(*id_args, **kwargs)
        kw_vars = []
        if variables is not None:
          kw_vars = list(variables)
        grads = t.gradient(
            result,
            list(id_args) + kw_vars,
            output_gradients=dresult,
            unconnected_gradients=UnconnectedGradients.ZERO)

        def transpose(*t_args, **t_kwargs):
          """Gradient function calculation for forward mode autodiff."""
          # Just throw an error since gradients / activations are not stored on
          # tape for recompute.
          raise NotImplementedError(
              "recompute_grad tried to transpose grad of {}. "
              "Consider not using recompute_grad in forward mode "
              "autodiff".format(f.__name__))

        return (grads[:len(id_args)], grads[len(id_args):]), transpose

      return inner_recompute_grad(*wrapper_args)

    return result, grad_wrapper

  return inner


@tf_export("grad_pass_through")
def grad_pass_through(f):
  """Creates a grad-pass-through op with the forward behavior provided in f.

  Use this function to wrap any op, maintaining its behavior in the forward
  pass, but replacing the original op in the backward graph with an identity.
  For example:

  ```python
  x = tf.Variable(1.0, name="x")
  z = tf.Variable(3.0, name="z")

  with tf.GradientTape() as tape:
    # y will evaluate to 9.0
    y = tf.grad_pass_through(x.assign)(z**2)
  # grads will evaluate to 6.0
  grads = tape.gradient(y, z)
  ```

  Another example is a 'differentiable' moving average approximation, where
  gradients are allowed to flow into the last value fed to the moving average,
  but the moving average is still used for the forward pass:

  ```python
  x = ...  # Some scalar value
  # A moving average object; we don't need to know how this is implemented.
  moving_average = MovingAverage()
  with backprop.GradientTape() as tape:
    # mavg_x will evaluate to the current running average value
    mavg_x = tf.grad_pass_through(moving_average)(x)
  grads = tape.gradient(mavg_x, x)  # grads will evaluate to 1.0
  ```

  Args:
    f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor`
      outputs.

  Returns:
    A function `h(x)` which returns the same values as `f(x)` and whose
    gradients are the same as those of an identity function.
  """
  @custom_gradient
  def _grad_pass_through_op(*args, **kwargs):
    def grad(*args, **kwargs):
      variables = kwargs.get("variables")
      if variables is not None:
        # Variables involved in the wrapped op will not receive gradients.
        return args, [None] * len(variables)
      return args
    return f(*args, **kwargs), grad
  return tf_decorator.make_decorator(f, _grad_pass_through_op)