# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decorator to override the gradient for a function."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import op_selector
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


# Op types that `_get_dependent_variables` treats as variable ops when walking
# the graph between the inputs and outputs of a decorated function.
VAR_OP_TYPES = [
    "VariableV2",
    "VarHandleOp",
]


@tf_export("custom_gradient")
def custom_gradient(f=None):
  """Decorator to define a function with a custom gradient.

  This decorator allows fine grained control over the gradients of a sequence
  of operations.  This may be useful for multiple reasons, including providing
  a more efficient or numerically stable gradient for a sequence of operations.

  For example, consider the following function that commonly occurs in the
  computation of cross entropy and log likelihoods:

  ```python
  def log1pexp(x):
    return tf.math.log(1 + tf.exp(x))
  ```

  Due to numerical instability, the gradient of this function evaluated at x=100
  is NaN.  For example:

  ```python
  x = tf.constant(100.)
  y = log1pexp(x)
  dy_dx = tf.gradients(y, x) # Will be NaN when evaluated.
  ```

  The gradient expression can be analytically simplified to provide numerical
  stability:

  ```python
  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(upstream):
      return upstream * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad
  ```

  With this definition, the gradient `dy_dx` at `x = 100` will be correctly
  evaluated as 1.0.

  The variable `upstream` is defined as the upstream gradient. i.e. the gradient
  from all the layers or functions originating from this layer. The above
  example has no upstream functions, therefore `upstream = dy/dy = 1.0`.

  Assume that `x_i` is `log1pexp` in the forward pass `x_1 = x_1(x_0)`,
  `x_2 = x_2(x_1)`, ..., `x_i = x_i(x_i-1)`, ..., `x_n = x_n(x_n-1)`. By
  chain rule we know that `dx_n/dx_0 = dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... *
  dx_i/dx_i-1 * ... * dx_1/dx_0`.

  In this case the gradient of our current function is defined as
  `dx_i/dx_i-1 = (1 - 1 / (1 + e))`. The upstream gradient `upstream` would be
  `dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... * dx_i+1/dx_i`. The upstream gradient
  multiplied by the current gradient is then passed downstream.

  In case the function takes multiple variables as input, the `grad`
  function must also return  the same number of variables.
  We take the function `z = x * y` as an example.

  >>> @tf.custom_gradient
  ... def bar(x, y):
  ...   def grad(upstream):
  ...     dz_dx = y
  ...     dz_dy = x
  ...     return upstream * dz_dx, upstream * dz_dy
  ...   z = x * y
  ...   return z, grad
  >>> x = tf.constant(2.0, dtype=tf.float32)
  >>> y = tf.constant(3.0, dtype=tf.float32)
  >>> with tf.GradientTape(persistent=True) as tape:
  ...   tape.watch(x)
  ...   tape.watch(y)
  ...   z = bar(x, y)
  >>> z
  <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
  >>> tape.gradient(z, x)
  <tf.Tensor: shape=(), dtype=float32, numpy=3.0>
  >>> tape.gradient(z, y)
  <tf.Tensor: shape=(), dtype=float32, numpy=2.0>

  Nesting custom gradients can lead to unintuitive results. The default
  behavior does not correspond to n-th order derivatives. For example

  ```python
  @tf.custom_gradient
  def op(x):
    y = op1(x)
    @tf.custom_gradient
    def grad_fn(dy):
      gdy = op2(x, y, dy)
      def grad_grad_fn(ddy):  # Not the 2nd order gradient of op w.r.t. x.
        return op3(x, y, dy, ddy)
      return gdy, grad_grad_fn
    return y, grad_fn
  ```

  The function `grad_grad_fn` will be calculating the first order gradient
  of `grad_fn` with respect to `dy`, which is used to generate forward-mode
  gradient graphs from backward-mode gradient graphs, but is not the same as
  the second order gradient of `op` with respect to `x`.

  Instead, wrap nested `@tf.custom_gradient`s in another function:

  ```python
  @tf.custom_gradient
  def op_with_fused_backprop(x):
    y, x_grad = fused_op(x)
    def first_order_gradient(dy):
      @tf.custom_gradient
      def first_order_custom(unused_x):
        def second_order_and_transpose(ddy):
          return second_order_for_x(...), gradient_wrt_dy(...)
        return x_grad, second_order_and_transpose
      return dy * first_order_custom(x)
    return y, first_order_gradient
  ```

  Additional arguments to the inner `@tf.custom_gradient`-decorated function
  control the expected return values of the innermost function.

  The examples above illustrate how to specify custom gradients for functions
  which do not read from variables. The following example uses variables, which
  require special handling because they are effectively inputs of the forward
  function.

  >>> weights = tf.Variable(tf.ones([2]))  # Trainable variable weights
  >>> @tf.custom_gradient
  ... def linear_poly(x):
  ...   # Creating polynomial
  ...   poly = weights[1] * x + weights[0]
  ...
  ...   def grad_fn(dpoly, variables):
  ...     # dy/dx = weights[1] and we need to left multiply dpoly
  ...     grad_xs = dpoly * weights[1]  # Scalar gradient
  ...
  ...     grad_vars = []  # To store gradients of passed variables
  ...     assert variables is not None
  ...     assert len(variables) == 1
  ...     assert variables[0] is weights
  ...     # Manually computing dy/dweights
  ...     dy_dw = dpoly * tf.stack([x ** 1, x ** 0])
  ...     grad_vars.append(
  ...         tf.reduce_sum(tf.reshape(dy_dw, [2, -1]), axis=1)
  ...     )
  ...     return grad_xs, grad_vars
  ...   return poly, grad_fn
  >>> x = tf.constant([1., 2., 3.])
  >>> with tf.GradientTape(persistent=True) as tape:
  ...   tape.watch(x)
  ...   poly = linear_poly(x)
  >>> poly # poly = x + 1
  <tf.Tensor: shape=(3,),
    dtype=float32,
    numpy=array([2., 3., 4.], dtype=float32)>
  >>> tape.gradient(poly, x)  # conventional scalar gradient dy/dx
  <tf.Tensor: shape=(3,),
    dtype=float32,
    numpy=array([1., 1., 1.], dtype=float32)>
  >>> tape.gradient(poly, weights)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 3.], dtype=float32)>

  The above example illustrates usage of the trainable variable `weights`.
  In the example, the inner `grad_fn` accepts an extra `variables` input
  parameter and also returns an extra `grad_vars` output. That extra argument
  is passed if the forward function reads any variables. You need to
  compute the gradient w.r.t. each of those `variables` and output it as a list
  of `grad_vars`. Note here that default value of `variables` is set to `None`
  when no variables are used in the forward function.

  It should be noted `tf.GradientTape` is still watching the forward pass of a
  `tf.custom_gradient`, and will use the ops it watches. As a consequence,
  calling `tf.function` while the tape is still watching leads
  to a gradient graph being built. If an op is used in `tf.function` without
  registered gradient, a `LookupError` will be raised.

  Users can insert `tf.stop_gradient` to customize this behavior. This
  is demonstrated in the example below. `tf.random.shuffle` does not have a
  registered gradient. As a result `tf.stop_gradient` is used to avoid the
  `LookupError`.

  ```python
  x = tf.constant([0.3, 0.5], dtype=tf.float32)

  @tf.custom_gradient
  def test_func_with_stop_grad(x):
    @tf.function
    def _inner_func():
      # Avoid exception during the forward pass
      return tf.stop_gradient(tf.random.shuffle(x))
      # return tf.random.shuffle(x)  # This will raise

    res = _inner_func()
    def grad(upstream):
      return upstream  # Arbitrarily defined custom gradient
    return res, grad

  with tf.GradientTape() as g:
    g.watch(x)
    res = test_func_with_stop_grad(x)

  g.gradient(res, x)
  ```

  See also `tf.RegisterGradient` which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows
  for fine grained control over the gradient computation of a sequence of
  operations.

  Note that if the decorated function uses `Variable`s, the enclosing variable
  scope must be using `ResourceVariable`s.

  Args:
    f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
       - `x` is a sequence of (nested structures of) `Tensor` inputs to the
         function.
       - `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow
         operations in `f` to `x`.
       - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
         a list of `Tensor`s the same size as (flattened) `x` - the derivatives
         of `Tensor`s in `y` with respect to the `Tensor`s in `x`.  `grad_ys` is
         a sequence of `Tensor`s the same size as (flattened) `y` holding the
         initial value gradients for each `Tensor` in `y`.

         In a pure mathematical sense, a vector-argument vector-valued function
         `f`'s derivatives should be its Jacobian matrix `J`. Here we are
         expressing the Jacobian `J` as a function `grad_fn` which defines how
         `J` will transform a vector `grad_ys` when left-multiplied with it
         (`grad_ys * J`, the vector-Jacobian product, or VJP). This functional
         representation of a matrix is convenient to use for chain-rule
         calculation (in e.g. the back-propagation algorithm).

         If `f` uses `Variable`s (that are not part of the
         inputs), i.e. through `get_variable`, then `grad_fn` should have
         signature `g(*grad_ys, variables=None)`, where `variables` is a list of
         the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
         `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
         with the derivatives of `Tensor`s in `y` with respect to the variables
         (that is, grad_vars has one Tensor per variable in variables).

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.
  """

  # Support use as `@custom_gradient()` as well as bare `@custom_gradient`.
  if f is None:
    return lambda f: custom_gradient(f=f)

  @Bind.decorator
  def decorated(wrapped, args, kwargs):
    """Decorated function with custom gradient."""
    # The execution mode is checked on every call, not at decoration time.
    if context.executing_eagerly():
      return _eager_mode_decorator(wrapped, args, kwargs)
    else:
      return _graph_mode_decorator(wrapped, args, kwargs)

  return tf_decorator.make_decorator(f, decorated(f))  # pylint: disable=no-value-for-parameter


