1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Decorator to overrides the gradient for a function.""" 16 17from tensorflow.python.eager import backprop 18from tensorflow.python.eager import context 19from tensorflow.python.eager import tape as tape_lib 20from tensorflow.python.framework import dtypes 21from tensorflow.python.framework import ops 22from tensorflow.python.ops import array_ops 23from tensorflow.python.ops import gen_array_ops 24from tensorflow.python.ops import handle_data_util 25from tensorflow.python.ops import math_ops 26from tensorflow.python.ops import op_selector 27from tensorflow.python.ops import resource_variable_ops 28from tensorflow.python.ops import variable_scope 29from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients 30from tensorflow.python.platform import tf_logging as logging 31from tensorflow.python.util import nest 32from tensorflow.python.util import tf_decorator 33from tensorflow.python.util import tf_inspect 34from tensorflow.python.util.tf_export import tf_export 35 36 37VAR_OP_TYPES = [ 38 "VariableV2", 39 "VarHandleOp", 40] 41 42 43@tf_export("custom_gradient") 44def custom_gradient(f=None): 45 """Decorator to define a function with a custom gradient. 46 47 This decorator allows fine grained control over the gradients of a sequence 48 for operations. This may be useful for multiple reasons, including providing 49 a more efficient or numerically stable gradient for a sequence of operations. 50 51 For example, consider the following function that commonly occurs in the 52 computation of cross entropy and log likelihoods: 53 54 ```python 55 def log1pexp(x): 56 return tf.math.log(1 + tf.exp(x)) 57 ``` 58 59 Due to numerical instability, the gradient of this function evaluated at x=100 60 is NaN. For example: 61 62 ```python 63 x = tf.constant(100.) 64 y = log1pexp(x) 65 dy_dx = tf.gradients(y, x) # Will be NaN when evaluated. 66 ``` 67 68 The gradient expression can be analytically simplified to provide numerical 69 stability: 70 71 ```python 72 @tf.custom_gradient 73 def log1pexp(x): 74 e = tf.exp(x) 75 def grad(upstream): 76 return upstream * (1 - 1 / (1 + e)) 77 return tf.math.log(1 + e), grad 78 ``` 79 80 With this definition, the gradient `dy_dx` at `x = 100` will be correctly 81 evaluated as 1.0. 82 83 The variable `upstream` is defined as the upstream gradient. i.e. the gradient 84 from all the layers or functions originating from this layer. The above 85 example has no upstream functions, therefore `upstream = dy/dy = 1.0`. 86 87 Assume that `x_i` is `log1pexp` in the forward pass `x_1 = x_1(x_0)`, 88 `x_2 = x_2(x_1)`, ..., `x_i = x_i(x_i-1)`, ..., `x_n = x_n(x_n-1)`. By 89 chain rule we know that `dx_n/dx_0 = dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... * 90 dx_i/dx_i-1 * ... * dx_1/dx_0`. 91 92 In this case the gradient of our current function defined as 93 `dx_i/dx_i-1 = (1 - 1 / (1 + e))`. The upstream gradient `upstream` would be 94 `dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... * dx_i+1/dx_i`. The upstream gradient 95 multiplied by the current gradient is then passed downstream. 96 97 In case the function takes multiple variables as input, the `grad` 98 function must also return the same number of variables. 99 We take the function `z = x * y` as an example. 100 101 >>> @tf.custom_gradient 102 ... def bar(x, y): 103 ... def grad(upstream): 104 ... dz_dx = y 105 ... dz_dy = x 106 ... return upstream * dz_dx, upstream * dz_dy 107 ... z = x * y 108 ... return z, grad 109 >>> x = tf.constant(2.0, dtype=tf.float32) 110 >>> y = tf.constant(3.0, dtype=tf.float32) 111 >>> with tf.GradientTape(persistent=True) as tape: 112 ... tape.watch(x) 113 ... tape.watch(y) 114 ... z = bar(x, y) 115 >>> z 116 <tf.Tensor: shape=(), dtype=float32, numpy=6.0> 117 >>> tape.gradient(z, x) 118 <tf.Tensor: shape=(), dtype=float32, numpy=3.0> 119 >>> tape.gradient(z, y) 120 <tf.Tensor: shape=(), dtype=float32, numpy=2.0> 121 122 Nesting custom gradients can lead to unintuitive results. The default 123 behavior does not correspond to n-th order derivatives. For example 124 125 ```python 126 @tf.custom_gradient 127 def op(x): 128 y = op1(x) 129 @tf.custom_gradient 130 def grad_fn(dy): 131 gdy = op2(x, y, dy) 132 def grad_grad_fn(ddy): # Not the 2nd order gradient of op w.r.t. x. 133 return op3(x, y, dy, ddy) 134 return gdy, grad_grad_fn 135 return y, grad_fn 136 ``` 137 138 The function `grad_grad_fn` will be calculating the first order gradient 139 of `grad_fn` with respect to `dy`, which is used to generate forward-mode 140 gradient graphs from backward-mode gradient graphs, but is not the same as 141 the second order gradient of `op` with respect to `x`. 142 143 Instead, wrap nested `@tf.custom_gradients` in another function: 144 145 ```python 146 @tf.custom_gradient 147 def op_with_fused_backprop(x): 148 y, x_grad = fused_op(x) 149 def first_order_gradient(dy): 150 @tf.custom_gradient 151 def first_order_custom(unused_x): 152 def second_order_and_transpose(ddy): 153 return second_order_for_x(...), gradient_wrt_dy(...) 154 return x_grad, second_order_and_transpose 155 return dy * first_order_custom(x) 156 return y, first_order_gradient 157 ``` 158 159 Additional arguments to the inner `@tf.custom_gradient`-decorated function 160 control the expected return values of the innermost function. 161 162 The examples above illustrate how to specify custom gradients for functions 163 which do not read from variables. The following example uses variables, which 164 require special handling because they are effectively inputs of the forward 165 function. 166 167 >>> weights = tf.Variable(tf.ones([2])) # Trainable variable weights 168 >>> @tf.custom_gradient 169 ... def linear_poly(x): 170 ... # Creating polynomial 171 ... poly = weights[1] * x + weights[0] 172 ... 173 ... def grad_fn(dpoly, variables): 174 ... # dy/dx = weights[1] and we need to left multiply dpoly 175 ... grad_xs = dpoly * weights[1] # Scalar gradient 176 ... 177 ... grad_vars = [] # To store gradients of passed variables 178 ... assert variables is not None 179 ... assert len(variables) == 1 180 ... assert variables[0] is weights 181 ... # Manually computing dy/dweights 182 ... dy_dw = dpoly * tf.stack([x ** 1, x ** 0]) 183 ... grad_vars.append( 184 ... tf.reduce_sum(tf.reshape(dy_dw, [2, -1]), axis=1) 185 ... ) 186 ... return grad_xs, grad_vars 187 ... return poly, grad_fn 188 >>> x = tf.constant([1., 2., 3.]) 189 >>> with tf.GradientTape(persistent=True) as tape: 190 ... tape.watch(x) 191 ... poly = linear_poly(x) 192 >>> poly # poly = x + 1 193 <tf.Tensor: shape=(3,), 194 dtype=float32, 195 numpy=array([2., 3., 4.], dtype=float32)> 196 >>> tape.gradient(poly, x) # conventional scalar gradient dy/dx 197 <tf.Tensor: shape=(3,), 198 dtype=float32, 199 numpy=array([1., 1., 1.], dtype=float32)> 200 >>> tape.gradient(poly, weights) 201 <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 3.], dtype=float32)> 202 203 Above example illustrates usage of trainable variable `weights`. 204 In the example, the inner `grad_fn` accepts an extra `variables` input 205 parameter and also returns an extra `grad_vars` output. That extra argument 206 is passed if the forward function reads any variables. You need to 207 compute the gradient w.r.t. each of those `variables` and output it as a list 208 of `grad_vars`. Note here that default value of `variables` is set to `None` 209 when no variables are used in the forward function. 210 211 It should be noted `tf.GradientTape` is still watching the forward pass of a 212 `tf.custom_gradient`, and will use the ops it watches. As a consequence, 213 calling `tf.function` while the tape is still watching leads 214 to a gradient graph being built. If an op is used in `tf.function` without 215 registered gradient, a `LookupError` will be raised. 216 217 Users can insert `tf.stop_gradient` to customize this behavior. This 218 is demonstrated in the example below. `tf.random.shuffle` does not have a 219 registered gradient. As a result `tf.stop_gradient` is used to avoid the 220 `LookupError`. 221 222 ```python 223 x = tf.constant([0.3, 0.5], dtype=tf.float32) 224 225 @tf.custom_gradient 226 def test_func_with_stop_grad(x): 227 @tf.function 228 def _inner_func(): 229 # Avoid exception during the forward pass 230 return tf.stop_gradient(tf.random.shuffle(x)) 231 # return tf.random.shuffle(x) # This will raise 232 233 res = _inner_func() 234 def grad(upstream): 235 return upstream # Arbitrarily defined custom gradient 236 return res, grad 237 238 with tf.GradientTape() as g: 239 g.watch(x) 240 res = test_func_with_stop_grad(x) 241 242 g.gradient(res, x) 243 ``` 244 245 See also `tf.RegisterGradient` which registers a gradient function for a 246 primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows 247 for fine grained control over the gradient computation of a sequence of 248 operations. 249 250 Note that if the decorated function uses `Variable`s, the enclosing variable 251 scope must be using 252 [ResourceVariables](https://www.tensorflow.org/guide/migrate/tf1_vs_tf2#resourcevariables_instead_of_referencevariables). 253 254 Args: 255 f: function `f(*x)` that returns a tuple `(y, grad_fn)` where: 256 - `x` is a sequence of (nested structures of) `Tensor` inputs to the 257 function. 258 - `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow 259 operations in `f` to `x`. 260 - `grad_fn` is a function with the signature `g(*grad_ys)` which returns 261 a list of `Tensor`s the same size as (flattened) `x` - the derivatives 262 of `Tensor`s in `y` with respect to the `Tensor`s in `x`. `grad_ys` is 263 a sequence of `Tensor`s the same size as (flattened) `y` holding the 264 initial value gradients for each `Tensor` in `y`. 265 266 In a pure mathematical sense, a vector-argument vector-valued function 267 `f`'s derivatives should be its Jacobian matrix `J`. Here we are 268 expressing the Jacobian `J` as a function `grad_fn` which defines how 269 `J` will transform a vector `grad_ys` when left-multiplied with it 270 (`grad_ys * J`, the vector-Jacobian product, or VJP). This functional 271 representation of a matrix is convenient to use for chain-rule 272 calculation (in e.g. the back-propagation algorithm). 273 274 If `f` uses `Variable`s (that are not part of the 275 inputs), i.e. through `get_variable`, then `grad_fn` should have 276 signature `g(*grad_ys, variables=None)`, where `variables` is a list of 277 the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where 278 `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>` 279 with the derivatives of `Tensor`s in `y` with respect to the variables 280 (that is, grad_vars has one Tensor per variable in variables). 281 282 Returns: 283 A function `h(x)` which returns the same value as `f(x)[0]` and whose 284 gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`. 285 """ 286 287 if f is None: 288 return lambda f: custom_gradient(f=f) 289 290 @Bind.decorator 291 def decorated(wrapped, args, kwargs): 292 """Decorated function with custom gradient.""" 293 if context.executing_eagerly(): 294 return _eager_mode_decorator(wrapped, args, kwargs) 295 else: 296 return _graph_mode_decorator(wrapped, args, kwargs) 297 298 return tf_decorator.make_decorator(f, decorated(f)) # pylint: disable=no-value-for-parameter 299 300 301class Bind: 302 """When called evaluates `d(f, args, kwargs)` but supports binding `f`. 303 304 >>> @Bind.decorator 305 ... def my_decorator(f, args, kwargs): 306 ... print("my_decorator called with", args, kwargs) 307 ... return f(*args, **kwargs) 308 309 >>> class Foo: 310 ... @my_decorator 311 ... def bar(self, a, b, c): 312 ... return a * b * c 313 314 >>> Foo.bar(None, 1, 2, c=3) 315 my_decorator called with (None, 1, 2) {'c': 3} 316 6 317 318 >>> foo = Foo() 319 >>> foo.bar(1, 2, c=3) 320 my_decorator called with (1, 2) {'c': 3} 321 6 322 """ 323 324 @classmethod 325 def decorator(cls, d): 326 return lambda f: Bind(f, d) 327 328 def __init__(self, f, d): 329 self._f = f 330 self._d = d 331 332 def __get__(self, instance, owner): 333 if instance is not None: 334 f = self._f.__get__(instance, owner) 335 return tf_decorator.make_decorator(f, Bind(f, self._d)) 336 else: 337 return self 338 339 def __call__(self, *a, **k): 340 return self._d(self._f, a, k) 341 342 343def get_variable_by_name(var_name): 344 """Given a variable name, retrieves a handle on the tensorflow Variable.""" 345 global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 346 347 def _filter_fn(item): 348 try: 349 return var_name == item.op.name 350 except AttributeError: 351 # Collection items without operation are ignored. 352 return False 353 354 candidate_vars = list(filter(_filter_fn, global_vars)) 355 356 if len(candidate_vars) >= 1: 357 # Filter out non-trainable variables. 358 candidate_vars = [v for v in candidate_vars if v.trainable] 359 else: 360 raise ValueError("Unsuccessful at finding variable {}.".format(var_name)) 361 362 if len(candidate_vars) == 1: 363 return candidate_vars[0] 364 elif len(candidate_vars) > 1: 365 raise ValueError( 366 "Unsuccessful at finding trainable variable {}. " 367 "Number of candidates: {}. " 368 "Candidates: {}".format(var_name, len(candidate_vars), candidate_vars)) 369 else: 370 # The variable is not trainable. 371 return None 372 373 374def _get_dependent_variables(input_ops, output_ops): 375 """Finds variables involved in the subgraph between input_ops and output_ops. 376 377 Args: 378 input_ops: Flattened list of input ops 379 output_ops: Flattened list of output ops 380 381 Returns: 382 A list of variables 383 """ 384 385 # avoids the edge-case when input_ops == output_ops. 386 output_ops = nest.map_structure(gen_array_ops.identity, output_ops) 387 inbetween_ops = op_selector.get_backward_walk_ops( 388 seed_ops=output_ops, 389 stop_at_ts=input_ops, 390 inclusive=False, 391 only_differentiable=True) 392 var_ops = (op for op in inbetween_ops if op.type in VAR_OP_TYPES) 393 var_names = (op.name for op in var_ops) 394 tf_vars = (get_variable_by_name(var_name) for var_name in var_names) 395 tf_vars = [v for v in tf_vars if v is not None] 396 return tf_vars 397 398 399def generate_name(): 400 return "CustomGradient-%s" % ops.uid() 401 402 403def _graph_mode_decorator(f, args, kwargs): 404 """Implement custom gradient decorator for graph mode.""" 405 # TODO(rsepassi): Add support for kwargs 406 if kwargs: 407 raise ValueError( 408 "The custom_gradient decorator currently supports keywords " 409 "arguments only when eager execution is enabled.") 410 name = generate_name() 411 args = nest.map_structure(ops.convert_to_tensor, args) 412 413 # Checking global and local variables attempts to ensure that no non-resource 414 # Variables are added to the graph. 415 current_var_scope = variable_scope.get_variable_scope() 416 before_vars = set([ 417 v.ref() for v in current_var_scope.global_variables() + 418 current_var_scope.local_variables() 419 ]) 420 with tape_lib.VariableWatcher() as variable_watcher: 421 result, grad_fn = f(*args) 422 423 args = nest.flatten(args) 424 flat_result = nest.flatten(result) 425 flat_result_len = len(flat_result) 426 427 after_vars = set([ 428 v.ref() for v in current_var_scope.global_variables() + 429 current_var_scope.local_variables() 430 ]) 431 new_vars = after_vars - before_vars 432 new_vars_list = [v.deref() for v in new_vars] 433 for v in new_vars_list: 434 if not resource_variable_ops.is_resource_variable(v): 435 raise TypeError( 436 "All variables used by a function wrapped with @custom_gradient must " 437 "be `ResourceVariable`s. Ensure that no `variable_scope` is created " 438 "with `use_resource=False`.") 439 440 # The variables that grad_fn needs to return gradients for are the set of 441 # variables used that are *not* part of the inputs. 442 variables_in_tape = frozenset([ 443 v.ref() for v in variable_watcher.watched_variables() 444 ]) 445 446 graphs = {getattr(o, "graph", None) for o in flat_result} 447 # Not all results may be tensors. However, we want to ensure all tensor 448 # outputs are from the same graph and get a list of captured inputs for 449 # variable search 450 graphs.discard(None) # Discard non-graph outputs 451 if graphs: 452 if len(graphs) > 1: 453 raise ValueError( 454 "All custom_gradient outputs should be from the same graph") 455 output_graph = graphs.pop() 456 filtered_input_tensors = [] 457 for i in args: 458 if i.graph == output_graph: 459 filtered_input_tensors.append(i) 460 else: 461 filtered_input_tensors = args 462 463 variables_in_subgraph = frozenset([ 464 v.ref() for v in _get_dependent_variables( 465 input_ops=filtered_input_tensors, output_ops=flat_result) 466 ]) 467 variables = sorted( 468 [v.deref() for v in variables_in_subgraph.union(variables_in_tape)], 469 key=lambda v: v.name) 470 471 grad_argspec = tf_inspect.getfullargspec(grad_fn) 472 variables_in_signature = ("variables" in grad_argspec.args or 473 "variables" in grad_argspec.kwonlyargs or 474 grad_argspec.varkw) 475 if variables and not variables_in_signature: 476 raise TypeError( 477 "@tf.custom_gradient grad_fn must accept keyword argument 'variables', " 478 "since function uses variables: {}".format(variables)) 479 if variables_in_signature and not variables: 480 # User seems to intend to use variables but none were captured. 481 logging.vlog( 482 1, "@custom_gradient grad_fn has 'variables' in signature, " 483 "but no ResourceVariables were used on the forward pass.") 484 485 all_tensors = flat_result + args + variables 486 487 def tape_grad_fn(*result_grads): 488 """Custom grad fn wrapper.""" 489 result_grads = result_grads[:flat_result_len] 490 if variables: 491 input_grads, variable_grads = grad_fn(*result_grads, variables=variables) 492 if len(variable_grads) != len(variables): 493 raise ValueError("Must return gradient for each variable from " 494 "@custom_gradient grad_fn.") 495 else: 496 input_grads = grad_fn(*result_grads) 497 variable_grads = [] 498 499 # Need to return one value per input to the IdentityN, so pad the 500 # gradients of the inputs of the custom_gradient function with the 501 # gradients of the outputs as well. 502 input_grads = nest.flatten(input_grads) 503 return ([None] * flat_result_len) + input_grads + variable_grads 504 505 @ops.RegisterGradient(name) 506 def internal_grad_fn(unused_op, *result_grads): # pylint: disable=unused-variable 507 """Custom grad fn wrapper.""" 508 return tape_grad_fn(*result_grads) 509 510 original_tensors = all_tensors 511 with ops.get_default_graph().gradient_override_map({"IdentityN": name}): 512 all_tensors = array_ops.identity_n(all_tensors) 513 514 original_tensors = [ops.convert_to_tensor(x) for x in original_tensors] 515 516 # Propagate handle data for happier shape inference for resource variables. 517 for i, t in enumerate(original_tensors): 518 if t.dtype == dtypes.resource and hasattr(t, "_handle_data"): 519 all_tensors[i]._handle_data = t._handle_data # pylint: disable=protected-access 520 tape_lib.record_operation( 521 f.__name__, all_tensors, original_tensors, tape_grad_fn) 522 for ot, t in zip(original_tensors, all_tensors): 523 handle_data_util.copy_handle_data(ot, t) 524 return nest.pack_sequence_as( 525 structure=result, flat_sequence=all_tensors[:flat_result_len]) 526 527 528def _eager_mode_decorator(f, args, kwargs): 529 """Implement custom gradient decorator for eager mode.""" 530 with tape_lib.VariableWatcher() as variable_watcher: 531 result, grad_fn = f(*args, **kwargs) 532 args = nest.flatten(args) 533 all_inputs = list(args) + list(kwargs.values()) 534 # The variables that grad_fn needs to return gradients for are the set of 535 # variables used that are *not* part of the inputs. 536 variables = [ 537 v.deref() # pylint: disable=g-complex-comprehension 538 for v in set(v.ref() for v in variable_watcher.watched_variables()) 539 if all(v.deref() is not i for i in all_inputs) 540 ] 541 grad_argspec = tf_inspect.getfullargspec(grad_fn) 542 if (variables and ("variables" not in grad_argspec.args) and 543 ("variables" not in grad_argspec.kwonlyargs) and 544 not grad_argspec.varkw): 545 raise TypeError( 546 "@tf.custom_gradient grad_fn must accept keyword argument 'variables', " 547 "since function uses variables: {}".format(variables)) 548 flat_result = nest.flatten(result) 549 # TODO(apassos) consider removing the identity below. 550 flat_result = [gen_array_ops.identity(x) for x in flat_result] 551 552 input_tensors = [ops.convert_to_tensor(x) for x 553 in list(args) + list(variables)] 554 555 recorded_inputs = input_tensors 556 arg_count = len(args) 557 558 def actual_grad_fn(*result_grads): 559 """Custom grad fn wrapper.""" 560 if variables: 561 input_grads, variable_grads = grad_fn(*result_grads, variables=variables) 562 if len(variable_grads) != len(variables): 563 raise ValueError("Must return gradient for each variable from " 564 "@custom_gradient grad_fn.") 565 else: 566 input_grads = grad_fn(*result_grads) 567 variable_grads = [] 568 flat_grads = nest.flatten(input_grads) 569 if len(flat_grads) != arg_count: 570 raise ValueError( 571 "custom_gradient function expected to return", arg_count, 572 "gradients but returned", len(flat_grads), "instead.") 573 return flat_grads + variable_grads 574 575 tape_lib.record_operation(f.__name__, flat_result, recorded_inputs, 576 actual_grad_fn) 577 flat_result = list(flat_result) 578 return nest.pack_sequence_as(result, flat_result) 579 580 581@tf_export("recompute_grad") 582def recompute_grad(f): 583 """Defines a function as a recompute-checkpoint for the tape auto-diff. 584 585 Tape checkpointing is a technique to reduce the memory consumption of the 586 auto-diff tape: 587 588 - Without tape checkpointing operations and intermediate values are 589 recorded to the tape for use in the backward pass. 590 591 - With tape checkpointing, only the function call and its inputs are 592 recorded. During back-propagation the `recompute_grad` custom gradient 593 (`tf.custom_gradient`) recomputes the function under a localized Tape object. 594 This recomputation of the function during backpropagation performs redundant 595 calculation, but reduces the overall memory usage of the Tape. 596 597 >>> y = tf.Variable(1.0) 598 599 >>> def my_function(x): 600 ... tf.print('running') 601 ... z = x*y 602 ... return z 603 604 >>> my_function_recompute = tf.recompute_grad(my_function) 605 606 >>> with tf.GradientTape() as tape: 607 ... r = tf.constant(1.0) 608 ... for i in range(4): 609 ... r = my_function_recompute(r) 610 running 611 running 612 running 613 running 614 615 >>> grad = tape.gradient(r, [y]) 616 running 617 running 618 running 619 running 620 621 Without `recompute_grad`, the tape contains all intermitate steps, and no 622 recomputation is performed. 623 624 >>> with tf.GradientTape() as tape: 625 ... r = tf.constant(1.0) 626 ... for i in range(4): 627 ... r = my_function(r) 628 running 629 running 630 running 631 running 632 633 >>> grad = tape.gradient(r, [y]) 634 635 636 If `f` was a `tf.keras` `Model` or `Layer` object, methods and attributes 637 such as `f.variables` are not available on the returned function `g`. 638 Either keep a reference of `f` , or use `g.__wrapped__` for accessing 639 these variables and methods. 640 641 642 >>> def print_running_and_return(x): 643 ... tf.print("running") 644 ... return x 645 646 >>> model = tf.keras.Sequential([ 647 ... tf.keras.layers.Lambda(print_running_and_return), 648 ... tf.keras.layers.Dense(2) 649 ... ]) 650 651 >>> model_recompute = tf.recompute_grad(model) 652 653 >>> with tf.GradientTape(persistent=True) as tape: 654 ... r = tf.constant([[1,2]]) 655 ... for i in range(4): 656 ... r = model_recompute(r) 657 running 658 running 659 running 660 running 661 662 >>> grad = tape.gradient(r, model.variables) 663 running 664 running 665 running 666 running 667 668 Alternatively, use the `__wrapped__` attribute to access the original 669 model object. 670 671 >>> grad = tape.gradient(r, model_recompute.__wrapped__.variables) 672 running 673 running 674 running 675 running 676 677 678 Args: 679 f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs. 680 681 Returns: 682 A function `g` wrapping `f` that defines a custom gradient, which recomputes 683 `f` on the backwards pass of a gradient call. 684 """ 685 # TODO(cdfreeman) Add is_recomputing functionality from graph mode version 686 687 @custom_gradient 688 def inner(*args, **kwargs): 689 """Inner function closure for calculating gradients.""" 690 current_var_scope = variable_scope.get_variable_scope() 691 with tape_lib.stop_recording(): 692 result = f(*args, **kwargs) 693 694 def grad_wrapper(*wrapper_args, variables=None): 695 """Wrapper function to accomodate lack of kwargs in graph mode custom_gradient.""" 696 697 @custom_gradient 698 def inner_recompute_grad(*dresult): 699 """Nested custom gradient function for computing grads in reverse and forward mode autodiff.""" 700 # Gradient calculation for reverse mode autodiff. 701 with backprop.GradientTape() as t: 702 id_args = nest.map_structure(gen_array_ops.identity, args) 703 # Tuple `dresult` should contain at least one tensor. 704 assert len(dresult) >= 1 705 706 if not context.executing_eagerly(): 707 # XLA doesn't respect `tf.control_dependencies`. The code block 708 # below manually adds a data dependency to `dresult` to ensure 709 # recomputation of `f(*args, **kwargs)` happens after `dresult`. 710 711 # This works even if `dresult[0]` is a size 0 tensor as reduce_max 712 # of a size 0 tensor returns -inf. Use reshape here to avoid reading 713 # the entire `dresult[0]`. 714 elem = math_ops.reduce_max(array_ops.reshape(dresult[0], [-1])[:1]) 715 # Cast elem to bool in case elem is NaN. 716 elem_bool = math_ops.cast(elem, dtypes.bool) 717 dresult_dep = array_ops.where_v2( 718 elem_bool == elem_bool, 0., float("nan")) # pylint: disable=comparison-with-itself 719 id_args = nest.map_structure( 720 lambda x: x + math_ops.cast(dresult_dep, x.dtype), id_args) 721 722 t.watch(id_args) 723 if variables is not None: 724 t.watch(variables) 725 with variable_scope.variable_scope(current_var_scope): 726 recomputed_result = f(*id_args, **kwargs) 727 kw_vars = [] 728 if variables is not None: 729 kw_vars = list(variables) 730 grads = t.gradient( 731 recomputed_result, 732 list(id_args) + kw_vars, 733 output_gradients=dresult, 734 unconnected_gradients=UnconnectedGradients.ZERO) 735 736 def transpose(*t_args, **t_kwargs): 737 """Gradient function calculation for forward mode autodiff.""" 738 # Just throw an error since gradients / activations are not stored on 739 # tape for recompute. 740 raise NotImplementedError( 741 "recompute_grad tried to transpose grad of {}. " 742 "Consider not using recompute_grad in forward mode" 743 "autodiff".format(f.__name__)) 744 745 return (grads[:len(id_args)], grads[len(id_args):]), transpose 746 747 return inner_recompute_grad(*wrapper_args) 748 749 return result, grad_wrapper 750 751 return tf_decorator.make_decorator(f, inner) 752 753 754@tf_export("grad_pass_through") 755def grad_pass_through(f): 756 """Creates a grad-pass-through op with the forward behavior provided in f. 757 758 Use this function to wrap any op, maintaining its behavior in the forward 759 pass, but replacing the original op in the backward graph with an identity. 760 For example: 761 762 ```python 763 x = tf.Variable(1.0, name="x") 764 z = tf.Variable(3.0, name="z") 765 766 with tf.GradientTape() as tape: 767 # y will evaluate to 9.0 768 y = tf.grad_pass_through(x.assign)(z**2) 769 # grads will evaluate to 6.0 770 grads = tape.gradient(y, z) 771 ``` 772 773 Another example is a 'differentiable' moving average approximation, where 774 gradients are allowed to flow into the last value fed to the moving average, 775 but the moving average is still used for the forward pass: 776 777 ```python 778 x = ... # Some scalar value 779 # A moving average object, we don't need to know how this is implemented 780 moving_average = MovingAverage() 781 with backprop.GradientTape() as tape: 782 # mavg_x will evaluate to the current running average value 783 mavg_x = tf.grad_pass_through(moving_average)(x) 784 grads = tape.gradient(mavg_x, x) # grads will evaluate to 1.0 785 ``` 786 787 Args: 788 f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor` 789 outputs. 790 791 Returns: 792 A function `h(x)` which returns the same values as `f(x)` and whose 793 gradients are the same as those of an identity function. 794 """ 795 @custom_gradient 796 def _grad_pass_through_op(*args, **kwargs): 797 def grad(*args, **kwargs): 798 variables = kwargs.get("variables") 799 if variables is not None: 800 # Variables involved in the wrapped op will not receive gradients. 801 return args, [None] * len(variables) 802 return args 803 return f(*args, **kwargs), grad 804 return tf_decorator.make_decorator(f, _grad_pass_through_op) 805