# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for forward-mode automatic differentiation."""

import functools
import threading

from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import execute
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.eager import function

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape

from tensorflow.python.ops import array_ops
from tensorflow.python.ops.parallel_for import control_flow_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# Dictionary mapping from op names to special-cased jvp functions. Otherwise
# backward functions are transposed on the tape.
_SPECIAL_CASES = {}


def _identity_jvp(attr_tuple, inputs, outputs, tangents):
  # Special-cased mostly for resource handles, where creating ones Tensors from
  # handle data for transposing the backward function on the tape is error-prone
  # (even if we get good handle data, partially defined shapes are an issue).
  del attr_tuple, inputs, outputs
  return [array_ops.identity(t) for t in tangents]


_SPECIAL_CASES["Identity"] = _identity_jvp


def _read_variable_jvp(attr_tuple, inputs, outputs, tangents):
  # Like for Identity, this special case means we don't need to create
  # variable-shaped Tensors from resource handles.
  del attr_tuple, inputs, outputs
  return [array_ops.identity(t) for t in tangents]


_SPECIAL_CASES["ReadVariableOp"] = _read_variable_jvp


_TRACE_COUNT_CONSISTENCY_LOCK = threading.Lock()
# Map from op names to number of traces of _jvp_helper. Used to cap the number
# of traces due to shape differences while still specializing where possible.
_TRACE_COUNT = {}


def _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents):
  """Computes a Jacobian-vector product for an op.

  Note that this function would be wasteful if executed eagerly. It runs the
  backward gradient function and throws away the result just to record its
  operations on a GradientTape. These unused ops are pruned away when this
  function is traced.

  Args:
    op_name: A string, the type of operation being executed.
    attr_tuple: Attributes of the operation.
    inputs: A flat list of input Tensors to the operation.
    outputs: A flat list of output Tensors from the operation.
    tangents: A flat list of Tensors, same shape as `inputs`.

  Returns:
    A flat list of tangents corresponding to `outputs`.
  """
  with _TRACE_COUNT_CONSISTENCY_LOCK:
    # Just make sure writes don't clobber each other's increments; reads in
    # _jvp_dispatch do not lock.
    _TRACE_COUNT[op_name] = _TRACE_COUNT.get(op_name, 0) + 1

  special_case = _SPECIAL_CASES.get(op_name, None)
  if special_case is not None:
    return special_case(attr_tuple, inputs, outputs, tangents)
  if not outputs:
    # tape.gradients([], inputs) doesn't make much sense
    return []
  # Generally inner GradientTapes won't function while outer accumulators are
  # recording. We temporarily reset forwardprop state to allow GradientTapes to
  # function here.
  with forwardprop_util.push_forwardprop_state():
    trainable_inputs = []
    trainable_indices = []
    nontrivial_tangents = []
    for input_index, tensor in enumerate(inputs):
      if backprop_util.IsTrainable(tensor):
        trainable_inputs.append(tensor)
        trainable_indices.append(input_index)
        nontrivial_tangents.append(tangents[input_index])

    with backprop.GradientTape() as transpose_tape:
      with backprop.GradientTape() as backfunc_tape:
        backfunc_tape.watch(trainable_inputs)
        execute.record_gradient(op_name, inputs, attr_tuple, outputs)

      forwardprop_aids = []
      trainable_outputs = []
      nontrivial_output_indices = []
      for output_index, output in enumerate(outputs):
        if backprop_util.IsTrainable(output):
          forwardprop_aids.append(
              array_ops.ones_like(output, name="unused_forwardprop_aid"))
          trainable_outputs.append(output)
          nontrivial_output_indices.append(output_index)

      transpose_tape.watch(forwardprop_aids)
      grads = backfunc_tape.gradient(
          trainable_outputs,
          trainable_inputs,
          forwardprop_aids,
          unconnected_gradients=UnconnectedGradients.ZERO)
    nontrivial_output_tangents = transpose_tape.gradient(
        grads, forwardprop_aids, output_gradients=nontrivial_tangents)
    output_tangents = [None] * len(outputs)
    for index, tangent in zip(nontrivial_output_indices,
                              nontrivial_output_tangents):
      output_tangents[index] = tangent
    return output_tangents
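

# An illustrative, self-contained sketch of the double-tape trick described in
# `_jvp_helper`'s docstring, restated for a single Python callable `f` using
# the public GradientTape API: record the backward (VJP) function against a
# placeholder cotangent, then differentiate that VJP with respect to the
# placeholder, which transposes it into a forward-mode JVP. Nothing in this
# module calls this helper, and its name is a hypothetical example.
def _example_jvp_via_double_vjp(f, primal, tangent):
  """Sketch: JVP of `f` at `primal` in the direction `tangent`."""
  with backprop.GradientTape() as transpose_tape:
    with backprop.GradientTape() as backfunc_tape:
      backfunc_tape.watch(primal)
      output = f(primal)
    # Placeholder cotangent; its value is irrelevant, only its connectivity.
    aid = array_ops.ones_like(output, name="unused_forwardprop_aid")
    transpose_tape.watch(aid)
    # grad = J^T @ aid, recorded on `transpose_tape`.
    grad = backfunc_tape.gradient(output, primal, output_gradients=aid)
  # Differentiating `grad` w.r.t. `aid` transposes J^T back into J, and the
  # `output_gradients` argument supplies the tangent, giving J @ tangent.
  return transpose_tape.gradient(grad, aid, output_gradients=tangent)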


def _jvp_helper_wrapper(op_name, attr_tuple, inputs, outputs, tangents,
                        use_batch):
  """Computes a batch of Jacobian-vector products for an op.

  Args:
    op_name: A string, the type of operation being executed.
    attr_tuple: Attributes of the operation.
    inputs: A flat list of input Tensors to the operation.
    outputs: A flat list of output Tensors from the operation.
    tangents: A flat list of Tensors, compatible with shape `[None] +
      input_shape`.
    use_batch: A bool, True to vectorize over a batch of tangents of shape
      `[None] + input_shape`.

  Returns:
    A flat list of tangents compatible with `outputs`
    or `[None] + output_shape`.

  Raises:
    ValueError: if tangent shapes are not compatible with input shapes.
  """
  if use_batch:
    for primal, tangent in zip(inputs, tangents):
      if not tangent.shape.is_compatible_with([None] + primal.shape):
        raise ValueError("Tangent {} was expected to be of shape "
                         "{} but is instead of shape {}".format(
                             tangent, [None] + primal.shape, tangent.shape))

    return control_flow_ops.vectorized_map(
        functools.partial(_jvp_helper, op_name, attr_tuple, inputs, outputs),
        tangents,
    )
  return _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents)
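

# A companion sketch to `_example_jvp_via_double_vjp` above: the `use_batch`
# path of `_jvp_helper_wrapper` maps a single-tangent JVP over the leading
# axis of a stacked batch of tangents with `control_flow_ops.vectorized_map`.
# Nothing in this module calls this helper, and its name is hypothetical.
def _example_batched_jvp(f, primal, batched_tangents):
  """Sketch: one JVP per row of `batched_tangents` (batch-major)."""
  return control_flow_ops.vectorized_map(
      functools.partial(_example_jvp_via_double_vjp, f, primal),
      batched_tangents)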


# TODO(allenl): Gradients which rely on static shape information are
# underspecialized when reduce_retracing is used. We may want hand-written
# forward implementations, or a more satisfying story about how we
# re-specialize gradients which were traced with relaxed shapes (e.g. use conds
# instead of trace-time Python logic).
#
# Using function.defun rather than def_function.function avoids
# tf.config.run_functions_eagerly(True). `_jvp_helper` doesn't successfully run
# eagerly (infinite recursion), and even if it did it would use extra memory and
# run unnecessary computation. The function does not create variables, so the
# two symbols are otherwise equivalent.
_jvp_relaxed_shapes = function.defun(
    _jvp_helper_wrapper, reduce_retracing=True)
_jvp_exact_shapes = function.defun(
    _jvp_helper_wrapper, reduce_retracing=False)

# The maximum number of exact-shape traces to perform for a single op before
# switching to shape relaxation.
_TRACE_COUNT_LIMIT = 32


def _jvp_dispatch(op_name,
                  attr_tuple,
                  inputs,
                  outputs,
                  tangents,
                  use_batch=False):
  """Determine which forwardprop function to call."""
  # Note that this _TRACE_COUNT read races with writes. That's fine, it just
  # means we may trace a few more exact shapes before moving on to relaxation.
  if _TRACE_COUNT.get(op_name, 0) < _TRACE_COUNT_LIMIT:
    return _jvp_exact_shapes(op_name, attr_tuple, inputs, outputs, tangents,
                             use_batch)
  return _jvp_relaxed_shapes(op_name, attr_tuple, inputs, outputs, tangents,
                             use_batch)


pywrap_tfe.TFE_Py_RegisterJVPFunction(_jvp_dispatch)


@tf_export("autodiff.ForwardAccumulator", v1=[])
class ForwardAccumulator():
  """Computes Jacobian-vector products ("JVP"s) using forward-mode autodiff.

  Compare to `tf.GradientTape` which computes vector-Jacobian products ("VJP"s)
  using reverse-mode autodiff (backprop). Reverse mode is more attractive when
  computing gradients of a scalar-valued function with respect to many inputs
  (e.g. a neural network with many parameters and a scalar loss). Forward mode
  works best on functions with many outputs and few inputs. Since it does not
  hold on to intermediate activations, it is much more memory efficient than
  backprop where it is applicable.

  Consider a simple linear regression:

  >>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
  >>> targets = tf.constant([[1.], [-1.]])
  >>> dense = tf.keras.layers.Dense(1)
  >>> dense.build([None, 2])
  >>> with tf.autodiff.ForwardAccumulator(
  ...    primals=dense.kernel,
  ...    tangents=tf.constant([[1.], [0.]])) as acc:
  ...   loss = tf.reduce_sum((dense(x) - targets) ** 2.)
  >>> acc.jvp(loss)
  <tf.Tensor: shape=(), dtype=float32, numpy=...>

  The example has two variables containing parameters, `dense.kernel` (2
  parameters) and `dense.bias` (1 parameter). Considering the training data `x`
  as a constant, this means the Jacobian matrix for the function mapping from
  parameters to loss has one row and three columns.

  With forwardprop, we specify a length-three vector in advance which multiplies
  the Jacobian. The `primals` constructor argument is the parameter (a
  `tf.Tensor` or `tf.Variable`) we're specifying a vector for, and the
  `tangents` argument is the "vector" in Jacobian-vector product. If our goal is
  to compute the entire Jacobian matrix, forwardprop computes one column at a
  time while backprop computes one row at a time. Since the Jacobian in the
  linear regression example has only one row, backprop requires fewer
  invocations:

  >>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
  >>> targets = tf.constant([[1.], [-1.]])
  >>> dense = tf.keras.layers.Dense(1)
  >>> dense.build([None, 2])
  >>> loss_fn = lambda: tf.reduce_sum((dense(x) - targets) ** 2.)
  >>> kernel_fprop = []
  >>> with tf.autodiff.ForwardAccumulator(
  ...     dense.kernel, tf.constant([[1.], [0.]])) as acc:
  ...   kernel_fprop.append(acc.jvp(loss_fn()))
  >>> with tf.autodiff.ForwardAccumulator(
  ...     dense.kernel, tf.constant([[0.], [1.]])) as acc:
  ...   kernel_fprop.append(acc.jvp(loss_fn()))
  >>> with tf.autodiff.ForwardAccumulator(dense.bias, tf.constant([1.])) as acc:
  ...   bias_fprop = acc.jvp(loss_fn())
  >>> with tf.GradientTape() as tape:
  ...   loss = loss_fn()
  >>> kernel_grad, bias_grad = tape.gradient(loss, (dense.kernel, dense.bias))
  >>> np.testing.assert_allclose(
  ...     kernel_grad, tf.stack(kernel_fprop)[:, tf.newaxis])
  >>> np.testing.assert_allclose(bias_grad, bias_fprop[tf.newaxis])

  Implicit in the `tape.gradient` call is a length-one vector which
  left-multiplies the Jacobian, a vector-Jacobian product.

  `ForwardAccumulator` maintains JVPs corresponding to primal tensors it is
  watching, derived from the original `primals` specified in the constructor. As
  soon as a primal tensor is deleted, `ForwardAccumulator` deletes the
  corresponding JVP.

  `acc.jvp(x)` retrieves `acc`'s JVP corresponding to the primal tensor `x`. It
  does not perform any computation. `acc.jvp` calls can be repeated as long as
  `acc` is accessible, whether the context manager is active or not. New JVPs
  are only computed while the context manager is active.

  Note that `ForwardAccumulator`s are always applied in the order their context
  managers were entered, so inner accumulators will not see JVP computation from
  outer accumulators. Take higher-order JVPs from outer accumulators:

  >>> primal = tf.constant(1.1)
  >>> with tf.autodiff.ForwardAccumulator(primal, tf.constant(1.)) as outer:
  ...   with tf.autodiff.ForwardAccumulator(primal, tf.constant(1.)) as inner:
  ...     primal_out = primal ** tf.constant(3.5)
  >>> inner_jvp = inner.jvp(primal_out)
  >>> inner_jvp  # 3.5 * 1.1 ** 2.5
  <tf.Tensor: shape=(), dtype=float32, numpy=4.4417057>
  >>> outer.jvp(inner_jvp)  # 3.5 * 2.5 * 1.1 ** 1.5
  <tf.Tensor: shape=(), dtype=float32, numpy=10.094786>

  Reversing the order of the two calls in the last line to instead retrieve
  `inner.jvp(outer.jvp(primal_out))` will not work.

  Strict nesting also applies to combinations of `ForwardAccumulator` and
  `tf.GradientTape`. More deeply nested `GradientTape` objects will ignore the
  products of outer `ForwardAccumulator` objects. This allows (for example)
  memory-efficient forward-over-backward computation of Hessian-vector products,
  where the inner `GradientTape` would otherwise hold on to all intermediate
  JVPs:

  >>> v = tf.Variable([1., 2.])
  >>> with tf.autodiff.ForwardAccumulator(
  ...     v,
  ...     # The "vector" in Hessian-vector product.
  ...     tf.constant([1., 0.])) as acc:
  ...   with tf.GradientTape() as tape:
  ...     y = tf.reduce_sum(v ** 3.)
  ...   backward = tape.gradient(y, v)
  >>> backward  # gradient from backprop
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([ 3., 12.], dtype=float32)>
  >>> acc.jvp(backward)  # forward-over-backward Hessian-vector product
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 0.], dtype=float32)>
  """

  def __init__(self, primals, tangents):
    """Specify tensors to watch and their Jacobian-vector products.

    Mathematically, `tangents` is a vector right-multiplying the Jacobian matrix
    (a Jacobian-vector product) for the function computed while this accumulator
    is active. Since JVPs are computed in forward mode as the computation
    happens, this vector must be supplied in advance.

    Listing a single tensor multiple times in `primals` raises an
    exception. Excluding a tensor from `primals` is equivalent to watching it
    with a tangent tensor of zeros.

    Args:
      primals: A tensor or nested structure of tensors to watch.
      tangents: A tensor or nested structure of tensors, with the same nesting
        structure as `primals`, with each element being a vector with the same
        size as the corresponding primal element.

    Raises:
      ValueError: If the same tensor or variable is specified multiple times in
        `primals`.
    """
    self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(False)
    self._recording = False
    primal_ids = set()
    for primal in nest.flatten(primals):
      if id(primal) in primal_ids:
        raise ValueError(
            "Tensor {} was specified as a primal multiple times. This may "
            "indicate an error. If it was intended, please sum the "
            "corresponding tangents.".format(primal))
      primal_ids.add(id(primal))
    self._watch(primals, tangents)

  def __enter__(self):
    self._push_accumulator()
    return self

  def __exit__(self, typ, value, traceback):
    if self._recording:
      self._pop_accumulator()

  def _push_accumulator(self):
    if self._recording:
      raise ValueError("Accumulator is already recording.")
    pywrap_tfe.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator)
    self._recording = True

  def _pop_accumulator(self):
    if not self._recording:
      raise ValueError("Accumulator is not recording.")
    pywrap_tfe.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator)
    self._recording = False

  def _watch(self, primals, tangents):
    """Ensures that `primals` are being traced by this accumulator.

    Mathematically, `tangents` is a vector right-multiplying the Jacobian matrix
    (a Jacobian-vector product) for the function computed while this accumulator
    is active. Since JVPs are computed in forward mode as the computation
    happens, this vector must be supplied in advance.

    Watching a single tensor multiple times sums each of its `tangents`. Any
    un-watched tensor has zeros for its tangent vector.

    Args:
      primals: A Tensor or list of Tensors.
      tangents: A Tensor or list of Tensors matching `primals`.
    """

    def _watch(primal, tangent):
      if not primal.dtype.is_floating:
        logging.log_first_n(
            logging.WARN, "The dtype of the watched primal must be "
            "floating (e.g. tf.float32), got %r", 5, primal.dtype)
      tangent = ops.convert_to_tensor(tangent, dtype=primal.dtype)
      if hasattr(primal, "handle"):
        # Run convert_to_tensor to get the captured handle from whichever
        # function we're running if necessary.
        primal = ops.convert_to_tensor(primal.handle)
      pywrap_tfe.TFE_Py_ForwardAccumulatorWatch(self._accumulator, primal,
                                                tangent)

    nest.map_structure(_watch, primals, tangents)

  def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE):
    """Fetches the Jacobian-vector product computed for `primals`.

    Note that this method performs no computation, and simply looks up a JVP
    that was already computed (unlike backprop using a `tf.GradientTape`, where
    the computation happens on the call to `tape.gradient`).

    Args:
      primals: A watched Tensor or structure of Tensors to fetch the JVPs for.
      unconnected_gradients: A value which can either hold 'none' or 'zero' and
        determines what is returned if no JVP was computed for `primals`. The
        possible values and effects are detailed in `tf.UnconnectedGradients`;
        it defaults to 'none'.

    Returns:
      Tensors with the same shapes and dtypes as `primals`, or None if no JVP
      is available.
    """
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
    if self._accumulator is None:
      raise ValueError("Called jvp() without first tracing anything.")

    def _fetch_jvp(tensor):
      if hasattr(tensor, "handle"):
        unwrapped_tensor = ops.convert_to_tensor(tensor.handle)
      else:
        unwrapped_tensor = tensor
      result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(self._accumulator,
                                                       unwrapped_tensor)
      if result is None and unconnected_gradients == UnconnectedGradients.ZERO:
        result = array_ops.zeros_like(tensor)
      return result

    return nest.map_structure(_fetch_jvp, primals)

  @classmethod
  def _batch_accumulator(cls, primals, tangents):
    """Factory constructor to test accumulator on batches of tangents.

    Args:
      primals: A tensor or nested structure of tensors to watch.
      tangents: A tensor or nested structure of tensors, with the same nesting
        structure as `primals`, with each element being a batch of tangents
        with shape compatible with `[None] + primal.shape` for the
        corresponding primal element.

    Returns:
      A batch accumulator object.
    """
    acc = super(ForwardAccumulator, cls).__new__(cls, primals, tangents)
    acc._recording = False
    acc._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(True)
    primal_ids = set()
    for primal, tangent in zip(nest.flatten(primals), nest.flatten(tangents)):
      tangent.shape.assert_is_compatible_with(
          tensor_shape.TensorShape([None]) + primal.shape)
      if id(primal) in primal_ids:
        raise ValueError(
            "Tensor {} was specified as a primal multiple times. This may "
            "indicate an error. If it was intended, please sum the "
            "corresponding tangents.".format(primal))
      primal_ids.add(id(primal))
    acc._watch(primals, tangents)
    return acc
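

# A usage sketch (not part of the public API, not called anywhere in this
# module) tying together `ForwardAccumulator.jvp` and the
# `unconnected_gradients` behavior documented above: an output that no watched
# primal feeds into yields `None` by default, or zeros when
# `UnconnectedGradients.ZERO` is requested. The helper name is hypothetical and
# the sketch assumes eager execution.
def _example_unconnected_jvp():
  """Sketch: returns (connected_jvp, default_jvp, zero_jvp)."""
  primal = array_ops.ones([2])
  unrelated = array_ops.zeros([2])
  with ForwardAccumulator(primal, tangents=array_ops.ones([2])) as acc:
    connected = primal * 2.0       # A JVP is recorded for this output.
    unconnected = unrelated * 2.0  # No watched primal feeds this output.
  connected_jvp = acc.jvp(connected)  # Tangent pushed forward: [2., 2.]
  default_jvp = acc.jvp(unconnected)  # None: no JVP was computed.
  zero_jvp = acc.jvp(
      unconnected, unconnected_gradients=UnconnectedGradients.ZERO)
  return connected_jvp, default_jvp, zero_jvp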