class Bind(object):
  """When called evaluates `d(f, args, kwargs)` but supports binding `f`.

  >>> @Bind.decorator
  ... def my_decorator(f, args, kwargs):
  ...   print("my_decorator called with", args, kwargs)
  ...   return f(*args, **kwargs)

  >>> class Foo(object):
  ...   @my_decorator
  ...   def bar(self, a, b, c):
  ...     return a * b * c

  >>> Foo.bar(None, 1, 2, c=3)
  my_decorator called with (None, 1, 2) {'c': 3}
  6

  >>> foo = Foo()
  >>> foo.bar(1, 2, c=3)
  my_decorator called with (1, 2) {'c': 3}
  6
  """

  @classmethod
  def decorator(cls, d):
    """Turns `d(f, args, kwargs)` into a decorator producing `Bind(f, d)`."""
    return lambda f: Bind(f, d)

  def __init__(self, f, d):
    """Stores the wrapped callable `f` and the dispatcher `d`."""
    self._f = f
    self._d = d

  def __get__(self, instance, owner):
    """Descriptor hook: binds `f` to `instance` when accessed on an instance.

    Accessing through the class (`instance is None`) returns this object
    unchanged, so the caller must pass `self` explicitly.
    """
    if instance is not None:
      f = self._f.__get__(instance, owner)
      return tf_decorator.make_decorator(f, Bind(f, self._d))
    else:
      return self

  def __call__(self, *a, **k):
    """Forwards the call to `d` with the positional and keyword args packed."""
    return self._d(self._f, a, k)


