• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Code for backpropagation using the tape utilities."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import operator
23import sys
24
25import six
26
27from tensorflow.python import pywrap_tfe
28from tensorflow.python import _pywrap_utils
29from tensorflow.python.eager import backprop_util
30from tensorflow.python.eager import context
31from tensorflow.python.eager import execute
32from tensorflow.python.eager import imperative_grad
33from tensorflow.python.eager import tape
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import tensor_shape
38from tensorflow.python.framework import tensor_util
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import check_ops
41from tensorflow.python.ops import default_gradient
42from tensorflow.python.ops import gen_array_ops
43from tensorflow.python.ops import gen_math_ops
44from tensorflow.python.ops import math_ops
45from tensorflow.python.ops import resource_variable_ops
46from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
47from tensorflow.python.platform import tf_logging as logging
48from tensorflow.python.util import nest
49from tensorflow.python.util import tf_contextlib
50from tensorflow.python.util import tf_inspect
51from tensorflow.python.util.lazy_loader import LazyLoader
52from tensorflow.python.util.tf_export import tf_export
53
54
# Note that we need to lazy load the following two modules to avoid creating
# circular dependencies.
# TODO(b/119775953): fix the circular dependencies.
pfor_ops = LazyLoader(
    "pfor_ops", globals(),
    "tensorflow.python.ops.parallel_for.control_flow_ops")

function = LazyLoader("function", globals(),
                      "tensorflow.python.eager.function")

# Memoizes attr-type lookups, keyed by the `(op_type, attr_name)` tuple. It is
# read and filled in by `op_attr_type`.
_op_attr_type_cache = {}
66
67
def op_attr_type(op_type, attr_name):
  """Returns the attr type of `attr_name` for ops of type `op_type`.

  Lookups are memoized in `_op_attr_type_cache`. On a cache miss the eager
  context is initialized if needed and the type is queried through
  `pywrap_tfe.TFE_OpNameGetAttrType`.

  Args:
    op_type: the op type to query.
    attr_name: the name of the attr on that op.

  Returns:
    Whatever `TFE_OpNameGetAttrType` reports for the pair (cached afterwards).
  """
  cache_key = (op_type, attr_name)
  if cache_key in _op_attr_type_cache:
    return _op_attr_type_cache[cache_key]

  context.ensure_initialized()
  ctx_handle = context.context()._handle  # pylint: disable=protected-access
  resolved_type = pywrap_tfe.TFE_OpNameGetAttrType(
      ctx_handle, op_type, attr_name)
  _op_attr_type_cache[cache_key] = resolved_type
  return resolved_type
77
78
def make_attr(attr_type, value):
  """Converts a raw attr `value` into the object matching `attr_type`.

  Args:
    attr_type: the attr type as an integer enum value, or a one-element list
      holding such an integer for list-valued attrs.
    value: the raw attr value.

  Returns:
    A `DType` (or list of them) for type attrs, a shape proto (or list of them)
    for shape attrs, encoded bytes for `str` values, and `value` unchanged
    otherwise.
  """
  # pybind11 enums do not return the raw value like SWIG enums do. They are
  # useful when comparing amongst each other but not direct integers as we are
  # doing in most tests.
  # https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types
  # TODO(amitpatankar): After all SWIG transitions, convert the enum comparisons
  # from integer value to class.
  type_enum = int(pywrap_tfe.TF_ATTR_TYPE)
  if attr_type == type_enum:
    return dtypes.as_dtype(value)
  if attr_type == [type_enum]:
    return [dtypes.as_dtype(item) for item in value]

  shape_enum = int(pywrap_tfe.TF_ATTR_SHAPE)
  if attr_type == shape_enum:
    return tensor_shape.as_shape(value).as_proto()
  if attr_type == [shape_enum]:
    return [tensor_shape.as_shape(item).as_proto() for item in value]

  if isinstance(value, str):
    return value.encode()
  return value
97
98
class _MockOp(object):
  """Pretends to be a tf.Operation for the gradient functions.

  Only the pieces read here are provided: `attrs`, `inputs`, `outputs`, `type`
  and `skip_input_indices`, plus `get_attr`.
  """

  def __init__(self, attrs, inputs, outputs, typ, skip_input_indices):
    """Stores the op description verbatim.

    Args:
      attrs: flat sequence of alternating attr names and values.
      inputs: inputs of the original operation.
      outputs: outputs of the original operation.
      typ: the op type name.
      skip_input_indices: input indices whose gradient may be skipped.
    """
    self.type = typ
    self.attrs = attrs
    self.inputs = inputs
    self.outputs = outputs
    self.skip_input_indices = skip_input_indices

  def get_attr(self, attr):
    """Returns the converted value of attr `attr`; raises KeyError if absent."""
    attr_type = op_attr_type(self.type, attr)
    flat_attrs = self.attrs
    # Names sit at even positions, each followed by its value.
    for name_index in range(0, len(flat_attrs), 2):
      if flat_attrs[name_index] != attr:
        continue
      return make_attr(attr_type, flat_attrs[name_index + 1])
    raise KeyError(attr)

  def _get_control_flow_context(self):
    """Always raises: graph control flow is not supported through this mock."""
    raise NotImplementedError(
        "tf.GradientTape.gradients() does not support graph control flow "
        "operations like tf.cond or tf.while at this time. Use tf.gradients() "
        "instead. If you need this feature, please file a feature request at "
        "https://github.com/tensorflow/tensorflow/issues/new"
    )
123
124
def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
                       out_grads, skip_input_indices):
  """Looks up and invokes the registered gradient function for an op.

  Args:
    op_name: the name of the op to be differentiated.
    attr_tuple: the attrs, as a tuple.
    num_inputs: the number of inputs to the op.
    inputs: inputs to the original operation.
    outputs: outputs to the original operation.
    out_grads: gradients of the operation wrt its outputs.
    skip_input_indices: a tuple that is passed to the gradient function,
      indicating which inputs to skip calculating the gradient for

  Returns:
    The gradients with respect to the inputs of the function, as a list. When
    the registry yields no gradient function for `op_name`, a list of
    `num_inputs` `None`s.
  """
  registered_grad = ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
  if registered_grad is None:
    return [None] * num_inputs

  # Gradient functions expect an Operation-like first argument.
  op_stub = _MockOp(attr_tuple, inputs, outputs, op_name, skip_input_indices)
  return registered_grad(op_stub, *out_grads)
