• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Code for backpropagation using the tape utilities."""
16
17# TODO(b/159343581): Properly support CompositeTensor in all functions in this
18# file.
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import functools
25import operator
26import sys
27
28import six
29
30from tensorflow.python import pywrap_tfe
31from tensorflow.python.eager import backprop_util
32from tensorflow.python.eager import context
33from tensorflow.python.eager import execute
34from tensorflow.python.eager import imperative_grad
35from tensorflow.python.eager import tape
36from tensorflow.python.framework import constant_op
37from tensorflow.python.framework import dtypes
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import tensor_shape
40from tensorflow.python.framework import tensor_util
41from tensorflow.python.ops import array_ops
42from tensorflow.python.ops import check_ops
43from tensorflow.python.ops import control_flow_util
44from tensorflow.python.ops import default_gradient
45from tensorflow.python.ops import gen_array_ops
46from tensorflow.python.ops import gen_math_ops
47from tensorflow.python.ops import math_ops
48from tensorflow.python.ops import resource_variable_ops
49from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
50from tensorflow.python.platform import tf_logging as logging
51from tensorflow.python.util import _pywrap_utils
52from tensorflow.python.util import nest
53from tensorflow.python.util import tf_contextlib
54from tensorflow.python.util import tf_inspect
55from tensorflow.python.util.lazy_loader import LazyLoader
56from tensorflow.python.util.tf_export import tf_export
57
58
# Note that we need to lazy load the following two modules to avoid creating
# circular dependencies.
# TODO(b/119775953): fix the circular dependencies.
pfor_ops = LazyLoader(
    "pfor_ops", globals(),
    "tensorflow.python.ops.parallel_for.control_flow_ops")

function = LazyLoader("function", globals(),
                      "tensorflow.python.eager.function")

# Memoizes the results of `op_attr_type`, keyed by `(op_type, attr_name)`
# tuples.
_op_attr_type_cache = {}
70
71
def op_attr_type(op_type, attr_name):
  """Returns the attr type of `attr_name` for ops of type `op_type`.

  Results are memoized in `_op_attr_type_cache`, so the runtime is queried only
  the first time a given `(op_type, attr_name)` pair is requested.

  Args:
    op_type: the op type name.
    attr_name: the name of the attribute to look up.

  Returns:
    The value reported by `pywrap_tfe.TFE_OpNameGetAttrType` for the pair.
  """
  cache_key = (op_type, attr_name)
  if cache_key in _op_attr_type_cache:
    return _op_attr_type_cache[cache_key]

  # Cache miss: ask the eager runtime, which needs an initialized context.
  context.ensure_initialized()
  handle = context.context()._handle  # pylint: disable=protected-access
  found = pywrap_tfe.TFE_OpNameGetAttrType(handle, op_type, attr_name)
  _op_attr_type_cache[cache_key] = found
  return found
81
82
def make_attr(attr_type, value):
  """Converts a raw attr `value` according to `attr_type`.

  Args:
    attr_type: the attr type code as an integer, or a one-element list holding
      that code when the attr is a list of that type.
    value: the raw attribute value.

  Returns:
    `dtypes.as_dtype(value)` (or a list of those) for type attrs, the proto of
    `tensor_shape.as_shape(value)` (or a list of those) for shape attrs, the
    encoded bytes for `str` values, and `value` unchanged otherwise.
  """
  # pybind11 enums do not return the raw value like SWIG enums do. They are
  # useful when comparing amongst each other but not direct integers as we are
  # doing in most tests.
  # https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types
  # TODO(amitpatankar): After all SWIG transitions, convert the enum comparisons
  # from integer value to class.
  type_code = int(pywrap_tfe.TF_ATTR_TYPE)
  if attr_type == type_code:
    return dtypes.as_dtype(value)
  if attr_type == [type_code]:
    return [dtypes.as_dtype(item) for item in value]

  shape_code = int(pywrap_tfe.TF_ATTR_SHAPE)
  if attr_type == shape_code:
    return tensor_shape.as_shape(value).as_proto()
  if attr_type == [shape_code]:
    return [tensor_shape.as_shape(item).as_proto() for item in value]

  return value.encode() if isinstance(value, str) else value
101
102
103class _MockOp(object):
104  """Pretends to be a tf.Operation for the gradient functions."""
105
106  def __init__(self, attrs, inputs, outputs, typ, skip_input_indices):
107    self.attrs = attrs
108    self.inputs = inputs
109    self.outputs = outputs
110    self.type = typ
111    self.skip_input_indices = skip_input_indices
112
113  def get_attr(self, attr):
114    typ = op_attr_type(self.type, attr)
115    for i in range(0, len(self.attrs), 2):
116      if self.attrs[i] == attr:
117        return make_attr(typ, self.attrs[i + 1])
118    raise KeyError(attr)
119
120  def _get_control_flow_context(self):
121    raise NotImplementedError(
122        "tf.GradientTape.gradients() does not support graph control flow "
123        "operations like tf.cond or tf.while at this time. Use tf.gradients() "
124        "instead. If you need this feature, please file a feature request at "
125        "https://github.com/tensorflow/tensorflow/issues/new"
126    )
127
128
def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
                       out_grads, skip_input_indices, forward_pass_name_scope):
  """Looks up and invokes the registered gradient function for an op.

  Args:
    op_name: the name of the op to be differentiated.
    attr_tuple: the attrs, as a tuple.
    num_inputs: the number of inputs to the op.
    inputs: inputs to the original operation.
    outputs: outputs to the original operation.
    out_grads: gradients of the operation wrt its outputs.
    skip_input_indices: a tuple that is passed to the gradient function,
      indicating which inputs to skip calculating the gradient for
    forward_pass_name_scope: the namescope of the op in the forward pass.

  Returns:
    The gradients with respect to the inputs of the function, as a list. When
    the registry lookup yields None, a list of `num_inputs` None entries.
  """
  mock_op = _MockOp(attr_tuple, inputs, outputs, op_name, skip_input_indices)
  grad_fn = ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
  if grad_fn is None:
    return [None] * num_inputs

  # This does not work with v1 TensorArrays.
  wrap_in_name_scope = (
      ops.executing_eagerly_outside_functions() or
      control_flow_util.EnableControlFlowV2(ops.get_default_graph()))
  if not wrap_in_name_scope:
    return grad_fn(mock_op, *out_grads)

  # Nest the forward-pass scope (when there is one) under "gradient_tape/".
  if forward_pass_name_scope:
    gradient_name_scope = "gradient_tape/" + forward_pass_name_scope + "/"
  else:
    gradient_name_scope = "gradient_tape/"
  with ops.name_scope(gradient_name_scope):
    return grad_fn(mock_op, *out_grads)
162
163
# Module-level side effect: hands `_gradient_function` to the pywrap_tfe layer
# at import time (presumably so the runtime can call back into Python to
# compute op gradients -- confirm against the pywrap_tfe bindings).
pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function)
165
166
def _must_record_gradient():
  """Returns whether the tape set reported by the runtime is non-empty."""
  tape_set_is_empty = pywrap_tfe.TFE_Py_TapeSetIsEmpty()
  return not tape_set_is_empty
169
170
@tf_export("__internal__.record_gradient", v1=[])
def record_gradient(op_name, inputs, attrs, outputs):
  """Explicitly records the gradient for a given op.

  The current name scope is forwarded together with the op description.

  Args:
    op_name: The op name as listed in the `OpDef` for the op.
    inputs: A list of tensor inputs to the op.
    attrs: The op attributes as a flattened list of alternating attribute names
      and attribute values.
    outputs: A list of tensor outputs from the op.
  """
  current_scope = ops.get_name_scope()
  pywrap_tfe.TFE_Py_RecordGradient(
      op_name, inputs, attrs, outputs, current_scope)
184
185
# Install this module's implementations as attributes of the `execute` module
# (done by assignment here rather than by having `execute` import this file).
execute.must_record_gradient = _must_record_gradient
execute.record_gradient = record_gradient
188
189
def implicit_val_and_grad(f):
  """Returns a function which differentiates f with respect to variables.

  The wrapped function returns the value and the gradient of f when called with
  the same arguments. The gradient is with respect to all trainable TFE
  variables accessed by `f`.

  This function is useful when the exact set of variables to differentiate with
  is not known ahead of time.

  Example:

  ```python
  dense_layer = tf.compat.v1.layers.Dense(1)
  def loss(x, y):
    return tf.reduce_sum(tf.square(dense_layer(x) - y))

  # Obtain the gradient function.
  val_grad_fn = tfe.implicit_value_and_gradients(loss)

  # Invoke the gradient function with concrete values of x and y.
  x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  y = tf.constant([[10.0], [20.0]])
  value, grads_and_vars = val_grad_fn(x, y)
  print('Value of loss: %s' % value)

  # Apply the gradients to Variables.
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  optimizer.apply_gradients(grads_and_vars)
  ```

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar.

  Returns:
    A function which, when called, returns a tuple pair.
    Its first element is the value to which the function evaluates.
    Its second element is list of (gradient, variable) pairs.

  Raises:
    ValueError: raised by the returned function if `f` returns None, if the
      tape watched no variables while `f` ran, or if the handle of a watched
      variable reports `is_packed`.
  """
  # TODO(cais): Remove calls to tf.constant() once the gradients functions
  # accept lists and np.ndarrays.

  def grad_fn(*args, **kwds):
    """Computes the value of `f` and its gradient w.r.t. watched variables."""
    this_tape = tape.push_new_tape()
    try:
      end_node = f(*args, **kwds)
      if end_node is None:
        # Callables such as functools.partial objects have no `__name__`; fall
        # back to their repr so the intended ValueError is still raised.
        fn_name = getattr(f, "__name__", repr(f))
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             fn_name))
    finally:
      # Always pop the tape, even when `f` raised.
      tape.pop_tape(this_tape)
    # Note: variables are returned in construction order. This ensures unique
    # order across executions.
    variables = this_tape.watched_variables()
    if not variables:
      raise ValueError("No trainable variables were accessed while the "
                       "function was being computed.")

    sources = [v.handle for v in variables]
    for s in sources:
      if getattr(s, "is_packed", False):
        raise ValueError(
            "GradientTape.gradient is not supported on packed EagerTensors yet."
        )
    grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node),
                                           sources)
    return end_node, list(zip(grad, variables))

  return grad_fn
267
268
def implicit_grad(f):
  """Returns a function which differentiates f with respect to variables.

  Calling the returned function with the arguments of `f` yields the gradient
  of `f` with respect to all trainable TFE variables accessed by `f`.

  This is handy when the exact set of variables to differentiate with is not
  known ahead of time.

  Example:

  ```python
  dense_layer = tf.compat.v1.layers.Dense(1)
  def loss(x, y):
    return tf.reduce_sum(tf.square(dense_layer(x) - y))

  # Obtain the gradient function.
  grad_fn = tfe.implicit_gradients(loss)

  # Invoke the gradient function with concrete values of x and y.
  x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  y = tf.constant([[10.0], [20.0]])
  grads_and_vars = grad_fn(x, y)

  # Apply the gradients to Variables.
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  optimizer.apply_gradients(grads_and_vars)
  ```

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar.

  Returns:
    A function which, when called, returns a list of (gradient, variable) pairs.
  """
  # TODO(cais): Remove calls to tf.constant() once the gradients functions
  # accept lists and np.ndarrays.

  def grad_fn(*args, **kwds):
    """Runs `implicit_val_and_grad(f)` and keeps only the gradient pairs."""
    value_and_grads = implicit_val_and_grad(f)(*args, **kwds)
    return value_and_grads[1]

  return grad_fn
316
317
def _get_arg_spec(f, params, param_args):
  """The positions of the parameters of f to be differentiated in param_args.

  Args:
    f: the callable whose arguments are inspected.
    params: None, a sequence of argument names, or a sequence of integer
      positions.
    param_args: the positional arguments `f` is about to be called with; only
      their count is used, as a fallback when `f` exposes no argument names.

  Returns:
    An iterable of integer positions: `params` itself when it already holds
    integers, the positions of the named arguments when it holds strings, and
    a range over all arguments (skipping a leading "self") when it is None.

  Raises:
    ValueError: if `f` cannot be inspected and `params` is not None or all
      integers, if `params` mixes strings and integers, or if a name in
      `params` is not an argument of `f`.
  """
  try:
    args = tf_inspect.getfullargspec(f).args
  except TypeError as e:
    # TypeError can happen when f is a callable object.
    if params is None:
      return range(len(param_args))
    elif all(isinstance(x, int) for x in params):
      return params
    raise ValueError("Either callable provided is not a function or could not "
                     "inspect its arguments by name: %s. Original error: %s"
                     % (f, e))
  if params is None:
    if not args:
      return range(len(param_args))
    if args[0] == "self":
      return range(len(args) - 1)
    else:
      return range(len(args))
  elif all(isinstance(x, six.string_types) for x in params):
    return [args.index(n) for n in params]
  elif all(isinstance(x, int) for x in params):
    return params
  else:
    # Wrap `params` in a 1-tuple: a bare tuple on the right of `%` would be
    # unpacked as several format arguments and raise TypeError instead.
    raise ValueError(
        "params must be all strings or all integers; got %s." % (params,))
345
346
def gradients_function(f, params=None):
  """Returns a function which differentiates f with respect to params.

  Example:
  ```python
  # f(x, y) = (x ^ 3) * y - x * (y ^ 2)
  # Therefore, the 1st order derivatives are:
  #   df / dx = 3 * (x ^ 2) * y - y ^ 2
  #   df / dy = x ^ 3 - 2 * x * y
  # The 2nd order derivatives with respect to x is:
  #   d^2 f / (dx)^2 = 6 * x * y
  def f(x, y):
    return x * x * x * y - x * y * y

  # Obtain a function that returns 1st order gradients.
  grad_fn = tfe.gradients_function(f)

  x = 2.0
  y = 3.0

  # Invoke the 1st order gradient function.
  x_grad, y_grad = grad_fn(x, y)
  assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3

  # Obtain a function that returns the 2nd order gradient with respect to x.
  gradgrad_fn = tfe.gradients_function(lambda x, y: grad_fn(x, y)[0])

  # Invoke the 2nd order gradient function.
  x_gradgrad = gradgrad_fn(x, y)[0]
  assert x_gradgrad.numpy() == 6 * 2 * 3

  # To obtain a callable that returns the gradient(s) of `f` with respect to a
  # subset of its inputs, use the `params` keyword argument with
  # `gradients_function()`.
  ygrad_fn = tfe.gradients_function(f, params=[1])

  (y_grad,) = ygrad_fn(x, y)
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
  ```

  Note that only tensors with real or complex dtypes are differentiable.

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar. If desired, the tensors can be elementwise multiplied by the
      tensors passed as the `dy` keyword argument to the returned gradient
      function.
    params: list of parameter names of f or list of integers indexing the
      parameters with respect to which we'll differentiate. Passing None
      differentiates with respect to all parameters.

  Returns:
    function which, when called, returns the gradient of `f` with respect to
    all of `params` (the value of `f` is computed but discarded). The function
    takes an extra optional keyword argument `dy`. Setting it allows
    computation of vector jacobian products for vectors other than the vector
    of ones.

  Raises:
    ValueError: if the params are not all strings or all integers.
  """

  def decorated(*args, **kwds):
    """Computes the gradient of the decorated function."""
    value_and_grad = val_and_grad_function(f, params=params)
    return value_and_grad(*args, **kwds)[1]

  return decorated
418
419
420def _ensure_unique_tensor_objects(parameter_positions, args):
421  """Make each of the parameter_positions in args a unique ops.Tensor object.
422
423  Ensure that each parameter is treated independently.
424  For example:
425
426  def f(x, y): return x * y
427  g = gradients_function(f)
428  one = tf.constant(1.)
429
430  g(one, one) should return [1., 1.]
431  (even though the two arguments are the same Tensor object).
432
433  Args:
434    parameter_positions: List of indices into args defining the arguments to
435      differentiate against.
436    args: A list of arguments to the function to be differentiated.
437
438  Returns:
439    args, possibly edited in-place.
440  """
441  s = set()
442  for (i, t) in enumerate(args):
443    if i in parameter_positions:
444      tid = ops.tensor_id(t)
445      if tid in s:
446        args[i] = gen_array_ops.identity(args[i])
447      else:
448        s.add(tid)
449  return args
450
451
def val_and_grad_function(f, params=None):
  """Returns a function that computes f and its derivative w.r.t. params.

  Example:
  ```python
  # f(x, y) = (x ^ 3) * y - x * (y ^ 2)
  # Therefore, the 1st order derivatives are:
  #   df / dx = 3 * (x ^ 2) * y - y ^ 2
  #   df / dy = x ^ 3 - 2 * x * y
  def f(x, y):
    return x * x * x * y - x * y * y

  # Obtain a function that returns the function value and the 1st order
  # gradients.
  val_grads_fn = tfe.value_and_gradients_function(f)

  x = 2.0
  y = 3.0

  # Invoke the value-and-gradients function.
  f_val, (x_grad, y_grad) = val_grads_fn(x, y)
  assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2)
  assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3

  # To obtain a callable that returns the value of `f` and the gradient(s) of
  # `f` with respect to a subset of its inputs, use the `params` keyword
  # argument with `value_and_gradients_function()`.
  val_ygrad_fn = tfe.value_and_gradients_function(f, params=[1])

  f_val, (y_grad,) = val_ygrad_fn(x, y)
  assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2)
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
  ```

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar. If desired, the tensors can be elementwise multiplied by the
      tensors passed as the `dy` keyword argument to the returned gradient
      function.
    params: list of parameter names of f or list of integers indexing the
      parameters with respect to which we'll differentiate. Passing `None`
      differentiates with respect to all parameters.

  Returns:
    function which, when called, returns the value of f and the gradient
    of f with respect to all of `params`. The function takes an extra optional
    keyword argument "dy". Setting it allows computation of vector jacobian
    products for vectors other than the vector of ones. Any other keyword
    argument makes the returned function raise ValueError.

  Raises:
    ValueError: if the params are not all strings or all integers.
  """

  def decorated(*args, **kwds):
    """Computes the value and gradient of the decorated function."""
    output_grads = kwds.pop("dy", None)
    # "dy" is the only keyword accepted; anything left over is an error.
    if kwds:
      raise ValueError("Functions to be differentiated cannot "
                       "receive keyword arguments.")
    value_and_vjp = make_vjp(f, params)
    val, vjp = value_and_vjp(*args)
    return val, vjp(dy=output_grads)

  return decorated
518
519
def make_vjp(f, params=None, persistent=True):
  """Returns a function that computes f and its vjp w.r.t. params.

  The term "vjp" here is an abbreviation for vector-jacobian product.

  Args:
    f: the function to be differentiated.
    params: the parameters (numbers or names) to differentiate with respect to.
      A value of None will differentiate with respect to all parameters.
    persistent: Boolean controlling whether the VJP function can be re-used.
      Must be True or False.

  Returns:
    A function, which when called, returns a tuple (value, vjp), where:
    - value is the result of calling f.
    - vjp is a function, which takes a vector as an argument and
      returns the product of that vector with the Jacobian of f.
      Providing no argument to vjp is equivalent to providing a
      vector of ones.

    For example,
    ```python
    def f(x):
      return x * x

    wrapped_fn = tfe.make_vjp(f)
    result, vjp = wrapped_fn(tf.constant(3.0))
    # result is 9.0
    vjp()  # the vjp function returns 6.0
    ```

  Raises:
    ValueError: raised by the returned function if `f` returns None or if one
      of the differentiated arguments reports `is_packed`.
  """

  def decorated(*args, **kwds):
    """Computes the value and gradient of the decorated function."""
    parameter_positions = _get_arg_spec(f, params, args)
    assert not kwds, "The gradient function can't take keyword arguments."
    this_tape = tape.push_new_tape(persistent=persistent)
    try:
      sources = []
      # Only the differentiated positions are converted to tensors; all other
      # arguments are passed to `f` untouched.
      args = [
          ops.convert_to_tensor(arg) if i in parameter_positions else arg
          for i, arg in enumerate(args)
      ]
      args = _ensure_unique_tensor_objects(parameter_positions, args)
      for i in parameter_positions:
        if getattr(args[i], "is_packed", False):
          raise ValueError(
              "GradientTape.gradient is not supported on packed EagerTensors "
              "yet.")
        sources.append(args[i])
        tape.watch(this_tape, args[i])
      result = f(*args)
      if result is None:
        # Callables such as functools.partial objects have no `__name__`; fall
        # back to their repr so the intended ValueError is still raised.
        fn_name = getattr(f, "__name__", repr(f))
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             fn_name))
      # Pass every output through identity while the tape is still recording.
      flat_result = nest.flatten(result)
      flat_result = [gen_array_ops.identity(x) for x in flat_result]
      result = nest.pack_sequence_as(result, flat_result)
    finally:
      # Always pop the tape, even when `f` raised.
      tape.pop_tape(this_tape)

    def vjp(dy=None):
      """Computes the gradient of `result` w.r.t. `sources`, weighted by dy."""
      if dy is not None:
        dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
      return imperative_grad.imperative_grad(
          this_tape, nest.flatten(result), sources, output_gradients=dy)

    return result, vjp

  return decorated
594
595
def flatten_nested_indexed_slices(grad):
  """Collapses nested `IndexedSlices` into a single level.

  Args:
    grad: an `IndexedSlices` whose `values` is either a `Tensor` or another
      `IndexedSlices`.

  Returns:
    `grad` itself when its values are a `Tensor`; otherwise a new
    `IndexedSlices` built from the flattened inner slices, with the outer
    indices gathered at the inner indices.
  """
  assert isinstance(grad, ops.IndexedSlices)
  inner = grad.values
  if isinstance(inner, ops.Tensor):
    return grad

  assert isinstance(inner, ops.IndexedSlices)
  flat_inner = flatten_nested_indexed_slices(inner)
  outer_indices = array_ops.gather(grad.indices, flat_inner.indices)
  return ops.IndexedSlices(
      flat_inner.values, outer_indices, flat_inner.dense_shape)
606
607
def aggregate_indexed_slices_gradients(grads):
  """Aggregates gradients containing `IndexedSlices`s.

  Args:
    grads: a sequence of gradients; entries may be `Tensor`s, `IndexedSlices`
      or None.

  Returns:
    None if `grads` is empty or holds only None entries. The sole entry if
    `grads` has exactly one. Otherwise, with the None entries dropped: the
    `add_n` sum if any entry is a `Tensor`, else one `IndexedSlices` whose
    values and indices are the concatenation of those of the entries.
  """
  if len(grads) < 1:
    return None
  if len(grads) == 1:
    return grads[0]
  grads = [g for g in grads if g is not None]
  # Every entry was None: there is nothing to aggregate, and indexing
  # `grads[0]` below would fail on the empty list.
  if not grads:
    return None
  # If any gradient is a `Tensor`, sum them up and return a dense tensor
  # object.
  if any(isinstance(g, ops.Tensor) for g in grads):
    return math_ops.add_n(grads)

  # The following `_as_indexed_slices_list` casts ids of IndexedSlices into
  # int64. It is to make sure the inputs of `concat` all have same the data
  # type.
  grads = math_ops._as_indexed_slices_list(grads)  # pylint: disable=protected-access

  grads = [flatten_nested_indexed_slices(x) for x in grads]
  # Form IndexedSlices out of the concatenated values and indices.
  concat_grad = ops.IndexedSlices(
      array_ops.concat([x.values for x in grads], axis=0),
      array_ops.concat([x.indices for x in grads], axis=0),
      grads[0].dense_shape)

  return concat_grad
633
634
635def _aggregate_grads(gradients):
636  """Aggregate gradients from multiple sources.
637
638  Args:
639    gradients: A list of 'Tensor' or 'IndexedSlices' gradients.
640
641  Returns:
642    If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'.
643    Otherwise returns an aggregated 'IndexedSlices'.
644  """
645  assert gradients, "No gradients to aggregate"
646
647  if len(gradients) == 1:
648    return gradients[0]
649  if all(isinstance(g, ops.Tensor) for g in gradients):
650    return gen_math_ops.add_n(gradients)
651  else:
652    assert all(isinstance(g, (ops.Tensor, ops.IndexedSlices))
653               for g in gradients)
654    return aggregate_indexed_slices_gradients(gradients)
655
656
def _num_elements(grad):
  """Returns the number of elements in the `grad` tensor.

  For `IndexedSlices` the shape of `grad.values` is used. When the shape, or
  any of its dimensions, is unknown (None), 0 is returned.

  Raises:
    ValueError: if `grad` is neither a `Tensor` nor an `IndexedSlices`.
  """
  if isinstance(grad, ops.Tensor):
    dims = grad._shape_tuple()  # pylint: disable=protected-access
  elif isinstance(grad, ops.IndexedSlices):
    dims = grad.values._shape_tuple()  # pylint: disable=protected-access
  else:
    raise ValueError("`grad` not a Tensor or IndexedSlices.")

  if dims is not None and None not in dims:
    # Product of all dimensions; an empty shape (scalar) yields 1.
    count = 1
    for dim in dims:
      count = count * dim
    return count
  return 0
668
669
def _fast_fill(value, shape, dtype):
  """Returns `array_ops.fill` of `value` (as `dtype`) over `shape` (as int32)."""
  dims = constant_op.constant(shape, dtype=dtypes.int32)
  fill_value = constant_op.constant(value, dtype=dtype)
  return array_ops.fill(dims, fill_value)
674
675
def _zeros(shape, dtype):
  """Helper to return (possibly cached) zero tensors in eager mode.

  Returns None for string and resource dtypes. Outside eager execution this is
  a plain `array_ops.zeros` call; in eager mode results are stored in the
  context's zeros cache, keyed by shape, dtype and device name.
  """
  # Note: variants will use _zeros_like
  no_zeros_for_dtype = dtype == dtypes.string or dtype == dtypes.resource
  if no_zeros_for_dtype:
    return None

  ctx = context.context()
  if not ctx.executing_eagerly():
    return array_ops.zeros(shape, dtype)

  device = ctx.device_name

  # Tensor-valued shapes are keyed through `ref()`.
  shape_key = shape.ref() if tensor_util.is_tf_type(shape) else shape
  cache_key = (shape_key, dtype, device)

  zeros_tensor = ctx.zeros_cache().get(cache_key)
  if zeros_tensor is not None:
    return zeros_tensor

  fill_value = False if dtypes.as_dtype(dtype).is_bool else 0
  zeros_tensor = _fast_fill(fill_value, shape, dtype)
  ctx.zeros_cache().put(cache_key, zeros_tensor)
  return zeros_tensor
702
703
def _ones(shape, dtype):
  """Returns a tensor of ones (True for bool dtypes) of `shape` and `dtype`.

  Returns None for the string dtype. Outside eager execution this is a plain
  `array_ops.ones` call; in eager mode a scalar shape yields a constant and
  any other shape goes through `_fast_fill`.
  """
  as_dtype = dtypes.as_dtype(dtype)
  if as_dtype == dtypes.string:
    return None

  if not context.executing_eagerly():
    return array_ops.ones(shape, dtype)

  one = True if as_dtype.is_bool else 1

  if shape == ():  # pylint: disable=g-explicit-bool-comparison
    return constant_op.constant(one, dtype=dtype)
  return _fast_fill(one, shape, dtype)
720
721
# Bundles this module's gradient helpers (element counting, aggregation, and
# zeros/ones construction) into a `VSpace` and registers it with the pywrap_tfe
# layer at import time.
_default_vspace = imperative_grad.VSpace(
    num_elements_fn=_num_elements,
    aggregate_fn=_aggregate_grads,
    zeros_fn=_zeros,
    ones_fn=_ones,
    zeros_like_fn=default_gradient.zeros_like,
    ones_like_fn=default_gradient.ones_like,
    graph_shape_fn=gen_array_ops.shape)
pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace)
731
732
def _handle_or_self(x):
  """Returns `x.handle` for resource variables and `x` itself otherwise."""
  return x.handle if resource_variable_ops.is_resource_variable(x) else x
738
739
740@tf_export("GradientTape", "autodiff.GradientTape", v1=["GradientTape"])
741class GradientTape(object):
742  """Record operations for automatic differentiation.
743
744  Operations are recorded if they are executed within this context manager and
745  at least one of their inputs is being "watched".
746
747  Trainable variables (created by `tf.Variable` or `tf.compat.v1.get_variable`,
748  where `trainable=True` is default in both cases) are automatically watched.
749  Tensors can be manually watched by invoking the `watch` method on this context
750  manager.
751
752  For example, consider the function `y = x * x`. The gradient at `x = 3.0` can
753  be computed as:
754
755  >>> x = tf.constant(3.0)
756  >>> with tf.GradientTape() as g:
757  ...   g.watch(x)
758  ...   y = x * x
759  >>> dy_dx = g.gradient(y, x)
760  >>> print(dy_dx)
761  tf.Tensor(6.0, shape=(), dtype=float32)
762
763  GradientTapes can be nested to compute higher-order derivatives. For example,
764
765  >>> x = tf.constant(5.0)
766  >>> with tf.GradientTape() as g:
767  ...   g.watch(x)
768  ...   with tf.GradientTape() as gg:
769  ...     gg.watch(x)
770  ...     y = x * x
771  ...   dy_dx = gg.gradient(y, x)  # dy_dx = 2 * x
772  >>> d2y_dx2 = g.gradient(dy_dx, x)  # d2y_dx2 = 2
773  >>> print(dy_dx)
774  tf.Tensor(10.0, shape=(), dtype=float32)
775  >>> print(d2y_dx2)
776  tf.Tensor(2.0, shape=(), dtype=float32)
777
778  By default, the resources held by a GradientTape are released as soon as
779  GradientTape.gradient() method is called. To compute multiple gradients over
780  the same computation, create a persistent gradient tape. This allows multiple
781  calls to the gradient() method as resources are released when the tape object
782  is garbage collected. For example:
783
784  >>> x = tf.constant(3.0)
785  >>> with tf.GradientTape(persistent=True) as g:
786  ...   g.watch(x)
787  ...   y = x * x
788  ...   z = y * y
789  >>> dz_dx = g.gradient(z, x)  # (4*x^3 at x = 3)
790  >>> print(dz_dx)
791  tf.Tensor(108.0, shape=(), dtype=float32)
792  >>> dy_dx = g.gradient(y, x)
793  >>> print(dy_dx)
794  tf.Tensor(6.0, shape=(), dtype=float32)
795
796  By default GradientTape will automatically watch any trainable variables that
797  are accessed inside the context. If you want fine grained control over which
798  variables are watched you can disable automatic tracking by passing
799  `watch_accessed_variables=False` to the tape constructor:
800
801  >>> x = tf.Variable(2.0)
802  >>> w = tf.Variable(5.0)
803  >>> with tf.GradientTape(
804  ...     watch_accessed_variables=False, persistent=True) as tape:
805  ...   tape.watch(x)
806  ...   y = x ** 2  # Gradients will be available for `x`.
807  ...   z = w ** 3  # No gradients will be available as `w` isn't being watched.
808  >>> dy_dx = tape.gradient(y, x)
809  >>> print(dy_dx)
810  tf.Tensor(4.0, shape=(), dtype=float32)
811  >>> # No gradients will be available as `w` isn't being watched.
812  >>> dz_dy = tape.gradient(z, w)
813  >>> print(dz_dy)
814  None
815
816  Note that when using models you should ensure that your variables exist when
817  using `watch_accessed_variables=False`. Otherwise it's quite easy to make your
818  first iteration not have any gradients:
819
820  ```python
821  a = tf.keras.layers.Dense(32)
822  b = tf.keras.layers.Dense(32)
823
824  with tf.GradientTape(watch_accessed_variables=False) as tape:
825    tape.watch(a.variables)  # Since `a.build` has not been called at this point
826                             # `a.variables` will return an empty list and the
827                             # tape will not be watching anything.
828    result = b(a(inputs))
829    tape.gradient(result, a.variables)  # The result of this computation will be
830                                        # a list of `None`s since a's variables
831                                        # are not being watched.
832  ```
833
834  Note that only tensors with real or complex dtypes are differentiable.
835  """
836
837  def __init__(self, persistent=False, watch_accessed_variables=True):
838    """Creates a new GradientTape.
839
840    Args:
841      persistent: Boolean controlling whether a persistent gradient tape
842        is created. False by default, which means at most one call can
843        be made to the gradient() method on this object.
844      watch_accessed_variables: Boolean controlling whether the tape will
845        automatically `watch` any (trainable) variables accessed while the tape
846        is active. Defaults to True meaning gradients can be requested from any
847        result computed in the tape derived from reading a trainable `Variable`.
848        If False users must explicitly `watch` any `Variable`s they want to
849        request gradients from.
850    """
851    self._tape = None
852    self._persistent = persistent
853    self._watch_accessed_variables = watch_accessed_variables
854    self._watched_variables = ()
855    self._recording = False
856
857  def __enter__(self):
858    """Enters a context inside which operations are recorded on this tape."""
859    self._push_tape()
860    return self
861
862  def __exit__(self, typ, value, traceback):
863    """Exits the recording context, no further operations are traced."""
864    if self._recording:
865      self._pop_tape()
866
867  def _push_tape(self):
868    """Pushes a new tape onto the tape stack."""
869    if self._recording:
870      raise ValueError("Tape is still recording, This can happen if you try to "
871                       "re-enter an already-active tape.")
872    if self._tape is None:
873      self._tape = tape.push_new_tape(
874          persistent=self._persistent,
875          watch_accessed_variables=self._watch_accessed_variables)
876    else:
877      tape.push_tape(self._tape)
878    self._recording = True
879
880  def _pop_tape(self):
881    if not self._recording:
882      raise ValueError("Tape is not recording.")
883    tape.pop_tape(self._tape)
884    self._recording = False
885
886  @tf_contextlib.contextmanager
887  def _ensure_recording(self):
888    """Ensures that this tape is recording."""
889    if not self._recording:
890      try:
891        self._push_tape()
892        yield
893      finally:
894        self._pop_tape()
895    else:
896      yield
897
898  def watch(self, tensor):
899    """Ensures that `tensor` is being traced by this tape.
900
901    Args:
902      tensor: a Tensor or list of Tensors.
903
904    Raises:
905      ValueError: if it encounters something that is not a tensor.
906    """
907    for t in nest.flatten(tensor, expand_composites=True):
908      if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)):
909        raise ValueError("Passed in object of type {}, not tf.Tensor".format(
910            type(t)))
911      if not backprop_util.IsTrainable(t):
912        logging.log_first_n(
913            logging.WARN, "The dtype of the watched tensor must be "
914            "floating (e.g. tf.float32), got %r", 5, t.dtype)
915      if hasattr(t, "handle"):
916        # There are many variable-like objects, all of them currently have
917        # `handle` attribute that points to a tensor. If this changes, internals
918        # of watch_variable need to change as well.
919        tape.watch_variable(self._tape, t)
920      else:
921        tape.watch(self._tape, t)
922
923  @tf_contextlib.contextmanager
924  def stop_recording(self):
925    """Temporarily stops recording operations on this tape.
926
927    Operations executed while this context manager is active will not be
928    recorded on the tape. This is useful for reducing the memory used by tracing
929    all computations.
930
931    For example:
932
933    >>> x = tf.constant(4.0)
934    >>> with tf.GradientTape() as tape:
935    ...   with tape.stop_recording():
936    ...     y = x ** 2
937    >>> dy_dx = tape.gradient(y, x)
938    >>> print(dy_dx)
939    None
940
941    Yields:
942      None
943    Raises:
944      RuntimeError: if the tape is not currently recording.
945    """
946    if self._tape is None:
947      raise RuntimeError(
948          "Trying to stop recording a tape which is not recording.")
949    self._pop_tape()
950    try:
951      yield
952    finally:
953      self._push_tape()
954
955  def reset(self):
956    """Clears all information stored in this tape.
957
958    Equivalent to exiting and reentering the tape context manager with a new
959    tape. For example, the two following code blocks are equivalent:
960
961    ```
962    with tf.GradientTape() as t:
963      loss = loss_fn()
964    with tf.GradientTape() as t:
965      loss += other_loss_fn()
966    t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
967
968
969    # The following is equivalent to the above
970    with tf.GradientTape() as t:
971      loss = loss_fn()
972      t.reset()
973      loss += other_loss_fn()
974    t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
975    ```
976
977    This is useful if you don't want to exit the context manager for the tape,
978    or can't because the desired reset point is inside a control flow construct:
979
980    ```
981    with tf.GradientTape() as t:
982      loss = ...
983      if loss > k:
984        t.reset()
985    ```
986    """
987    self._pop_tape()
988    self._tape = None
989    self._push_tape()
990
991  def watched_variables(self):
992    """Returns variables watched by this tape in order of construction."""
993    if self._tape is not None:
994      self._watched_variables = self._tape.watched_variables()
995    return self._watched_variables
996
997  def gradient(self,
998               target,
999               sources,
1000               output_gradients=None,
1001               unconnected_gradients=UnconnectedGradients.NONE):
1002    """Computes the gradient using operations recorded in context of this tape.
1003
1004    Note: Unless you set `persistent=True` a GradientTape can only be used to
1005    compute one set of gradients (or jacobians).
1006
1007    Args:
1008      target: a list or nested structure of Tensors or Variables to be
1009        differentiated.
1010      sources: a list or nested structure of Tensors or Variables. `target`
1011        will be differentiated against elements in `sources`.
1012      output_gradients: a list of gradients, one for each element of
1013        target. Defaults to None.
1014      unconnected_gradients: a value which can either hold 'none' or 'zero' and
1015        alters the value which will be returned if the target and sources are
1016        unconnected. The possible values and effects are detailed in
1017        'UnconnectedGradients' and it defaults to 'none'.
1018
1019    Returns:
1020      a list or nested structure of Tensors (or IndexedSlices, or None),
1021      one for each element in `sources`. Returned structure is the same as
1022      the structure of `sources`.
1023
1024    Raises:
1025      RuntimeError: If called on a used, non-persistent tape.
1026      RuntimeError: If called inside the context of the tape.
1027      TypeError: If the target is a None object.
1028      ValueError: If the target is a variable or if unconnected gradients is
1029       called with an unknown value.
1030    """
1031    if self._tape is None:
1032      raise RuntimeError("A non-persistent GradientTape can only be used to "
1033                         "compute one set of gradients (or jacobians)")
1034    if self._recording:
1035      if not self._persistent:
1036        self._pop_tape()
1037      else:
1038        logging.log_first_n(
1039            logging.WARN, "Calling GradientTape.gradient on a persistent "
1040            "tape inside its context is significantly less "
1041            "efficient than calling it outside the context (it "
1042            "causes the gradient ops to be recorded on the "
1043            "tape, leading to increased CPU and memory usage). "
1044            "Only call GradientTape.gradient inside the "
1045            "context if you actually want to trace the "
1046            "gradient in order to compute higher order "
1047            "derivatives.", 1)
1048
1049    if target is None:
1050      raise TypeError("Target should be a list or nested structure"
1051                      " of Tensors or Variables to be differentiated,"
1052                      " but recieved %r" % (target))
1053
1054    flat_targets = []
1055    for t in nest.flatten(target):
1056      if not backprop_util.IsTrainable(t):
1057        logging.vlog(
1058            logging.WARN, "The dtype of the target tensor must be "
1059            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
1060            "got %r", t.dtype)
1061      if resource_variable_ops.is_resource_variable(t):
1062        with self:
1063          t = ops.convert_to_tensor(t)
1064      flat_targets.append(t)
1065
1066    flat_sources = nest.flatten(sources)
1067    flat_sources_raw = flat_sources
1068    flat_sources = [_handle_or_self(x) for x in flat_sources]
1069    for t in flat_sources_raw:
1070      if not backprop_util.IsTrainable(t):
1071        logging.vlog(
1072            logging.WARN, "The dtype of the source tensor must be "
1073            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
1074            "got %r", t.dtype)
1075      if getattr(t, "is_packed", False):
1076        raise ValueError(
1077            "GradientTape.gradient is not supported on packed EagerTensors yet."
1078        )
1079
1080    if output_gradients is not None:
1081      output_gradients = [None if x is None else ops.convert_to_tensor(x)
1082                          for x in nest.flatten(output_gradients)]
1083
1084    flat_grad = imperative_grad.imperative_grad(
1085        self._tape,
1086        flat_targets,
1087        flat_sources,
1088        output_gradients=output_gradients,
1089        sources_raw=flat_sources_raw,
1090        unconnected_gradients=unconnected_gradients)
1091
1092    if not self._persistent:
1093      # Keep track of watched variables before setting tape to None
1094      self._watched_variables = self._tape.watched_variables()
1095      self._tape = None
1096
1097    grad = nest.pack_sequence_as(sources, flat_grad)
1098    return grad
1099
1100  def jacobian(self,
1101               target,
1102               sources,
1103               unconnected_gradients=UnconnectedGradients.NONE,
1104               parallel_iterations=None,
1105               experimental_use_pfor=True):
1106    """Computes the jacobian using operations recorded in context of this tape.
1107
1108    Note: Unless you set `persistent=True` a GradientTape can only be used to
1109    compute one set of gradients (or jacobians).
1110
1111    Note: By default the jacobian implementation uses parallel for (pfor), which
1112    creates a tf.function under the hood for each jacobian call. For better
1113    performance, and to avoid recompilation and vectorization rewrites on each
1114    call, enclose GradientTape code in @tf.function.
1115
1116    See[wikipedia
1117    article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant)
1118    for the definition of a Jacobian.
1119
1120    Example usage:
1121
1122    ```python
1123    with tf.GradientTape() as g:
1124      x  = tf.constant([1.0, 2.0])
1125      g.watch(x)
1126      y = x * x
1127    jacobian = g.jacobian(y, x)
1128    # jacobian value is [[2., 0.], [0., 4.]]
1129    ```
1130
1131    Args:
1132      target: Tensor to be differentiated.
1133      sources: a list or nested structure of Tensors or Variables. `target`
1134        will be differentiated against elements in `sources`.
1135      unconnected_gradients: a value which can either hold 'none' or 'zero' and
1136        alters the value which will be returned if the target and sources are
1137        unconnected. The possible values and effects are detailed in
1138        'UnconnectedGradients' and it defaults to 'none'.
1139      parallel_iterations: A knob to control how many iterations are dispatched
1140        in parallel. This knob can be used to control the total memory usage.
1141      experimental_use_pfor: If true, vectorizes the jacobian computation. Else
1142        falls back to a sequential while_loop. Vectorization can sometimes fail
1143        or lead to excessive memory usage. This option can be used to disable
1144        vectorization in such cases.
1145
1146    Returns:
1147      A list or nested structure of Tensors (or None), one for each element in
1148      `sources`. Returned structure is the same as the structure of `sources`.
1149      Note if any gradient is sparse (IndexedSlices), jacobian function
1150      currently makes it dense and returns a Tensor instead. This may change in
1151      the future.
1152
1153
1154    Raises:
1155      RuntimeError: If called on a used, non-persistent tape.
1156      RuntimeError: If called on a non-persistent tape with eager execution
1157        enabled and without enabling experimental_use_pfor.
1158      ValueError: If vectorization of jacobian computation fails.
1159    """
1160    if self._tape is None:
1161      raise RuntimeError("A non-persistent GradientTape can only be used to "
1162                         "compute one set of gradients (or jacobians)")
1163
1164    flat_sources = nest.flatten(sources)
1165    target_static_shape = target.shape
1166    target_shape = array_ops.shape(target)
1167    # Note that we push and pop the tape here and below. This is needed since we
1168    # need gradients through the enclosed operations.
1169    with self._ensure_recording():
1170      target = array_ops.reshape(target, [-1])
1171
1172    def loop_fn(i):
1173      with self._ensure_recording():
1174        y = array_ops.gather(target, i)
1175      return self.gradient(y, flat_sources,
1176                           unconnected_gradients=unconnected_gradients)
1177
1178    try:
1179      target_size = int(target.shape[0])
1180    except TypeError:
1181      target_size = array_ops.shape(target)[0]
1182
1183    if experimental_use_pfor:
1184      try:
1185        output = pfor_ops.pfor(loop_fn, target_size,
1186                               parallel_iterations=parallel_iterations)
1187      except ValueError as err:
1188        six.reraise(
1189            ValueError,
1190            ValueError(
1191                str(err) + "\nEncountered an exception while vectorizing the "
1192                "jacobian computation. Vectorization can be disabled by setting"
1193                " experimental_use_pfor to False."),
1194            sys.exc_info()[2])
1195    else:
1196      if context.executing_eagerly() and not self._persistent:
1197        raise RuntimeError(
1198            "GradientTape must be created with persistent=True"
1199            " to compute the jacobian with eager execution enabled and with "
1200            " experimental_use_pfor set to False.")
1201      output = pfor_ops.for_loop(
1202          loop_fn, [target.dtype] * len(flat_sources), target_size,
1203          parallel_iterations=parallel_iterations)
1204
1205    for i, out in enumerate(output):
1206      if out is not None:
1207        new_shape = array_ops.concat(
1208            [target_shape, array_ops.shape(out)[1:]], axis=0)
1209        out = array_ops.reshape(out, new_shape)
1210        if context.executing_eagerly():
1211          out.set_shape(target_static_shape.concatenate(flat_sources[i].shape))
1212      output[i] = out
1213
1214    return nest.pack_sequence_as(sources, output)
1215
1216  def batch_jacobian(self,
1217                     target,
1218                     source,
1219                     unconnected_gradients=UnconnectedGradients.NONE,
1220                     parallel_iterations=None,
1221                     experimental_use_pfor=True):
1222    """Computes and stacks per-example jacobians.
1223
1224    See [wikipedia article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant)
1225    for the definition of a Jacobian. This function is essentially an efficient
1226    implementation of the following:
1227
1228    `tf.stack([self.jacobian(y[i], x[i]) for i in range(x.shape[0])])`.
1229
1230    Note that compared to `GradientTape.jacobian` which computes gradient of
1231    each output value w.r.t each input value, this function is useful when
1232    `target[i,...]` is independent of `source[j,...]` for `j != i`. This
1233    assumption allows more efficient computation as compared to
1234    `GradientTape.jacobian`. The output, as well as intermediate activations,
1235    are lower dimensional and avoid a bunch of redundant zeros which would
1236    result in the jacobian computation given the independence assumption.
1237
1238    Note: Unless you set `persistent=True` a GradientTape can only be used to
1239    compute one set of gradients (or jacobians).
1240
1241    Note: By default the batch_jacobian implementation uses parallel for (pfor),
1242    which creates a tf.function under the hood for each batch_jacobian call.
1243    For better performance, and to avoid recompilation and vectorization
1244    rewrites on each call, enclose GradientTape code in @tf.function.
1245
1246
1247    Example usage:
1248
1249    ```python
1250    with tf.GradientTape() as g:
1251      x = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32)
1252      g.watch(x)
1253      y = x * x
1254    batch_jacobian = g.batch_jacobian(y, x)
1255    # batch_jacobian is [[[2,  0], [0,  4]], [[6,  0], [0,  8]]]
1256    ```
1257
1258    Args:
1259      target: A tensor with rank 2 or higher and with shape [b, y1, ..., y_n].
1260        `target[i,...]` should only depend on `source[i,...]`.
1261      source: A tensor with rank 2 or higher and with shape [b, x1, ..., x_m].
1262      unconnected_gradients: a value which can either hold 'none' or 'zero' and
1263        alters the value which will be returned if the target and sources are
1264        unconnected. The possible values and effects are detailed in
1265        'UnconnectedGradients' and it defaults to 'none'.
1266      parallel_iterations: A knob to control how many iterations are dispatched
1267        in parallel. This knob can be used to control the total memory usage.
1268      experimental_use_pfor: If true, uses pfor for computing the Jacobian. Else
1269        uses a tf.while_loop.
1270
1271    Returns:
1272      A tensor `t` with shape [b, y_1, ..., y_n, x1, ..., x_m] where `t[i, ...]`
1273      is the jacobian of `target[i, ...]` w.r.t. `source[i, ...]`, i.e. stacked
1274      per-example jacobians.
1275
1276    Raises:
1277      RuntimeError: If called on a used, non-persistent tape.
1278      RuntimeError: If called on a non-persistent tape with eager execution
1279        enabled and without enabling experimental_use_pfor.
1280      ValueError: If vectorization of jacobian computation fails or if first
1281        dimension of `target` and `source` do not match.
1282    """
1283    if self._tape is None:
1284      raise RuntimeError("A non-persistent GradientTape can only be used to"
1285                         "compute one set of gradients (or jacobians)")
1286    target_shape = target.shape
1287    if target_shape.rank is None:
1288      dim = tensor_shape.Dimension(None)
1289    else:
1290      dim = target_shape.dims[0]
1291    if not (target_shape.with_rank_at_least(2) and
1292            source.shape.with_rank_at_least(2) and
1293            dim.is_compatible_with(source.shape[0])):
1294      raise ValueError(
1295          "Need first dimension of target shape (%s) and "
1296          "source shape (%s) to match." % (target.shape, source.shape))
1297    if target_shape.is_fully_defined():
1298      batch_size = int(target_shape[0])
1299      target_row_size = target_shape.num_elements() // batch_size
1300    else:
1301      target_shape = array_ops.shape(target)
1302      batch_size = target_shape[0]
1303      target_row_size = array_ops.size(target) // batch_size
1304    source_shape = array_ops.shape(source)
1305    # Flatten target to 2-D.
1306    # Note that we push and pop the tape here and below. This is needed since we
1307    # need gradients through the enclosed operations.
1308    with self._ensure_recording():
1309      with ops.control_dependencies(
1310          [check_ops.assert_equal(batch_size, source_shape[0])]):
1311        target = array_ops.reshape(target, [batch_size, target_row_size])
1312
1313    run_once = False
1314
1315    def loop_fn(i):
1316      nonlocal run_once
1317      if run_once and not self._persistent:
1318        if parallel_iterations is not None:
1319          raise RuntimeError(
1320              "GradientTape must be created with persistent=True"
1321              " to compute the batch_jacobian with parallel_iterations.")
1322        else:
1323          raise RuntimeError(
1324              "GradientTape must be created with persistent=True"
1325              " to compute the batch_jacobian.")
1326      run_once = True
1327
1328      with self._ensure_recording():
1329        y = array_ops.gather(target, i, axis=1)
1330      return self.gradient(y, source,
1331                           unconnected_gradients=unconnected_gradients)
1332
1333    if experimental_use_pfor:
1334      try:
1335        output = pfor_ops.pfor(loop_fn, target_row_size,
1336                               parallel_iterations=parallel_iterations)
1337      except ValueError as err:
1338        six.reraise(
1339            ValueError,
1340            ValueError(
1341                str(err) + "\nEncountered an exception while vectorizing the "
1342                "batch_jacobian computation. Vectorization can be disabled by "
1343                "setting experimental_use_pfor to False."),
1344            sys.exc_info()[2])
1345    else:
1346      if context.executing_eagerly() and not self._persistent:
1347        raise RuntimeError(
1348            "GradientTape must be created with persistent=True"
1349            " to compute the batch_jacobian with eager execution enabled and "
1350            " with experimental_use_pfor set to False.")
1351      output = pfor_ops.for_loop(loop_fn, target.dtype, target_row_size,
1352                                 parallel_iterations=parallel_iterations)
1353    new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0)
1354    if output is None:
1355      # Note that this block is returning zeros when it could use `None` to
1356      # represent unconnected gradients. This is to maintain compatibility with
1357      # the previous behavior, which ignored `unconnected_gradients`.
1358      output = array_ops.zeros(new_shape, target.dtype)
1359      return output
1360    else:
1361      output = array_ops.reshape(output,
1362                                 [target_row_size, batch_size, -1])
1363      output = array_ops.transpose(output, [1, 0, 2])
1364
1365      output = array_ops.reshape(output, new_shape)
1366      return output
1367