• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Decorator to override the gradient for a function."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20from tensorflow.python.eager import backprop
21from tensorflow.python.eager import context
22from tensorflow.python.eager import tape as tape_lib
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import gen_array_ops
27from tensorflow.python.ops import handle_data_util
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import op_selector
30from tensorflow.python.ops import resource_variable_ops
31from tensorflow.python.ops import variable_scope
32from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
33from tensorflow.python.platform import tf_logging as logging
34from tensorflow.python.util import nest
35from tensorflow.python.util import tf_decorator
36from tensorflow.python.util import tf_inspect
37from tensorflow.python.util.tf_export import tf_export
38
39
# Op types that `_get_dependent_variables` treats as variable ops while walking
# the forward graph backwards from the custom_gradient outputs.
VAR_OP_TYPES = ["VariableV2", "VarHandleOp"]
44
45
@tf_export("custom_gradient")
def custom_gradient(f=None):
  """Decorator to define a function with a custom gradient.

  This decorator allows fine grained control over the gradients of a sequence
  of operations.  This may be useful for multiple reasons, including providing
  a more efficient or numerically stable gradient for a sequence of operations.

  For example, consider the following function that commonly occurs in the
  computation of cross entropy and log likelihoods:

  ```python
  def log1pexp(x):
    return tf.math.log(1 + tf.exp(x))
  ```

  Due to numerical instability, the gradient of this function evaluated at x=100
  is NaN.  For example:

  ```python
  x = tf.constant(100.)
  y = log1pexp(x)
  dy_dx = tf.gradients(y, x) # Will be NaN when evaluated.
  ```

  The gradient expression can be analytically simplified to provide numerical
  stability:

  ```python
  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(upstream):
      return upstream * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad
  ```

  With this definition, the gradient `dy_dx` at `x = 100` will be correctly
  evaluated as 1.0.

  The variable `upstream` is defined as the upstream gradient, i.e. the gradient
  from all the layers or functions originating from this layer. The above
  example has no upstream functions, therefore `upstream = dy/dy = 1.0`.

  Assume that `x_i` is `log1pexp` in the forward pass `x_1 = x_1(x_0)`,
  `x_2 = x_2(x_1)`, ..., `x_i = x_i(x_i-1)`, ..., `x_n = x_n(x_n-1)`. By
  the chain rule we know that `dx_n/dx_0 = dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... *
  dx_i/dx_i-1 * ... * dx_1/dx_0`.

  In this case the gradient of our current function is defined as
  `dx_i/dx_i-1 = (1 - 1 / (1 + e))`. The upstream gradient `upstream` would be
  `dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... * dx_i+1/dx_i`. The upstream gradient
  multiplied by the current gradient is then passed downstream.

  If the function takes multiple inputs, the `grad` function must return the
  same number of gradients, one per input.
  We take the function `z = x * y` as an example.

  >>> @tf.custom_gradient
  ... def bar(x, y):
  ...   def grad(upstream):
  ...     dz_dx = y
  ...     dz_dy = x
  ...     return upstream * dz_dx, upstream * dz_dy
  ...   z = x * y
  ...   return z, grad
  >>> x = tf.constant(2.0, dtype=tf.float32)
  >>> y = tf.constant(3.0, dtype=tf.float32)
  >>> with tf.GradientTape(persistent=True) as tape:
  ...   tape.watch(x)
  ...   tape.watch(y)
  ...   z = bar(x, y)
  >>> z
  <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
  >>> tape.gradient(z, x)
  <tf.Tensor: shape=(), dtype=float32, numpy=3.0>
  >>> tape.gradient(z, y)
  <tf.Tensor: shape=(), dtype=float32, numpy=2.0>

  Nesting custom gradients can lead to unintuitive results. The default
  behavior does not correspond to n-th order derivatives. For example

  ```python
  @tf.custom_gradient
  def op(x):
    y = op1(x)
    @tf.custom_gradient
    def grad_fn(dy):
      gdy = op2(x, y, dy)
      def grad_grad_fn(ddy):  # Not the 2nd order gradient of op w.r.t. x.
        return op3(x, y, dy, ddy)
      return gdy, grad_grad_fn
    return y, grad_fn
  ```

  The function `grad_grad_fn` will be calculating the first order gradient
  of `grad_fn` with respect to `dy`, which is used to generate forward-mode
  gradient graphs from backward-mode gradient graphs, but is not the same as
  the second order gradient of `op` with respect to `x`.

  Instead, wrap nested `@tf.custom_gradient` functions in another function:

  ```python
  @tf.custom_gradient
  def op_with_fused_backprop(x):
    y, x_grad = fused_op(x)
    def first_order_gradient(dy):
      @tf.custom_gradient
      def first_order_custom(unused_x):
        def second_order_and_transpose(ddy):
          return second_order_for_x(...), gradient_wrt_dy(...)
        return x_grad, second_order_and_transpose
      return dy * first_order_custom(x)
    return y, first_order_gradient
  ```

  Additional arguments to the inner `@tf.custom_gradient`-decorated function
  control the expected return values of the innermost function.

  The examples above illustrate how to specify custom gradients for functions
  which do not read from variables. The following example uses variables, which
  require special handling because they are effectively inputs of the forward
  function.

  >>> weights = tf.Variable(tf.ones([2]))  # Trainable variable weights
  >>> @tf.custom_gradient
  ... def linear_poly(x):
  ...   # Creating polynomial
  ...   poly = weights[1] * x + weights[0]
  ...
  ...   def grad_fn(dpoly, variables):
  ...     # dy/dx = weights[1] and we need to left multiply dpoly
  ...     grad_xs = dpoly * weights[1]  # Scalar gradient
  ...
  ...     grad_vars = []  # To store gradients of passed variables
  ...     assert variables is not None
  ...     assert len(variables) == 1
  ...     assert variables[0] is weights
  ...     # Manually computing dy/dweights
  ...     dy_dw = dpoly * tf.stack([x ** 1, x ** 0])
  ...     grad_vars.append(
  ...         tf.reduce_sum(tf.reshape(dy_dw, [2, -1]), axis=1)
  ...     )
  ...     return grad_xs, grad_vars
  ...   return poly, grad_fn
  >>> x = tf.constant([1., 2., 3.])
  >>> with tf.GradientTape(persistent=True) as tape:
  ...   tape.watch(x)
  ...   poly = linear_poly(x)
  >>> poly # poly = x + 1
  <tf.Tensor: shape=(3,),
    dtype=float32,
    numpy=array([2., 3., 4.], dtype=float32)>
  >>> tape.gradient(poly, x)  # conventional scalar gradient dy/dx
  <tf.Tensor: shape=(3,),
    dtype=float32,
    numpy=array([1., 1., 1.], dtype=float32)>
  >>> tape.gradient(poly, weights)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 3.], dtype=float32)>

  The above example illustrates usage of the trainable variable `weights`.
  In the example, the inner `grad_fn` accepts an extra `variables` input
  parameter and also returns an extra `grad_vars` output. That extra argument
  is passed if the forward function reads any variables. You need to
  compute the gradient w.r.t. each of those `variables` and output it as a list
  of `grad_vars`. Note here that default value of `variables` is set to `None`
  when no variables are used in the forward function.

  It should be noted `tf.GradientTape` is still watching the forward pass of a
  `tf.custom_gradient`, and will use the ops it watches. As a consequence,
  calling `tf.function` while the tape is still watching leads
  to a gradient graph being built. If an op is used in `tf.function` without
  registered gradient, a `LookupError` will be raised.

  Users can insert `tf.stop_gradient` to customize this behavior. This
  is demonstrated in the example below. `tf.random.shuffle` does not have a
  registered gradient. As a result `tf.stop_gradient` is used to avoid the
  `LookupError`.

  ```python
  x = tf.constant([0.3, 0.5], dtype=tf.float32)

  @tf.custom_gradient
  def test_func_with_stop_grad(x):
    @tf.function
    def _inner_func():
      # Avoid exception during the forward pass
      return tf.stop_gradient(tf.random.shuffle(x))
      # return tf.random.shuffle(x)  # This will raise

    res = _inner_func()
    def grad(upstream):
      return upstream  # Arbitrarily defined custom gradient
    return res, grad

  with tf.GradientTape() as g:
    g.watch(x)
    res = test_func_with_stop_grad(x)

  g.gradient(res, x)
  ```

  See also `tf.RegisterGradient` which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows
  for fine grained control over the gradient computation of a sequence of
  operations.

  Note that if the decorated function uses `Variable`s, the enclosing variable
  scope must be using `ResourceVariable`s.

  Args:
    f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
       - `x` is a sequence of (nested structures of) `Tensor` inputs to the
         function.
       - `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow
         operations in `f` to `x`.
       - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
         a list of `Tensor`s the same size as (flattened) `x` - the derivatives
         of `Tensor`s in `y` with respect to the `Tensor`s in `x`.  `grad_ys` is
         a sequence of `Tensor`s the same size as (flattened) `y` holding the
         initial value gradients for each `Tensor` in `y`.

         In a pure mathematical sense, a vector-argument vector-valued function
         `f`'s derivatives should be its Jacobian matrix `J`. Here we are
         expressing the Jacobian `J` as a function `grad_fn` which defines how
         `J` will transform a vector `grad_ys` when left-multiplied with it
         (`grad_ys * J`, the vector-Jacobian product, or VJP). This functional
         representation of a matrix is convenient to use for chain-rule
         calculation (in e.g. the back-propagation algorithm).

         If `f` uses `Variable`s (that are not part of the
         inputs), i.e. through `get_variable`, then `grad_fn` should have
         signature `g(*grad_ys, variables=None)`, where `variables` is a list of
         the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
         `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
         with the derivatives of `Tensor`s in `y` with respect to the variables
         (that is, grad_vars has one Tensor per variable in variables).

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.
  """
  if f is None:
    # Called as `@custom_gradient()`: return a decorator awaiting `f`.
    return lambda f: custom_gradient(f=f)

  def _dispatch(wrapped, args, kwargs):
    """Routes one call of the decorated function to the graph or eager path."""
    if not context.executing_eagerly():
      return _graph_mode_decorator(wrapped, args, kwargs)
    return _eager_mode_decorator(wrapped, args, kwargs)

  # `Bind` defers to `_dispatch` on every call and re-binds `f` when the
  # decorated function is accessed as a method of an instance.
  return tf_decorator.make_decorator(f, Bind(f, _dispatch))