148
149
# Hand `_gradient_function` to pywrap_tfe as the gradient callback. This runs
# at import time.
pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function)
151
152
def _must_record_gradient():
  """Returns True unless `TFE_Py_TapeSetIsEmpty` reports an empty tape set."""
  tape_set_is_empty = pywrap_tfe.TFE_Py_TapeSetIsEmpty()
  return not tape_set_is_empty
155
156
def _record_gradient(op_name, inputs, attrs, results):
  """Forwards an executed op to `TFE_Py_RecordGradient` and returns its result.

  Args:
    op_name: name of the op that was executed.
    inputs: the op's inputs.
    attrs: the op's attrs.
    results: the op's outputs.
  """
  recorded = pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, results)
  return recorded
159
160
# Install the tape-aware implementations above as the `must_record_gradient`
# and `record_gradient` hooks on the `execute` module.
execute.must_record_gradient = _must_record_gradient
execute.record_gradient = _record_gradient
163
164
def implicit_val_and_grad(f):
  """Returns a function which differentiates f with respect to variables.

  The wrapped function returns the value and the gradient of f when called with
  the same arguments. The gradient is with respect to all trainable TFE
  variables accessed by `f`.

  This function is useful when the exact set of variables to differentiate with
  is not known ahead of time.

  Example:

  ```python
  dense_layer = tf.compat.v1.layers.Dense(1)
  def loss(x, y):
    return tf.reduce_sum(tf.square(dense_layer(x) - y))

  # Obtain the gradient function.
  val_grad_fn = tfe.implicit_value_and_gradients(loss)

  # Invoke the gradient function with concrete values of x and y.
  x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  y = tf.constant([[10.0], [20.0]])
  value, grads_and_vars = val_grad_fn(x, y)
  print('Value of loss: %s' % value)

  # Apply the gradients to Variables.
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  optimizer.apply_gradients(grads_and_vars)
  ```

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar.

  Returns:
    A function which, when called, returns a tuple pair.
    Its first element is the value to which the function evaluates.
    Its second element is list of (gradient, variable) pairs.

  Raises:
    ValueError: (when the returned function is called) if `f` returns None, or
      if no trainable variables were accessed while `f` was being computed.
  """
  # TODO(cais): Remove calls to tf.constant() once the gradients functions
  # accept lists and np.ndarrays.

  def grad_fn(*args, **kwds):
    """Computes the gradient of the wrapped function."""
    this_tape = tape.push_new_tape()
    try:
      end_node = f(*args, **kwds)
      if end_node is None:
        # Callable objects (e.g. functools.partial instances) may not define
        # `__name__`; fall back to their repr so the intended ValueError is
        # raised instead of an AttributeError.
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             getattr(f, "__name__", repr(f))))
    finally:
      # Always pop the tape, even when `f` raises.
      tape.pop_tape(this_tape)
    # Note: variables are returned in construction order. This ensures unique
    # order across executions.
    variables = this_tape.watched_variables()
    if not variables:
      raise ValueError("No trainable variables were accessed while the "
                       "function was being computed.")

    # Gradients are taken with respect to the variables' handles.
    sources = [v.handle for v in variables]
    grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node),
                                           sources)
    return end_node, list(zip(grad, variables))

  return grad_fn
237
238
def implicit_grad(f):
  """Returns a function which differentiates f with respect to variables.

  Calling the returned function with the arguments of `f` yields the gradient
  of `f` with respect to all trainable TFE variables accessed by `f`. It is a
  thin wrapper that keeps only the gradient half of `implicit_val_and_grad`.

  Use it when the exact set of variables to differentiate with is not known
  ahead of time.

  Example:

  ```python
  dense_layer = tf.compat.v1.layers.Dense(1)
  def loss(x, y):
    return tf.reduce_sum(tf.square(dense_layer(x) - y))

  # Obtain the gradient function.
  grad_fn = tfe.implicit_gradients(loss)

  # Invoke the gradient function with concrete values of x and y.
  x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  y = tf.constant([[10.0], [20.0]])
  grads_and_vars = grad_fn(x, y)

  # Apply the gradients to Variables.
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  optimizer.apply_gradients(grads_and_vars)
  ```

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar.

  Returns:
    A function which, when called, returns a list of (gradient, variable) pairs.
  """
  # TODO(cais): Remove calls to tf.constant() once the gradients functions
  # accept lists and np.ndarrays.

  def grad_fn(*args, **kwds):
    """Computes the gradient of the wrapped function."""
    value_and_grads = implicit_val_and_grad(f)(*args, **kwds)
    # Drop the value of `f`; keep only the (gradient, variable) pairs.
    return value_and_grads[1]

  return grad_fn
286
287
def _get_arg_spec(f, params, param_args):
  """Resolves `params` to positions of the arguments of f to differentiate.

  Args:
    f: the callable whose signature is inspected.
    params: None, a list of argument names, or a list of integer positions.
    param_args: the positional arguments `f` is about to be called with; only
      its length is used, as a fallback when the signature is unavailable or
      has no named arguments.

  Returns:
    An iterable of integer positions.

  Raises:
    ValueError: if `params` mixes strings and integers, or if names were given
      but the signature of `f` could not be inspected.
  """
  try:
    arg_names = tf_inspect.getfullargspec(f).args
  except TypeError as e:
    # TypeError can happen when f is a callable object.
    if params is None:
      return range(len(param_args))
    if all(isinstance(p, int) for p in params):
      return params
    raise ValueError("Either callable provided is not a function or could not "
                     "inspect its arguments by name: %s. Original error: %s"
                     % (f, e))

  if params is None:
    if not arg_names:
      return range(len(param_args))
    num_positions = len(arg_names)
    if arg_names[0] == "self":
      # A leading `self` is not counted as a differentiable parameter.
      num_positions -= 1
    return range(num_positions)

  if all(isinstance(p, six.string_types) for p in params):
    return [arg_names.index(name) for name in params]
  if all(isinstance(p, int) for p in params):
    return params
  raise ValueError(
      "params must be all strings or all integers; got %s." % params)
