• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for forward-mode automatic differentiation."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import threading
23
24from tensorflow.python import pywrap_tfe
25from tensorflow.python.eager import backprop
26from tensorflow.python.eager import backprop_util
27from tensorflow.python.eager import execute
28from tensorflow.python.eager import forwardprop_util
29from tensorflow.python.eager import function
30
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import tensor_shape
33
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops.parallel_for import control_flow_ops
36from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
37from tensorflow.python.platform import tf_logging as logging
38from tensorflow.python.util import nest
39from tensorflow.python.util.tf_export import tf_export
40
41
42# Dictionary mapping from op names to special-cased jvp functions. Otherwise
43# backward functions are transposed on the tape.
44_SPECIAL_CASES = {}
45
46
47def _identity_jvp(attr_tuple, inputs, outputs, tangents):
48  # Special-cased mostly for resource handles, where creating ones Tensors from
49  # handle data for transposing the backward function on the tape is error-prone
50  # (even if we get good handle data, partially defined shapes are an issue).
51  del attr_tuple, inputs, outputs
52  return [array_ops.identity(t) for t in tangents]
53
54
55_SPECIAL_CASES["Identity"] = _identity_jvp
56
57
58def _read_variable_jvp(attr_tuple, inputs, outputs, tangents):
59  # Like for Identity, this special case means we don't need to create
60  # variable-shaped Tensors from resource handles.
61  del attr_tuple, inputs, outputs
62  return [array_ops.identity(t) for t in tangents]
63
64
65_SPECIAL_CASES["ReadVariableOp"] = _read_variable_jvp
66
67
68_TRACE_COUNT_CONSISTENCY_LOCK = threading.Lock()
69# Map from op names to number of traces of _jvp_helper. Used to cap the number
70# of traces due to shape differences while still specializing where possible.
71_TRACE_COUNT = {}
72
73
74def _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents):
75  """Computes a Jacobian-vector product for an op.
76
77  Note that this function would be wasteful if executed eagerly. It runs the
78  backward gradient function and throws away the result just to record its
79  operations on a GradientTape. These unused ops are pruned away when this
80  function is traced.
81
82  Args:
83    op_name: A string, the type of operation being executed.
84    attr_tuple: Attributes of the operation.
85    inputs: A flat list of input Tensors to the operation.
86    outputs: A flat list of output Tensors from the operation.
87    tangents: A flat list of Tensors, same shape as `inputs`.
88
89  Returns:
90    A flat list of tangents corresponding to `outputs`.
91  """
92  with _TRACE_COUNT_CONSISTENCY_LOCK:
93    # Just make sure writes don't clobber each other's increments; reads in
94    # _jvp_dispatch do not lock.
95    _TRACE_COUNT[op_name] = _TRACE_COUNT.get(op_name, 0) + 1
96
97  special_case = _SPECIAL_CASES.get(op_name, None)
98  if special_case is not None:
99    return special_case(attr_tuple, inputs, outputs, tangents)
100  if not outputs:
101    # tape.gradients([], inputs) doesn't make much sense
102    return []
103  # Generally inner GradientTapes won't function while outer accumulators are
104  # recording. We temporarily reset forwardprop state to allow GradientTapes to
105  # function here.
106  with forwardprop_util.push_forwardprop_state():
107    trainable_inputs = []
108    trainable_indices = []
109    nontrivial_tangents = []
110    for input_index, tensor in enumerate(inputs):
111      if backprop_util.IsTrainable(tensor):
112        trainable_inputs.append(tensor)
113        trainable_indices.append(input_index)
114        nontrivial_tangents.append(tangents[input_index])
115
116    with backprop.GradientTape() as transpose_tape:
117      with backprop.GradientTape() as backfunc_tape:
118        backfunc_tape.watch(trainable_inputs)
119        execute.record_gradient(op_name, inputs, attr_tuple, outputs)
120
121      forwardprop_aids = []
122      trainable_outputs = []
123      nontrivial_output_indices = []
124      for output_index, output in enumerate(outputs):
125        if backprop_util.IsTrainable(output):
126          forwardprop_aids.append(
127              array_ops.ones_like(output, name="unused_forwardprop_aid"))
128          trainable_outputs.append(output)
129          nontrivial_output_indices.append(output_index)
130
131      transpose_tape.watch(forwardprop_aids)
132      grads = backfunc_tape.gradient(
133          trainable_outputs,
134          trainable_inputs,
135          forwardprop_aids,
136          unconnected_gradients=UnconnectedGradients.ZERO)
137    nontrivial_output_tangents = transpose_tape.gradient(
138        grads, forwardprop_aids, output_gradients=nontrivial_tangents)
139    output_tangents = [None] * len(outputs)
140    for index, tangent in zip(nontrivial_output_indices,
141                              nontrivial_output_tangents):
142      output_tangents[index] = tangent
143    return output_tangents
144
145
146def _jvp_helper_wrapper(op_name, attr_tuple, inputs, outputs, tangents,
147                        use_batch):
148  """Computes a batch of Jacobian-vector product for an op.
149
150  Args:
151    op_name: A string, the type of operation being executed.
152    attr_tuple: Attributes of the operation.
153    inputs: A flat list of input Tensors to the operation.
154    outputs: A flat list of output Tensors from the operation.
155    tangents: A flat list of Tensors, compatible with shape `[None] +
156      input_shape`.
157    use_batch: A bool, True to vetorize over batch of tangents of shape `[None]
158      + input_shape`.
159
160  Returns:
161    A flat list of tangents compatible with `outputs`
162    or `[None] + output_shape`.
163
164  Raises:
165    ValueError: if tangent shapes are not compatible with input shapes.
166  """
167  if use_batch:
168    for primal, tangent in zip(inputs, tangents):
169      if not tangent.shape.is_compatible_with([None] + primal.shape):
170        raise ValueError("Tangent {} was expected to be of shape "
171                         "{} but is instead of shape {}".format(
172                             tangent, [None] + primal.shape, tangent.shape))
173
174    return control_flow_ops.vectorized_map(
175        functools.partial(_jvp_helper, op_name, attr_tuple, inputs, outputs),
176        tangents,
177    )
178  return _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents)
179
180
181# TODO(allenl): experimental_relax_shapes for gradients which rely on static
182# shape information are underspecialized. We may want hand-written forward
183# implementations, or a more satisfying story about how we re-specialize
184# gradients which were traced with relaxed shapes (e.g. use conds instead of
185# trace-time Python logic).
186#
187# Using function.defun rather than def_function.function avoids
188# tf.config.run_functions_eagerly(True). `_jvp_helper` doesn't successfully run
189# eagerly (infinite recursion), and even if it did it would use extra memory and
190# run unnecessary computation. The function does not create variables, so the
191# two symbols are otherwise equivalent.
192_jvp_relaxed_shapes = function.defun(
193    _jvp_helper_wrapper, experimental_relax_shapes=True)
194_jvp_exact_shapes = function.defun(
195    _jvp_helper_wrapper, experimental_relax_shapes=False)
196
197# The maximum number of exact-shape traces to perform for a single op before
198# switching to shape relaxation.
199_TRACE_COUNT_LIMIT = 32
200
201
202def _jvp_dispatch(op_name,
203                  attr_tuple,
204                  inputs,
205                  outputs,
206                  tangents,
207                  use_batch=False):
208  """Determine which forwardprop function to call."""
209  # Note that this _TRACE_COUNT read races with writes. That's fine, it just
210  # means we may trace a few more exact shapes before moving on to relaxation.
211  if _TRACE_COUNT.get(op_name, 0) < _TRACE_COUNT_LIMIT:
212    return _jvp_exact_shapes(op_name, attr_tuple, inputs, outputs, tangents,
213                             use_batch)
214  return _jvp_relaxed_shapes(op_name, attr_tuple, inputs, outputs, tangents,
215                             use_batch)
216
217
218pywrap_tfe.TFE_Py_RegisterJVPFunction(_jvp_dispatch)
219
220
221@tf_export("autodiff.ForwardAccumulator", v1=[])
222class ForwardAccumulator():
223  """Computes Jacobian-vector products ("JVP"s) using forward-mode autodiff.
224
225  Compare to `tf.GradientTape` which computes vector-Jacobian products ("VJP"s)
226  using reverse-mode autodiff (backprop). Reverse mode is more attractive when
227  computing gradients of a scalar-valued function with respect to many inputs
228  (e.g. a neural network with many parameters and a scalar loss). Forward mode
229  works best on functions with many outputs and few inputs. Since it does not
230  hold on to intermediate activations, it is much more memory efficient than
231  backprop where it is applicable.
232
233  Consider a simple linear regression:
234
235  >>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
236  >>> targets = tf.constant([[1.], [-1.]])
237  >>> dense = tf.keras.layers.Dense(1)
238  >>> dense.build([None, 2])
239  >>> with tf.autodiff.ForwardAccumulator(
240  ...    primals=dense.kernel,
241  ...    tangents=tf.constant([[1.], [0.]])) as acc:
242  ...   loss = tf.reduce_sum((dense(x) - targets) ** 2.)
243  >>> acc.jvp(loss)
244  <tf.Tensor: shape=(), dtype=float32, numpy=...>
245
246  The example has two variables containing parameters, `dense.kernel` (2
247  parameters) and `dense.bias` (1 parameter). Considering the training data `x`
248  as a constant, this means the Jacobian matrix for the function mapping from
249  parameters to loss has one row and three columns.
250
251  With forwardprop, we specify a length-three vector in advance which multiplies
252  the Jacobian. The `primals` constructor argument is the parameter (a
253  `tf.Tensor` or `tf.Variable`) we're specifying a vector for, and the
254  `tangents` argument is the "vector" in Jacobian-vector product. If our goal is
255  to compute the entire Jacobian matrix, forwardprop computes one column at a
256  time while backprop computes one row at a time. Since the Jacobian in the
257  linear regression example has only one row, backprop requires fewer
258  invocations:
259
260  >>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
261  >>> targets = tf.constant([[1.], [-1.]])
262  >>> dense = tf.keras.layers.Dense(1)
263  >>> dense.build([None, 2])
264  >>> loss_fn = lambda: tf.reduce_sum((dense(x) - targets) ** 2.)
265  >>> kernel_fprop = []
266  >>> with tf.autodiff.ForwardAccumulator(
267  ...     dense.kernel, tf.constant([[1.], [0.]])) as acc:
268  ...   kernel_fprop.append(acc.jvp(loss_fn()))
269  >>> with tf.autodiff.ForwardAccumulator(
270  ...     dense.kernel, tf.constant([[0.], [1.]])) as acc:
271  ...   kernel_fprop.append(acc.jvp(loss_fn()))
272  >>> with tf.autodiff.ForwardAccumulator(dense.bias, tf.constant([1.])) as acc:
273  ...   bias_fprop = acc.jvp(loss_fn())
274  >>> with tf.GradientTape() as tape:
275  ...   loss = loss_fn()
276  >>> kernel_grad, bias_grad = tape.gradient(loss, (dense.kernel, dense.bias))
277  >>> np.testing.assert_allclose(
278  ...     kernel_grad, tf.stack(kernel_fprop)[:, tf.newaxis])
279  >>> np.testing.assert_allclose(bias_grad, bias_fprop[tf.newaxis])
280
281  Implicit in the `tape.gradient` call is a length-one vector which
282  left-multiplies the Jacobian, a vector-Jacobian product.
283
284  `ForwardAccumulator` maintains JVPs corresponding primal tensors it is
285  watching, derived from the original `primals` specified in the constructor. As
286  soon as a primal tensor is deleted, `ForwardAccumulator` deletes the
287  corresponding JVP.
288
289  `acc.jvp(x)` retrieves `acc`'s JVP corresponding to the primal tensor `x`. It
290  does not perform any computation. `acc.jvp` calls can be repeated as long as
291  `acc` is accessible, whether the context manager is active or not. New JVPs
292  are only computed while the context manager is active.
293
294  Note that `ForwardAccumulator`s are always applied in the order their context
295  managers were entered, so inner accumulators will not see JVP computation from
296  outer accumulators. Take higher-order JVPs from outer accumulators:
297
298  >>> primal = tf.constant(1.1)
299  >>> with tf.autodiff.ForwardAccumulator(primal, tf.constant(1.)) as outer:
300  ...   with tf.autodiff.ForwardAccumulator(primal, tf.constant(1.)) as inner:
301  ...     primal_out = primal ** tf.constant(3.5)
302  >>> inner_jvp = inner.jvp(primal_out)
303  >>> inner_jvp  # 3.5 * 1.1 ** 2.5
304  <tf.Tensor: shape=(), dtype=float32, numpy=4.4417057>
305  >>> outer.jvp(inner_jvp)  # 3.5 * 2.5 * 1.1 ** 1.5
306  <tf.Tensor: shape=(), dtype=float32, numpy=10.094786>
307
308  Reversing the collection in the last line to instead retrieve
309  `inner.jvp(outer.jvp(primal_out))` will not work.
310
311  Strict nesting also applies to combinations of `ForwardAccumulator` and
312  `tf.GradientTape`. More deeply nested `GradientTape` objects will ignore the
313  products of outer `ForwardAccumulator` objects. This allows (for example)
314  memory-efficient forward-over-backward computation of Hessian-vector products,
315  where the inner `GradientTape` would otherwise hold on to all intermediate
316  JVPs:
317
318  >>> v = tf.Variable([1., 2.])
319  >>> with tf.autodiff.ForwardAccumulator(
320  ...     v,
321  ...     # The "vector" in Hessian-vector product.
322  ...     tf.constant([1., 0.])) as acc:
323  ...   with tf.GradientTape() as tape:
324  ...     y = tf.reduce_sum(v ** 3.)
325  ...   backward = tape.gradient(y, v)
326  >>> backward  # gradient from backprop
327  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([ 3., 12.], dtype=float32)>
328  >>> acc.jvp(backward)  # forward-over-backward Hessian-vector product
329  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 0.], dtype=float32)>
330  """
331
332  def __init__(self, primals, tangents):
333    """Specify tensors to watch and their Jacobian-vector products.
334
335    Mathematically, `tangents` is a vector right-multiplying the Jacobian matrix
336    (a Jacobian-vector product) for the function computed while this accumulator
337    is active. Since JVPs are computed in forward mode as the computation
338    happens, this vector must be supplied in advance.
339
340    Listing a single tensor multiple times in `primals` raises an
341    exception. Excluding a tensor from `primals` is equivalent to watching it
342    with a tangent tensor of zeros.
343
344    Args:
345      primals: A tensor or nested structure of tensors to watch.
346      tangents: A tensor or nested structure of tensors, with the same nesting
347        structure as `primals`, with each element being a vector with the same
348        size as the corresponding primal element.
349
350    Raises:
351      ValueError: If the same tensor or variable is specified multiple times in
352        `primals`.
353    """
354    self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(False)
355    self._recording = False
356    primal_ids = set()
357    for primal in nest.flatten(primals):
358      if id(primal) in primal_ids:
359        raise ValueError(
360            "Tensor {} was specified as a primal multiple times. This may "
361            "indicate an error. If it was intended, please sum the "
362            "corresponding tangents.")
363      primal_ids.add(id(primal))
364    self._watch(primals, tangents)
365
366  def __enter__(self):
367    self._push_accumulator()
368    return self
369
370  def __exit__(self, typ, value, traceback):
371    if self._recording:
372      self._pop_accumulator()
373
374  def _push_accumulator(self):
375    if self._recording:
376      raise ValueError("Accumulator is already recording.")
377    pywrap_tfe.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator)
378    self._recording = True
379
380  def _pop_accumulator(self):
381    if not self._recording:
382      raise ValueError("Accumulator is not recording.")
383    pywrap_tfe.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator)
384    self._recording = False
385
386  def _watch(self, primals, tangents):
387    """Ensures that `primals` are being traced by this accumulator.
388
389    Mathematically, `tangents` is a vector right-multiplying the Jacobian matrix
390    (a Jacobian-vector product) for the function computed while this accumulator
391    is active. Since JVPs are computed in forward mode as the computation
392    happens, this vector must be supplied in advance.
393
394    Watching a single tensor multiple times sums each of its `tangents`. Any
395    un-watched tensor has zeros for its tangent vector.
396
397    Args:
398      primals: A Tensor or list of Tensors.
399      tangents: A Tensor or list of Tensors matching `primals`.
400    """
401
402    def _watch(primal, tangent):
403      if not primal.dtype.is_floating:
404        logging.log_first_n(
405            logging.WARN, "The dtype of the watched primal must be "
406            "floating (e.g. tf.float32), got %r", 5, primal.dtype)
407      tangent = ops.convert_to_tensor(tangent, dtype=primal.dtype)
408      if hasattr(primal, "handle"):
409        # Run convert_to_tensor to get the captured handle from whichever
410        # function we're running if necessary.
411        primal = ops.convert_to_tensor(primal.handle)
412      pywrap_tfe.TFE_Py_ForwardAccumulatorWatch(self._accumulator, primal,
413                                                tangent)
414
415    nest.map_structure(_watch, primals, tangents, expand_composites=True)
416
417  def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE):
418    """Fetches the Jacobian-vector product computed for `primals`.
419
420    Note that this method performs no computation, and simply looks up a JVP
421    that was already computed (unlike backprop using a `tf.GradientTape`, where
422    the computation happens on the call to `tape.gradient`).
423
424    Args:
425      primals: A watched Tensor or structure of Tensors to fetch the JVPs for.
426      unconnected_gradients: A value which can either hold 'none' or 'zero' and
427        alters the value which will be returned if no JVP was computed for
428        `primals`. The possible values and effects are detailed in
429        'tf.UnconnectedGradients' and it defaults to 'none'.
430
431    Returns:
432      Tensors with the same shapes and dtypes as `primals`, or None if no JVP
433      is available.
434    """
435    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
436    if self._accumulator is None:
437      raise ValueError("Called jvp() without first tracing anything.")
438
439    def _fetch_jvp(tensor):
440      if hasattr(tensor, "handle"):
441        unwrapped_tensor = ops.convert_to_tensor(tensor.handle)
442      else:
443        unwrapped_tensor = tensor
444      result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(self._accumulator,
445                                                       unwrapped_tensor)
446      if result is None and unconnected_gradients == UnconnectedGradients.ZERO:
447        result = array_ops.zeros_like(tensor)
448      return result
449
450    return nest.map_structure(_fetch_jvp, primals)
451
452  @classmethod
453  def _batch_accumulator(cls, primals, tangents):
454    """Factory constructor to test accumulator on batches of tangents.
455
456    Args:
457      primals: A tensor or nested structure of tensors to watch.
458      tangents: A tensor or nested structure of tensors, with the same nesting
459        structure as `primals`, with each element being a vector with compatible
460        shape `[None] + primal.shape` of the corresponding primal element.
461
462    Returns:
463      A batch accumulator object.
464    """
465    acc = super(ForwardAccumulator, cls).__new__(cls, primals, tangents)
466    acc._recording = False
467    acc._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(True)
468    primal_ids = set()
469    for primal, tangent in zip(nest.flatten(primals), nest.flatten(tangents)):
470      tangent.shape.assert_is_compatible_with(
471          tensor_shape.TensorShape([None]) + primal.shape)
472      if id(primal) in primal_ids:
473        raise ValueError(
474            "Tensor {} was specified as a primal multiple times. This may "
475            "indicate an error. If it was intended, please sum the "
476            "corresponding tangents.")
477      primal_ids.add(id(primal))
478    acc._watch(primals, tangents)
479    return acc
480