301
302
class Bind(object):
  """Calls `d(f, args, kwargs)` when invoked, while supporting binding of `f`.

  Looking a `Bind` up through an instance binds the wrapped function to that
  instance first, so decorators built with `Bind.decorator` work on methods.

  >>> @Bind.decorator
  ... def my_decorator(f, args, kwargs):
  ...   print("my_decorator called with", args, kwargs)
  ...   return f(*args, **kwargs)

  >>> class Foo(object):
  ...   @my_decorator
  ...   def bar(self, a, b, c):
  ...     return a * b * c

  >>> Foo.bar(None, 1, 2, c=3)
  my_decorator called with (None, 1, 2) {'c': 3}
  6

  >>> foo = Foo()
  >>> foo.bar(1, 2, c=3)
  my_decorator called with (1, 2) {'c': 3}
  6
  """

  @classmethod
  def decorator(cls, d):
    """Returns a decorator that turns a function `f` into `Bind(f, d)`."""

    def _wrap(f):
      return Bind(f, d)

    return _wrap

  def __init__(self, f, d):
    # `f` is the wrapped callable; `d` is invoked as `d(f, args, kwargs)`.
    self._f = f
    self._d = d

  def __get__(self, instance, owner):
    """Descriptor hook: binds the wrapped function when read via an instance."""
    if instance is None:
      # Class-level access (e.g. `Foo.bar`) yields this unbound wrapper.
      return self
    bound_f = self._f.__get__(instance, owner)
    return tf_decorator.make_decorator(bound_f, Bind(bound_f, self._d))

  def __call__(self, *args, **kwargs):
    return self._d(self._f, args, kwargs)
