• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Decorator to override the gradient for a function."""
16
17from tensorflow.python.eager import backprop
18from tensorflow.python.eager import context
19from tensorflow.python.eager import tape as tape_lib
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import gen_array_ops
24from tensorflow.python.ops import handle_data_util
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops import op_selector
27from tensorflow.python.ops import resource_variable_ops
28from tensorflow.python.ops import variable_scope
29from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
30from tensorflow.python.platform import tf_logging as logging
31from tensorflow.python.util import nest
32from tensorflow.python.util import tf_decorator
33from tensorflow.python.util import tf_inspect
34from tensorflow.python.util.tf_export import tf_export
35
36
# Op types that `_get_dependent_variables` treats as variable ops when it walks
# the subgraph between a custom-gradient function's inputs and outputs.
VAR_OP_TYPES = ["VariableV2", "VarHandleOp"]
41
42
@tf_export("custom_gradient")
def custom_gradient(f=None):
  """Decorator to define a function with a custom gradient.

  This decorator allows fine-grained control over the gradients of a sequence
  of operations. This may be useful for multiple reasons, including providing
  a more efficient or numerically stable gradient for a sequence of operations.

  For example, consider the following function that commonly occurs in the
  computation of cross entropy and log likelihoods:

  ```python
  def log1pexp(x):
    return tf.math.log(1 + tf.exp(x))
  ```

  Due to numerical instability, the gradient of this function evaluated at x=100
  is NaN.  For example:

  ```python
  x = tf.constant(100.)
  y = log1pexp(x)
  dy_dx = tf.gradients(y, x) # Will be NaN when evaluated.
  ```

  The gradient expression can be analytically simplified to provide numerical
  stability:

  ```python
  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(upstream):
      return upstream * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad
  ```

  With this definition, the gradient `dy_dx` at `x = 100` will be correctly
  evaluated as 1.0.

  The variable `upstream` is the upstream gradient, i.e. the gradient flowing
  back from all the layers or functions that consume the output of this one.
  The above example has no upstream functions, therefore
  `upstream = dy/dy = 1.0`.

  Assume that `x_i` is `log1pexp` in the forward pass `x_1 = x_1(x_0)`,
  `x_2 = x_2(x_1)`, ..., `x_i = x_i(x_i-1)`, ..., `x_n = x_n(x_n-1)`. By the
  chain rule we know that `dx_n/dx_0 = dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... *
  dx_i/dx_i-1 * ... * dx_1/dx_0`.

  In this case the gradient of our current function is
  `dx_i/dx_i-1 = (1 - 1 / (1 + e))`. The upstream gradient `upstream` would be
  `dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... * dx_i+1/dx_i`. The upstream gradient
  multiplied by the current gradient is then passed downstream.

  If the function takes multiple inputs, the `grad` function must return the
  same number of gradients, one per input. We take the function `z = x * y` as
  an example.

  >>> @tf.custom_gradient
  ... def bar(x, y):
  ...   def grad(upstream):
  ...     dz_dx = y
  ...     dz_dy = x
  ...     return upstream * dz_dx, upstream * dz_dy
  ...   z = x * y
  ...   return z, grad
  >>> x = tf.constant(2.0, dtype=tf.float32)
  >>> y = tf.constant(3.0, dtype=tf.float32)
  >>> with tf.GradientTape(persistent=True) as tape:
  ...   tape.watch(x)
  ...   tape.watch(y)
  ...   z = bar(x, y)
  >>> z
  <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
  >>> tape.gradient(z, x)
  <tf.Tensor: shape=(), dtype=float32, numpy=3.0>
  >>> tape.gradient(z, y)
  <tf.Tensor: shape=(), dtype=float32, numpy=2.0>

  Nesting custom gradients can lead to unintuitive results. The default
  behavior does not correspond to n-th order derivatives. For example:

  ```python
  @tf.custom_gradient
  def op(x):
    y = op1(x)
    @tf.custom_gradient
    def grad_fn(dy):
      gdy = op2(x, y, dy)
      def grad_grad_fn(ddy):  # Not the 2nd order gradient of op w.r.t. x.
        return op3(x, y, dy, ddy)
      return gdy, grad_grad_fn
    return y, grad_fn
  ```

  The function `grad_grad_fn` computes the first-order gradient of `grad_fn`
  with respect to `dy`. That gradient is used to generate forward-mode gradient
  graphs from backward-mode gradient graphs. It is not the same as the
  second-order gradient of `op` with respect to `x`.

  Instead, wrap nested `@tf.custom_gradient`s in another function:

  ```python
  @tf.custom_gradient
  def op_with_fused_backprop(x):
    y, x_grad = fused_op(x)
    def first_order_gradient(dy):
      @tf.custom_gradient
      def first_order_custom(unused_x):
        def second_order_and_transpose(ddy):
          return second_order_for_x(...), gradient_wrt_dy(...)
        return x_grad, second_order_and_transpose
      return dy * first_order_custom(x)
    return y, first_order_gradient
  ```

  Additional arguments to the inner `@tf.custom_gradient`-decorated function
  control the expected return values of the innermost function.

  The examples above illustrate how to specify custom gradients for functions
  which do not read from variables. The following example uses variables, which
  require special handling because they are effectively inputs of the forward
  function.

  >>> weights = tf.Variable(tf.ones([2]))  # Trainable variable weights
  >>> @tf.custom_gradient
  ... def linear_poly(x):
  ...   # Creating polynomial
  ...   poly = weights[1] * x + weights[0]
  ...
  ...   def grad_fn(dpoly, variables):
  ...     # dy/dx = weights[1] and we need to left multiply dpoly
  ...     grad_xs = dpoly * weights[1]  # Scalar gradient
  ...
  ...     grad_vars = []  # To store gradients of passed variables
  ...     assert variables is not None
  ...     assert len(variables) == 1
  ...     assert variables[0] is weights
  ...     # Manually computing dy/dweights
  ...     dy_dw = dpoly * tf.stack([x ** 1, x ** 0])
  ...     grad_vars.append(
  ...         tf.reduce_sum(tf.reshape(dy_dw, [2, -1]), axis=1)
  ...     )
  ...     return grad_xs, grad_vars
  ...   return poly, grad_fn
  >>> x = tf.constant([1., 2., 3.])
  >>> with tf.GradientTape(persistent=True) as tape:
  ...   tape.watch(x)
  ...   poly = linear_poly(x)
  >>> poly # poly = x + 1
  <tf.Tensor: shape=(3,),
    dtype=float32,
    numpy=array([2., 3., 4.], dtype=float32)>
  >>> tape.gradient(poly, x)  # conventional scalar gradient dy/dx
  <tf.Tensor: shape=(3,),
    dtype=float32,
    numpy=array([1., 1., 1.], dtype=float32)>
  >>> tape.gradient(poly, weights)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 3.], dtype=float32)>

  The above example illustrates the use of the trainable variable `weights`.
  The inner `grad_fn` accepts an extra `variables` input parameter and also
  returns an extra `grad_vars` output. That extra argument is passed if the
  forward function reads any variables. You need to compute the gradient
  w.r.t. each of those `variables` and return them as the list `grad_vars`.
  When the forward function uses no variables, `grad_fn` is called without the
  `variables` argument, so it should default to `None`.

  Note that `tf.GradientTape` still watches the forward pass of a
  `tf.custom_gradient` and will use the ops it watches. As a consequence,
  calling a `tf.function` while the tape is still watching leads to a gradient
  graph being built. If an op without a registered gradient is used in that
  `tf.function`, a `LookupError` will be raised.

  Users can insert `tf.stop_gradient` to customize this behavior, as the example
  below demonstrates. `tf.random.shuffle` does not have a registered gradient,
  so `tf.stop_gradient` is used to avoid the `LookupError`.

  ```python
  x = tf.constant([0.3, 0.5], dtype=tf.float32)

  @tf.custom_gradient
  def test_func_with_stop_grad(x):
    @tf.function
    def _inner_func():
      # Avoid exception during the forward pass
      return tf.stop_gradient(tf.random.shuffle(x))
      # return tf.random.shuffle(x)  # This will raise

    res = _inner_func()
    def grad(upstream):
      return upstream  # Arbitrarily defined custom gradient
    return res, grad

  with tf.GradientTape() as g:
    g.watch(x)
    res = test_func_with_stop_grad(x)

  g.gradient(res, x)
  ```

  See also `tf.RegisterGradient` which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows
  for fine-grained control over the gradient computation of a sequence of
  operations.

  Note that if the decorated function uses `Variable`s, the enclosing variable
  scope must be using
  [ResourceVariables](https://www.tensorflow.org/guide/migrate/tf1_vs_tf2#resourcevariables_instead_of_referencevariables).

  Args:
    f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
       - `x` is a sequence of (nested structures of) `Tensor` inputs to the
         function.
       - `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow
         operations in `f` to `x`.
       - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
         a list of `Tensor`s the same size as (flattened) `x` - the derivatives
         of `Tensor`s in `y` with respect to the `Tensor`s in `x`.  `grad_ys` is
         a sequence of `Tensor`s the same size as (flattened) `y` holding the
         initial value gradients for each `Tensor` in `y`.

         In a pure mathematical sense, a vector-argument vector-valued function
         `f`'s derivatives should be its Jacobian matrix `J`. Here we are
         expressing the Jacobian `J` as a function `grad_fn` which defines how
         `J` will transform a vector `grad_ys` when left-multiplied with it
         (`grad_ys * J`, the vector-Jacobian product, or VJP). This functional
         representation of a matrix is convenient to use for chain-rule
         calculation (in e.g. the back-propagation algorithm).

         If `f` uses `Variable`s (that are not part of the
         inputs), e.g. through `get_variable`, then `grad_fn` should have
         signature `g(*grad_ys, variables=None)`, where `variables` is a list of
         the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
         `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
         with the derivatives of `Tensor`s in `y` with respect to the variables
         (that is, grad_vars has one Tensor per variable in variables).

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.
  """

  if f is None:
    # Called with explicit parentheses (`@custom_gradient()`): return a
    # decorator that will receive the function later.
    return lambda f: custom_gradient(f=f)

  @Bind.decorator
  def decorated(wrapped, args, kwargs):
    """Decorated function with custom gradient."""
    # The execution mode is checked on every call, not once at decoration time.
    if context.executing_eagerly():
      return _eager_mode_decorator(wrapped, args, kwargs)
    else:
      return _graph_mode_decorator(wrapped, args, kwargs)

  # `Bind` lets `decorated` also be used on methods (see `Bind.__get__`).
  return tf_decorator.make_decorator(f, decorated(f))  # pylint: disable=no-value-for-parameter
299
300
class Bind:
  """Invokes `d(f, args, kwargs)` on call, while still letting `f` bind.

  A `Bind(f, d)` object forwards every call to `d(f, args, kwargs)`. When it is
  stored as a class attribute and looked up through an instance, `f` is first
  bound to that instance via `f.__get__`, so `d` receives the bound method and
  only the explicit call arguments.

  >>> @Bind.decorator
  ... def my_decorator(f, args, kwargs):
  ...   print("my_decorator called with", args, kwargs)
  ...   return f(*args, **kwargs)

  >>> class Foo:
  ...   @my_decorator
  ...   def bar(self, a, b, c):
  ...     return a * b * c

  >>> Foo.bar(None, 1, 2, c=3)
  my_decorator called with (None, 1, 2) {'c': 3}
  6

  >>> foo = Foo()
  >>> foo.bar(1, 2, c=3)
  my_decorator called with (1, 2) {'c': 3}
  6
  """

  @classmethod
  def decorator(cls, d):
    """Returns a decorator that turns its target `f` into `Bind(f, d)`."""
    def _wrap(f):
      return Bind(f, d)
    return _wrap

  def __init__(self, f, d):
    self._f = f  # The wrapped callable.
    self._d = d  # Invoked as `d(f, args, kwargs)` on every call.

  def __get__(self, instance, owner):
    # Access through the class itself (e.g. `Foo.bar`) yields this wrapper.
    if instance is None:
      return self
    # Bind `f` to the instance first, then re-wrap the bound method so that
    # `d` receives it instead of the plain function.
    bound_f = self._f.__get__(instance, owner)
    return tf_decorator.make_decorator(bound_f, Bind(bound_f, self._d))

  def __call__(self, *args, **kwargs):
    return self._d(self._f, args, kwargs)
341
342
def get_variable_by_name(var_name):
  """Looks up the trainable global variable whose op is named `var_name`.

  Args:
    var_name: Op name to match against `item.op.name` for each entry of the
      `ops.GraphKeys.GLOBAL_VARIABLES` collection.

  Returns:
    The single matching trainable variable, or `None` if every matching
    variable is non-trainable.

  Raises:
    ValueError: If no entry matches `var_name`, or if more than one trainable
      entry matches.
  """
  global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

  def _has_matching_op_name(item):
    # Collection entries that have no `.op` are skipped instead of erroring.
    try:
      return var_name == item.op.name
    except AttributeError:
      return False

  matches = [item for item in global_vars if _has_matching_op_name(item)]
  if not matches:
    raise ValueError("Unsuccessful at finding variable {}.".format(var_name))

  trainable = [v for v in matches if v.trainable]
  if not trainable:
    # A variable with this name exists, but it is not trainable.
    return None
  if len(trainable) > 1:
    raise ValueError(
        "Unsuccessful at finding trainable variable {}. "
        "Number of candidates: {}. "
        "Candidates: {}".format(var_name, len(trainable), trainable))
  return trainable[0]
372
373
def _get_dependent_variables(input_ops, output_ops):
  """Finds variables involved in the subgraph between input_ops and output_ops.

  Walks backwards from `output_ops` (stopping at `input_ops`, following only
  differentiable edges) and collects the trainable variables behind every op
  whose type is listed in `VAR_OP_TYPES`.

  Args:
    input_ops: Flattened list of input ops
    output_ops: Flattened list of output ops

  Returns:
    A list of variables, in the order their ops were encountered by the walk.
  """
  # Seed the walk with identities of the outputs so there is still a subgraph
  # to traverse when input_ops == output_ops.
  seeds = nest.map_structure(gen_array_ops.identity, output_ops)
  walked_ops = op_selector.get_backward_walk_ops(
      seed_ops=seeds,
      stop_at_ts=input_ops,
      inclusive=False,
      only_differentiable=True)

  found = []
  for op in walked_ops:
    if op.type not in VAR_OP_TYPES:
      continue
    # `get_variable_by_name` returns None for non-trainable variables.
    var = get_variable_by_name(op.name)
    if var is not None:
      found.append(var)
  return found
397
398
def generate_name():
  """Returns a new gradient-override name of the form `CustomGradient-<uid>`."""
  return "CustomGradient-{}".format(ops.uid())
401
402
def _graph_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for graph mode.

  Calls `f(*args)` to obtain `(result, grad_fn)`. It then passes the flattened
  results, the flattened inputs and every variable used by `f` through a single
  `IdentityN` op. That op's gradient is overridden, and the same backward
  function is recorded on the tape, so that `grad_fn` supplies the gradients.

  Args:
    f: The function being decorated. Called as `f(*args)` and expected to
      return a `(result, grad_fn)` pair.
    args: Positional arguments for `f`. Each (nested) element is converted to a
      `Tensor` with `ops.convert_to_tensor` before `f` is called.
    kwargs: Keyword arguments. Not supported in graph mode; must be empty.

  Returns:
    `result`, repacked from the first `len(nest.flatten(result))` outputs of the
    `IdentityN` op.

  Raises:
    ValueError: If `kwargs` is non-empty, or if the tensor outputs of `f` come
      from more than one graph. Also raised when the gradient is computed, if
      `grad_fn` does not return one gradient per variable.
    TypeError: If `f` created non-resource variables, or if `f` used variables
      but `grad_fn` does not accept a `variables` argument.
  """
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = generate_name()
  args = nest.map_structure(ops.convert_to_tensor, args)

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = {
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  }
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args)

  args = nest.flatten(args)
  flat_result = nest.flatten(result)
  flat_result_len = len(flat_result)

  after_vars = {
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  }
  new_vars = after_vars - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  # NOTE(review): unlike the eager path, nothing here filters out variables
  # passed as inputs. `args` were already converted to tensors above, which
  # presumably makes that filter unnecessary; confirm.
  variables_in_tape = frozenset([
      v.ref() for v in variable_watcher.watched_variables()
  ])

  graphs = {getattr(o, "graph", None) for o in flat_result}
  # Not all results may be tensors. However, we want to ensure all tensor
  # outputs are from the same graph and get a list of captured inputs for
  # variable search
  graphs.discard(None)  # Discard non-graph outputs
  if graphs:
    if len(graphs) > 1:
      raise ValueError(
          "All custom_gradient outputs should be from the same graph")
    output_graph = graphs.pop()
    filtered_input_tensors = []
    for i in args:
      if i.graph == output_graph:
        filtered_input_tensors.append(i)
  else:
    filtered_input_tensors = args

  # Gradients are requested for the union of the variables watched on the tape
  # and those found by walking the graph. They are sorted by name so that the
  # order is deterministic.
  variables_in_subgraph = frozenset([
      v.ref() for v in _get_dependent_variables(
          input_ops=filtered_input_tensors, output_ops=flat_result)
  ])
  variables = sorted(
      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)],
      key=lambda v: v.name)

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            "variables" in grad_argspec.kwonlyargs or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    logging.vlog(
        1, "@custom_gradient grad_fn has 'variables' in signature, "
        "but no ResourceVariables were used on the forward pass.")

  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    # Upstream gradients also arrive for the IdentityN outputs that mirror
    # `args` and `variables`. Only those for `result` go to `grad_fn`.
    result_grads = result_grads[:flat_result_len]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
      # Accept any sequence (e.g. a tuple); the list concatenation below would
      # otherwise raise a TypeError.
      variable_grads = list(variable_grads)
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * flat_result_len) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    handle_data_util.copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:flat_result_len])
526
527
def _eager_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for eager mode.

  Runs `f(*args, **kwargs)` under a `VariableWatcher` to obtain
  `(result, grad_fn)`. Each flattened output is passed through an identity, and
  an operation is recorded on the tape whose backward function calls `grad_fn`.

  Args:
    f: The function being decorated. Expected to return a `(result, grad_fn)`
      pair.
    args: Positional arguments for `f`. Gradients are expected for each
      flattened element.
    kwargs: Keyword arguments for `f`. Their values are excluded from the
      variable list, but no gradients are recorded for them.

  Returns:
    `result`, with its flattened elements replaced by the recorded identities.

  Raises:
    TypeError: If `f` reads variables that are not inputs but `grad_fn` does
      not accept a `variables` argument.
    ValueError: Raised when the gradient is computed, if `grad_fn` returns the
      wrong number of input or variable gradients.
  """
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args, **kwargs)
  args = nest.flatten(args)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [
      v.deref()  # pylint: disable=g-complex-comprehension
      for v in {v.ref() for v in variable_watcher.watched_variables()}
      if all(v.deref() is not i for i in all_inputs)
  ]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  if (variables and ("variables" not in grad_argspec.args) and
      ("variables" not in grad_argspec.kwonlyargs) and
      not grad_argspec.varkw):
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]

  arg_count = len(args)

  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
      # Accept any sequence (e.g. a tuple); the list concatenation below would
      # otherwise raise a TypeError.
      variable_grads = list(variable_grads)
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = nest.flatten(input_grads)
    if len(flat_grads) != arg_count:
      # Build one message string; passing several positional arguments to
      # ValueError would render the message as a tuple.
      raise ValueError(
          "custom_gradient function expected to return {} gradients but "
          "returned {} instead.".format(arg_count, len(flat_grads)))
    return flat_grads + variable_grads

  tape_lib.record_operation(f.__name__, flat_result, input_tensors,
                            actual_grad_fn)
  return nest.pack_sequence_as(result, flat_result)
579
580
@tf_export("recompute_grad")
def recompute_grad(f):
  """Defines a function as a recompute-checkpoint for the tape auto-diff.

  Tape checkpointing is a technique to reduce the memory consumption of the
  auto-diff tape:

  - Without tape checkpointing, operations and intermediate values are
  recorded to the tape for use in the backward pass.

  - With tape checkpointing, only the function call and its inputs are
  recorded. During back-propagation the `recompute_grad` custom gradient
  (`tf.custom_gradient`) recomputes the function under a localized Tape object.
  This recomputation of the function during backpropagation performs redundant
  calculation, but reduces the overall memory usage of the Tape.

  >>> y = tf.Variable(1.0)

  >>> def my_function(x):
  ...   tf.print('running')
  ...   z = x*y
  ...   return z

  >>> my_function_recompute = tf.recompute_grad(my_function)

  >>> with tf.GradientTape() as tape:
  ...   r = tf.constant(1.0)
  ...   for i in range(4):
  ...     r = my_function_recompute(r)
  running
  running
  running
  running

  >>> grad = tape.gradient(r, [y])
  running
  running
  running
  running

  Without `recompute_grad`, the tape contains all intermediate steps, and no
  recomputation is performed.

  >>> with tf.GradientTape() as tape:
  ...   r = tf.constant(1.0)
  ...   for i in range(4):
  ...     r = my_function(r)
  running
  running
  running
  running

  >>> grad = tape.gradient(r, [y])


  If `f` was a `tf.keras` `Model` or `Layer` object, methods and attributes
  such as `f.variables` are not available on the returned function `g`.
  Either keep a reference to `f`, or use `g.__wrapped__` for accessing
  these variables and methods.


  >>> def print_running_and_return(x):
  ...   tf.print("running")
  ...   return x

  >>> model = tf.keras.Sequential([
  ...   tf.keras.layers.Lambda(print_running_and_return),
  ...   tf.keras.layers.Dense(2)
  ... ])

  >>> model_recompute = tf.recompute_grad(model)

  >>> with tf.GradientTape(persistent=True) as tape:
  ...   r = tf.constant([[1,2]])
  ...   for i in range(4):
  ...     r = model_recompute(r)
  running
  running
  running
  running

  >>> grad = tape.gradient(r, model.variables)
  running
  running
  running
  running

  Alternatively, use the `__wrapped__` attribute to access the original
  model object.

  >>> grad = tape.gradient(r, model_recompute.__wrapped__.variables)
  running
  running
  running
  running


  Args:
    f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs.

  Returns:
    A function `g` wrapping `f` that defines a custom gradient, which recomputes
    `f` on the backwards pass of a gradient call.
  """
  # TODO(cdfreeman) Add is_recomputing functionality from graph mode version

  @custom_gradient
  def inner(*args, **kwargs):
    """Inner function closure for calculating gradients."""
    current_var_scope = variable_scope.get_variable_scope()
    # The forward pass is not recorded; it is recomputed in `grad_wrapper`.
    with tape_lib.stop_recording():
      result = f(*args, **kwargs)

    def grad_wrapper(*wrapper_args, variables=None):
      """Wrapper function to accommodate lack of kwargs in graph mode custom_gradient.

      `inner_recompute_grad` is itself a `custom_gradient`, which accepts only
      positional arguments in graph mode, so `variables` is captured here by
      closure rather than passed through.
      """

      @custom_gradient
      def inner_recompute_grad(*dresult):
        """Nested custom gradient function for computing grads in reverse and forward mode autodiff."""
        # Gradient calculation for reverse mode autodiff.
        with backprop.GradientTape() as t:
          id_args = nest.map_structure(gen_array_ops.identity, args)
          # Tuple `dresult` should contain at least one tensor.
          assert len(dresult) >= 1

          if not context.executing_eagerly():
            # XLA doesn't respect `tf.control_dependencies`. The code block
            # below manually adds a data dependency to `dresult` to ensure
            # recomputation of `f(*args, **kwargs)` happens after `dresult`.

            # This works even if `dresult[0]` is a size 0 tensor as reduce_max
            # of a size 0 tensor returns -inf. Use reshape here to avoid reading
            # the entire `dresult[0]`.
            elem = math_ops.reduce_max(array_ops.reshape(dresult[0], [-1])[:1])
            # Cast elem to bool in case elem is NaN.
            elem_bool = math_ops.cast(elem, dtypes.bool)
            dresult_dep = array_ops.where_v2(
                elem_bool == elem_bool, 0., float("nan"))  # pylint: disable=comparison-with-itself
            id_args = nest.map_structure(
                lambda x: x + math_ops.cast(dresult_dep, x.dtype), id_args)

          t.watch(id_args)
          if variables is not None:
            t.watch(variables)
          with variable_scope.variable_scope(current_var_scope):
            recomputed_result = f(*id_args, **kwargs)
        kw_vars = []
        if variables is not None:
          kw_vars = list(variables)
        grads = t.gradient(
            recomputed_result,
            list(id_args) + kw_vars,
            output_gradients=dresult,
            unconnected_gradients=UnconnectedGradients.ZERO)

        def transpose(*t_args, **t_kwargs):
          """Gradient function calculation for forward mode autodiff."""
          # Just throw an error since gradients / activations are not stored on
          # tape for recompute.
          # `f` may be a callable object (e.g. a Keras `Model`) without a
          # `__name__`; fall back to the object itself for the message.
          raise NotImplementedError(
              "recompute_grad tried to transpose grad of {}. "
              "Consider not using recompute_grad in forward mode "
              "autodiff".format(getattr(f, "__name__", f)))

        # First element: gradients for the positional args; second element:
        # gradients for `variables` (empty when none were passed).
        return (grads[:len(id_args)], grads[len(id_args):]), transpose

      return inner_recompute_grad(*wrapper_args)

    return result, grad_wrapper

  return tf_decorator.make_decorator(f, inner)
752
753
@tf_export("grad_pass_through")
def grad_pass_through(f):
  """Creates a grad-pass-through op with the forward behavior provided in f.

  The wrapped op keeps its behavior in the forward pass. In the backward pass
  it is replaced by an identity, so the upstream gradients flow straight
  through to its inputs. For example:

  ```python
  x = tf.Variable(1.0, name="x")
  z = tf.Variable(3.0, name="z")

  with tf.GradientTape() as tape:
    # y will evaluate to 9.0
    y = tf.grad_pass_through(x.assign)(z**2)
  # grads will evaluate to 6.0
  grads = tape.gradient(y, z)
  ```

  Another example is a 'differentiable' moving average approximation. Gradients
  are allowed to flow into the last value fed to the moving average, while the
  moving average itself is still used for the forward pass:

  ```python
  x = ... # Some scalar value
  # A moving average object, we don't need to know how this is implemented
  moving_average = MovingAverage()
  with tf.GradientTape() as tape:
    # mavg_x will evaluate to the current running average value
    mavg_x = tf.grad_pass_through(moving_average)(x)
  grads = tape.gradient(mavg_x, x) # grads will evaluate to 1.0
  ```

  Variables read by `f` receive `None` gradients.

  Args:
    f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor`
      outputs.

  Returns:
    A function `h(x)` which returns the same values as `f(x)` and whose
    gradients are the same as those of an identity function.
  """
  @custom_gradient
  def _grad_pass_through_op(*args, **kwargs):
    def grad(*upstream, **grad_kwargs):
      passed_vars = grad_kwargs.get("variables")
      if passed_vars is None:
        # Identity backward pass: forward the upstream gradients unchanged.
        return upstream
      # Variables involved in the wrapped op will not receive gradients.
      return upstream, [None] * len(passed_vars)
    return f(*args, **kwargs), grad
  return tf_decorator.make_decorator(f, _grad_pass_through_op)
805