315
316
def gradients_function(f, params=None):
  """Returns a function which differentiates f with respect to params.

  Example:
  ```python
  # f(x, y) = (x ^ 3) * y - x * (y ^ 2)
  # Therefore, the 1st order derivatives are:
  #   df / dx = 3 * (x ^ 2) * y - y ^ 2
  #   df / dy = x ^ 3 - 2 * x * y
  # The 2nd order derivatives with respect to x is:
  #   d^2 f / (dx)^2 = 6 * x * y
  def f(x, y):
    return x * x * x * y - x * y * y

  # Obtain a function that returns 1st order gradients.
  grad_fn = tfe.gradients_function(f)

  x = 2.0
  y = 3.0

  # Invoke the 1st order gradient function.
  x_grad, y_grad = grad_fn(x, y)
  assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3

  # Obtain a function that returns the 2nd order gradient with respect to x.
  gradgrad_fn = tfe.gradients_function(lambda x, y: grad_fn(x, y)[0])

  # Invoke the 2nd order gradient function.
  x_gradgrad = gradgrad_fn(x, y)[0]
  assert x_gradgrad.numpy() == 6 * 2 * 3

  # To obtain a callable that returns the gradient(s) of `f` with respect to a
  # subset of its inputs, use the `params` keyword argument with
  # `gradients_function()`.
  ygrad_fn = tfe.gradients_function(f, params=[1])

  (y_grad,) = ygrad_fn(x, y)
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
  ```

  Note that only tensors with real or complex dtypes are differentiable.

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar. If desired, the tensors can be elementwise multiplied by the
      tensors passed as the `dy` keyword argument to the returned gradient
      function.
    params: list of parameter names of f or list of integers indexing the
      parameters with respect to which we'll differentiate. Passing None
      differentiates with respect to all parameters.

  Returns:
    function which, when called, returns the gradient of `f` with respect to
    all of `params` (the value of `f` itself is discarded). The function takes
    an extra optional keyword argument `dy`. Setting it allows computation of
    vector jacobian products for vectors other than the vector of ones.

  Raises:
    ValueError: if the params are not all strings or all integers.
  """

  def decorated(*args, **kwds):
    """Computes the gradient of the decorated function."""
    value_and_grads = val_and_grad_function(f, params=params)(*args, **kwds)
    _, gradients = value_and_grads
    return gradients

  return decorated
388
389
def _ensure_unique_tensor_objects(parameter_positions, args):
  """Make each of the parameter_positions in args a unique ops.Tensor object.

  This lets every differentiated parameter be treated independently even when
  the caller passes the same Tensor object more than once. For example:

  def f(x, y): return x * y
  g = gradients_function(f)
  one = tf.constant(1.)

  g(one, one) should return [1., 1.]
  (even though the two arguments are the same Tensor object).

  Args:
    parameter_positions: List of indices into args defining the arguments to
      differentiate against.
    args: A list of arguments to the function to be differentiated.

  Returns:
    args, possibly edited in-place.
  """
  seen_ids = set()
  for position, arg in enumerate(args):
    if position not in parameter_positions:
      continue
    arg_id = ops.tensor_id(arg)
    if arg_id in seen_ids:
      # Repeated tensor: replace this occurrence with an identity copy.
      args[position] = gen_array_ops.identity(arg)
    else:
      seen_ids.add(arg_id)
  return args
420
421
def val_and_grad_function(f, params=None):
  """Returns a function that computes f and its derivative w.r.t. params.

  Example:
  ```python
  # f(x, y) = (x ^ 3) * y - x * (y ^ 2)
  # Therefore, the 1st order derivatives are:
  #   df / dx = 3 * (x ^ 2) * y - y ^ 2
  #   df / dy = x ^ 3 - 2 * x * y
  def f(x, y):
    return x * x * x * y - x * y * y

  # Obtain a function that returns the function value and the 1st order
  # gradients.
  val_grads_fn = tfe.value_and_gradients_function(f)

  x = 2.0
  y = 3.0

  # Invoke the value-and-gradients function.
  f_val, (x_grad, y_grad) = val_grads_fn(x, y)
  assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2)
  assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3

  # To obtain a callable that returns the value of `f` and the gradient(s) of
  # `f` with respect to a subset of its inputs, use the `params` keyword
  # argument with `value_and_gradients_function()`.
  val_ygrad_fn = tfe.value_and_gradients_function(f, params=[1])

  f_val, (y_grad,) = val_ygrad_fn(x, y)
  assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2)
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
  ```

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar. If desired, the tensors can be elementwise multiplied by the
      tensors passed as the `dy` keyword argument to the returned gradient
      function.
    params: list of parameter names of f or list of integers indexing the
      parameters with respect to which we'll differentiate. Passing `None`
      differentiates with respect to all parameters.

  Returns:
    function which, when called, returns the value of f and the gradient
    of f with respect to all of `params`. The function takes an extra optional
    keyword argument "dy". Setting it allows computation of vector jacobian
    products for vectors other than the vector of ones.

  Raises:
    ValueError: if the params are not all strings or all integers, or if the
      returned function is called with keyword arguments other than `dy`.
  """

  def decorated(*args, **kwds):
    """Computes the value and gradient of the decorated function."""
    # `dy` is consumed here; any keyword left over is rejected.
    upstream_grad = kwds.pop("dy", None)
    if kwds:
      raise ValueError("Functions to be differentiated cannot "
                       "receive keyword arguments.")
    value_and_vjp = make_vjp(f, params)(*args, **kwds)
    value, vjp_fn = value_and_vjp
    return value, vjp_fn(dy=upstream_grad)

  return decorated