344
def get_variable_by_name(var_name):
  """Looks up a global variable by the name of its op.

  Args:
    var_name: Op name to match against `item.op.name` for every item in the
      `GLOBAL_VARIABLES` collection.

  Returns:
    The single trainable matching variable, or `None` when every matching
    variable is non-trainable.

  Raises:
    ValueError: If no collection item has that op name, or if more than one
      trainable variable matches.
  """

  def _has_matching_op(item):
    # Collection items without an operation are ignored.
    try:
      return var_name == item.op.name
    except AttributeError:
      return False

  global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
  matches = [item for item in global_vars if _has_matching_op(item)]
  if not matches:
    raise ValueError("Unsuccessful at finding variable {}.".format(var_name))

  # Only trainable variables are candidates.
  trainable = [v for v in matches if v.trainable]
  if not trainable:
    # The variable is not trainable.
    return None
  if len(trainable) > 1:
    raise ValueError(
        "Unsuccessful at finding trainable variable {}. "
        "Number of candidates: {}. "
        "Candidates: {}".format(var_name, len(trainable), trainable))
  return trainable[0]
374
375
def _get_dependent_variables(input_ops, output_ops):
  """Returns trainable variables read between `input_ops` and `output_ops`.

  Walks the graph backwards from `output_ops`, stopping at `input_ops`, and
  resolves every op whose type is in `VAR_OP_TYPES` to a variable with
  `get_variable_by_name`. Ops that resolve to `None` (non-trainable) are
  dropped.

  Args:
    input_ops: Flattened list of input tensors at which the walk stops.
    output_ops: Flattened structure of output tensors to start the walk from.

  Returns:
    A list of variables, in the order their ops were visited.
  """
  # Wrapping each output in an identity keeps the walk meaningful even when
  # input_ops == output_ops.
  seeds = nest.map_structure(gen_array_ops.identity, output_ops)
  walked_ops = op_selector.get_backward_walk_ops(
      seed_ops=seeds,
      stop_at_ts=input_ops,
      inclusive=False,
      only_differentiable=True)

  found = []
  for op in walked_ops:
    if op.type not in VAR_OP_TYPES:
      continue
    variable = get_variable_by_name(op.name)
    if variable is not None:
      found.append(variable)
  return found
