# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import gc
import weakref

from absl.testing import parameterized
import numpy as np

from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import forwardprop
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.parallel_for import control_flow_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import test
from tensorflow.python.util import nest

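# Value and first three derivatives of f(x) = x**3.5 evaluated at x = 1.1; used
# as expected values by ForwardpropTest.testHigherOrderPureForward below.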
_X11_35_DERIVATIVES = [
    1.1**3.5, 3.5 * 1.1**2.5, 3.5 * 2.5 * 1.1**1.5, 3.5 * 2.5 * 1.5 * 1.1**0.5
]


# TODO(allenl): Move this somewhere useful once forward gradients are stable.
def _jvp(f, primals, tangents):
  """Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
  with forwardprop.ForwardAccumulator(primals, tangents) as acc:
    primals_out = f(*primals)
  return primals_out, acc.jvp(
      primals_out, unconnected_gradients=UnconnectedGradients.ZERO)
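# Illustrative example (cf. testJVPManual below): for f = math_ops.sin with
# primal 0.1 and tangent 0.2, _jvp returns (sin(0.1), cos(0.1) * 0.2).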


def _jacfwd(f, primals):
  """Compute the jacobian of `f` at `primals` using forward-mode autodiff."""
  jac_flat = []
  flat_primals = nest.flatten(primals)
  tangent_mask = [array_ops.zeros_like(primal) for primal in flat_primals]
  for primal_index, primal in enumerate(flat_primals):
    primal_vector = array_ops.reshape(primal, [-1])
    primal_vector_length = array_ops.size(primal_vector)
    jac_columns = []
    for element_index in math_ops.range(primal_vector_length):
      mask = array_ops.one_hot(element_index, primal_vector_length)
      tangent_mask[primal_index] = array_ops.reshape(mask,
                                                     array_ops.shape(primal))
      jac_columns.append(
          nest.map_structure(
              functools.partial(array_ops.reshape, shape=[-1]),
              _jvp(f, primals, nest.pack_sequence_as(primals,
                                                     tangent_mask))[1]))
    jac_flat.append(array_ops.stack(jac_columns, axis=1))
    tangent_mask[primal_index] = array_ops.zeros_like(primal)
  return nest.pack_sequence_as(primals, jac_flat)
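# For a single-tensor output, each stacked slice above has shape
# [output_size, primal_size]: one column per one-hot tangent pushed through
# _jvp, so concatenating the per-primal slices along axis 1 recovers the full
# Jacobian.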


def _jvp_batch(f, primal, tangents):
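  """Compute `_jvp` of `f` at `primal` once per tangent in `tangents`."""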
  tf_function = def_function.function(f)

  return control_flow_ops.vectorized_map(
      functools.partial(_jvp, tf_function, primal), tangents)


def _jvp_batch_matmul(f, primals, tangent_batch):
  """Multiply the jacobian of `f` at `primals` by each tangent in the batch."""
  jac_fwd = _jacfwd(f, primals)

  def jac_mul(tangent):
    flat_tangent = array_ops.reshape(tangent, shape=[-1])
    tangent_vector = array_ops.expand_dims(flat_tangent, 1)
    jvp_vector = math_ops.matmul(jac_fwd, tangent_vector)
    return array_ops.reshape(jvp_vector, tangent.shape)

  return control_flow_ops.vectorized_map(jac_mul, tangent_batch)
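# Used by BatchTests.testJVPBatchCorrectness as a dense-Jacobian reference for
# the batched JVPs computed by _jvp_batch.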


def _grad(f, argnums=0):
  """Return a function which computes the gradient of `f`."""

  def _f(*params):
    with backprop.GradientTape() as tape:
      tape.watch(params)
      primals_out = f(*params)
    return tape.gradient(
        primals_out,
        params[argnums],
        unconnected_gradients=UnconnectedGradients.ZERO)

  return _f


def _gradfwd(f, argnums=0, f_out_dtypes=dtypes.float32):
  """Return a function which computes the gradient of `f` in forward mode."""

  def _f(*params):

    def _single_jvp(param_mask):
      with forwardprop.ForwardAccumulator(
          primals=[params[argnums]], tangents=param_mask) as acc:
        primals_out = f(*params)
      return acc.jvp(primals_out)

    # Building up a function to run with pfor takes a bit too long since we're
    # only running it a handful of times.
    return _vectorize_parameters(
        _single_jvp, [params[argnums]], use_pfor=False, dtype=f_out_dtypes)

  return _f


def _hvp(f, primals, tangents):
  """Compute a forward-over-back Hessian-vector product."""
  with forwardprop.ForwardAccumulator(primals, tangents) as acc:
    with backprop.GradientTape() as tape:
      tape.watch(primals)
      f_out = f(*primals)
      f_out.shape.assert_is_compatible_with([])
    return acc.jvp(tape.gradient(f_out, primals))
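# testHVPCorrectness below cross-checks this forward-over-back HVP against a
# back-over-back Hessian contracted with the same tangents.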


def _vectorize_parameters(f, params, use_pfor, dtype):
  """Loop over `params`, providing a one-hot mask to `f` for each."""
  parameter_sizes = [array_ops.size(param) for param in params]
  total_size = math_ops.add_n(parameter_sizes)

  def _wrapper(index):
    full_onehot = array_ops.one_hot(index, total_size)
    split_onehot = array_ops.split(full_onehot, parameter_sizes)
    tangents = [
        array_ops.reshape(v, array_ops.shape(param))
        for param, v in zip(params, split_onehot)
    ]
    return f(tangents)

  if use_pfor:
    return control_flow_ops.vectorized_map(_wrapper, math_ops.range(total_size))

  return map_fn.map_fn(_wrapper, math_ops.range(total_size), dtype)


def _forward_over_back_hessian(f, params, use_pfor, dtype=None):
  """Computes the full Hessian matrix for the scalar-valued f(*params).

  Args:
    f: A function taking `params` and returning a scalar.
    params: A possibly nested structure of tensors.
    use_pfor: If true, uses `tf.vectorized_map` calls instead of looping.
    dtype: Required if `use_pfor=False`. A possibly nested structure of dtypes
      (e.g. `tf.float32`) matching the structure of `f`'s returns.

  Returns:
    A possibly nested structure of matrix slices corresponding to `params`. Each
    slice has shape [P, p_s] where `p_s` is the number of parameters (`tf.size`)
    in the corresponding element of `params` and `P` is the total number of
    parameters (`sum_s(p_s)`). The full matrix can be obtained by concatenating
    along the second axis.
  """
  return _vectorize_parameters(
      functools.partial(_hvp, f, params),
      params,
      use_pfor=use_pfor,
      dtype=dtype)
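# HessianTests.testHessian1D exercises this helper eagerly, inside
# def_function.function, and with use_pfor=True.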


def _test_gradients(testcase,
                    f,
                    primals,
                    order,
                    delta=1e-3,
                    rtol=1e-2,
                    atol=1e-6):
  """Tests forward/backward jacobians of `f`'s [0, `order`)-order gradients."""
  if order < 1:
    raise ValueError(
        "`order` should be a positive integer, got '{}'.".format(order))
  if order > 1:
    _test_gradients(
        testcase=testcase,
        f=_grad(f),
        primals=primals,
        order=order - 1,
        delta=delta,
        rtol=rtol,
        atol=atol)
  sym_jac_back, num_jac = gradient_checker_v2.compute_gradient(
      f, primals, delta=delta)
  testcase.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol)
  sym_jac_fwd = _jacfwd(f, primals)
  testcase.assertAllClose(num_jac, sym_jac_fwd, rtol=rtol, atol=atol)
  # And the symbolic computations should be much closer.
  testcase.assertAllClose(sym_jac_back, sym_jac_fwd)


class ForwardpropTest(test.TestCase, parameterized.TestCase):

  def testJVPFunction(self):
    add_outputs = (constant_op.constant(4.),)
    vp, = forwardprop._jvp_dispatch(
        op_name="Add",
        attr_tuple=(),
        inputs=(constant_op.constant(1.), constant_op.constant(3.)),
        outputs=add_outputs,
        tangents=(
            constant_op.constant(1.),
            constant_op.constant(5.),
        ))
    self.assertAllClose(1. + 5., self.evaluate(vp))

    mul_outputs = (constant_op.constant([20.]),)
    vp, = forwardprop._jvp_dispatch(
        op_name="Mul",
        attr_tuple=(),
        inputs=(constant_op.constant([4.]), constant_op.constant([5.])),
        outputs=mul_outputs,
        tangents=(
            constant_op.constant([2.]),
            constant_op.constant([3.]),
        ))
    self.assertAllClose([2. * 5. + 3. * 4.], self.evaluate(vp))

  def testJVPFunctionWithBatchOfTangents(self):
    add_outputs = (constant_op.constant(4.),)
    jvp_flat = forwardprop._jvp_dispatch(
        op_name="Add",
        attr_tuple=(),
        inputs=(constant_op.constant(1.), constant_op.constant(3.)),
        outputs=add_outputs,
        tangents=(
            constant_op.constant([1., 2., 3.]),
            constant_op.constant([4., 5., 6.]),
        ),
        use_batch=True)

    # Using evaluate and asserting with just a list works too
    # but the output is more explicit this way
    self.assertAllClose([constant_op.constant([1. + 4., 2. + 5., 3. + 6.])],
                        jvp_flat)

    mul_outputs = (constant_op.constant([20.]),)
    jvp_flat = forwardprop._jvp_dispatch(
        op_name="Mul",
        attr_tuple=(),
        inputs=(constant_op.constant([4.]), constant_op.constant([5.])),
        outputs=mul_outputs,
        tangents=(
            constant_op.constant([[1.], [0.], [1.]]),
            constant_op.constant([[0.], [1.], [1.]]),
        ),
        use_batch=True)
    self.assertAllClose([constant_op.constant([[5.], [4.], [5. + 4.]])],
                        jvp_flat)

  def testJVPFunctionRaisesError(self):
    sum_outputs = (constant_op.constant(6.),)

    with self.assertRaisesRegex(ValueError, r".*was expected to be of shape*"):
      forwardprop._jvp_dispatch(
          op_name="Add",
          attr_tuple=(),
          inputs=(constant_op.constant(2.), constant_op.constant(4.)),
          outputs=sum_outputs,
          tangents=(constant_op.constant([1., 2.]),
                    constant_op.constant([[1.], [2.]])),
          use_batch=True)

  def testNonDifferentiableOpWithInputTangent(self):
    x = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(x, 2.) as acc1:
      with forwardprop.ForwardAccumulator(x, 2.) as acc2:
        y = array_ops.zeros_like(x)
      self.assertIsNone(acc1.jvp(y))
    self.assertIsNone(acc2.jvp(y))

  def testRunFunctionsEagerly(self):
    try:
      original_setting = def_function.functions_run_eagerly()
      def_function.run_functions_eagerly(True)
      x = constant_op.constant(1.)
      with forwardprop.ForwardAccumulator(x, 2.) as acc:
        y = x * 3.
      self.assertAllClose(6., acc.jvp(y))
    finally:
      def_function.run_functions_eagerly(original_setting)

  def testJVPFunctionUsedByAccumulatorForOps(self):
    previous_fn = forwardprop._jvp_dispatch
    try:
      x = constant_op.constant(1.)
      with forwardprop.ForwardAccumulator(x, 2.) as acc:
        y = x + x
        pywrap_tfe.TFE_Py_RegisterJVPFunction(
            lambda *args, **kwargs: [constant_op.constant(-15.)])
        z = x + x
      self.assertAllClose(4., acc.jvp(y))
      self.assertAllClose(-15., acc.jvp(z))
    finally:
      pywrap_tfe.TFE_Py_RegisterJVPFunction(previous_fn)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testFunctionCacheLimited(self):
    # Every time this test is executed, it will create a slightly larger Tensor
    # and push it through Add's gradient. Since we check for new pyobjects after
    # the warmup, retracing each time without cleaning up old traces fails the
    # test. It works because of experimental_relax_shapes.
    for _ in range(forwardprop._TRACE_COUNT_LIMIT):
      execution_count = getattr(self, "_execution_count", 0)
      self._execution_count = execution_count + 1
      x = array_ops.zeros([execution_count])
      with forwardprop.ForwardAccumulator(x, array_ops.ones_like(x)) as acc:
        y = x + x
      self.assertAllClose(2. * array_ops.ones_like(x), acc.jvp(y))

  def testVariableUnwatchedZero(self):
    v = variables.Variable([[1.]])
    x = constant_op.constant(1.)
    xt = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(x, xt) as acc:
      pass
    self.assertIsNone(acc.jvp(v))
    self.assertAllClose([[0.]], acc.jvp(v, unconnected_gradients="zero"))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testFunctionReturnsResource(self):
    v = variables.Variable([[1.]])
    x = constant_op.constant(1.)
    xt = constant_op.constant(2.)

    @def_function.function
    def f(a):
      return a, v.handle

    with forwardprop.ForwardAccumulator(x, xt) as acc:
      y, _ = f(x)
    self.assertAllClose(2., acc.jvp(y))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testMultipleWatchesAdd(self):
    x = constant_op.constant(-2.)
    with self.assertRaisesRegex(ValueError, "multiple times"):
      with forwardprop.ForwardAccumulator([x, x], [1., 2.]):
        pass
    with forwardprop.ForwardAccumulator([x], [3.]) as acc:
      self.assertAllClose(3., acc.jvp(x))
      acc._watch(x, constant_op.constant(10.))
      self.assertAllClose(13., acc.jvp(x))
      acc._watch(x, constant_op.constant(11.))
      self.assertAllClose(24., acc.jvp(x))
      y = constant_op.constant(3.) * x
    self.assertAllClose(24., acc.jvp(x))
    self.assertAllClose(24. * 3., acc.jvp(y))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testReenter(self):
    x = constant_op.constant(-2.)
    with forwardprop.ForwardAccumulator(x, 1.5) as acc:
      self.assertAllClose(1.5, acc.jvp(x))
      y = 4. * x
      self.assertAllClose(6., acc.jvp(y))
      with self.assertRaisesRegex(ValueError, "already recording"):
        with acc:
          pass
    z = 4. * x
    self.assertIsNone(acc.jvp(z))
    with acc:
      yy = y * y
    self.assertAllClose(6. * -8. * 2., acc.jvp(yy))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testDeadTensorsJVPCleared(self):
    x = array_ops.ones([100])
    x_weak = weakref.ref(x)
    grad_tensor = constant_op.constant(array_ops.zeros([100]))
    grad_tensor_weak = weakref.ref(grad_tensor)
    with forwardprop.ForwardAccumulator(x, grad_tensor) as acc:
      derived_tensor = constant_op.constant(2.) * x
      del grad_tensor
      self.assertAllClose(array_ops.zeros([100]), acc.jvp(x))
      del x
      self.assertIsNone(x_weak())
      self.assertIsNone(grad_tensor_weak())
      derived_tensor_weak = weakref.ref(derived_tensor)
      derived_tensor_grad = acc.jvp(derived_tensor)
      derived_tensor_grad_weak = weakref.ref(derived_tensor_grad)
      del derived_tensor
      del derived_tensor_grad
      self.assertIsNone(derived_tensor_weak())
      self.assertIsNone(derived_tensor_grad_weak())

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testJVPManual(self):
    primal, tangent = _jvp(math_ops.sin, (constant_op.constant(0.1),),
                           (constant_op.constant(0.2),))
    self.assertAllClose(math_ops.sin(0.1), primal)
    self.assertAllClose(math_ops.cos(0.1) * 0.2, tangent)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testNumericHigherOrder(self):

    def f(x):
      pointwise = math_ops.sin(x) * math_ops.tan(x)
      return math_ops.reduce_prod(
          pointwise + math_ops.reduce_sum(pointwise), axis=1)

    _test_gradients(
        self, f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])], order=3)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testCustomGradient(self):

    @custom_gradient.custom_gradient
    def f(x):

      def grad(dy):
        return dy * math_ops.cos(x)

      return np.sin(x.numpy()), grad

    _test_gradients(self, f, [constant_op.constant([1., 2.])], order=3)

  # TODO(allenl): investigate why assert_no_new_pyobjects_executing_eagerly
  # fails around this test?
  def testExceptionCustomGradientRecomputeGradForward(self):

    @custom_gradient.recompute_grad
    def f(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    with self.assertRaisesRegex(NotImplementedError,
                                "recompute_grad tried to transpose"):
      primals = [constant_op.constant([1.])]
      sym_jac_fwd = _jacfwd(f, primals)

  def testExceptionInCustomGradientNotSwallowed(self):

    @custom_gradient.custom_gradient
    def f(unused_x):

      def grad(unused_dy):
        raise ValueError("test_error_string")

      return 1., grad

    c = constant_op.constant(1.)
    d = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, d):
      with self.assertRaisesRegex(ValueError, "test_error_string"):
        f(c)

  @parameterized.named_parameters([("EluM5", -0.5, nn_ops.elu),
                                   ("EluP5", [0.5], nn_ops.elu),
                                   ("SwishP5", 0.5, nn_impl.swish),
                                   ("SwishM5", [-0.5], nn_impl.swish)])
  def testElementwiseNNOps(self, value, op_fn):
    _test_gradients(self, op_fn, [constant_op.constant(value)], order=3)

  def testFusedBatchNormGradsInference(self):

    x_shape = [4, 10, 10, 2]
    increment = 3. / math_ops.reduce_prod(
        constant_op.constant(x_shape, dtype=dtypes.float32))
    x = array_ops.reshape(math_ops.range(-2., 1., increment), x_shape)
    scale = constant_op.constant([1., 1.1])
    offset = constant_op.constant([-0.5, -0.6])
    mean = constant_op.constant([-1.3, 1.4])
    variance = constant_op.constant([0.7, 0.9])
    epsilon = 0.001

    def _bn_fused(x_arg, scale_arg, offset_arg):
      return nn_impl.fused_batch_norm(
          x_arg,
          scale_arg,
          offset_arg,
          mean,
          variance,
          epsilon=epsilon,
          is_training=False)[0]

    _test_gradients(self, _bn_fused, [x, scale, offset], order=2, atol=1e-2)

  def testPushPopAccumulatorState(self):
    # Note that this example is somewhat contrived. push_forwardprop_state is
    # probably only useful in practice for building functions that compute jvps
    # alongside their usual outputs.
    c = constant_op.constant(1.)
    d = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, d) as acc:

      @custom_gradient.custom_gradient
      def f(x):
        y = math_ops.sin(x.numpy())

        def grad(dy):
          with forwardprop_util.push_forwardprop_state():
            x_copy = constant_op.constant(x.numpy())
            acc._watch(x_copy, dy)
            y_copy = math_ops.sin(x_copy)
          return dy * acc.jvp(y_copy)

        return y, grad

      output = f(c)
      self.assertAllClose(d * math_ops.cos(c), acc.jvp(output))

  @parameterized.named_parameters([
      ("Order{}".format(order), order, expected)
      for order, expected in enumerate(_X11_35_DERIVATIVES)
  ])
  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testHigherOrderPureForward(self, order, expected):

    def _forwardgrad(f):

      def _compute_forwardgrad(primal):
        tangent = constant_op.constant(1.)
        with forwardprop.ForwardAccumulator(primal, tangent) as acc:
          primal_out = f(primal)
        return acc.jvp(primal_out)

      return _compute_forwardgrad

    def _forward(x):
      return x**3.5

    f = _forward
    primal = constant_op.constant(1.1)
    for _ in range(order):
      f = _forwardgrad(f)
    self.assertAllClose(expected, f(primal))

  @parameterized.named_parameters([("Function", def_function.function),
                                   ("NoFunction", lambda f: f)])
  def testGradPureForward(self, decorator):

    @decorator
    def f(x):
      return x**3.5

    primal = constant_op.constant(1.1)
    with forwardprop.ForwardAccumulator(primal,
                                        constant_op.constant(1.)) as outer_acc:
      with forwardprop.ForwardAccumulator(primal,
                                          constant_op.constant(1.)) as acc:
        primal_out = f(primal)
    inner_jvp = acc.jvp(primal_out)
    outer_jvp = outer_acc.jvp(inner_jvp)
    self.assertAllClose(1.1**3.5, primal_out)
    self.assertAllClose(3.5 * 1.1**2.5, inner_jvp)
    self.assertAllClose(3.5 * 2.5 * 1.1**1.5, outer_jvp)
    self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testJVPPacking(self):
    two = constant_op.constant(2.)
    primal_in = constant_op.constant(1.)
    inner_jvp = constant_op.constant(3.)
    with forwardprop.ForwardAccumulator(
        [primal_in, inner_jvp],
        [constant_op.constant(2.),
         constant_op.constant(4.)]) as outer_acc:
      with forwardprop.ForwardAccumulator(primal_in, inner_jvp) as inner_acc:
        packed_input_indices, packed_input_tangents = (
            forwardprop_util.pack_tangents([primal_in]))
        self.assertAllClose([3., 2., 4.], packed_input_tangents)
        expected_indices = (
            # inner_acc watches primal_in
            ((0, 1),),
            # outer_acc watches primal_in and inner_jvp
            ((0, 2), (1, 3)))
        self.assertAllEqual(expected_indices, packed_input_indices)
        primal_out = primal_in * two
        self.assertAllClose(6., inner_acc.jvp(primal_out))
        self.assertAllClose(4., outer_acc.jvp(primal_out))
        self.assertAllClose(8., outer_acc.jvp(inner_acc.jvp(primal_out)))
        packed_output_indices, packed_output_tangents = (
            forwardprop_util.pack_tangents([primal_out]))
        self.assertAllClose([6., 4., 8.], packed_output_tangents)
        self.assertAllEqual(expected_indices, packed_output_indices)

  def testFunctionGradInFunctionPureForward(self):

    @def_function.function
    def take_gradients():

      @def_function.function
      def f(x):
        return x**3.5

      primal = constant_op.constant(1.1)
      with forwardprop.ForwardAccumulator(
          primal, constant_op.constant(1.)) as outer_acc:
        with forwardprop.ForwardAccumulator(primal,
                                            constant_op.constant(1.)) as acc:
          primal_out = f(primal)
      inner_jvp = acc.jvp(primal_out)
      outer_jvp = outer_acc.jvp(inner_jvp)
      self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out)))
      return primal_out, inner_jvp, outer_jvp

    primal_out, inner_jvp, outer_jvp = take_gradients()
    self.assertAllClose(1.1**3.5, primal_out)
    self.assertAllClose(3.5 * 1.1**2.5, inner_jvp)
    self.assertAllClose(3.5 * 2.5 * 1.1**1.5, outer_jvp)

  def testFunctionGrad(self):

    @def_function.function
    def f(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    _test_gradients(self, f, [constant_op.constant([1., 2.])], order=3)

  def testReusingJVP(self):
    m1 = random_ops.random_uniform((256, 2096))
    m2 = array_ops.identity(m1)
    tangent1 = random_ops.random_uniform((256, 2096))
    tangent2 = random_ops.random_uniform((256, 2096))
    matmul = def_function.function(math_ops.matmul)

    with forwardprop.ForwardAccumulator(
        primals=[m1, m2], tangents=[tangent1, tangent2]) as acc:
      result1 = matmul(m1, m1, transpose_b=True)
      result2 = matmul(m2, m2, transpose_b=True)

    def _expected(mat, tangent):
      return (math_ops.matmul(tangent, mat, transpose_b=True) +
              math_ops.matmul(mat, tangent, transpose_b=True))

    self.assertAllClose(result1, result2)
    self.assertAllClose(_expected(m1, tangent1), acc.jvp(result1))
    self.assertAllClose(_expected(m2, tangent2), acc.jvp(result2))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testHVPMemory(self):

    def fun(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    primals = constant_op.constant([1., 2., 3.])
    tangents = constant_op.constant([3., 4., 5.])
    _hvp(fun, (primals,), (tangents,))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testHVPCorrectness(self):

    def fun(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    primals = constant_op.constant([1., 2., 3.])
    tangents = constant_op.constant([3., 4., 5.])
    forwardback_hvp_eager, = _hvp(fun, (primals,), (tangents,))
    forwardback_hvp_function, = def_function.function(_hvp)(fun, (primals,),
                                                            (tangents,))

    with backprop.GradientTape(persistent=True) as g:
      g.watch(primals)
      with backprop.GradientTape() as gg:
        gg.watch(primals)
        out = fun(primals)
      grad = array_ops.unstack(gg.gradient(out, primals))
    hessian = []
    for i in range(3):
      hessian.append(g.gradient(grad[i], primals))
    hessian = array_ops.stack(hessian, axis=0)
    backback_hvp = math_ops.tensordot(hessian, tangents, axes=1)

    self.assertAllClose(backback_hvp, forwardback_hvp_eager)
    self.assertAllClose(backback_hvp, forwardback_hvp_function)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testShouldRecordAndStopRecord(self):
    c = constant_op.constant(1.)
    c_tangent = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
      with backprop.GradientTape() as tape:
        self.assertFalse(tape_lib.should_record_backprop([c]))
        self.assertEqual(1, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
        tape.watch(c)
        self.assertEqual(2, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
        self.assertTrue(tape_lib.should_record_backprop([c]))
        with tape_lib.stop_recording():
          self.assertEqual(0,
                           pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
          self.assertFalse(tape_lib.should_record_backprop([c]))
          d = c * 2.
        self.assertEqual(2, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
        self.assertTrue(tape_lib.should_record_backprop([c]))
        self.assertFalse(tape_lib.should_record_backprop([d]))
        self.assertIsNone(acc.jvp(d))
      self.assertIsNone(tape.gradient(d, c))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testRecordingSelectively(self):
    c = constant_op.constant(1.)
    c_tangent = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
      with backprop.GradientTape(persistent=True) as tape:
        tape.watch(c)
        with tape_lib.stop_recording():
          two = constant_op.constant(2.)
          d = c * two
          three = constant_op.constant(3.)
          e = c * three
        self.assertIsNone(acc.jvp(d))
        self.assertIsNone(acc.jvp(e))
        self.assertIsNone(tape.gradient(d, c))
        self.assertIsNone(tape.gradient(e, c))
        tape_lib.record_operation_forwardprop_only(
            "CustomForwardMul", [d], [c, two], lambda dd: (two * dd, c * dd),
            None)
        tape_lib.record_operation_backprop_only(
            "CustomBackwardMul", [e], [c, three],
            lambda de: (three * de, c * de))
        self.assertAllClose(4., acc.jvp(d))
        self.assertIsNone(acc.jvp(e))
        self.assertIsNone(tape.gradient(d, c))
        self.assertAllClose(3., tape.gradient(e, c))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testOpWithNoTrainableOutputs(self):
    v = variables.Variable(1.)
    with forwardprop.ForwardAccumulator(v, 11.):
      v.assign_sub(0.5)
      self.assertAllClose(0.5, self.evaluate(v))

  # TODO(b/141025187): Add a no_new_pyobjects decorator.
  def testVariableReadInFunction(self):
    v = variables.Variable(1.)
    with forwardprop.ForwardAccumulator(v, 11.) as acc:

      @def_function.function
      def f():
        return v.read_value(), 2. * v.read_value()

      result = f()
      self.assertAllClose((1.0, 2.), result)
      self.assertAllClose((11., 22.), acc.jvp(result))

  @parameterized.named_parameters([("ForwardPropFirst", True),
                                   ("TapeFirst", False)])
  def testForwardOverBackwardMemoryEfficiency(self, forward_prop_first):
    # Watching depends on nesting, not creation order
    c = constant_op.constant(1.)
    if forward_prop_first:
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
      gradient_tape = backprop.GradientTape()
    else:
      gradient_tape = backprop.GradientTape()
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
    try:
      gc.disable()
      with gradient_tape as tape:
        # Adding and removing the tape multiple times in different nesting
        # patterns does not affect watch ordering.
        pass
      with forward_accumulator as acc:
        with gradient_tape as tape:
          tape.watch(c)
          d = math_ops.cos(c)
          self.assertFalse(tape_lib.should_record_backprop((acc.jvp(d),)))
          e = math_ops.cos(acc.jvp(d))
          math_ops.cos(e)
          weak_e = weakref.ref(e)
          del e
          self.assertIsNone(weak_e())
        self.assertIsNone(tape.gradient(acc.jvp(d), c))
    finally:
      gc.enable()

  @parameterized.named_parameters([("ForwardPropFirst", True),
                                   ("TapeFirst", False)])
  def testBackwardOverForward(self, forward_prop_first):
    c = constant_op.constant(1.)
    # Watching depends on nesting, not creation order
    if forward_prop_first:
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
      gradient_tape = backprop.GradientTape()
    else:
      gradient_tape = backprop.GradientTape()
      forward_accumulator = forwardprop.ForwardAccumulator(c, .1)
    with gradient_tape as tape:
      with forward_accumulator as acc:
        tape.watch(c)
        d = math_ops.cos(c)
        self.assertTrue(tape_lib.should_record_backprop((acc.jvp(d),)))
      self.assertAllClose(-.1 * math_ops.cos(1.), tape.gradient(acc.jvp(d), c))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testRecordingWithJVPIndices(self):
    c = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(c, 10.) as acc:
      packed_input_tangents = forwardprop_util.pack_tangents([c]).tangents
      self.assertAllClose([10.], packed_input_tangents)
      d = constant_op.constant(2.)
      d_tangent = constant_op.constant(3.)
      tape_lib.record_operation_forwardprop_only("FunctionWithInlineJVPs",
                                                 [d] + [d_tangent],
                                                 [c] + packed_input_tangents,
                                                 None, (((0, 1),),))
      self.assertAllClose(3., acc.jvp(d))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testSpecialForwardFunctionUsed(self):
    c = constant_op.constant(1.)
    d = constant_op.constant(2.)
    e = constant_op.constant(3.)
    with forwardprop.ForwardAccumulator(c, 10.) as acc:
      tape_lib.record_operation("ForwardIsSpecial", [d], [c], None,
                                lambda jvp: [-2. * jvp])
      self.assertAllClose(-20., acc.jvp(d))
      tape_lib.record_operation("ForwardIsSpecial2", [], [], None, lambda: [])
      tape_lib.record_operation("ForwardIsSpecial3", [e], [d], None,
                                lambda x: [x])
      self.assertAllClose(-20., acc.jvp(e))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testVariableWatched(self):
    v = variables.Variable([1., 2., 3.])
    with forwardprop.ForwardAccumulator(v, constant_op.constant([.1, -.2,
                                                                 .3])) as acc:
      self.assertAllClose([.1, -.2, .3], acc.jvp(v))
      x = v * 2.
      self.assertAllClose([.2, -.4, .6], acc.jvp(x))
      x2 = v + .1
      self.assertAllClose([.1, -.2, .3], acc.jvp(x2))

  def testUnconnectedGradients(self):
    x = constant_op.constant(-1.)
    with forwardprop.ForwardAccumulator(x, 0.1) as acc:
      self.assertAllClose(0.1, acc.jvp(x, unconnected_gradients="zero"))
      self.assertAllClose(0.1, acc.jvp(x, unconnected_gradients="none"))
      y = constant_op.constant(-2.)
      self.assertAllClose(0.0, acc.jvp(y, unconnected_gradients="zero"))
      self.assertIsNone(acc.jvp(y, unconnected_gradients="none"))

  # TODO(kkb): One weakref instance is created with warmup_iters=2,
  # investigate.
  @test_util.assert_no_new_pyobjects_executing_eagerly(warmup_iters=3)
  def testVariableWatchedFunction(self):

    class _Model(module.Module):

      def __init__(self):
        self._v = None

      @def_function.function
      def compute_jvps(self):
        if self._v is None:
          self._v = variables.Variable([1., 2., 3.])
        with forwardprop.ForwardAccumulator(self._v,
                                            constant_op.constant([.1, -.2,
                                                                  .3])) as acc:
          x = self._v * 2.
          x2 = self._v + .1
        return acc.jvp((self._v, x, x2))

    model = _Model()
    v_jvp, x_jvp, x2_jvp = model.compute_jvps()
    self.assertAllClose([.1, -.2, .3], v_jvp)
    self.assertAllClose([.2, -.4, .6], x_jvp)
    self.assertAllClose([.1, -.2, .3], x2_jvp)

  def testIndexSlicesGrad(self):
    x = constant_op.constant([1.])

    with forwardprop.ForwardAccumulator(x, constant_op.constant([3.])) as acc:
      y = array_ops.gather(x, 0)
    self.assertAllClose(3., acc.jvp(y))

  def testIndexSlicesGradInFunction(self):

    @def_function.function
    def f(a):
      return array_ops.gather(a, 0)

    x = constant_op.constant([1.])

    with forwardprop.ForwardAccumulator(x, constant_op.constant([3.])) as acc:
      y = f(x)
    self.assertAllClose(3., acc.jvp(y))

  # NOTE: assert_no_new_pyobjects_executing_eagerly fails flakily on this
  # test... could be something wrong with the test decorator, or some sort of
  # nondeterministic caching.
  def testMirroredVariableWatched(self):

    def _replicated(input_tangent):
      with forwardprop.ForwardAccumulator(v, input_tangent) as acc:
        self.assertAllClose([.1, -.2, .3], acc.jvp(v))
        x = v * 2.
        self.assertAllClose([.2, -.4, .6], acc.jvp(x))
        x2 = v + .1
        self.assertAllClose([.1, -.2, .3], acc.jvp(x2))

    strategy = mirrored_strategy.MirroredStrategy()
    with strategy.scope():
      v = variables.Variable([1., 2., 3.])
      strategy.run(_replicated, args=(constant_op.constant([.1, -.2, .3]),))

  # TODO(b/141025187): Add a no_new_pyobjects decorator.
  def testArgumentUnused(self):
    v = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(v, 11.) as acc:

      @def_function.function
      def _f(x):
        del x
        return constant_op.constant(1.)

      result = _f(v)
      self.assertAllClose(1.0, result)
      self.assertIsNone(acc.jvp(result))


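# The tf.functions below exercise ForwardAccumulator around function-compiled
# control flow (_has_loop, _has_cond are differentiated by an outer
# accumulator) and inside it (_fprop_while, _fprop_cond run the accumulator
# within the function); see ControlFlowTests.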
@def_function.function
def _has_loop(iters, y):
  ret = 0.
  for i in math_ops.range(iters):
    ret += y * math_ops.cast(i, dtypes.float32)
  return ret


@def_function.function
def _has_cond(k, y):
  if k > 1:
    ret = 3. * y
  else:
    ret = 0.
  return ret


@def_function.function
def _fprop_while(iters, y):
  with forwardprop.ForwardAccumulator(y, 1.) as acc:
    ret = 0.
    for i in math_ops.range(iters):
      ret += y * math_ops.cast(i, dtypes.float32)
  return acc.jvp(ret)


@def_function.function
def _fprop_cond(k, y):
  with forwardprop.ForwardAccumulator(y, 1.) as acc:
    if k > 1:
      ret = 3. * y
    else:
      ret = 0.
  return acc.jvp(ret)


class ControlFlowTests(test.TestCase):

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testOfFunctionWhile(self):
    y = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(y, 1.) as acc:
      self.assertAllClose(10., acc.jvp(_has_loop(constant_op.constant(5), y)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testOfFunctionCond(self):
    y = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(y, 1.) as acc:
      self.assertAllClose(3., acc.jvp(_has_cond(constant_op.constant(5), y)))
      self.assertAllClose(0., acc.jvp(_has_cond(constant_op.constant(0), y)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testInFunctionWhile(self):
    self.assertAllClose(
        10., _fprop_while(constant_op.constant(5), constant_op.constant(1.)))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testInFunctionCond(self):
    self.assertAllClose(
        3., _fprop_cond(constant_op.constant(5), constant_op.constant(1.)))
    self.assertAllClose(
        0., _fprop_cond(constant_op.constant(0), constant_op.constant(1.)))


class HessianTests(test.TestCase, parameterized.TestCase):

  def testHessian1D(self):
    # Note: stolen from ops/gradients_test.py
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    x_value = rng.randn(m).astype("float32")
    hess_value = mat_value + mat_value.T
    mat = variables.Variable(mat_value)

    def _f(x):
      return math_ops.reduce_sum(x[:, None] * mat * x[None, :])

    hessian_eager, = _forward_over_back_hessian(
        _f, [constant_op.constant(x_value)],
        use_pfor=False,
        dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_eager)
    hessian_function, = def_function.function(_forward_over_back_hessian)(
        _f, [constant_op.constant(x_value)],
        use_pfor=False,
        dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_function)
    hessian_pfor, = def_function.function(_forward_over_back_hessian)(
        _f, [constant_op.constant(x_value)],
        use_pfor=True,
        dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_pfor)


class BatchTests(test.TestCase, parameterized.TestCase):

  @parameterized.parameters([(math_ops.sin, (2, 3), 5),
                             (math_ops.sin, (2, 3, 4), 10)])
  def testJVPBatchCorrectness(self, f, primal_shape, batch_size):
    primals = [random_ops.random_uniform(primal_shape)]
    tangent_batch = [random_ops.random_uniform([batch_size, *primal_shape])]
    self.assertAllClose(
        _jvp_batch(f, primals, tangent_batch)[1],
        _jvp_batch_matmul(f, primals, *tangent_batch))

  def testBatchCorrectness(self):
    x = constant_op.constant(2.0)
    y = constant_op.constant(5.0)
    tangents = (
        constant_op.constant([1., 0., 1.]),
        constant_op.constant([0., 1., 1.]),
    )
    with forwardprop.ForwardAccumulator._batch_accumulator((x, y),
                                                           tangents) as acc:
      z = x * y
    self.assertAllClose(acc.jvp(z), constant_op.constant([5.0, 2.0, 7.0]))

  @parameterized.named_parameters([("ForwardPropFirst", True),
                                   ("TapeFirst", False)])
  def testBatchBackwardOverForward(self, forward_prop_first):
    x = constant_op.constant(1.)
    tangents = random_ops.random_normal(shape=[10], seed=1)
    expected = [-t * math_ops.cos(1.) for t in tangents]
    if forward_prop_first:
      batch_acc = forwardprop.ForwardAccumulator._batch_accumulator(x, tangents)
      gradient_tape = backprop.GradientTape(persistent=True)
    else:
      gradient_tape = backprop.GradientTape(persistent=True)
      batch_acc = forwardprop.ForwardAccumulator._batch_accumulator(x, tangents)
    with gradient_tape as tape:
      with batch_acc as acc:
        tape.watch(x)
        y = math_ops.cos(x)
        self.assertTrue(tape_lib.should_record_backprop((acc.jvp(y),)))
        jvps = acc.jvp(y)
      d2y_dx2 = [tape.gradient(dy_dx, x) for dy_dx in jvps]
    self.assertAllClose(expected, d2y_dx2)


if __name__ == "__main__":
  # TODO(allenl): Also test with 1.x-style graph mode.
  ops.enable_eager_execution()
  test.main()