488
489
def make_vjp(f, params=None, persistent=True):
  """Returns a function that computes f and its vjp w.r.t. params.

  The term "vjp" here is an abbreviation for vector-jacobian product.

  For example,
  ```python
  def f(x):
    return x * x

  wrapped_fn = tfe.make_vjp(f)
  result, vjp = wrapped_fn(tf.constant(3.0))
  # result is 9.0
  vjp()  # the vjp function returns 6.0
  ```

  Args:
    f: the function to be differentiated.
    params: the parameters (numbers or names) to differentiate with respect to.
      A value of None will differentiate with respect to all parameters.
    persistent: Boolean controlling whether the VJP function can be re-used.
      Must be True or False.

  Returns:
    A function, which when called, returns a tuple (value, vjp), where:
    - value is the result of calling f.
    - vjp is a function, which takes a vector as an argument and
      returns the product of that vector with the Jacobian of f.
      Providing no argument to vjp is equivalent to providing a
      vector of ones.

  Raises:
    ValueError: (when the returned function is called) if `f` returns None.
  """

  def decorated(*args, **kwds):
    """Computes the value and gradient of the decorated function."""
    parameter_positions = _get_arg_spec(f, params, args)
    assert not kwds, "The gradient function can't take keyword arguments."
    this_tape = tape.push_new_tape(persistent=persistent)
    try:
      sources = []
      # Only the differentiated positions are converted to tensors.
      args = [
          ops.convert_to_tensor(arg) if i in parameter_positions else arg
          for i, arg in enumerate(args)
      ]
      args = _ensure_unique_tensor_objects(parameter_positions, args)
      for i in parameter_positions:
        sources.append(args[i])
        tape.watch(this_tape, args[i])
      result = f(*args)
      if result is None:
        # Callable objects (e.g. functools.partial instances) may not define
        # `__name__`; fall back to their repr so the intended ValueError is
        # raised instead of an AttributeError.
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             getattr(f, "__name__", repr(f))))
      # Pass every output through identity while the tape is still active.
      flat_result = nest.flatten(result)
      flat_result = [gen_array_ops.identity(x) for x in flat_result]
      result = nest.pack_sequence_as(result, flat_result)
    finally:
      # Always pop the tape, even when `f` raises.
      tape.pop_tape(this_tape)

    def vjp(dy=None):
      """Returns the gradients of `result` w.r.t. `sources`, weighted by dy."""
      if dy is not None:
        dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
      return imperative_grad.imperative_grad(
          this_tape, nest.flatten(result), sources, output_gradients=dy)

    return result, vjp

  return decorated
560
561
def flatten_nested_indexed_slices(grad):
  """Collapses an `IndexedSlices` whose values are themselves `IndexedSlices`.

  Args:
    grad: an `ops.IndexedSlices`, possibly nested through its `values`.

  Returns:
    `grad` itself when its values are already a `Tensor`; otherwise a new
    `IndexedSlices` built from the flattened inner values, with the outer
    indices gathered at the inner indices and the inner dense shape.
  """
  assert isinstance(grad, ops.IndexedSlices)
  inner = grad.values
  if isinstance(inner, ops.Tensor):
    return grad

  assert isinstance(inner, ops.IndexedSlices)
  flat_inner = flatten_nested_indexed_slices(inner)
  outer_indices = array_ops.gather(grad.indices, flat_inner.indices)
  return ops.IndexedSlices(flat_inner.values, outer_indices,
                           flat_inner.dense_shape)
572
573
def aggregate_indexed_slices_gradients(grads):
  """Aggregates gradients containing `IndexedSlices`s.

  Args:
    grads: a sequence of gradients; entries may be `Tensor`, `IndexedSlices`
      or None.

  Returns:
    None when `grads` is empty or contains only None entries; the single entry
    when `grads` has exactly one element; a dense `Tensor` sum when any
    remaining gradient is a `Tensor`; otherwise an `IndexedSlices` formed by
    concatenating the values and indices of all remaining gradients.
  """
  if len(grads) < 1:
    return None
  if len(grads) == 1:
    return grads[0]

  grads = [g for g in grads if g is not None]
  if not grads:
    # Every entry was None: there is nothing to aggregate, and the concat /
    # `grads[0]` below would fail on an empty list.
    return None

  # If any gradient is a `Tensor`, sum them up and return a dense tensor
  # object.
  if any(isinstance(g, ops.Tensor) for g in grads):
    return math_ops.add_n(grads)

  # The following `_as_indexed_slices_list` casts ids of IndexedSlices into
  # int64. It is to make sure the inputs of `concat` all have same the data
  # type.
  grads = math_ops._as_indexed_slices_list(grads)  # pylint: disable=protected-access

  grads = [flatten_nested_indexed_slices(x) for x in grads]
  # Form IndexedSlices out of the concatenated values and indices.
  concat_grad = ops.IndexedSlices(
      array_ops.concat([x.values for x in grads], axis=0),
      array_ops.concat([x.indices for x in grads], axis=0),
      grads[0].dense_shape)

  return concat_grad
600
601
def _aggregate_grads(gradients):
  """Combines several gradients for the same value into one.

  Args:
    gradients: A list of 'Tensor' or 'IndexedSlices' gradients.

  Returns:
    If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'.
    Otherwise returns an aggregated 'IndexedSlices'. A single-element list is
    returned as that element, untouched.
  """
  assert gradients, "No gradients to aggregate"

  if len(gradients) == 1:
    return gradients[0]

  only_dense = all(isinstance(g, ops.Tensor) for g in gradients)
  if only_dense:
    return gen_math_ops.add_n(gradients)

  assert all(isinstance(g, (ops.Tensor, ops.IndexedSlices))
             for g in gradients)
  return aggregate_indexed_slices_gradients(gradients)
622
623
def _num_elements(grad):
  """The number of elements in the `grad` tensor.

  Args:
    grad: a `Tensor` or an `IndexedSlices`. For `IndexedSlices` the shape of
      its `values` is used.

  Returns:
    The product of the static shape dimensions, or 0 when the static shape is
    unknown (None, or containing a None dimension).

  Raises:
    ValueError: if `grad` is neither a `Tensor` nor an `IndexedSlices`.
  """
  if isinstance(grad, ops.Tensor):
    shape_tuple = grad._shape_tuple()  # pylint: disable=protected-access
  elif isinstance(grad, ops.IndexedSlices):
    shape_tuple = grad.values._shape_tuple()  # pylint: disable=protected-access
  else:
    raise ValueError("`grad` not a Tensor or IndexedSlices.")

  # Apply the same unknown-shape guard to both kinds of gradient; multiplying
  # by a None dimension would otherwise raise a TypeError.
  if shape_tuple is None or None in shape_tuple:
    return 0
  return functools.reduce(operator.mul, shape_tuple, 1)