399
400
def generate_name():
  """Returns a gradient-registration name of the form `CustomGradient-<uid>`."""
  unique_id = ops.uid()
  return "CustomGradient-%s" % unique_id
403
404
def _graph_mode_decorator(f, args, kwargs):
  """Implements the custom gradient decorator for graph mode.

  `f(*args)` is called once to obtain `(result, grad_fn)`. The flattened
  results, the flattened inputs and the variables `f` uses are then routed
  through a single `IdentityN` op whose registered gradient is overridden to
  call `grad_fn`. The same op is also recorded with `tape_lib.record_operation`.

  Args:
    f: The decorated function. Must return a `(result, grad_fn)` pair.
    args: Positional arguments for `f`; converted with `ops.convert_to_tensor`.
    kwargs: Keyword arguments for `f`. Must be empty in graph mode.

  Returns:
    `result`, with each flattened element replaced by the matching output of
    the `IdentityN` op.

  Raises:
    ValueError: If `kwargs` is not empty, if the tensor outputs of `f` come
      from more than one graph, or (when gradients are computed) if `grad_fn`
      returns the wrong number of input or variable gradients.
    TypeError: If `f` creates a non-resource variable, or if `f` uses
      variables but `grad_fn` cannot accept a `variables` keyword argument.
  """
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = generate_name()
  args = nest.map_structure(ops.convert_to_tensor, args)

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()

  def _scope_variable_refs():
    # Hashable refs of every global and local variable in the current scope.
    return {
        v.ref() for v in current_var_scope.global_variables() +
        current_var_scope.local_variables()
    }

  before_vars = _scope_variable_refs()
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args)

  args = nest.flatten(args)
  flat_result = nest.flatten(result)
  flat_result_len = len(flat_result)

  new_vars = _scope_variable_refs() - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables_in_tape = frozenset(
      v.ref() for v in variable_watcher.watched_variables())

  # Not all results may be tensors. However, we want to ensure all tensor
  # outputs are from the same graph and get a list of captured inputs for
  # variable search.
  graphs = {getattr(o, "graph", None) for o in flat_result}
  graphs.discard(None)  # Discard non-graph outputs
  if graphs:
    if len(graphs) > 1:
      raise ValueError(
          "All custom_gradient outputs should be from the same graph")
    output_graph = graphs.pop()
    filtered_input_tensors = [i for i in args if i.graph == output_graph]
  else:
    filtered_input_tensors = args

  variables_in_subgraph = frozenset(
      v.ref() for v in _get_dependent_variables(
          input_ops=filtered_input_tensors, output_ops=flat_result))
  # Sort by name rather than relying on set iteration order.
  variables = sorted(
      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)],
      key=lambda v: v.name)

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  # A parameter literally named `variables` signals that grad_fn expects
  # variables; a bare `**kwargs` only makes it *able* to receive them, so it
  # must not trigger the "variables in signature" warning below.
  names_variables = ("variables" in grad_argspec.args or
                     "variables" in grad_argspec.kwonlyargs)
  accepts_variables = names_variables or bool(grad_argspec.varkw)
  if variables and not accepts_variables:
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  if names_variables and not variables:
    # User seems to intend to use variables but none were captured.
    logging.warning("@custom_gradient grad_fn has 'variables' in signature, "
                    "but no ResourceVariables were used on the forward pass.")

  all_tensors = flat_result + args + variables
  arg_count = len(args)

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper returning one gradient per IdentityN input."""
    result_grads = result_grads[:flat_result_len]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    input_grads = nest.flatten(input_grads)
    # Fail with a clear message (matching eager mode) instead of an opaque
    # gradient-count mismatch further down the gradient machinery.
    if len(input_grads) != arg_count:
      raise ValueError(
          "custom_gradient function expected to return {} gradients but "
          "returned {} instead.".format(arg_count, len(input_grads)))

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well. `list()` also accepts a tuple of
    # variable gradients.
    return ([None] * flat_result_len) + input_grads + list(variable_grads)

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    handle_data_util.copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:flat_result_len])
527
528
def _eager_mode_decorator(f, args, kwargs):
  """Implements the custom gradient decorator for eager mode.

  `f(*args, **kwargs)` is called once to obtain `(result, grad_fn)`. Each
  flattened result is passed through an identity op. That op is recorded on
  the tape with the flattened positional `args`, plus the variables `f` used,
  as its inputs. Variables passed directly as positional or keyword arguments
  are excluded.

  Args:
    f: The decorated function. Must return a `(result, grad_fn)` pair.
    args: Positional arguments for `f`.
    kwargs: Keyword arguments for `f`. They are not recorded as tape inputs.

  Returns:
    `result`, with each flattened element replaced by its identity.

  Raises:
    TypeError: If `f` uses variables but `grad_fn` cannot accept a `variables`
      keyword argument.
    ValueError: When gradients are computed, if `grad_fn` returns the wrong
      number of input or variable gradients.
  """
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args, **kwargs)
  args = nest.flatten(args)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [
      v.deref()  # pylint: disable=g-complex-comprehension
      for v in set(v.ref() for v in variable_watcher.watched_variables())
      if all(v.deref() is not i for i in all_inputs)
  ]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  if (variables and ("variables" not in grad_argspec.args) and
      ("variables" not in grad_argspec.kwonlyargs) and
      not grad_argspec.varkw):
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]
  arg_count = len(args)

  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper returning input grads then variable grads."""
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = nest.flatten(input_grads)
    if len(flat_grads) != arg_count:
      # Build one formatted message; passing several arguments to ValueError
      # would render the message as a tuple.
      raise ValueError(
          "custom_gradient function expected to return {} gradients but "
          "returned {} instead.".format(arg_count, len(flat_grads)))
    # `list()` also accepts a tuple of variable gradients.
    return flat_grads + list(variable_grads)

  tape_lib.record_operation(f.__name__, flat_result, input_tensors,
                            actual_grad_fn)
  return nest.pack_sequence_as(result, flat_result)