def get_variable_by_name(var_name):
  """Given a variable name, retrieves a handle on the tensorflow Variable.

  Args:
    var_name: name compared against `item.op.name` for every item in the
      `GLOBAL_VARIABLES` collection.

  Returns:
    The unique trainable variable whose op name equals `var_name`, or `None`
    when matching variables exist but none of them is trainable.

  Raises:
    ValueError: if no variable matches `var_name`, or if more than one
      trainable variable matches.
  """
  global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

  def _filter_fn(item):
    try:
      return var_name == item.op.name
    except AttributeError:
      # Collection items without operation are ignored.
      return False

  candidate_vars = list(filter(_filter_fn, global_vars))

  if len(candidate_vars) >= 1:
    # Filter out non-trainable variables.
    candidate_vars = [v for v in candidate_vars if v.trainable]
  else:
    raise ValueError("Unsuccessful at finding variable {}.".format(var_name))

  if len(candidate_vars) == 1:
    return candidate_vars[0]
  elif len(candidate_vars) > 1:
    raise ValueError(
        "Unsuccessful at finding trainable variable {}. "
        "Number of candidates: {}. "
        "Candidates: {}".format(var_name, len(candidate_vars), candidate_vars))
  else:
    # The variable is not trainable.
    return None


def _get_dependent_variables(input_ops, output_ops):
  """Finds variables involved in the subgraph between input_ops and output_ops.

  Args:
    input_ops: Flattened list of input ops
    output_ops: Flattened list of output ops

  Returns:
    A list of variables
  """

  # avoids the edge-case when input_ops == output_ops.
  output_ops = nest.map_structure(gen_array_ops.identity, output_ops)
  inbetween_ops = op_selector.get_backward_walk_ops(
      seed_ops=output_ops,
      stop_at_ts=input_ops,
      inclusive=False,
      only_differentiable=True)
  var_ops = (op for op in inbetween_ops if op.type in VAR_OP_TYPES)
  var_names = (op.name for op in var_ops)
  tf_vars = (get_variable_by_name(var_name) for var_name in var_names)
  # `get_variable_by_name` returns None for non-trainable variables; drop them.
  tf_vars = [v for v in tf_vars if v is not None]
  return tf_vars