634
635
def _fast_fill(value, shape, dtype):
  """Returns `array_ops.fill` of `value` over `shape`.

  The shape is passed as an int32 constant and the value as a constant of
  `dtype`.
  """
  dims = constant_op.constant(shape, dtype=dtypes.int32)
  fill_value = constant_op.constant(value, dtype=dtype)
  return array_ops.fill(dims, fill_value)
640
641
def _zeros(shape, dtype):
  """Helper to return (possibly cached) zero tensors in eager mode.

  Returns None for string and resource dtypes. Outside eager execution this
  defers to `array_ops.zeros`. In eager mode results are looked up in, and
  stored into, the context's zeros cache under a (shape, dtype, device) key.
  """
  # Note: variants will use _zeros_like
  is_string = dtype == dtypes.string
  if is_string or dtype == dtypes.resource:
    return None

  ctx = context.context()
  if not ctx.executing_eagerly():
    return array_ops.zeros(shape, dtype)

  device = ctx.device_name

  # A tensor-valued shape is keyed by its `experimental_ref()`.
  shape_key = shape.experimental_ref() if tensor_util.is_tensor(shape) else shape
  cache_key = (shape_key, dtype, device)

  cached = ctx.zeros_cache().get(cache_key)
  if cached is not None:
    return cached

  fill_value = False if dtypes.as_dtype(dtype).is_bool else 0
  cached = _fast_fill(fill_value, shape, dtype)
  ctx.zeros_cache().put(cache_key, cached)
  return cached
668
669
def _ones(shape, dtype):
  """Returns a tensor of ones (True for bool dtypes) of `shape` and `dtype`.

  Returns None for the string dtype. Outside eager execution this defers to
  `array_ops.ones`. In eager mode a scalar shape `()` yields a plain constant
  and every other shape goes through `_fast_fill`.
  """
  as_dtype = dtypes.as_dtype(dtype)
  if as_dtype == dtypes.string:
    return None

  if not context.context().executing_eagerly():
    return array_ops.ones(shape, dtype)

  one = True if as_dtype.is_bool else 1

  if shape == ():  # pylint: disable=g-explicit-bool-comparison
    return constant_op.constant(one, dtype=dtype)
  return _fast_fill(one, shape, dtype)
686
687
# Bundle the gradient helpers defined above (element counting, aggregation,
# zeros/ones construction) into a VSpace and register it with pywrap_tfe at
# import time.
_default_vspace = imperative_grad.VSpace(
    num_elements_fn=_num_elements,
    aggregate_fn=_aggregate_grads,
    zeros_fn=_zeros,
    ones_fn=_ones,
    zeros_like_fn=default_gradient.zeros_like,
    ones_like_fn=default_gradient.ones_like,
    graph_shape_fn=gen_array_ops.shape)
pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace)
697
698
def _handle_or_self(x):
  """Returns `x.handle` when x is a resource variable, otherwise x itself."""
  return x.handle if resource_variable_ops.is_resource_variable(x) else x
704
705
706@tf_export("GradientTape", "autodiff.GradientTape", v1=["GradientTape"])
707class GradientTape(object):
708  """Record operations for automatic differentiation.
709
710  Operations are recorded if they are executed within this context manager and
711  at least one of their inputs is being "watched".
712
713  Trainable variables (created by `tf.Variable` or `tf.compat.v1.get_variable`,
714  where `trainable=True` is default in both cases) are automatically watched.
715  Tensors can be manually watched by invoking the `watch` method on this context
716  manager.
717
718  For example, consider the function `y = x * x`. The gradient at `x = 3.0` can
719  be computed as:
720
721  ```python
722  x = tf.constant(3.0)
723  with tf.GradientTape() as g:
724    g.watch(x)
725    y = x * x
726  dy_dx = g.gradient(y, x) # Will compute to 6.0
727  ```
728
729  GradientTapes can be nested to compute higher-order derivatives. For example,
730
731  ```python
732  x = tf.constant(3.0)
733  with tf.GradientTape() as g:
734    g.watch(x)
735    with tf.GradientTape() as gg:
736      gg.watch(x)
737      y = x * x
738    dy_dx = gg.gradient(y, x)     # Will compute to 6.0
739  d2y_dx2 = g.gradient(dy_dx, x)  # Will compute to 2.0
740  ```
741
742  By default, the resources held by a GradientTape are released as soon as
743  GradientTape.gradient() method is called. To compute multiple gradients over
744  the same computation, create a persistent gradient tape. This allows multiple
745  calls to the gradient() method as resources are released when the tape object
746  is garbage collected. For example:
747
748  ```python
749  x = tf.constant(3.0)
750  with tf.GradientTape(persistent=True) as g:
751    g.watch(x)
752    y = x * x
753    z = y * y
754  dz_dx = g.gradient(z, x)  # 108.0 (4*x^3 at x = 3)
755  dy_dx = g.gradient(y, x)  # 6.0
756  del g  # Drop the reference to the tape
757  ```
758
759  By default GradientTape will automatically watch any trainable variables that
760  are accessed inside the context. If you want fine grained control over which
761  variables are watched you can disable automatic tracking by passing
762  `watch_accessed_variables=False` to the tape constructor:
763
764  ```python
765  with tf.GradientTape(watch_accessed_variables=False) as tape:
766    tape.watch(variable_a)
767    y = variable_a ** 2  # Gradients will be available for `variable_a`.
768    z = variable_b ** 3  # No gradients will be available since `variable_b` is
769                         # not being watched.
770  ```
771
772  Note that when using models you should ensure that your variables exist when
773  using `watch_accessed_variables=False`. Otherwise it's quite easy to make your
774  first iteration not have any gradients:
775
776  ```python
777  a = tf.keras.layers.Dense(32)
778  b = tf.keras.layers.Dense(32)
779
780  with tf.GradientTape(watch_accessed_variables=False) as tape:
781    tape.watch(a.variables)  # Since `a.build` has not been called at this point
782                             # `a.variables` will return an empty list and the
783                             # tape will not be watching anything.
784    result = b(a(inputs))
785    tape.gradient(result, a.variables)  # The result of this computation will be
786                                        # a list of `None`s since a's variables
787                                        # are not being watched.
788  ```
789
790  Note that only tensors with real or complex dtypes are differentiable.
791  """
792
793  def __init__(self, persistent=False, watch_accessed_variables=True):
794    """Creates a new GradientTape.
795
796    Args:
797      persistent: Boolean controlling whether a persistent gradient tape
798        is created. False by default, which means at most one call can
799        be made to the gradient() method on this object.
800      watch_accessed_variables: Boolean controlling whether the tape will
801        automatically `watch` any (trainable) variables accessed while the tape
802        is active. Defaults to True meaning gradients can be requested from any
803        result computed in the tape derived from reading a trainable `Variable`.
804        If False users must explicitly `watch` any `Variable`s they want to
805        request gradients from.
806    """
807    self._tape = None
808    self._persistent = persistent
809    self._watch_accessed_variables = watch_accessed_variables
810    self._watched_variables = ()
811    self._recording = False
812    self._created_eagerly = context.executing_eagerly()
813    if self._created_eagerly:
814      context.ensure_initialized()
815      context.context().start_step()
816
  def __enter__(self):
    """Enters a context inside which operations are recorded on this tape.

    Returns:
      This `GradientTape`, so it can be bound with `with ... as tape`.
    """
    # `_push_tape` raises ValueError if this tape is already recording.
    self._push_tape()
    return self