580
581
@tf_export("recompute_grad")
def recompute_grad(f):
  """An eager-compatible version of recompute_grad.

  For `f(*args, **kwargs)`, gradients are computed with respect to the
  positional `args` and to any variables that `custom_gradient` passes to the
  gradient function. Keyword arguments are forwarded to `f` (including when it
  is recomputed) but receive no gradients. They are currently only supported
  in eager mode.
  Note that for keras layer and model objects, this is handled automatically.

  Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
  be able to access the member variables of that object, because `g` returns
  through the wrapper function `inner`.  When recomputing gradients through
  objects that inherit from keras, we suggest keeping a reference to the
  underlying object around for the purpose of accessing these variables.

  Args:
    f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs.

  Returns:
    A function `g` that wraps `f`, but which recomputes `f` on the backwards
    pass of a gradient call.
  """
  # TODO(cdfreeman) Add is_recomputing functionality from graph mode version

  @custom_gradient
  def inner(*args, **kwargs):
    """Inner function closure for calculating gradients."""
    current_var_scope = variable_scope.get_variable_scope()
    # Run `f` without tape recording; it is re-run inside the gradient below.
    with tape_lib.stop_recording():
      result = f(*args, **kwargs)

    def grad_wrapper(*wrapper_args, **grad_kwargs):
      """Wrapper function to accommodate lack of kwargs in graph mode decorator."""

      @custom_gradient
      def inner_recompute_grad(*dresult):
        """Nested custom gradient computing grads for reverse mode autodiff.

        Its own gradient (`transpose`) handles forward mode autodiff.
        """
        # Gradient calculation for reverse mode autodiff.
        variables = grad_kwargs.get("variables")
        with backprop.GradientTape() as t:
          id_args = nest.map_structure(gen_array_ops.identity, args)
          # Tuple `dresult` should contain at least one tensor.
          assert len(dresult) >= 1

          if not context.executing_eagerly():
            # XLA doesn't respect `tf.control_dependencies`. The code block
            # below manually adds a data dependency to `dresult` to ensure
            # recomputation of `f(*args, **kwargs)` happens after `dresult`.

            # This works even if `dresult[0]` is a size 0 tensor as reduce_max
            # of a size 0 tensor returns -inf. Use reshape here to avoid reading
            # the entire `dresult[0]`.
            elem = math_ops.reduce_max(array_ops.reshape(dresult[0], [-1])[:1])
            # Cast elem to bool in case elem is NaN.
            elem_bool = math_ops.cast(elem, dtypes.bool)
            dresult_dep = array_ops.where_v2(
                elem_bool == elem_bool, 0., float("nan"))  # pylint: disable=comparison-with-itself
            id_args = nest.map_structure(
                lambda x: x + math_ops.cast(dresult_dep, x.dtype), id_args)

          t.watch(id_args)
          if variables is not None:
            t.watch(variables)
          with variable_scope.variable_scope(current_var_scope):
            recomputed_result = f(*id_args, **kwargs)
        kw_vars = []
        if variables is not None:
          kw_vars = list(variables)
        grads = t.gradient(
            recomputed_result,
            list(id_args) + kw_vars,
            output_gradients=dresult,
            unconnected_gradients=UnconnectedGradients.ZERO)

        def transpose(*t_args, **t_kwargs):
          """Gradient function calculation for forward mode autodiff."""
          # Just throw an error since gradients / activations are not stored on
          # tape for recompute.
          raise NotImplementedError(
              "recompute_grad tried to transpose grad of {}. "
              "Consider not using recompute_grad in forward mode "
              "autodiff".format(f.__name__))

        # (gradients w.r.t. args, gradients w.r.t. variables)
        return (grads[:len(id_args)], grads[len(id_args):]), transpose

      return inner_recompute_grad(*wrapper_args)

    return result, grad_wrapper

  return inner