def generate_name():
  """Returns a "CustomGradient-<uid>" name built from `ops.uid()`."""
  return "CustomGradient-%s" % ops.uid()


def _graph_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for graph mode.

  Args:
    f: the user function returning `(result, grad_fn)`.
    args: positional arguments for `f`; converted to tensors before the call.
    kwargs: keyword arguments; must be empty in graph mode.

  Returns:
    The outputs of `f`, routed through an `IdentityN` op whose gradient is
    overridden with `grad_fn`, packed back into the structure of `result`.

  Raises:
    ValueError: if `kwargs` is non-empty, or if the outputs of `f` come from
      more than one graph.
    TypeError: if `f` created a non-resource variable, or if `f` uses
      variables but `grad_fn` does not accept a `variables` argument.
  """
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = generate_name()
  args = nest.map_structure(ops.convert_to_tensor, args)

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = {
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  }
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args)

  args = nest.flatten(args)
  flat_result = nest.flatten(result)
  flat_result_len = len(flat_result)

  after_vars = {
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  }
  new_vars = after_vars - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables_in_tape = frozenset(
      v.ref() for v in variable_watcher.watched_variables())

  graphs = {getattr(o, "graph", None) for o in flat_result}
  # Not all results may be tensors. However, we want to ensure all tensor
  # outputs are from the same graph and get a list of captured inputs for
  # variable search
  graphs.discard(None)  # Discard non-graph outputs
  if graphs:
    if len(graphs) > 1:
      raise ValueError(
          "All custom_gradient outputs should be from the same graph")
    output_graph = graphs.pop()
    # Only inputs living in the output graph can bound the backward walk.
    filtered_input_tensors = [i for i in args if i.graph == output_graph]
  else:
    filtered_input_tensors = args

  variables_in_subgraph = frozenset(
      v.ref() for v in _get_dependent_variables(
          input_ops=filtered_input_tensors, output_ops=flat_result))
  # Sort by name so the variable order passed to `grad_fn` is deterministic.
  variables = sorted(
      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)],
      key=lambda v: v.name)

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  # A `**kwargs` catch-all in grad_fn also counts as accepting `variables`.
  variables_in_signature = ("variables" in grad_argspec.args or
                            "variables" in grad_argspec.kwonlyargs or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                 "no ResourceVariables were used on the forward pass.")

  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    # Only the leading gradients correspond to the outputs of `f`; the rest
    # belong to the inputs and variables that also pass through IdentityN.
    result_grads = result_grads[:flat_result_len]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * flat_result_len) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    handle_data_util.copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:flat_result_len])


def _eager_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for eager mode.

  Args:
    f: the user function returning `(result, grad_fn)`.
    args: positional arguments forwarded to `f`.
    kwargs: keyword arguments forwarded to `f`.

  Returns:
    The outputs of `f` (each passed through an identity op), packed back into
    the structure of `result`, with `grad_fn` recorded on the tape as their
    gradient function.

  Raises:
    TypeError: if `f` uses variables that are not among its inputs but
      `grad_fn` does not accept a `variables` argument.
  """
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args, **kwargs)
  args = nest.flatten(args)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [
      v.deref()  # pylint: disable=g-complex-comprehension
      for v in set(v.ref() for v in variable_watcher.watched_variables())
      if all(v.deref() is not i for i in all_inputs)
  ]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  if (variables and ("variables" not in grad_argspec.args) and
      ("variables" not in grad_argspec.kwonlyargs) and
      not grad_argspec.varkw):
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]

  recorded_inputs = input_tensors
  # Number of flattened positional inputs; grad_fn must return this many grads.
  arg_count = len(args)

  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper.

    Raises:
      ValueError: if `grad_fn` returns the wrong number of variable gradients
        or the wrong number of input gradients.
    """
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = nest.flatten(input_grads)
    if len(flat_grads) != arg_count:
      # Build a single message string; passing the pieces as separate
      # positional arguments would render the error as a tuple repr.
      raise ValueError(
          "custom_gradient function expected to return {} gradients but "
          "returned {} instead.".format(arg_count, len(flat_grads)))
    return flat_grads + variable_grads

  tape_lib.record_operation(f.__name__, flat_result, recorded_inputs,
                            actual_grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(result, flat_result)


@tf_export("recompute_grad")
def recompute_grad(f):
  """An eager-compatible version of recompute_grad.

  For f(*args, **kwargs), this supports gradients with respect to args or
  kwargs, but kwargs are currently only supported in eager-mode.
  Note that for keras layer and model objects, this is handled automatically.

  Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
  be able to access the member variables of that object, because `g` returns
  through the wrapper function `inner`.  When recomputing gradients through
  objects that inherit from keras, we suggest keeping a reference to the
  underlying object around for the purpose of accessing these variables.

  Args:
    f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs.

  Returns:
    A function `g` that wraps `f`, but which recomputes `f` on the backwards
    pass of a gradient call.
  """
  # TODO(cdfreeman) Add is_recomputing functionality from graph mode version

  @custom_gradient
  def inner(*args, **kwargs):
    """Inner function closure for calculating gradients."""
    current_var_scope = variable_scope.get_variable_scope()
    # The forward call is made with tape recording stopped; `f` is called again
    # under a fresh GradientTape inside `inner_recompute_grad`.
    with tape_lib.stop_recording():
      result = f(*args, **kwargs)

    def grad_wrapper(*wrapper_args, **grad_kwargs):
      """Wrapper to accommodate lack of kwargs in graph mode decorator."""

      @custom_gradient
      def inner_recompute_grad(*dresult):
        """Nested custom gradient function.

        Computes grads in reverse and forward mode autodiff.
        """
        # Gradient calculation for reverse mode autodiff.
        variables = grad_kwargs.get("variables")
        with backprop.GradientTape() as t:
          id_args = nest.map_structure(gen_array_ops.identity, args)
          # Tuple `dresult` should contain at least one tensor.
          assert len(dresult) >= 1

          if not context.executing_eagerly():
            # XLA doesn't respect `tf.control_dependencies`. The code block
            # below manually adds a data dependency to `dresult` to ensure
            # recomputation of `f(*args, **kwargs)` happens after `dresult`.

            # This works even if `dresult[0]` is a size 0 tensor as reduce_max
            # of a size 0 tensor returns -inf. Use reshape here to avoid reading
            # the entire `dresult[0]`.
            elem = math_ops.reduce_max(array_ops.reshape(dresult[0], [-1])[:1])
            # Cast elem to bool in case elem is NaN.
            elem_bool = math_ops.cast(elem, dtypes.bool)
            dresult_dep = array_ops.where_v2(
                elem_bool == elem_bool, 0., float("nan"))  # pylint: disable=comparison-with-itself
            id_args = nest.map_structure(
                lambda x: x + math_ops.cast(dresult_dep, x.dtype), id_args)

          t.watch(id_args)
          if variables is not None:
            t.watch(variables)
          with variable_scope.variable_scope(current_var_scope):
            recomputed_result = f(*id_args, **kwargs)
        kw_vars = []
        if variables is not None:
          kw_vars = list(variables)
        grads = t.gradient(
            recomputed_result,
            list(id_args) + kw_vars,
            output_gradients=dresult,
            unconnected_gradients=UnconnectedGradients.ZERO)

        def transpose(*t_args, **t_kwargs):
          """Gradient function calculation for forward mode autodiff."""
          # Just throw an error since gradients / activations are not stored on
          # tape for recompute.
          raise NotImplementedError(
              "recompute_grad tried to transpose grad of {}. "
              "Consider not using recompute_grad in forward mode "
              "autodiff".format(f.__name__))

        # Split into (gradients w.r.t. args, gradients w.r.t. variables).
        return (grads[:len(id_args)], grads[len(id_args):]), transpose

      return inner_recompute_grad(*wrapper_args)

    return result, grad_wrapper

  return inner


@tf_export("grad_pass_through")
def grad_pass_through(f):
  """Creates a grad-pass-through op with the forward behavior provided in f.

  Use this function to wrap any op, maintaining its behavior in the forward
  pass, but replacing the original op in the backward graph with an identity.
  For example:

  ```python
  x = tf.Variable(1.0, name="x")
  z = tf.Variable(3.0, name="z")

  with tf.GradientTape() as tape:
    # y will evaluate to 9.0
    y = tf.grad_pass_through(x.assign)(z**2)
  # grads will evaluate to 6.0
  grads = tape.gradient(y, z)
  ```

  Another example is a 'differentiable' moving average approximation, where
  gradients are allowed to flow into the last value fed to the moving average,
  but the moving average is still used for the forward pass:

  ```python
  x = ... # Some scalar value
  # A moving average object, we don't need to know how this is implemented
  moving_average = MovingAverage()
  with backprop.GradientTape() as tape:
    # mavg_x will evaluate to the current running average value
    mavg_x = tf.grad_pass_through(moving_average)(x)
  grads = tape.gradient(mavg_x, x) # grads will evaluate to 1.0
  ```

  Args:
    f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor`
      outputs.

  Returns:
    A function `h(x)` which returns the same values as `f(x)` and whose
    gradients are the same as those of an identity function.
  """
  @custom_gradient
  def _grad_pass_through_op(*args, **kwargs):
    def grad(*args, **kwargs):
      """Returns the incoming gradients unchanged (identity backward pass)."""
      variables = kwargs.get("variables")
      if variables is not None:
        # Variables involved in the wrapped op will not receive gradients.
        return args, [None] * len(variables)
      return args
    return f(*args, **kwargs), grad
  return tf_decorator.make_decorator(f, _grad_pass_through_op)