821
822  def __exit__(self, typ, value, traceback):
823    """Exits the recording context, no further operations are traced."""
824    if self._recording:
825      self._pop_tape()
826
827  def _push_tape(self):
828    """Pushes a new tape onto the tape stack."""
829    if self._recording:
830      raise ValueError("Tape is still recording, This can happen if you try to "
831                       "re-enter an already-active tape.")
832    if self._tape is None:
833      self._tape = tape.push_new_tape(
834          persistent=self._persistent,
835          watch_accessed_variables=self._watch_accessed_variables)
836    else:
837      tape.push_tape(self._tape)
838    self._recording = True
839
840  def _pop_tape(self):
841    if not self._recording:
842      raise ValueError("Tape is not recording.")
843    tape.pop_tape(self._tape)
844    self._recording = False
845
846  def __del__(self):
847    if self._created_eagerly:
848      try:
849        context.context().end_step()
850      except AttributeError:
851        pass
852      except TypeError:
853        pass
854
855  def watch(self, tensor):
856    """Ensures that `tensor` is being traced by this tape.
857
858    Args:
859      tensor: a Tensor or list of Tensors.
860
861    Raises:
862      ValueError: if it encounters something that is not a tensor.
863    """
864    for t in nest.flatten(tensor):
865      if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)):
866        raise ValueError("Passed in object of type {}, not tf.Tensor".format(
867            type(t)))
868      if not backprop_util.IsTrainable(t):
869        logging.log_first_n(
870            logging.WARN, "The dtype of the watched tensor must be "
871            "floating (e.g. tf.float32), got %r", 5, t.dtype)
872      if hasattr(t, "handle"):
873        # There are many variable-like objects, all of them currently have
874        # `handle` attribute that points to a tensor. If this changes, internals
875        # of watch_variable need to change as well.
876        tape.watch_variable(self._tape, t)
877      else:
878        tape.watch(self._tape, t)
879
880  @tf_contextlib.contextmanager
881  def stop_recording(self):
882    """Temporarily stops recording operations on this tape.
883
884    Operations executed while this context manager is active will not be
885    recorded on the tape. This is useful for reducing the memory used by tracing
886    all computations.
887
888    For example:
889
890    ```
891      with tf.GradientTape(persistent=True) as t:
892        loss = compute_loss(model)
893        with t.stop_recording():
894          # The gradient computation below is not traced, saving memory.
895          grads = t.gradient(loss, model.variables)
896    ```
897
898    Yields:
899      None
900    Raises:
901      RuntimeError: if the tape is not currently recording.
902    """
903    if self._tape is None:
904      raise RuntimeError(
905          "Trying to stop recording a tape which is not recording.")
906    self._pop_tape()
907    try:
908      yield
909    finally:
910      self._push_tape()
911
912  def reset(self):
913    """Clears all information stored in this tape.
914
915    Equivalent to exiting and reentering the tape context manager with a new
916    tape. For example, the two following code blocks are equivalent:
917
918    ```
919    with tf.GradientTape() as t:
920      loss = loss_fn()
921    with tf.GradientTape() as t:
922      loss += other_loss_fn()
923    t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
924
925
926    # The following is equivalent to the above
927    with tf.GradientTape() as t:
928      loss = loss_fn()
929      t.reset()
930      loss += other_loss_fn()
931    t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
932    ```
933
934    This is useful if you don't want to exit the context manager for the tape,
935    or can't because the desired reset point is inside a control flow construct:
936
937    ```
938    with tf.GradientTape() as t:
939      loss = ...
940      if loss > k:
941        t.reset()
942    ```
943    """
944    self._pop_tape()
945    self._tape = None
946    self._push_tape()
947
948  def watched_variables(self):
949    """Returns variables watched by this tape in order of construction."""
950    if self._tape is not None:
951      self._watched_variables = self._tape.watched_variables()
952    return self._watched_variables
953
954  def gradient(self,
955               target,
956               sources,
957               output_gradients=None,
958               unconnected_gradients=UnconnectedGradients.NONE):
959    """Computes the gradient using operations recorded in context of this tape.
960
961    Args:
962      target: a list or nested structure of Tensors or Variables to be
963        differentiated.
964      sources: a list or nested structure of Tensors or Variables. `target`
965        will be differentiated against elements in `sources`.
966      output_gradients: a list of gradients, one for each element of
967        target. Defaults to None.
968      unconnected_gradients: a value which can either hold 'none' or 'zero' and
969        alters the value which will be returned if the target and sources are
970        unconnected. The possible values and effects are detailed in
971        'UnconnectedGradients' and it defaults to 'none'.
972
973    Returns:
974      a list or nested structure of Tensors (or IndexedSlices, or None),
975      one for each element in `sources`. Returned structure is the same as
976      the structure of `sources`.
977
978    Raises:
979      RuntimeError: if called inside the context of the tape, or if called more
980       than once on a non-persistent tape.
981      ValueError: if the target is a variable or if unconnected gradients is
982       called with an unknown value.
983    """
984    if self._tape is None:
985      raise RuntimeError("GradientTape.gradient can only be called once on "
986                         "non-persistent tapes.")
987    if self._recording:
988      if not self._persistent:
989        self._pop_tape()
990      else:
991        logging.log_first_n(
992            logging.WARN, "Calling GradientTape.gradient on a persistent "
993            "tape inside its context is significantly less "
994            "efficient than calling it outside the context (it "
995            "causes the gradient ops to be recorded on the "
996            "tape, leading to increased CPU and memory usage). "
997            "Only call GradientTape.gradient inside the "
998            "context if you actually want to trace the "
999            "gradient in order to compute higher order "
1000            "derivatives.", 1)
1001
1002    flat_targets = []
1003    for t in nest.flatten(target):
1004      if not backprop_util.IsTrainable(t):
1005        logging.vlog(
1006            logging.WARN, "The dtype of the target tensor must be "
1007            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
1008            "got %r", t.dtype)
1009      if resource_variable_ops.is_resource_variable(t):
1010        with self:
1011          t = ops.convert_to_tensor(t)
1012      flat_targets.append(t)
1013
1014    flat_sources = nest.flatten(sources)
1015    flat_sources_raw = flat_sources
1016    flat_sources = [_handle_or_self(x) for x in flat_sources]
1017    for t in flat_sources_raw:
1018      if not backprop_util.IsTrainable(t):
1019        logging.vlog(
1020            logging.WARN, "The dtype of the source tensor must be "
1021            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
1022            "got %r", t.dtype)
1023
1024    if output_gradients is not None:
1025      output_gradients = [None if x is None else ops.convert_to_tensor(x)
1026                          for x in nest.flatten(output_gradients)]
1027
1028    flat_grad = imperative_grad.imperative_grad(
1029        self._tape,
1030        flat_targets,
1031        flat_sources,
1032        output_gradients=output_gradients,
1033        sources_raw=flat_sources_raw,
1034        unconnected_gradients=unconnected_gradients)
1035
1036    if not self._persistent:
1037      # Keep track of watched variables before setting tape to None
1038      self._watched_variables = self._tape.watched_variables()
1039      self._tape = None
1040
1041    grad = nest.pack_sequence_as(sources, flat_grad)
1042    return grad
1043
1044  def jacobian(self,
1045               target,
1046               sources,
1047               unconnected_gradients=UnconnectedGradients.NONE,
1048               parallel_iterations=None,
1049               experimental_use_pfor=True):
1050    """Computes the jacobian using operations recorded in context of this tape.
1051
1052    See [wikipedia article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant) for the
1053    definition of a Jacobian.
1054
1055    Example usage:
1056
1057    ```python
1058    with tf.GradientTape() as g:
1059      x  = tf.constant([1.0, 2.0])
1060      g.watch(x)
1061      y = x * x
1062    jacobian = g.jacobian(y, x)
1063    # jacobian value is [[2., 0.], [0., 4.]]
1064    ```
1065
1066    Args:
1067      target: Tensor to be differentiated.
1068      sources: a list or nested structure of Tensors or Variables. `target`
1069        will be differentiated against elements in `sources`.
1070      unconnected_gradients: a value which can either hold 'none' or 'zero' and
1071        alters the value which will be returned if the target and sources are
1072        unconnected. The possible values and effects are detailed in
1073        'UnconnectedGradients' and it defaults to 'none'.
1074      parallel_iterations: A knob to control how many iterations are dispatched
1075        in parallel. This knob can be used to control the total memory usage.
1076      experimental_use_pfor: If true, vectorizes the jacobian computation. Else
1077        falls back to a sequential while_loop. Vectorization can sometimes fail
1078        or lead to excessive memory usage. This option can be used to disable
1079        vectorization in such cases.
1080
1081    Returns:
1082      A list or nested structure of Tensors (or None), one for each element in
1083      `sources`. Returned structure is the same as the structure of `sources`.
1084      Note if any gradient is sparse (IndexedSlices), jacobian function
1085      currently makes it dense and returns a Tensor instead. This may change in
1086      the future.
1087
1088
1089    Raises:
1090      RuntimeError: If called on a non-persistent tape with eager execution
1091        enabled and without enabling experimental_use_pfor.
1092      ValueError: If vectorization of jacobian computation fails.
1093    """
1094    flat_sources = nest.flatten(sources)
1095    target_static_shape = target.shape
1096    target_shape = array_ops.shape(target)
1097    # Note that we push and pop the tape here and below. This is needed since we
1098    # need gradients through the enclosed operations.
1099    self._push_tape()
1100    target = array_ops.reshape(target, [-1])
1101    self._pop_tape()
1102
1103    def loop_fn(i):
1104      self._push_tape()
1105      y = array_ops.gather(target, i)
1106      self._pop_tape()
1107      return self.gradient(y, flat_sources,
1108                           unconnected_gradients=unconnected_gradients)
1109
1110    try:
1111      target_size = int(target.shape[0])
1112    except TypeError:
1113      target_size = array_ops.shape(target)[0]
1114
1115    if experimental_use_pfor:
1116      try:
1117        output = pfor_ops.pfor(loop_fn, target_size,
1118                               parallel_iterations=parallel_iterations)
1119      except ValueError as err:
1120        six.reraise(
1121            ValueError,
1122            ValueError(
1123                str(err) + "\nEncountered an exception while vectorizing the "
1124                "jacobian computation. Vectorization can be disabled by setting"
1125                " experimental_use_pfor to False."),
1126            sys.exc_info()[2])
1127    else:
1128      if context.executing_eagerly() and not self._persistent:
1129        raise RuntimeError(
1130            "GradientTape must be created with persistent=True"
1131            " to compute the jacobian with eager execution enabled and with "
1132            " experimental_use_pfor set to False.")
1133      output = pfor_ops.for_loop(
1134          loop_fn, [target.dtype] * len(flat_sources), target_size,
1135          parallel_iterations=parallel_iterations)
1136
1137    for i, out in enumerate(output):
1138      if out is not None:
1139        new_shape = array_ops.concat(
1140            [target_shape, array_ops.shape(out)[1:]], axis=0)
1141        out = array_ops.reshape(out, new_shape)
1142        if context.executing_eagerly():
1143          out.set_shape(target_static_shape.concatenate(flat_sources[i].shape))
1144      output[i] = out
1145
1146    return nest.pack_sequence_as(sources, output)
1147
1148  def batch_jacobian(self,
1149                     target,
1150                     source,
1151                     unconnected_gradients=UnconnectedGradients.NONE,
1152                     parallel_iterations=None,
1153                     experimental_use_pfor=True):
1154    """Computes and stacks per-example jacobians.
1155
1156    See [wikipedia article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant) for the
1157    definition of a Jacobian. This function is essentially an efficient
1158    implementation of the following:
1159
1160    `tf.stack([self.jacobian(y[i], x[i]) for i in range(x.shape[0])])`.
1161
1162    Note that compared to `GradientTape.jacobian` which computes gradient of
1163    each output value w.r.t each input value, this function is useful when
1164    `target[i,...]` is independent of `source[j,...]` for `j != i`. This
1165    assumption allows more efficient computation as compared to
1166    `GradientTape.jacobian`. The output, as well as intermediate activations,
1167    are lower dimensional and avoid a bunch of redundant zeros which would
1168    result in the jacobian computation given the independence assumption.
1169
1170    Example usage:
1171
1172    ```python
1173    with tf.GradientTape() as g:
1174      x = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32)
1175      g.watch(x)
1176      y = x * x
1177    batch_jacobian = g.batch_jacobian(y, x)
1178    # batch_jacobian is [[[2,  0], [0,  4]], [[6,  0], [0,  8]]]
1179    ```
1180
1181    Args:
1182      target: A tensor with rank 2 or higher and with shape [b, y1, ..., y_n].
1183        `target[i,...]` should only depend on `source[i,...]`.
1184      source: A tensor with rank 2 or higher and with shape [b, x1, ..., x_m].
1185      unconnected_gradients: a value which can either hold 'none' or 'zero' and
1186        alters the value which will be returned if the target and sources are
1187        unconnected. The possible values and effects are detailed in
1188        'UnconnectedGradients' and it defaults to 'none'.
1189      parallel_iterations: A knob to control how many iterations are dispatched
1190        in parallel. This knob can be used to control the total memory usage.
1191      experimental_use_pfor: If true, uses pfor for computing the Jacobian. Else
1192        uses a tf.while_loop.
1193
1194    Returns:
1195      A tensor `t` with shape [b, y_1, ..., y_n, x1, ..., x_m] where `t[i, ...]`
1196      is the jacobian of `target[i, ...]` w.r.t. `source[i, ...]`, i.e. stacked
1197      per-example jacobians.
1198
1199    Raises:
1200      RuntimeError: If called on a non-persistent tape with eager execution
1201        enabled and without enabling experimental_use_pfor.
1202      ValueError: If vectorization of jacobian computation fails or if first
1203        dimension of `target` and `source` do not match.
1204    """
1205    target_shape = target.shape
1206    if target_shape.rank is None:
1207      dim = tensor_shape.Dimension(None)
1208    else:
1209      dim = target_shape.dims[0]
1210    if not (target_shape.with_rank_at_least(2) and
1211            source.shape.with_rank_at_least(2) and
1212            dim.is_compatible_with(source.shape[0])):
1213      raise ValueError(
1214          "Need first dimension of target shape (%s) and "
1215          "source shape (%s) to match." % (target.shape, source.shape))
1216    if target_shape.is_fully_defined():
1217      batch_size = int(target_shape[0])
1218      target_row_size = target_shape.num_elements() // batch_size
1219    else:
1220      target_shape = array_ops.shape(target)
1221      batch_size = target_shape[0]
1222      target_row_size = array_ops.size(target) // batch_size
1223    source_shape = array_ops.shape(source)
1224    # Flatten target to 2-D.
1225    # Note that we push and pop the tape here and below. This is needed since we
1226    # need gradients through the enclosed operations.
1227    self._push_tape()
1228    with ops.control_dependencies(
1229        [check_ops.assert_equal(batch_size, source_shape[0])]):
1230      target = array_ops.reshape(target, [batch_size, target_row_size])
1231    self._pop_tape()
1232
1233    def loop_fn(i):
1234      self._push_tape()
1235      y = array_ops.gather(target, i, axis=1)
1236      self._pop_tape()
1237      return self.gradient(y, source,
1238                           unconnected_gradients=unconnected_gradients)
1239
1240    if experimental_use_pfor:
1241      try:
1242        output = pfor_ops.pfor(loop_fn, target_row_size,
1243                               parallel_iterations=parallel_iterations)
1244      except ValueError as err:
1245        six.reraise(
1246            ValueError,
1247            ValueError(
1248                str(err) + "\nEncountered an exception while vectorizing the "
1249                "batch_jacobian computation. Vectorization can be disabled by "
1250                "setting experimental_use_pfor to False."),
1251            sys.exc_info()[2])
1252    else:
1253      if context.executing_eagerly() and not self._persistent:
1254        raise RuntimeError(
1255            "GradientTape must be created with persistent=True"
1256            " to compute the batch_jacobian with eager execution enabled and "
1257            " with experimental_use_pfor set to False.")
1258      output = pfor_ops.for_loop(loop_fn, target.dtype, target_row_size,
1259                                 parallel_iterations=parallel_iterations)
1260    new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0)
1261    if output is None:
1262      return array_ops.zeros(new_shape)
1263    else:
1264      output = array_ops.reshape(output,
1265                                 [target_row_size, batch_size, -1])
1266      output = array_ops.transpose(output, [1, 0, 2])
1267      return array_ops.reshape(output, new_shape)
1268