671
672
@tf_export("grad_pass_through")
def grad_pass_through(f):
  """Creates a grad-pass-through op with the forward behavior provided in f.

  Use this function to wrap any op, maintaining its behavior in the forward
  pass, but replacing the original op in the backward graph with an identity.
  For example:

  ```python
  x = tf.Variable(1.0, name="x")
  z = tf.Variable(3.0, name="z")

  with tf.GradientTape() as tape:
    # y will evaluate to 9.0
    y = tf.grad_pass_through(x.assign)(z**2)
  # grads will evaluate to 6.0
  grads = tape.gradient(y, z)
  ```

  Another example is a 'differentiable' moving average approximation, where
  gradients are allowed to flow into the last value fed to the moving average,
  but the moving average is still used for the forward pass:

  ```python
  x = ... # Some scalar value
  # A moving average object, we don't need to know how this is implemented
  moving_average = MovingAverage()
  with tf.GradientTape() as tape:
    # mavg_x will evaluate to the current running average value
    mavg_x = tf.grad_pass_through(moving_average)(x)
  grads = tape.gradient(mavg_x, x) # grads will evaluate to 1.0
  ```

  Variables used inside `f` that are not among its inputs receive `None`
  gradients.

  Args:
    f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor`
      outputs.

  Returns:
    A function `h(x)` which returns the same values as `f(x)` and whose
    gradients are the same as those of an identity function.
  """

  @custom_gradient
  def _grad_pass_through_op(*args, **kwargs):

    def grad(*upstream, **grad_kwargs):
      # Hand the upstream gradients straight back as the input gradients.
      variables = grad_kwargs.get("variables")
      if variables is None:
        return upstream
      # Variables involved in the wrapped op will not receive gradients.
      return upstream, [None] * len(variables)

    return f(*args, **kwargs), grad

  return tf_decorator.make_decorator(f, _grad_pass_through_op)
724